Merge remote-tracking branch 'origin/master' into multiport

This commit is contained in:
Wade Simmons 2023-10-27 08:48:13 -04:00
commit f2aef0d6eb
74 changed files with 2540 additions and 1402 deletions

22
.github/dependabot.yml vendored Normal file
View File

@ -0,0 +1,22 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
- package-ecosystem: "gomod"
directory: "/"
schedule:
interval: "weekly"
groups:
golang-x-dependencies:
patterns:
- "golang.org/x/*"
zx2c4-dependencies:
patterns:
- "golang.zx2c4.com/*"
protobuf-dependencies:
patterns:
- "github.com/golang/protobuf"
- "google.golang.org/protobuf"

View File

@ -14,21 +14,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Set up Go 1.20
uses: actions/setup-go@v2
with:
go-version: "1.20"
id: go
- uses: actions/checkout@v4
- name: Check out code into the Go module directory
uses: actions/checkout@v2
- uses: actions/cache@v2
- uses: actions/setup-go@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-gofmt1.20-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gofmt1.20-
go-version-file: 'go.mod'
check-latest: true
- name: Install goimports
run: |

View File

@ -7,25 +7,24 @@ name: Create release and upload binaries
jobs:
build-linux:
name: Build Linux All
name: Build Linux/BSD All
runs-on: ubuntu-latest
steps:
- name: Set up Go 1.20
uses: actions/setup-go@v2
with:
go-version: "1.20"
- uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@v2
- uses: actions/setup-go@v4
with:
go-version-file: 'go.mod'
check-latest: true
- name: Build
run: |
make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" release-linux release-freebsd
make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" release-linux release-freebsd release-openbsd release-netbsd
mkdir release
mv build/*.tar.gz release
- name: Upload artifacts
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v3
with:
name: linux-latest
path: release
@ -34,13 +33,12 @@ jobs:
name: Build Windows
runs-on: windows-latest
steps:
- name: Set up Go 1.20
uses: actions/setup-go@v2
with:
go-version: "1.20"
- uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@v2
- uses: actions/setup-go@v4
with:
go-version-file: 'go.mod'
check-latest: true
- name: Build
run: |
@ -57,7 +55,7 @@ jobs:
mv dist\windows\wintun build\dist\windows\
- name: Upload artifacts
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v3
with:
name: windows-latest
path: build
@ -68,17 +66,16 @@ jobs:
HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
runs-on: macos-11
steps:
- name: Set up Go 1.20
uses: actions/setup-go@v2
with:
go-version: "1.20"
- uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@v2
- uses: actions/setup-go@v4
with:
go-version-file: 'go.mod'
check-latest: true
- name: Import certificates
if: env.HAS_SIGNING_CREDS == 'true'
uses: Apple-Actions/import-codesign-certs@v1
uses: Apple-Actions/import-codesign-certs@v2
with:
p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }}
p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }}
@ -107,7 +104,7 @@ jobs:
fi
- name: Upload artifacts
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v3
with:
name: darwin-latest
path: ./release/*
@ -117,12 +114,16 @@ jobs:
needs: [build-linux, build-darwin, build-windows]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Download artifacts
uses: actions/download-artifact@v2
uses: actions/download-artifact@v3
with:
path: artifacts
- name: Zip Windows
run: |
cd windows-latest
cd artifacts/windows-latest
cp windows-amd64/* .
zip -r nebula-windows-amd64.zip nebula.exe nebula-cert.exe dist
cp windows-arm64/* .
@ -130,6 +131,7 @@ jobs:
- name: Create sha256sum
run: |
cd artifacts
for dir in linux-latest darwin-latest windows-latest
do
(
@ -159,195 +161,12 @@ jobs:
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ github.ref }}
release_name: Release ${{ github.ref }}
draft: false
prerelease: false
##
## Upload assets (I wish we could just upload the whole folder at once...
##
- name: Upload SHASUM256.txt
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./SHASUM256.txt
asset_name: SHASUM256.txt
asset_content_type: text/plain
- name: Upload darwin zip
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./darwin-latest/nebula-darwin.zip
asset_name: nebula-darwin.zip
asset_content_type: application/zip
- name: Upload windows-amd64
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./windows-latest/nebula-windows-amd64.zip
asset_name: nebula-windows-amd64.zip
asset_content_type: application/zip
- name: Upload windows-arm64
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./windows-latest/nebula-windows-arm64.zip
asset_name: nebula-windows-arm64.zip
asset_content_type: application/zip
- name: Upload linux-amd64
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-linux-amd64.tar.gz
asset_name: nebula-linux-amd64.tar.gz
asset_content_type: application/gzip
- name: Upload linux-386
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-linux-386.tar.gz
asset_name: nebula-linux-386.tar.gz
asset_content_type: application/gzip
- name: Upload linux-ppc64le
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-linux-ppc64le.tar.gz
asset_name: nebula-linux-ppc64le.tar.gz
asset_content_type: application/gzip
- name: Upload linux-arm-5
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-linux-arm-5.tar.gz
asset_name: nebula-linux-arm-5.tar.gz
asset_content_type: application/gzip
- name: Upload linux-arm-6
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-linux-arm-6.tar.gz
asset_name: nebula-linux-arm-6.tar.gz
asset_content_type: application/gzip
- name: Upload linux-arm-7
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-linux-arm-7.tar.gz
asset_name: nebula-linux-arm-7.tar.gz
asset_content_type: application/gzip
- name: Upload linux-arm64
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-linux-arm64.tar.gz
asset_name: nebula-linux-arm64.tar.gz
asset_content_type: application/gzip
- name: Upload linux-mips
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-linux-mips.tar.gz
asset_name: nebula-linux-mips.tar.gz
asset_content_type: application/gzip
- name: Upload linux-mipsle
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-linux-mipsle.tar.gz
asset_name: nebula-linux-mipsle.tar.gz
asset_content_type: application/gzip
- name: Upload linux-mips64
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-linux-mips64.tar.gz
asset_name: nebula-linux-mips64.tar.gz
asset_content_type: application/gzip
- name: Upload linux-mips64le
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-linux-mips64le.tar.gz
asset_name: nebula-linux-mips64le.tar.gz
asset_content_type: application/gzip
- name: Upload linux-mips-softfloat
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-linux-mips-softfloat.tar.gz
asset_name: nebula-linux-mips-softfloat.tar.gz
asset_content_type: application/gzip
- name: Upload linux-riscv64
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-linux-riscv64.tar.gz
asset_name: nebula-linux-riscv64.tar.gz
asset_content_type: application/gzip
- name: Upload freebsd-amd64
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-freebsd-amd64.tar.gz
asset_name: nebula-freebsd-amd64.tar.gz
asset_content_type: application/gzip
run: |
cd artifacts
gh release create \
--verify-tag \
--title "Release ${{ github.ref_name }}" \
"${{ github.ref_name }}" \
SHASUM256.txt *-latest/*.zip *-latest/*.tar.gz

View File

@ -18,21 +18,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Set up Go 1.20
uses: actions/setup-go@v2
with:
go-version: "1.20"
id: go
- uses: actions/checkout@v4
- name: Check out code into the Go module directory
uses: actions/checkout@v2
- uses: actions/cache@v2
- uses: actions/setup-go@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go1.20-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go1.20-
go-version-file: 'go.mod'
check-latest: true
- name: build
run: make bin-docker

View File

@ -41,4 +41,4 @@ EOF
../../../../nebula-cert sign -name "host4" -groups "host,host4" -ip "192.168.100.4/24"
)
sudo docker build -t nebula:smoke-relay .
docker build -t nebula:smoke-relay .

View File

@ -36,4 +36,4 @@ mkdir ./build
../../../../nebula-cert sign -name "host4" -groups "host,host4" -ip "192.168.100.4/24"
)
sudo docker build -t "nebula:${NAME:-smoke}" .
docker build -t "nebula:${NAME:-smoke}" .

View File

@ -14,24 +14,24 @@ cleanup() {
set +e
if [ "$(jobs -r)" ]
then
sudo docker kill lighthouse1 host2 host3 host4
docker kill lighthouse1 host2 host3 host4
fi
}
trap cleanup EXIT
sudo docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test
sudo docker run --name host2 --rm nebula:smoke-relay -config host2.yml -test
sudo docker run --name host3 --rm nebula:smoke-relay -config host3.yml -test
sudo docker run --name host4 --rm nebula:smoke-relay -config host4.yml -test
docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test
docker run --name host2 --rm nebula:smoke-relay -config host2.yml -test
docker run --name host3 --rm nebula:smoke-relay -config host3.yml -test
docker run --name host4 --rm nebula:smoke-relay -config host4.yml -test
sudo docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
sleep 1
sudo docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
sleep 1
sudo docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
sleep 1
sudo docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
sleep 1
set +x
@ -39,43 +39,43 @@ echo
echo " *** Testing ping from lighthouse1"
echo
set -x
sudo docker exec lighthouse1 ping -c1 192.168.100.2
sudo docker exec lighthouse1 ping -c1 192.168.100.3
sudo docker exec lighthouse1 ping -c1 192.168.100.4
docker exec lighthouse1 ping -c1 192.168.100.2
docker exec lighthouse1 ping -c1 192.168.100.3
docker exec lighthouse1 ping -c1 192.168.100.4
set +x
echo
echo " *** Testing ping from host2"
echo
set -x
sudo docker exec host2 ping -c1 192.168.100.1
docker exec host2 ping -c1 192.168.100.1
# Should fail because no relay configured in this direction
! sudo docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
! sudo docker exec host2 ping -c1 192.168.100.4 -w5 || exit 1
! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
! docker exec host2 ping -c1 192.168.100.4 -w5 || exit 1
set +x
echo
echo " *** Testing ping from host3"
echo
set -x
sudo docker exec host3 ping -c1 192.168.100.1
sudo docker exec host3 ping -c1 192.168.100.2
sudo docker exec host3 ping -c1 192.168.100.4
docker exec host3 ping -c1 192.168.100.1
docker exec host3 ping -c1 192.168.100.2
docker exec host3 ping -c1 192.168.100.4
set +x
echo
echo " *** Testing ping from host4"
echo
set -x
sudo docker exec host4 ping -c1 192.168.100.1
docker exec host4 ping -c1 192.168.100.1
# Should fail because relays not allowed
! sudo docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1
sudo docker exec host4 ping -c1 192.168.100.3
! docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1
docker exec host4 ping -c1 192.168.100.3
sudo docker exec host4 sh -c 'kill 1'
sudo docker exec host3 sh -c 'kill 1'
sudo docker exec host2 sh -c 'kill 1'
sudo docker exec lighthouse1 sh -c 'kill 1'
docker exec host4 sh -c 'kill 1'
docker exec host3 sh -c 'kill 1'
docker exec host2 sh -c 'kill 1'
docker exec lighthouse1 sh -c 'kill 1'
sleep 1
if [ "$(jobs -r)" ]

View File

@ -14,7 +14,7 @@ cleanup() {
set +e
if [ "$(jobs -r)" ]
then
sudo docker kill lighthouse1 host2 host3 host4
docker kill lighthouse1 host2 host3 host4
fi
}
@ -22,51 +22,51 @@ trap cleanup EXIT
CONTAINER="nebula:${NAME:-smoke}"
sudo docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
sudo docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
sudo docker run --name host3 --rm "$CONTAINER" -config host3.yml -test
sudo docker run --name host4 --rm "$CONTAINER" -config host4.yml -test
docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
docker run --name host3 --rm "$CONTAINER" -config host3.yml -test
docker run --name host4 --rm "$CONTAINER" -config host4.yml -test
sudo docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
sleep 1
sudo docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
sleep 1
sudo docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' &
sleep 1
sudo docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' &
sleep 1
# grab tcpdump pcaps for debugging
sudo docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap &
sudo docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap &
sudo docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap &
sudo docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap &
sudo docker exec host3 tcpdump -i nebula1 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap &
sudo docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap &
sudo docker exec host4 tcpdump -i nebula1 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap &
sudo docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap &
docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap &
docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap &
docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap &
docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap &
docker exec host3 tcpdump -i nebula1 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap &
docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap &
docker exec host4 tcpdump -i nebula1 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap &
docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap &
sudo docker exec host2 ncat -nklv 0.0.0.0 2000 &
sudo docker exec host3 ncat -nklv 0.0.0.0 2000 &
sudo docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
sudo docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 &
docker exec host2 ncat -nklv 0.0.0.0 2000 &
docker exec host3 ncat -nklv 0.0.0.0 2000 &
docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 &
set +x
echo
echo " *** Testing ping from lighthouse1"
echo
set -x
sudo docker exec lighthouse1 ping -c1 192.168.100.2
sudo docker exec lighthouse1 ping -c1 192.168.100.3
docker exec lighthouse1 ping -c1 192.168.100.2
docker exec lighthouse1 ping -c1 192.168.100.3
set +x
echo
echo " *** Testing ping from host2"
echo
set -x
sudo docker exec host2 ping -c1 192.168.100.1
docker exec host2 ping -c1 192.168.100.1
# Should fail because not allowed by host3 inbound firewall
! sudo docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
set +x
echo
@ -74,34 +74,34 @@ echo " *** Testing ncat from host2"
echo
set -x
# Should fail because not allowed by host3 inbound firewall
! sudo docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1
! sudo docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1
! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
set +x
echo
echo " *** Testing ping from host3"
echo
set -x
sudo docker exec host3 ping -c1 192.168.100.1
sudo docker exec host3 ping -c1 192.168.100.2
docker exec host3 ping -c1 192.168.100.1
docker exec host3 ping -c1 192.168.100.2
set +x
echo
echo " *** Testing ncat from host3"
echo
set -x
sudo docker exec host3 ncat -nzv -w5 192.168.100.2 2000
sudo docker exec host3 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2
docker exec host3 ncat -nzv -w5 192.168.100.2 2000
docker exec host3 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2
set +x
echo
echo " *** Testing ping from host4"
echo
set -x
sudo docker exec host4 ping -c1 192.168.100.1
docker exec host4 ping -c1 192.168.100.1
# Should fail because not allowed by host4 outbound firewall
! sudo docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1
! sudo docker exec host4 ping -c1 192.168.100.3 -w5 || exit 1
! docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1
! docker exec host4 ping -c1 192.168.100.3 -w5 || exit 1
set +x
echo
@ -109,10 +109,10 @@ echo " *** Testing ncat from host4"
echo
set -x
# Should fail because not allowed by host4 outbound firewall
! sudo docker exec host4 ncat -nzv -w5 192.168.100.2 2000 || exit 1
! sudo docker exec host4 ncat -nzv -w5 192.168.100.3 2000 || exit 1
! sudo docker exec host4 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2 || exit 1
! sudo docker exec host4 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
! docker exec host4 ncat -nzv -w5 192.168.100.2 2000 || exit 1
! docker exec host4 ncat -nzv -w5 192.168.100.3 2000 || exit 1
! docker exec host4 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2 || exit 1
! docker exec host4 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
set +x
echo
@ -120,15 +120,15 @@ echo " *** Testing conntrack"
echo
set -x
# host2 can ping host3 now that host3 pinged it first
sudo docker exec host2 ping -c1 192.168.100.3
docker exec host2 ping -c1 192.168.100.3
# host4 can ping host2 once conntrack established
sudo docker exec host2 ping -c1 192.168.100.4
sudo docker exec host4 ping -c1 192.168.100.2
docker exec host2 ping -c1 192.168.100.4
docker exec host4 ping -c1 192.168.100.2
sudo docker exec host4 sh -c 'kill 1'
sudo docker exec host3 sh -c 'kill 1'
sudo docker exec host2 sh -c 'kill 1'
sudo docker exec lighthouse1 sh -c 'kill 1'
docker exec host4 sh -c 'kill 1'
docker exec host3 sh -c 'kill 1'
docker exec host2 sh -c 'kill 1'
docker exec lighthouse1 sh -c 'kill 1'
sleep 1
if [ "$(jobs -r)" ]

View File

@ -18,21 +18,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Set up Go 1.20
uses: actions/setup-go@v2
with:
go-version: "1.20"
id: go
- uses: actions/checkout@v4
- name: Check out code into the Go module directory
uses: actions/checkout@v2
- uses: actions/cache@v2
- uses: actions/setup-go@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go1.20-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go1.20-
go-version-file: 'go.mod'
check-latest: true
- name: Build
run: make all
@ -57,21 +48,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Set up Go 1.20
uses: actions/setup-go@v2
with:
go-version: "1.20"
id: go
- uses: actions/checkout@v4
- name: Check out code into the Go module directory
uses: actions/checkout@v2
- uses: actions/cache@v2
- uses: actions/setup-go@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go1.20-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go1.20-
go-version-file: 'go.mod'
check-latest: true
- name: Build
run: make bin-boringcrypto
@ -90,21 +72,12 @@ jobs:
os: [windows-latest, macos-11]
steps:
- name: Set up Go 1.20
uses: actions/setup-go@v2
with:
go-version: "1.20"
id: go
- uses: actions/checkout@v4
- name: Check out code into the Go module directory
uses: actions/checkout@v2
- uses: actions/cache@v2
- uses: actions/setup-go@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go1.20-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go1.20-
go-version-file: 'go.mod'
check-latest: true
- name: Build nebula
run: go build ./cmd/nebula

View File

@ -7,6 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [1.7.2] - 2023-06-01
### Fixed
- Fix a freeze during config reload if the `static_host_map` config was changed. (#886)
## [1.7.1] - 2023-05-18
### Fixed
- Fix IPv4 addresses returned by `static_host_map` DNS lookup queries being
treated as IPv6 addresses. (#877)
## [1.7.0] - 2023-05-17
### Added
@ -475,7 +488,9 @@ created.)
- Initial public release.
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.7.0...HEAD
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.7.2...HEAD
[1.7.2]: https://github.com/slackhq/nebula/releases/tag/v1.7.2
[1.7.1]: https://github.com/slackhq/nebula/releases/tag/v1.7.1
[1.7.0]: https://github.com/slackhq/nebula/releases/tag/v1.7.0
[1.6.1]: https://github.com/slackhq/nebula/releases/tag/v1.6.1
[1.6.0]: https://github.com/slackhq/nebula/releases/tag/v1.6.0

View File

@ -12,6 +12,8 @@ ifeq ($(OS),Windows_NT)
GOISMIN := $(shell IF "$(GOVERSION)" GEQ "$(GOMINVERSION)" ECHO 1)
NEBULA_CMD_SUFFIX = .exe
NULL_FILE = nul
# RIO on windows does pointer stuff that makes go vet angry
VET_FLAGS = -unsafeptr=false
else
GOVERSION := $(shell go version | awk '{print substr($$3, 3)}')
GOISMIN := $(shell expr "$(GOVERSION)" ">=" "$(GOMINVERSION)")
@ -44,10 +46,21 @@ ALL_LINUX = linux-amd64 \
linux-mips-softfloat \
linux-riscv64
ALL_FREEBSD = freebsd-amd64 \
freebsd-arm64
ALL_OPENBSD = openbsd-amd64 \
openbsd-arm64
ALL_NETBSD = netbsd-amd64 \
netbsd-arm64
ALL = $(ALL_LINUX) \
$(ALL_FREEBSD) \
$(ALL_OPENBSD) \
$(ALL_NETBSD) \
darwin-amd64 \
darwin-arm64 \
freebsd-amd64 \
windows-amd64 \
windows-arm64
@ -75,7 +88,11 @@ release: $(ALL:%=build/nebula-%.tar.gz)
release-linux: $(ALL_LINUX:%=build/nebula-%.tar.gz)
release-freebsd: build/nebula-freebsd-amd64.tar.gz
release-freebsd: $(ALL_FREEBSD:%=build/nebula-%.tar.gz)
release-openbsd: $(ALL_OPENBSD:%=build/nebula-%.tar.gz)
release-netbsd: $(ALL_NETBSD:%=build/nebula-%.tar.gz)
release-boringcrypto: build/nebula-linux-$(shell go env GOARCH)-boringcrypto.tar.gz
@ -93,6 +110,9 @@ bin-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert
bin-freebsd: build/freebsd-amd64/nebula build/freebsd-amd64/nebula-cert
mv $? .
bin-freebsd-arm64: build/freebsd-arm64/nebula build/freebsd-arm64/nebula-cert
mv $? .
bin-boringcrypto: build/linux-$(shell go env GOARCH)-boringcrypto/nebula build/linux-$(shell go env GOARCH)-boringcrypto/nebula-cert
mv $? .
@ -137,7 +157,7 @@ build/nebula-%.zip: build/%/nebula.exe build/%/nebula-cert.exe
cd build/$* && zip ../nebula-$*.zip nebula.exe nebula-cert.exe
vet:
go vet -v ./...
go vet $(VET_FLAGS) -v ./...
test:
go test -v ./...

View File

@ -108,7 +108,7 @@ For each host, copy the nebula binary to the host, along with `config.yml` from
## Building Nebula from source
Download go and clone this repo. Change to the nebula directory.
Make sure you have [go](https://go.dev/doc/install) installed and clone this repo. Change to the nebula directory.
To build nebula for all platforms:
`make all`

163
cert.go
View File

@ -1,163 +0,0 @@
package nebula
import (
"errors"
"fmt"
"io/ioutil"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
)
type CertState struct {
certificate *cert.NebulaCertificate
rawCertificate []byte
rawCertificateNoKey []byte
publicKey []byte
privateKey []byte
}
func NewCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) {
// Marshal the certificate to ensure it is valid
rawCertificate, err := certificate.Marshal()
if err != nil {
return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
}
publicKey := certificate.Details.PublicKey
cs := &CertState{
rawCertificate: rawCertificate,
certificate: certificate, // PublicKey has been set to nil above
privateKey: privateKey,
publicKey: publicKey,
}
cs.certificate.Details.PublicKey = nil
rawCertNoKey, err := cs.certificate.Marshal()
if err != nil {
return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
}
cs.rawCertificateNoKey = rawCertNoKey
// put public key back
cs.certificate.Details.PublicKey = cs.publicKey
return cs, nil
}
func NewCertStateFromConfig(c *config.C) (*CertState, error) {
var pemPrivateKey []byte
var err error
privPathOrPEM := c.GetString("pki.key", "")
if privPathOrPEM == "" {
return nil, errors.New("no pki.key path or PEM data provided")
}
if strings.Contains(privPathOrPEM, "-----BEGIN") {
pemPrivateKey = []byte(privPathOrPEM)
privPathOrPEM = "<inline>"
} else {
pemPrivateKey, err = ioutil.ReadFile(privPathOrPEM)
if err != nil {
return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
}
}
rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey)
if err != nil {
return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
var rawCert []byte
pubPathOrPEM := c.GetString("pki.cert", "")
if pubPathOrPEM == "" {
return nil, errors.New("no pki.cert path or PEM data provided")
}
if strings.Contains(pubPathOrPEM, "-----BEGIN") {
rawCert = []byte(pubPathOrPEM)
pubPathOrPEM = "<inline>"
} else {
rawCert, err = ioutil.ReadFile(pubPathOrPEM)
if err != nil {
return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err)
}
}
nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert)
if err != nil {
return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
}
if nebulaCert.Expired(time.Now()) {
return nil, fmt.Errorf("nebula certificate for this host is expired")
}
if len(nebulaCert.Details.Ips) == 0 {
return nil, fmt.Errorf("no IPs encoded in certificate")
}
if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil {
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
}
return NewCertState(nebulaCert, rawKey)
}
func loadCAFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) {
var rawCA []byte
var err error
caPathOrPEM := c.GetString("pki.ca", "")
if caPathOrPEM == "" {
return nil, errors.New("no pki.ca path or PEM data provided")
}
if strings.Contains(caPathOrPEM, "-----BEGIN") {
rawCA = []byte(caPathOrPEM)
} else {
rawCA, err = ioutil.ReadFile(caPathOrPEM)
if err != nil {
return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
}
}
CAs, err := cert.NewCAPoolFromBytes(rawCA)
if errors.Is(err, cert.ErrExpired) {
var expired int
for _, cert := range CAs.CAs {
if cert.Expired(time.Now()) {
expired++
l.WithField("cert", cert).Warn("expired certificate present in CA pool")
}
}
if expired >= len(CAs.CAs) {
return nil, errors.New("no valid CA certificates present")
}
} else if err != nil {
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
}
for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
l.WithField("fingerprint", fp).Info("Blocklisting cert")
CAs.BlocklistFingerprint(fp)
}
// Support deprecated config for at least one minor release to allow for migrations
//TODO: remove in 2022 or later
for _, fp := range c.GetStringSlice("pki.blacklist", []string{}) {
l.WithField("fingerprint", fp).Info("Blocklisting cert")
l.Warn("pki.blacklist is deprecated and will not be supported in a future release. Please migrate your config to use pki.blocklist")
CAs.BlocklistFingerprint(fp)
}
return CAs, nil
}

View File

@ -272,6 +272,9 @@ func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte
},
Ciphertext: ciphertext,
})
if err != nil {
return nil, err
}
switch curve {
case Curve_CURVE25519:

View File

@ -77,6 +77,9 @@ func aes256Decrypt(passphrase []byte, kdfParams *Argon2Parameters, data []byte)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonce, ciphertext, err := splitNonceCiphertext(data, gcm.NonceSize())
if err != nil {

View File

@ -59,13 +59,8 @@ func main() {
}
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
switch v := err.(type) {
case util.ContextualError:
v.Log(l)
os.Exit(1)
case error:
l.WithError(err).Error("Failed to start")
if err != nil {
util.LogWithContextIfNeeded("Failed to start", err, l)
os.Exit(1)
}

View File

@ -53,18 +53,14 @@ func main() {
}
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
switch v := err.(type) {
case util.ContextualError:
v.Log(l)
os.Exit(1)
case error:
l.WithError(err).Error("Failed to start")
if err != nil {
util.LogWithContextIfNeeded("Failed to start", err, l)
os.Exit(1)
}
if !*configTest {
ctrl.Start()
notifyReady(l)
ctrl.ShutdownBlock()
}

View File

@ -0,0 +1,42 @@
package main
import (
"net"
"os"
"time"
"github.com/sirupsen/logrus"
)
// SdNotifyReady tells systemd the service is ready and dependent services can now be started
// https://www.freedesktop.org/software/systemd/man/sd_notify.html
// https://www.freedesktop.org/software/systemd/man/systemd.service.html
const SdNotifyReady = "READY=1"
func notifyReady(l *logrus.Logger) {
sockName := os.Getenv("NOTIFY_SOCKET")
if sockName == "" {
l.Debugln("NOTIFY_SOCKET systemd env var not set, not sending ready signal")
return
}
conn, err := net.DialTimeout("unixgram", sockName, time.Second)
if err != nil {
l.WithError(err).Error("failed to connect to systemd notification socket")
return
}
defer conn.Close()
err = conn.SetWriteDeadline(time.Now().Add(time.Second))
if err != nil {
l.WithError(err).Error("failed to set the write deadline for the systemd notification socket")
return
}
if _, err = conn.Write([]byte(SdNotifyReady)); err != nil {
l.WithError(err).Error("failed to signal the systemd notification socket")
return
}
l.Debugln("notified systemd the service is ready")
}

View File

@ -0,0 +1,10 @@
//go:build !linux
// +build !linux
package main
import "github.com/sirupsen/logrus"
func notifyReady(_ *logrus.Logger) {
// No init service to notify
}

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io/ioutil"
"math"
"os"
"os/signal"
"path/filepath"
@ -15,7 +16,7 @@ import (
"syscall"
"time"
"github.com/imdario/mergo"
"dario.cat/mergo"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
)
@ -236,6 +237,15 @@ func (c *C) GetInt(k string, d int) int {
return v
}
// GetUint32 will get the uint32 for k or return the default d if not found or invalid
func (c *C) GetUint32(k string, d uint32) uint32 {
r := c.GetInt(k, int(d))
if uint64(r) > uint64(math.MaxUint32) {
return d
}
return uint32(r)
}
// GetBool will get the bool for k or return the default d if not found or invalid
func (c *C) GetBool(k string, d bool) bool {
r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))

View File

@ -7,7 +7,7 @@ import (
"testing"
"time"
"github.com/imdario/mergo"
"dario.cat/mergo"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

View File

@ -231,7 +231,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
index = existing.LocalIndex
switch r.Type {
case TerminalType:
relayFrom = newhostinfo.vpnIp
relayFrom = n.intf.myVpnIp
relayTo = existing.PeerIp
case ForwardingType:
relayFrom = existing.PeerIp
@ -256,7 +256,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
}
switch r.Type {
case TerminalType:
relayFrom = newhostinfo.vpnIp
relayFrom = n.intf.myVpnIp
relayTo = r.PeerIp
case ForwardingType:
relayFrom = r.PeerIp
@ -405,8 +405,8 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
return false
}
certState := n.intf.certState.Load()
return bytes.Equal(current.ConnectionState.certState.certificate.Signature, certState.certificate.Signature)
certState := n.intf.pki.GetCertState()
return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature)
}
func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
@ -427,7 +427,7 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn
return false
}
valid, err := remoteCert.VerifyWithCache(now, n.intf.caPool)
valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool())
if valid {
return false
}
@ -464,8 +464,8 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
}
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
certState := n.intf.certState.Load()
if bytes.Equal(hostinfo.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) {
certState := n.intf.pki.GetCertState()
if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) {
return
}
@ -473,18 +473,5 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote")
//TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out
newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp, n.intf.initHostInfo)
if !newHostinfo.HandshakeReady {
ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo)
}
//If this is a static host, we don't need to wait for the HostQueryReply
//We can trigger the handshake right now
if _, ok := n.intf.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok {
select {
case n.intf.handshakeManager.trigger <- hostinfo.vpnIp:
default:
}
}
n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil)
}

View File

@ -42,25 +42,26 @@ func Test_NewConnectionManagerTest(t *testing.T) {
preferredRanges := []*net.IPNet{localrange}
// Very incomplete mock objects
hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
hostMap := NewHostMap(l, vpncidr, preferredRanges)
cs := &CertState{
rawCertificate: []byte{},
privateKey: []byte{},
certificate: &cert.NebulaCertificate{},
rawCertificateNoKey: []byte{},
RawCertificate: []byte{},
PrivateKey: []byte{},
Certificate: &cert.NebulaCertificate{},
RawCertificateNoKey: []byte{},
}
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
outside: &udp.Conn{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
pki: &PKI{},
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
}
ifce.certState.Store(cs)
ifce.pki.cs.Store(cs)
// Create manager
ctx, cancel := context.WithCancel(context.Background())
@ -78,8 +79,8 @@ func Test_NewConnectionManagerTest(t *testing.T) {
remoteIndexId: 9901,
}
hostinfo.ConnectionState = &ConnectionState{
certState: cs,
H: &noise.HandshakeState{},
myCert: &cert.NebulaCertificate{},
H: &noise.HandshakeState{},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@ -121,25 +122,26 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
preferredRanges := []*net.IPNet{localrange}
// Very incomplete mock objects
hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
hostMap := NewHostMap(l, vpncidr, preferredRanges)
cs := &CertState{
rawCertificate: []byte{},
privateKey: []byte{},
certificate: &cert.NebulaCertificate{},
rawCertificateNoKey: []byte{},
RawCertificate: []byte{},
PrivateKey: []byte{},
Certificate: &cert.NebulaCertificate{},
RawCertificateNoKey: []byte{},
}
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
outside: &udp.Conn{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
pki: &PKI{},
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
}
ifce.certState.Store(cs)
ifce.pki.cs.Store(cs)
// Create manager
ctx, cancel := context.WithCancel(context.Background())
@ -157,8 +159,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
remoteIndexId: 9901,
}
hostinfo.ConnectionState = &ConnectionState{
certState: cs,
H: &noise.HandshakeState{},
myCert: &cert.NebulaCertificate{},
H: &noise.HandshakeState{},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@ -207,7 +209,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange}
hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
hostMap := NewHostMap(l, vpncidr, preferredRanges)
// Generate keys for CA and peer's cert.
pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
@ -220,7 +222,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
PublicKey: pubCA,
},
}
caCert.Sign(cert.Curve_CURVE25519, privCA)
assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA))
ncp := &cert.NebulaCAPool{
CAs: cert.NewCAPool().CAs,
}
@ -239,28 +242,29 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
Issuer: "ca",
},
}
peerCert.Sign(cert.Curve_CURVE25519, privCA)
assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA))
cs := &CertState{
rawCertificate: []byte{},
privateKey: []byte{},
certificate: &cert.NebulaCertificate{},
rawCertificateNoKey: []byte{},
RawCertificate: []byte{},
PrivateKey: []byte{},
Certificate: &cert.NebulaCertificate{},
RawCertificateNoKey: []byte{},
}
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
outside: &udp.Conn{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
disconnectInvalid: true,
caPool: ncp,
pki: &PKI{},
}
ifce.certState.Store(cs)
ifce.pki.cs.Store(cs)
ifce.pki.caPool.Store(ncp)
// Create manager
ctx, cancel := context.WithCancel(context.Background())
@ -268,12 +272,16 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
punchy := NewPunchyFromConfig(l, config.NewC(l))
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
ifce.connectionManager = nc
hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil)
hostinfo.ConnectionState = &ConnectionState{
certState: cs,
peerCert: &peerCert,
H: &noise.HandshakeState{},
hostinfo := &HostInfo{
vpnIp: vpnIp,
ConnectionState: &ConnectionState{
myCert: &cert.NebulaCertificate{},
peerCert: &peerCert,
H: &noise.HandshakeState{},
},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
// Move ahead 45s.
// Check if to disconnect with invalid certificate.

View File

@ -18,35 +18,35 @@ type ConnectionState struct {
eKey *NebulaCipherState
dKey *NebulaCipherState
H *noise.HandshakeState
certState *CertState
myCert *cert.NebulaCertificate
peerCert *cert.NebulaCertificate
initiator bool
messageCounter atomic.Uint64
window *Bits
queueLock sync.Mutex
writeLock sync.Mutex
ready bool
}
func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
var dhFunc noise.DHFunc
curCertState := f.certState.Load()
switch curCertState.certificate.Details.Curve {
switch certState.Certificate.Details.Curve {
case cert.Curve_CURVE25519:
dhFunc = noise.DH25519
case cert.Curve_P256:
dhFunc = noiseutil.DHP256
default:
l.Errorf("invalid curve: %s", curCertState.certificate.Details.Curve)
l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve)
return nil
}
cs := noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
if f.cipher == "chachapoly" {
var cs noise.CipherSuite
if cipher == "chachapoly" {
cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
} else {
cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
}
static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey}
static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}
b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
@ -72,7 +72,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
initiator: initiator,
window: b,
ready: false,
certState: curCertState,
myCert: certState.Certificate,
}
return ci

View File

@ -17,13 +17,23 @@ import (
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
type controlEach func(h *HostInfo)
type controlHostLister interface {
QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo
ForEachIndex(each controlEach)
ForEachVpnIp(each controlEach)
GetPreferredRanges() []*net.IPNet
}
type Control struct {
f *Interface
l *logrus.Logger
cancel context.CancelFunc
sshStart func()
statsStart func()
dnsStart func()
f *Interface
l *logrus.Logger
cancel context.CancelFunc
sshStart func()
statsStart func()
dnsStart func()
lighthouseStart func()
}
type ControlHostInfo struct {
@ -54,12 +64,15 @@ func (c *Control) Start() {
if c.dnsStart != nil {
go c.dnsStart()
}
if c.lighthouseStart != nil {
c.lighthouseStart()
}
// Start reading packets.
c.f.run()
}
// Stop signals nebula to shutdown, returns after the shutdown is complete
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
func (c *Control) Stop() {
// Stop the handshakeManager (and other services), to prevent new tunnels from
// being created while we're shutting them all down.
@ -89,7 +102,7 @@ func (c *Control) RebindUDPServer() {
_ = c.f.outside.Rebind()
// Trigger a lighthouse update, useful for mobile clients that should have an update interval of 0
c.f.lightHouse.SendUpdate(c.f)
c.f.lightHouse.SendUpdate()
// Let the main interface know that we rebound so that underlying tunnels know to trigger punches from their remotes
c.f.rebindCount++
@ -98,7 +111,7 @@ func (c *Control) RebindUDPServer() {
// ListHostmapHosts returns details about the actual or pending (handshaking) hostmap by vpn ip
func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo {
if pendingMap {
return listHostMapHosts(c.f.handshakeManager.pendingHostMap)
return listHostMapHosts(c.f.handshakeManager)
} else {
return listHostMapHosts(c.f.hostMap)
}
@ -107,7 +120,7 @@ func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo {
// ListHostmapIndexes returns details about the actual or pending (handshaking) hostmap by local index id
func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
if pendingMap {
return listHostMapIndexes(c.f.handshakeManager.pendingHostMap)
return listHostMapIndexes(c.f.handshakeManager)
} else {
return listHostMapIndexes(c.f.hostMap)
}
@ -115,15 +128,15 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo {
var hm *HostMap
var hl controlHostLister
if pending {
hm = c.f.handshakeManager.pendingHostMap
hl = c.f.handshakeManager
} else {
hm = c.f.hostMap
hl = c.f.hostMap
}
h, err := hm.QueryVpnIp(vpnIp)
if err != nil {
h := hl.QueryVpnIp(vpnIp)
if h == nil {
return nil
}
@ -133,8 +146,8 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH
// SetRemoteForTunnel forces a tunnel to use a specific remote
func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo {
hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
if err != nil {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil {
return nil
}
@ -145,8 +158,8 @@ func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *Control
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool {
hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
if err != nil {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil {
return false
}
@ -241,28 +254,20 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
return chi
}
func listHostMapHosts(hm *HostMap) []ControlHostInfo {
hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Hosts))
i := 0
for _, v := range hm.Hosts {
hosts[i] = copyHostInfo(v, hm.preferredRanges)
i++
}
hm.RUnlock()
func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
hosts := make([]ControlHostInfo, 0)
pr := hl.GetPreferredRanges()
hl.ForEachVpnIp(func(hostinfo *HostInfo) {
hosts = append(hosts, copyHostInfo(hostinfo, pr))
})
return hosts
}
func listHostMapIndexes(hm *HostMap) []ControlHostInfo {
hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Indexes))
i := 0
for _, v := range hm.Indexes {
hosts[i] = copyHostInfo(v, hm.preferredRanges)
i++
}
hm.RUnlock()
func listHostMapIndexes(hl controlHostLister) []ControlHostInfo {
hosts := make([]ControlHostInfo, 0)
pr := hl.GetPreferredRanges()
hl.ForEachIndex(func(hostinfo *HostInfo) {
hosts = append(hosts, copyHostInfo(hostinfo, pr))
})
return hosts
}

View File

@ -18,7 +18,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l := test.NewLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller
hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
hm := NewHostMap(l, &net.IPNet{}, make([]*net.IPNet, 0))
remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
ipNet := net.IPNet{
@ -50,7 +50,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
remotes := NewRemoteList(nil)
remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{
hm.unlockedAddHostInfo(&HostInfo{
remote: remote1,
remotes: remotes,
ConnectionState: &ConnectionState{
@ -64,9 +64,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
relayForByIp: map[iputil.VpnIp]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
})
}, &Interface{})
hm.Add(iputil.Ip2VpnIp(ipNet2.IP), &HostInfo{
hm.unlockedAddHostInfo(&HostInfo{
remote: remote1,
remotes: remotes,
ConnectionState: &ConnectionState{
@ -80,7 +80,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
relayForByIp: map[iputil.VpnIp]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
})
}, &Interface{})
c := Control{
f: &Interface{

View File

@ -21,7 +21,7 @@ import (
func (c *Control) WaitForType(msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) {
h := &header.H{}
for {
p := c.f.outside.Get(true)
p := c.f.outside.(*udp.TesterConn).Get(true)
if err := h.Parse(p.Data); err != nil {
panic(err)
}
@ -37,7 +37,7 @@ func (c *Control) WaitForType(msgType header.MessageType, subType header.Message
func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) {
h := &header.H{}
for {
p := c.f.outside.Get(true)
p := c.f.outside.(*udp.TesterConn).Get(true)
if err := h.Parse(p.Data); err != nil {
panic(err)
}
@ -90,11 +90,11 @@ func (c *Control) GetFromTun(block bool) []byte {
// GetFromUDP will pull a udp packet off the udp side of nebula
func (c *Control) GetFromUDP(block bool) *udp.Packet {
return c.f.outside.Get(block)
return c.f.outside.(*udp.TesterConn).Get(block)
}
func (c *Control) GetUDPTxChan() <-chan *udp.Packet {
return c.f.outside.TxPackets
return c.f.outside.(*udp.TesterConn).TxPackets
}
func (c *Control) GetTunTxChan() <-chan []byte {
@ -103,7 +103,7 @@ func (c *Control) GetTunTxChan() <-chan []byte {
// InjectUDPPacket will inject a packet into the udp side of nebula
func (c *Control) InjectUDPPacket(p *udp.Packet) {
c.f.outside.Send(p)
c.f.outside.(*udp.TesterConn).Send(p)
}
// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
@ -143,16 +143,16 @@ func (c *Control) GetVpnIp() iputil.VpnIp {
}
func (c *Control) GetUDPAddr() string {
return c.f.outside.Addr.String()
return c.f.outside.(*udp.TesterConn).Addr.String()
}
func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[iputil.Ip2VpnIp(vpnIp)]
if !ok {
hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp))
if hostinfo == nil {
return false
}
c.f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
c.f.handshakeManager.DeleteHostInfo(hostinfo)
return true
}
@ -161,19 +161,9 @@ func (c *Control) GetHostmap() *HostMap {
}
func (c *Control) GetCert() *cert.NebulaCertificate {
return c.f.certState.Load().certificate
return c.f.pki.GetCertState().Certificate
}
func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp, c.f.initHostInfo)
ixHandshakeStage0(c.f, vpnIp, hostinfo)
// If this is a static host, we don't need to wait for the HostQueryReply
// We can trigger the handshake right now
if _, ok := c.f.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok {
select {
case c.f.handshakeManager.trigger <- hostinfo.vpnIp:
default:
}
}
c.f.handshakeManager.StartHandshake(vpnIp, nil)
}

View File

@ -4,6 +4,8 @@ Wants=basic.target network-online.target nss-lookup.target time-sync.target
After=basic.target network.target network-online.target
[Service]
Type=notify
NotifyAccess=main
SyslogIdentifier=nebula
ExecReload=/bin/kill -HUP $MAINPID
ExecStart=/usr/bin/nebula -config /etc/nebula/config.yml

View File

@ -5,6 +5,8 @@ After=basic.target network.target network-online.target
Before=sshd.service
[Service]
Type=notify
NotifyAccess=main
SyslogIdentifier=nebula
ExecReload=/bin/kill -HUP $MAINPID
ExecStart=/usr/bin/nebula -config /etc/nebula/config.yml

View File

@ -47,8 +47,8 @@ func (d *dnsRecords) QueryCert(data string) string {
return ""
}
iip := iputil.Ip2VpnIp(ip)
hostinfo, err := d.hostMap.QueryVpnIp(iip)
if err != nil {
hostinfo := d.hostMap.QueryVpnIp(iip)
if hostinfo == nil {
return ""
}
q := hostinfo.GetCert()

View File

@ -410,6 +410,8 @@ func TestStage1RaceRelays(t *testing.T) {
p := r.RouteForAllUntilTxTun(myControl)
_ = p
r.FlushAll()
myControl.Stop()
theirControl.Stop()
relayControl.Stop()
@ -608,6 +610,110 @@ func TestRehandshakingRelays(t *testing.T) {
t.Logf("relayControl hostinfos got cleaned up!")
}
func TestRehandshakingRelaysPrimary(t *testing.T) {
// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
defer r.RenderFlow()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
r.Log("Renew relay certificate and spin until me and them sees it")
_, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
caB, err := ca.MarshalToPEM()
if err != nil {
panic(err)
}
relayConfig.Settings["pki"] = m{
"ca": string(caB),
"cert": string(myNextPEM),
"key": string(myNextPrivKey),
}
rc, err := yaml.Marshal(relayConfig.Settings)
assert.NoError(t, err)
relayConfig.ReloadConfigString(string(rc))
for {
r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
if len(c.Cert.Details.Groups) != 0 {
// We have a new certificate now
r.Log("Certificate between my and relay is updated!")
break
}
time.Sleep(time.Second)
}
for {
r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
if len(c.Cert.Details.Groups) != 0 {
// We have a new certificate now
r.Log("Certificate between their and relay is updated!")
break
}
time.Sleep(time.Second)
}
r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// We should have two hostinfos on all sides
for len(myControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
t.Logf("myControl hostinfos got cleaned up!")
for len(theirControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
t.Logf("theirControl hostinfos got cleaned up!")
for len(relayControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
t.Logf("relayControl hostinfos got cleaned up!")
}
func TestRehandshaking(t *testing.T) {
ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil)

View File

@ -12,9 +12,9 @@ import (
"testing"
"time"
"dario.cat/mergo"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/imdario/mergo"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert"

View File

@ -21,6 +21,19 @@ pki:
static_host_map:
"192.168.100.1": ["100.64.22.11:4242"]
# The static_map config stanza can be used to configure how the static_host_map behaves.
#static_map:
# cadence determines how frequently DNS is re-queried for updated IP addresses when a static_host_map entry contains
# a DNS name.
#cadence: 30s
# network determines the type of IP addresses to ask the DNS server for. The default is "ip4" because nodes typically
# do not know their public IPv4 address. Connecting to the Lighthouse via IPv4 allows the Lighthouse to detect the
# public address. Other valid options are "ip6" and "ip" (returns both.)
#network: ip4
# lookup_timeout is the DNS query timeout.
#lookup_timeout: 250ms
lighthouse:
# am_lighthouse is used to enable lighthouse functionality for a node. This should ONLY be true on nodes
@ -158,7 +171,8 @@ punchy:
# and has been deprecated for "preferred_ranges"
#preferred_ranges: ["172.16.0.0/24"]
# sshd can expose informational and administrative functions via ssh this is a
# sshd can expose informational and administrative functions via ssh. This can expose informational and administrative
# functions, and allows manual tweaking of various network settings when debugging or testing.
#sshd:
# Toggles the feature
#enabled: true
@ -194,7 +208,7 @@ tun:
disabled: false
# Name of the device. If not set, a default will be chosen by the OS.
# For macOS: if set, must be in the form `utun[0-9]+`.
# For FreeBSD: Required to be set, must be in the form `tun[0-9]+`.
# For NetBSD: Required to be set, must be in the form `tun[0-9]+`
dev: nebula1
# Toggles forwarding of local broadcast packets, the address of which depends on the ip/mask encoded in pki.cert
drop_local_broadcast: false

View File

@ -5,6 +5,8 @@ After=basic.target network.target network-online.target
Before=sshd.service
[Service]
Type=notify
NotifyAccess=main
SyslogIdentifier=nebula
ExecReload=/bin/kill -HUP $MAINPID
ExecStart=/usr/local/bin/nebula -config /etc/nebula/config.yml

27
go.mod
View File

@ -3,31 +3,32 @@ module github.com/slackhq/nebula
go 1.20
require (
dario.cat/mergo v1.0.0
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.0.0
github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19
github.com/imdario/mergo v0.3.15
github.com/kardianos/service v1.2.2
github.com/miekg/dns v1.1.54
github.com/miekg/dns v1.1.56
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
github.com/prometheus/client_golang v1.15.1
github.com/prometheus/client_golang v1.16.0
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/sirupsen/logrus v1.9.0
github.com/sirupsen/logrus v1.9.3
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
github.com/stretchr/testify v1.8.2
github.com/stretchr/testify v1.8.4
github.com/vishvananda/netlink v1.1.0
golang.org/x/crypto v0.8.0
golang.org/x/crypto v0.14.0
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53
golang.org/x/net v0.9.0
golang.org/x/sys v0.8.0
golang.org/x/term v0.8.0
golang.org/x/net v0.17.0
golang.org/x/sys v0.13.0
golang.org/x/term v0.13.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/protobuf v1.30.0
google.golang.org/protobuf v1.31.0
gopkg.in/yaml.v2 v2.4.0
)
@ -40,10 +41,10 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.4.0 // indirect
github.com/prometheus/common v0.42.0 // indirect
github.com/prometheus/procfs v0.9.0 // indirect
github.com/prometheus/procfs v0.10.1 // indirect
github.com/rogpeppe/go-internal v1.10.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
golang.org/x/mod v0.10.0 // indirect
golang.org/x/tools v0.8.0 // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/tools v0.13.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

60
go.sum
View File

@ -1,4 +1,6 @@
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
@ -54,8 +56,6 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/imdario/mergo v0.3.15 h1:M8XP7IuFNsqUx6VPK2P9OSmsYsI/YFaGil0uD21V3dM=
github.com/imdario/mergo v0.3.15/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
@ -78,8 +78,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
github.com/miekg/dns v1.1.54 h1:5jon9mWcb0sFJGpnI99tOMhCPyJ+RPVz5b63MQG0VWI=
github.com/miekg/dns v1.1.54/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY=
github.com/miekg/dns v1.1.56 h1:5imZaSeoRNvpM9SzWNhEcP9QliKiz20/dA2QabIGVnE=
github.com/miekg/dns v1.1.56/go.mod h1:cRm6Oo2C8TY9ZS/TqsSrseAcncm74lfK5G+ikN2SWWY=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
@ -97,8 +97,8 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
github.com/prometheus/client_golang v1.15.1 h1:8tXpTmJbyH5lydzFPoxSIJ0J46jdh3tylbvM1xCv0LI=
github.com/prometheus/client_golang v1.15.1/go.mod h1:e9yaBhRPU2pPNsZwE+JdQl0KEt1N9XgF6zxWmaC0xOk=
github.com/prometheus/client_golang v1.16.0 h1:yk/hx9hDbrGHovbci4BY+pRMfSuuat626eFsHb7tmT8=
github.com/prometheus/client_golang v1.16.0/go.mod h1:Zsulrv/L9oM40tJ7T815tM89lFEugiJ9HzIqaAx4LKc=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
@ -113,8 +113,8 @@ github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.9.0 h1:wzCHvIvM5SxWqYvwgVL7yJY8Lz3PKn49KQtpgMYJfhI=
github.com/prometheus/procfs v0.9.0/go.mod h1:+pB4zwohETzFnmlpe6yd2lSc+0/46IYZRB/chUwxUZY=
github.com/prometheus/procfs v0.10.1 h1:kYK1Va/YMlutzCGazswoHKo//tZVlFpKYh+PymziUAg=
github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM=
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
@ -122,24 +122,20 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0=
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
@ -152,16 +148,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ=
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 h1:5llv2sWeaMSnA3w2kS57ouQQ4pudlXrR0dCgw51QK9o=
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk=
golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@ -172,8 +168,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -181,7 +177,7 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -198,11 +194,11 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek=
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@ -211,14 +207,16 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.8.0 h1:vSDcovVPld282ceKgDimkRSC8kpaH1dgyc9UMzlt84Y=
golang.org/x/tools v0.8.0/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4=
golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
@ -230,8 +228,8 @@ google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzi
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@ -20,7 +20,7 @@ func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packe
case 1:
ixHandshakeStage1(f, addr, via, packet, h)
case 2:
newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex)
newHostinfo := f.handshakeManager.QueryIndex(h.RemoteIndex)
tearDown := ixHandshakeStage2(f, addr, via, newHostinfo, packet, h)
if tearDown && newHostinfo != nil {
f.handshakeManager.DeleteHostInfo(newHostinfo)

View File

@ -13,27 +13,22 @@ import (
// This function constructs a handshake packet, but does not actually send it
// Sending is done by the handshake manager
func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
// This queries the lighthouse if we don't know a remote for the host
// We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send
// more quickly, effect is a quicker handshake.
if hostinfo.remote == nil {
f.lightHouse.QueryServer(vpnIp, f)
}
err := f.handshakeManager.AddIndexHostInfo(hostinfo)
func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool {
err := f.handshakeManager.allocateIndex(hostinfo)
if err != nil {
f.l.WithError(err).WithField("vpnIp", vpnIp).
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
return
return false
}
ci := hostinfo.ConnectionState
certState := f.pki.GetCertState()
ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
hostinfo.ConnectionState = ci
hsProto := &NebulaHandshakeDetails{
InitiatorIndex: hostinfo.localIndexId,
Time: uint64(time.Now().UnixNano()),
Cert: ci.certState.rawCertificateNoKey,
Cert: certState.RawCertificateNoKey,
}
if f.multiPort.Tx || f.multiPort.Rx {
@ -53,9 +48,9 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
hsBytes, err = hs.Marshal()
if err != nil {
f.l.WithError(err).WithField("vpnIp", vpnIp).
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return
return false
}
h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1)
@ -63,9 +58,9 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
if err != nil {
f.l.WithError(err).WithField("vpnIp", vpnIp).
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return
return false
}
// We are sending handshake packet 1, so we don't expect to receive
@ -75,10 +70,12 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
hostinfo.HandshakePacket[0] = msg
hostinfo.HandshakeReady = true
hostinfo.handshakeStart = time.Now()
return true
}
func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
certState := f.pki.GetCertState()
ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1)
@ -100,7 +97,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
return
}
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
@ -190,7 +187,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
Info("Handshake message received")
hs.Details.ResponderIndex = myIndex
hs.Details.Cert = ci.certState.rawCertificateNoKey
hs.Details.Cert = certState.RawCertificateNoKey
// Update the time in case their clock is way off from ours
hs.Details.Time = uint64(time.Now().UnixNano())
@ -467,7 +464,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
}
}
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
if err != nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
@ -490,34 +487,30 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
Info("Incorrect host responded to handshake")
// Release our old handshake from pending, it should not continue
f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
f.handshakeManager.DeleteHostInfo(hostinfo)
// Create a new hostinfo/handshake for the intended vpn ip
//TODO: this adds it to the timer wheel in a way that aggressively retries
newHostInfo := f.getOrHandshake(hostinfo.vpnIp)
newHostInfo.Lock()
f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHostInfo *HostInfo) {
//TODO: this doesnt know if its being added or is being used for caching a packet
// Block the current used address
newHostInfo.remotes = hostinfo.remotes
newHostInfo.remotes.BlockRemote(addr)
// Block the current used address
newHostInfo.remotes = hostinfo.remotes
newHostInfo.remotes.BlockRemote(addr)
// Get the correct remote list for the host we did handshake with
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
// Get the correct remote list for the host we did handshake with
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
Info("Blocked addresses for handshakes")
f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
Info("Blocked addresses for handshakes")
// Swap the packet store to benefit the original intended recipient
newHostInfo.packetStore = hostinfo.packetStore
hostinfo.packetStore = []*cachedPacket{}
// Swap the packet store to benefit the original intended recipient
hostinfo.ConnectionState.queueLock.Lock()
newHostInfo.packetStore = hostinfo.packetStore
hostinfo.packetStore = []*cachedPacket{}
hostinfo.ConnectionState.queueLock.Unlock()
// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.vpnIp = vpnIp
f.sendCloseTunnel(hostinfo)
newHostInfo.Unlock()
// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.vpnIp = vpnIp
f.sendCloseTunnel(hostinfo)
})
return true
}

View File

@ -7,6 +7,7 @@ import (
"encoding/binary"
"errors"
"net"
"sync"
"time"
"github.com/rcrowley/go-metrics"
@ -42,15 +43,21 @@ type HandshakeConfig struct {
}
type HandshakeManager struct {
pendingHostMap *HostMap
// Mutex for interacting with the vpnIps and indexes maps
sync.RWMutex
vpnIps map[iputil.VpnIp]*HostInfo
indexes map[uint32]*HostInfo
mainHostMap *HostMap
lightHouse *LightHouse
outside *udp.Conn
outside udp.Conn
config HandshakeConfig
OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp]
messageMetrics *MessageMetrics
metricInitiated metrics.Counter
metricTimedOut metrics.Counter
f *Interface
l *logrus.Logger
multiPort MultiPortConfig
@ -60,9 +67,10 @@ type HandshakeManager struct {
trigger chan iputil.VpnIp
}
func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udp.Conn, config HandshakeConfig) *HandshakeManager {
func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
return &HandshakeManager{
pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges),
vpnIps: map[iputil.VpnIp]*HostInfo{},
indexes: map[uint32]*HostInfo{},
mainHostMap: mainHostMap,
lightHouse: lightHouse,
outside: outside,
@ -76,7 +84,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
}
}
func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
func (c *HandshakeManager) Run(ctx context.Context) {
clockSource := time.NewTicker(c.config.tryInterval)
defer clockSource.Stop()
@ -85,27 +93,27 @@ func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
case <-ctx.Done():
return
case vpnIP := <-c.trigger:
c.handleOutbound(vpnIP, f, true)
c.handleOutbound(vpnIP, true)
case now := <-clockSource.C:
c.NextOutboundHandshakeTimerTick(now, f)
c.NextOutboundHandshakeTimerTick(now)
}
}
}
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
c.OutboundHandshakeTimer.Advance(now)
for {
vpnIp, has := c.OutboundHandshakeTimer.Purge()
if !has {
break
}
c.handleOutbound(vpnIp, f, false)
c.handleOutbound(vpnIp, false)
}
}
func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, lighthouseTriggered bool) {
hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
if err != nil {
func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) {
hostinfo := c.QueryVpnIp(vpnIp)
if hostinfo == nil {
return
}
hostinfo.Lock()
@ -114,31 +122,34 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
// We may have raced to completion but now that we have a lock we should ensure we have not yet completed.
if hostinfo.HandshakeComplete {
// Ensure we don't exist in the pending hostmap anymore since we have completed
c.pendingHostMap.DeleteHostInfo(hostinfo)
return
}
// Check if we have a handshake packet to transmit yet
if !hostinfo.HandshakeReady {
// There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly
// Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
c.DeleteHostInfo(hostinfo)
return
}
// If we are out of time, clean up
if hostinfo.HandshakeCounter >= c.config.retries {
hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)).
hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges)).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("durationNs", time.Since(hostinfo.handshakeStart).Nanoseconds()).
Info("Handshake timed out")
c.metricTimedOut.Inc(1)
c.pendingHostMap.DeleteHostInfo(hostinfo)
c.DeleteHostInfo(hostinfo)
return
}
// Increment the counter to increase our delay, linear backoff
hostinfo.HandshakeCounter++
// Check if we have a handshake packet to transmit yet
if !hostinfo.HandshakeReady {
if !ixHandshakeStage0(c.f, hostinfo) {
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
return
}
}
// Get a remotes object if we don't already have one.
// This is mainly to protect us as this should never be the case
// NB ^ This comment doesn't jive. It's how the thing gets initialized.
@ -147,7 +158,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
hostinfo.remotes = c.lightHouse.QueryCache(vpnIp)
}
remotes := hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)
remotes := hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges)
remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hostinfo.HandshakeLastRemotes)
// We only care about a lighthouse trigger if we have new remotes to send to.
@ -166,15 +177,15 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
// Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
// the learned public ip for them. Query again to short circuit the promotion counter
c.lightHouse.QueryServer(vpnIp, f)
c.lightHouse.QueryServer(vpnIp, c.f)
}
// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
var sentTo []*udp.Addr
var sentMultiport bool
hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
hostinfo.remotes.ForEach(c.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
err := c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
if err != nil {
hostinfo.logger(c.l).WithField("udpAddr", addr).
WithField("initiatorIndex", hostinfo.localIndexId).
@ -230,10 +241,10 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
if *relay == vpnIp || *relay == c.lightHouse.myVpnIp {
continue
}
relayHostInfo, err := c.mainHostMap.QueryVpnIp(*relay)
if err != nil || relayHostInfo.remote == nil {
hostinfo.logger(c.l).WithError(err).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
f.Handshake(*relay)
relayHostInfo := c.mainHostMap.QueryVpnIp(*relay)
if relayHostInfo == nil || relayHostInfo.remote == nil {
hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
c.f.Handshake(*relay)
continue
}
// Check the relay HostInfo to see if we already established a relay through it
@ -241,7 +252,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
switch existingRelay.State {
case Established:
hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Send handshake via relay")
f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
c.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
case Requested:
hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
// Re-send the CreateRelay request, in case the previous one was lost.
@ -258,7 +269,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
Error("Failed to marshal Control message to create relay")
} else {
// This must send over the hostinfo, not over hm.Hosts[ip]
f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
c.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
c.l.WithFields(logrus.Fields{
"relayFrom": c.lightHouse.myVpnIp,
"relayTo": vpnIp,
@ -293,7 +304,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
WithError(err).
Error("Failed to marshal Control message to create relay")
} else {
f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
c.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
c.l.WithFields(logrus.Fields{
"relayFrom": c.lightHouse.myVpnIp,
"relayTo": vpnIp,
@ -306,23 +317,78 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
}
}
// Increment the counter to increase our delay, linear backoff
hostinfo.HandshakeCounter++
// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
if !lighthouseTriggered {
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
}
}
func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo {
hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init)
if created {
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
c.metricInitiated.Inc(1)
// GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present
// The 2nd argument will be true if the hostinfo is ready to transmit traffic
func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HostInfo)) (*HostInfo, bool) {
// Check the main hostmap and maintain a read lock if our host is not there
hm.mainHostMap.RLock()
if h, ok := hm.mainHostMap.Hosts[vpnIp]; ok {
hm.mainHostMap.RUnlock()
// Do not attempt promotion if you are a lighthouse
if !hm.lightHouse.amLighthouse {
h.TryPromoteBest(hm.mainHostMap.preferredRanges, hm.f)
}
return h, true
}
defer hm.mainHostMap.RUnlock()
return hm.StartHandshake(vpnIp, cacheCb), false
}
// StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HostInfo)) *HostInfo {
hm.Lock()
if hostinfo, ok := hm.vpnIps[vpnIp]; ok {
// We are already trying to handshake with this vpn ip
if cacheCb != nil {
cacheCb(hostinfo)
}
hm.Unlock()
return hostinfo
}
hostinfo := &HostInfo{
vpnIp: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0),
relayState: RelayState{
relays: map[iputil.VpnIp]struct{}{},
relayForByIp: map[iputil.VpnIp]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}
hm.vpnIps[vpnIp] = hostinfo
hm.metricInitiated.Inc(1)
hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval)
if cacheCb != nil {
cacheCb(hostinfo)
}
// If this is a static host, we don't need to wait for the HostQueryReply
// We can trigger the handshake right now
_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp]
if !doTrigger {
// Add any calculated remotes, and trigger early handshake if one found
doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp)
}
if doTrigger {
select {
case hm.trigger <- vpnIp:
default:
}
}
hm.Unlock()
hm.lightHouse.QueryServer(vpnIp, hm.f)
return hostinfo
}
@ -344,10 +410,10 @@ var (
// ErrLocalIndexCollision if we already have an entry in the main or pending
// hostmap for the hostinfo.localIndexId.
func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
c.pendingHostMap.Lock()
defer c.pendingHostMap.Unlock()
c.mainHostMap.Lock()
defer c.mainHostMap.Unlock()
c.Lock()
defer c.Unlock()
// Check if we already have a tunnel with this vpn ip
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
@ -376,7 +442,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
return existingIndex, ErrLocalIndexCollision
}
existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId]
existingIndex, found = c.indexes[hostinfo.localIndexId]
if found && existingIndex != hostinfo {
// We have a collision, but for a different hostinfo
return existingIndex, ErrLocalIndexCollision
@ -398,47 +464,47 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
// Complete is a simpler version of CheckAndComplete when we already know we
// won't have a localIndexId collision because we already have an entry in the
// pendingHostMap. An existing hostinfo is returned if there was one.
func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
c.pendingHostMap.Lock()
defer c.pendingHostMap.Unlock()
c.mainHostMap.Lock()
defer c.mainHostMap.Unlock()
func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
hm.mainHostMap.Lock()
defer hm.mainHostMap.Unlock()
hm.Lock()
defer hm.Unlock()
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
if found && existingRemoteIndex != nil {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger(c.l).
hostinfo.logger(hm.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
Info("New host shadows existing host remoteIndex")
}
// We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap.
c.pendingHostMap.unlockedDeleteHostInfo(hostinfo)
c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
hm.unlockedDeleteHostInfo(hostinfo)
hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
}
// AddIndexHostInfo generates a unique localIndexId for this HostInfo
// allocateIndex generates a unique localIndexId for this HostInfo
// and adds it to the pendingHostMap. Will error if we are unable to generate
// a unique localIndexId
func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
c.pendingHostMap.Lock()
defer c.pendingHostMap.Unlock()
c.mainHostMap.RLock()
defer c.mainHostMap.RUnlock()
func (hm *HandshakeManager) allocateIndex(h *HostInfo) error {
hm.mainHostMap.RLock()
defer hm.mainHostMap.RUnlock()
hm.Lock()
defer hm.Unlock()
for i := 0; i < 32; i++ {
index, err := generateIndex(c.l)
index, err := generateIndex(hm.l)
if err != nil {
return err
}
_, inPending := c.pendingHostMap.Indexes[index]
_, inMain := c.mainHostMap.Indexes[index]
_, inPending := hm.indexes[index]
_, inMain := hm.mainHostMap.Indexes[index]
if !inMain && !inPending {
h.localIndexId = index
c.pendingHostMap.Indexes[index] = h
hm.indexes[index] = h
return nil
}
}
@ -446,22 +512,73 @@ func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
return errors.New("failed to generate unique localIndexId")
}
func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
c.pendingHostMap.addRemoteIndexHostInfo(index, h)
}
func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
//l.Debugln("Deleting pending hostinfo :", hostinfo)
c.pendingHostMap.DeleteHostInfo(hostinfo)
c.Lock()
defer c.Unlock()
c.unlockedDeleteHostInfo(hostinfo)
}
func (c *HandshakeManager) QueryIndex(index uint32) (*HostInfo, error) {
return c.pendingHostMap.QueryIndex(index)
func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
delete(c.vpnIps, hostinfo.vpnIp)
if len(c.vpnIps) == 0 {
c.vpnIps = map[iputil.VpnIp]*HostInfo{}
}
delete(c.indexes, hostinfo.localIndexId)
if len(c.vpnIps) == 0 {
c.indexes = map[uint32]*HostInfo{}
}
if c.l.Level >= logrus.DebugLevel {
c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps),
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Pending hostmap hostInfo deleted")
}
}
func (c *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
c.RLock()
defer c.RUnlock()
return c.vpnIps[vpnIp]
}
func (c *HandshakeManager) QueryIndex(index uint32) *HostInfo {
c.RLock()
defer c.RUnlock()
return c.indexes[index]
}
func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
return c.mainHostMap.preferredRanges
}
func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
c.RLock()
defer c.RUnlock()
for _, v := range c.vpnIps {
f(v)
}
}
func (c *HandshakeManager) ForEachIndex(f controlEach) {
c.RLock()
defer c.RUnlock()
for _, v := range c.indexes {
f(v)
}
}
func (c *HandshakeManager) EmitStats() {
c.pendingHostMap.EmitStats("pending")
c.mainHostMap.EmitStats("main")
c.RLock()
hostLen := len(c.vpnIps)
indexLen := len(c.indexes)
c.RUnlock()
metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen))
metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen))
c.mainHostMap.EmitStats()
}
// Utility functions below

View File

@ -14,31 +14,20 @@ import (
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
l := test.NewLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
mainHM := NewHostMap(l, vpncidr, preferredRanges)
lh := newTestLighthouse()
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig)
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
blah.NextOutboundHandshakeTimerTick(now)
var initCalled bool
initFunc := func(*HostInfo) {
initCalled = true
}
i := blah.AddVpnIp(ip, initFunc)
assert.True(t, initCalled)
initCalled = false
i2 := blah.AddVpnIp(ip, initFunc)
assert.False(t, initCalled)
i := blah.StartHandshake(ip, nil)
i2 := blah.StartHandshake(ip, nil)
assert.Same(t, i, i2)
i.remotes = NewRemoteList(nil)
@ -48,22 +37,22 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
assert.Len(t, mainHM.Hosts, 0)
// Confirm they are in the pending index list
assert.Contains(t, blah.pendingHostMap.Hosts, ip)
assert.Contains(t, blah.vpnIps, ip)
// Jump ahead `HandshakeRetries` ticks, offset by one to get the sleep logic right
for i := 1; i <= DefaultHandshakeRetries+1; i++ {
now = now.Add(time.Duration(i) * DefaultHandshakeTryInterval)
blah.NextOutboundHandshakeTimerTick(now, mw)
blah.NextOutboundHandshakeTimerTick(now)
}
// Confirm they are still in the pending index list
assert.Contains(t, blah.pendingHostMap.Hosts, ip)
assert.Contains(t, blah.vpnIps, ip)
// Tick 1 more time, a minute will certainly flush it out
blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute), mw)
blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute))
// Confirm they have been removed
assert.NotContains(t, blah.pendingHostMap.Hosts, ip)
assert.NotContains(t, blah.vpnIps, ip)
}
func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {

View File

@ -2,7 +2,6 @@ package nebula
import (
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
@ -18,8 +17,9 @@ import (
)
// const ProbeLen = 100
const PromoteEvery = 1000
const ReQueryEvery = 5000
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse
const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery
const MaxRemotes = 10
// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
@ -52,7 +52,6 @@ type Relay struct {
type HostMap struct {
sync.RWMutex //Because we concurrently read and write to our maps
name string
Indexes map[uint32]*HostInfo
Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
RemoteIndexes map[uint32]*HostInfo
@ -205,13 +204,13 @@ type HostInfo struct {
multiportTx bool
multiportRx bool
ConnectionState *ConnectionState
handshakeStart time.Time //todo: this an entry in the handshake manager
HandshakeReady bool //todo: being in the manager means you are ready
HandshakeCounter int //todo: another handshake manager entry
HandshakeLastRemotes []*udp.Addr //todo: another handshake manager entry, which remotes we sent to last time
HandshakeComplete bool //todo: this should go away in favor of ConnectionState.ready
HandshakePacket map[uint8][]byte //todo: this is other handshake manager entry
packetStore []*cachedPacket //todo: this is other handshake manager entry
handshakeStart time.Time //todo: this an entry in the handshake manager
HandshakeReady bool //todo: being in the manager means you are ready
HandshakeCounter int //todo: another handshake manager entry
HandshakeLastRemotes []*udp.Addr //todo: another handshake manager entry, which remotes we sent to last time
HandshakeComplete bool //todo: this should go away in favor of ConnectionState.ready
HandshakePacket map[uint8][]byte
packetStore []*cachedPacket //todo: this is other handshake manager entry
remoteIndexId uint32
localIndexId uint32
vpnIp iputil.VpnIp
@ -219,6 +218,10 @@ type HostInfo struct {
remoteCidr *cidr.Tree4
relayState RelayState
// nextLHQuery is the earliest we can ask the lighthouse for new information.
// This is used to limit lighthouse re-queries in chatty clients
nextLHQuery atomic.Int64
// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
// with a handshake
@ -257,13 +260,12 @@ type cachedPacketMetrics struct {
dropped metrics.Counter
}
func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
func NewHostMap(l *logrus.Logger, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
h := map[iputil.VpnIp]*HostInfo{}
i := map[uint32]*HostInfo{}
r := map[uint32]*HostInfo{}
relays := map[uint32]*HostInfo{}
m := HostMap{
name: name,
Indexes: i,
Relays: relays,
RemoteIndexes: r,
@ -275,8 +277,8 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang
return &m
}
// UpdateStats takes a name and reports host and index counts to the stats collection system
func (hm *HostMap) EmitStats(name string) {
// EmitStats reports host, index, and relay counts to the stats collection system
func (hm *HostMap) EmitStats() {
hm.RLock()
hostLen := len(hm.Hosts)
indexLen := len(hm.Indexes)
@ -284,10 +286,10 @@ func (hm *HostMap) EmitStats(name string) {
relaysLen := len(hm.Relays)
hm.RUnlock()
metrics.GetOrRegisterGauge("hostmap."+name+".hosts", nil).Update(int64(hostLen))
metrics.GetOrRegisterGauge("hostmap."+name+".indexes", nil).Update(int64(indexLen))
metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen))
metrics.GetOrRegisterGauge("hostmap."+name+".relayIndexes", nil).Update(int64(relaysLen))
metrics.GetOrRegisterGauge("hostmap.main.hosts", nil).Update(int64(hostLen))
metrics.GetOrRegisterGauge("hostmap.main.indexes", nil).Update(int64(indexLen))
metrics.GetOrRegisterGauge("hostmap.main.remoteIndexes", nil).Update(int64(remoteIndexLen))
metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen))
}
func (hm *HostMap) RemoveRelay(localIdx uint32) {
@ -301,88 +303,6 @@ func (hm *HostMap) RemoveRelay(localIdx uint32) {
hm.Unlock()
}
func (hm *HostMap) GetIndexByVpnIp(vpnIp iputil.VpnIp) (uint32, error) {
hm.RLock()
if i, ok := hm.Hosts[vpnIp]; ok {
index := i.localIndexId
hm.RUnlock()
return index, nil
}
hm.RUnlock()
return 0, errors.New("vpn IP not found")
}
func (hm *HostMap) Add(ip iputil.VpnIp, hostinfo *HostInfo) {
hm.Lock()
hm.Hosts[ip] = hostinfo
hm.Unlock()
}
func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp, init func(hostinfo *HostInfo)) (hostinfo *HostInfo, created bool) {
hm.RLock()
if h, ok := hm.Hosts[vpnIp]; !ok {
hm.RUnlock()
h = &HostInfo{
vpnIp: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0),
relayState: RelayState{
relays: map[iputil.VpnIp]struct{}{},
relayForByIp: map[iputil.VpnIp]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}
if init != nil {
init(h)
}
hm.Lock()
hm.Hosts[vpnIp] = h
hm.Unlock()
return h, true
} else {
hm.RUnlock()
return h, false
}
}
// Only used by pendingHostMap when the remote index is not initially known
func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
hm.Lock()
h.remoteIndexId = index
hm.RemoteIndexes[index] = h
hm.Unlock()
if hm.l.Level > logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": h.vpnIp}}).
Debug("Hostmap remoteIndex added")
}
}
// DeleteReverseIndex is used to clean up on recv_error
// This function should only ever be called on the pending hostmap
func (hm *HostMap) DeleteReverseIndex(index uint32) {
hm.Lock()
hostinfo, ok := hm.RemoteIndexes[index]
if ok {
delete(hm.Indexes, hostinfo.localIndexId)
delete(hm.RemoteIndexes, index)
// Check if we have an entry under hostId that matches the same hostinfo
// instance. Clean it up as well if we do (they might not match in pendingHostmap)
var hostinfo2 *HostInfo
hostinfo2, ok = hm.Hosts[hostinfo.vpnIp]
if ok && hostinfo2 == hostinfo {
delete(hm.Hosts, hostinfo.vpnIp)
}
}
hm.Unlock()
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
Debug("Hostmap remote index deleted")
}
}
// DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
// Delete the host itself, ensuring it's not modified anymore
@ -395,12 +315,6 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
return final
}
func (hm *HostMap) DeleteRelayIdx(localIdx uint32) {
hm.Lock()
defer hm.Unlock()
delete(hm.RemoteIndexes, localIdx)
}
func (hm *HostMap) MakePrimary(hostinfo *HostInfo) {
hm.Lock()
defer hm.Unlock()
@ -478,7 +392,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
}
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Hostmap hostInfo deleted")
}
@ -488,55 +402,41 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
}
}
func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) {
//TODO: we probably just want to return bool instead of error, or at least a static error
func (hm *HostMap) QueryIndex(index uint32) *HostInfo {
hm.RLock()
if h, ok := hm.Indexes[index]; ok {
hm.RUnlock()
return h, nil
return h
} else {
hm.RUnlock()
return nil, errors.New("unable to find index")
return nil
}
}
// Retrieves a HostInfo by Index. Returns whether the HostInfo is primary at time of query.
// This helper exists so that the hostinfo.prev pointer can be read while the hostmap lock is held.
func (hm *HostMap) QueryIndexIsPrimary(index uint32) (*HostInfo, bool, error) {
//TODO: we probably just want to return bool instead of error, or at least a static error
hm.RLock()
if h, ok := hm.Indexes[index]; ok {
hm.RUnlock()
return h, h.prev == nil, nil
} else {
hm.RUnlock()
return nil, false, errors.New("unable to find index")
}
}
func (hm *HostMap) QueryRelayIndex(index uint32) (*HostInfo, error) {
func (hm *HostMap) QueryRelayIndex(index uint32) *HostInfo {
//TODO: we probably just want to return bool instead of error, or at least a static error
hm.RLock()
if h, ok := hm.Relays[index]; ok {
hm.RUnlock()
return h, nil
return h
} else {
hm.RUnlock()
return nil, errors.New("unable to find index")
return nil
}
}
func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) {
func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
hm.RLock()
if h, ok := hm.RemoteIndexes[index]; ok {
hm.RUnlock()
return h, nil
return h
} else {
hm.RUnlock()
return nil, fmt.Errorf("unable to find reverse index or connectionstate nil in %s hostmap", hm.name)
return nil
}
}
func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) (*HostInfo, error) {
func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
return hm.queryVpnIp(vpnIp, nil)
}
@ -558,13 +458,7 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host
return nil, nil, errors.New("unable to find host with relay")
}
// PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every
// `PromoteEvery` calls to this function for a given host.
func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) (*HostInfo, error) {
return hm.queryVpnIp(vpnIp, ifce)
}
func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*HostInfo, error) {
func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo {
hm.RLock()
if h, ok := hm.Hosts[vpnIp]; ok {
hm.RUnlock()
@ -572,12 +466,12 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*Host
if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse {
h.TryPromoteBest(hm.preferredRanges, promoteIfce)
}
return h, nil
return h
}
hm.RUnlock()
return nil, errors.New("unable to find host")
return nil
}
// unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps.
@ -600,7 +494,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
Debug("Hostmap vpnIp added")
}
@ -616,11 +510,33 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
}
}
func (hm *HostMap) GetPreferredRanges() []*net.IPNet {
return hm.preferredRanges
}
func (hm *HostMap) ForEachVpnIp(f controlEach) {
hm.RLock()
defer hm.RUnlock()
for _, v := range hm.Hosts {
f(v)
}
}
func (hm *HostMap) ForEachIndex(f controlEach) {
hm.RLock()
defer hm.RUnlock()
for _, v := range hm.Indexes {
f(v)
}
}
// TryPromoteBest handles re-querying lighthouses and probing for better paths
// NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
c := i.promoteCounter.Add(1)
if c%PromoteEvery == 0 {
if c%ifce.tryPromoteEvery.Load() == 0 {
// The lock here is currently protecting i.remote access
i.RLock()
remote := i.remote
@ -648,12 +564,18 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
}
// Re query our lighthouses for new remotes occasionally
if c%ReQueryEvery == 0 && ifce.lightHouse != nil {
if c%ifce.reQueryEvery.Load() == 0 && ifce.lightHouse != nil {
now := time.Now().UnixNano()
if now < i.nextLHQuery.Load() {
return
}
i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
ifce.lightHouse.QueryServer(i.vpnIp, ifce)
}
}
func (i *HostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
func (i *HostInfo) unlockedCachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
//TODO: return the error so we can log with more context
if len(i.packetStore) < 100 {
tempPacket := make([]byte, len(packet))
@ -682,7 +604,6 @@ func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) {
//TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send
//TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical
i.ConnectionState.queueLock.Lock()
i.HandshakeComplete = true
//TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen.
// Clamping it to 2 gets us out of the woods for now
@ -704,7 +625,6 @@ func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) {
i.remotes.ResetBlockedRemotes()
i.packetStore = make([]*cachedPacket, 0)
i.ConnectionState.ready = true
i.ConnectionState.queueLock.Unlock()
}
func (i *HostInfo) GetCert() *cert.NebulaCertificate {

View File

@ -11,7 +11,7 @@ import (
func TestHostMap_MakePrimary(t *testing.T) {
l := test.NewLogger()
hm := NewHostMap(
l, "test",
l,
&net.IPNet{
IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0},
@ -32,7 +32,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.unlockedAddHostInfo(h1, f)
// Make sure we go h1 -> h2 -> h3 -> h4
prim, _ := hm.QueryVpnIp(1)
prim := hm.QueryVpnIp(1)
assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -47,7 +47,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h3)
// Make sure we go h3 -> h1 -> h2 -> h4
prim, _ = hm.QueryVpnIp(1)
prim = hm.QueryVpnIp(1)
assert.Equal(t, h3.localIndexId, prim.localIndexId)
assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -62,7 +62,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2
prim, _ = hm.QueryVpnIp(1)
prim = hm.QueryVpnIp(1)
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -77,7 +77,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2
prim, _ = hm.QueryVpnIp(1)
prim = hm.QueryVpnIp(1)
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -92,7 +92,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
func TestHostMap_DeleteHostInfo(t *testing.T) {
l := test.NewLogger()
hm := NewHostMap(
l, "test",
l,
&net.IPNet{
IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0},
@ -119,11 +119,11 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
// h6 should be deleted
assert.Nil(t, h6.next)
assert.Nil(t, h6.prev)
_, err := hm.QueryIndex(h6.localIndexId)
assert.Error(t, err)
h := hm.QueryIndex(h6.localIndexId)
assert.Nil(t, h)
// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
prim, _ := hm.QueryVpnIp(1)
prim := hm.QueryVpnIp(1)
assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -142,7 +142,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h1.next)
// Make sure we go h2 -> h3 -> h4 -> h5
prim, _ = hm.QueryVpnIp(1)
prim = hm.QueryVpnIp(1)
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -160,7 +160,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h3.next)
// Make sure we go h2 -> h4 -> h5
prim, _ = hm.QueryVpnIp(1)
prim = hm.QueryVpnIp(1)
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -176,7 +176,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h5.next)
// Make sure we go h2 -> h4
prim, _ = hm.QueryVpnIp(1)
prim = hm.QueryVpnIp(1)
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -190,7 +190,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h2.next)
// Make sure we only have h4
prim, _ = hm.QueryVpnIp(1)
prim = hm.QueryVpnIp(1)
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Nil(t, prim.prev)
assert.Nil(t, prim.next)
@ -202,6 +202,6 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h4.next)
// Make sure we have nil
prim, _ = hm.QueryVpnIp(1)
prim = hm.QueryVpnIp(1)
assert.Nil(t, prim)
}

107
inside.go
View File

@ -1,7 +1,6 @@
package nebula
import (
"github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
@ -45,7 +44,10 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return
}
hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(h *HostInfo) {
h.unlockedCachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
})
if hostinfo == nil {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
@ -55,23 +57,14 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
}
return
}
ci := hostinfo.ConnectionState
if !ci.ready {
// Because we might be sending stored packets, lock here to stop new things going to
// the packet queue.
ci.queueLock.Lock()
if !ci.ready {
hostinfo.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
ci.queueLock.Unlock()
return
}
ci.queueLock.Unlock()
if !ready {
return
}
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil {
f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, packet, nb, out, q, fwPacket)
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, nil, packet, nb, out, q, fwPacket)
} else {
f.rejectInside(packet, out, q)
@ -110,71 +103,20 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
}
func (f *Interface) Handshake(vpnIp iputil.VpnIp) {
f.getOrHandshake(vpnIp)
f.getOrHandshake(vpnIp, nil)
}
// getOrHandshake returns nil if the vpnIp is not routable
func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
// getOrHandshake returns nil if the vpnIp is not routable.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(info *HostInfo)) (*HostInfo, bool) {
if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) {
vpnIp = f.inside.RouteFor(vpnIp)
if vpnIp == 0 {
return nil
}
}
hostinfo, err := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
//if err != nil || hostinfo.ConnectionState == nil {
if err != nil {
hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
if err != nil {
hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo)
}
}
ci := hostinfo.ConnectionState
if ci != nil && ci.eKey != nil && ci.ready {
return hostinfo
}
// Handshake is not ready, we need to grab the lock now before we start the handshake process
hostinfo.Lock()
defer hostinfo.Unlock()
// Double check, now that we have the lock
ci = hostinfo.ConnectionState
if ci != nil && ci.eKey != nil && ci.ready {
return hostinfo
}
// If we have already created the handshake packet, we don't want to call the function at all.
if !hostinfo.HandshakeReady {
ixHandshakeStage0(f, vpnIp, hostinfo)
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
//xx_handshakeStage0(f, ip, hostinfo)
// If this is a static host, we don't need to wait for the HostQueryReply
// We can trigger the handshake right now
_, doTrigger := f.lightHouse.GetStaticHostList()[vpnIp]
if !doTrigger {
// Add any calculated remotes, and trigger early handshake if one found
doTrigger = f.lightHouse.addCalculatedRemotes(vpnIp)
}
if doTrigger {
select {
case f.handshakeManager.trigger <- vpnIp:
default:
}
return nil, false
}
}
return hostinfo
}
// initHostInfo is the init function to pass to (*HandshakeManager).AddVpnIP that
// will create the initial Noise ConnectionState
func (f *Interface) initHostInfo(hostinfo *HostInfo) {
hostinfo.ConnectionState = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback)
}
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
@ -186,7 +128,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
}
// check if packet is in outbound fw rules
dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.caPool, nil)
dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.pki.GetCAPool(), nil)
if dropReason != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("fwPacket", fp).
@ -201,7 +143,10 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
hostInfo := f.getOrHandshake(vpnIp)
hostInfo, ready := f.getOrHandshake(vpnIp, func(h *HostInfo) {
h.unlockedCachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
})
if hostInfo == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", vpnIp).
@ -210,16 +155,8 @@ func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSu
return
}
if !hostInfo.ConnectionState.ready {
// Because we might be sending stored packets, lock here to stop new things going to
// the packet queue.
hostInfo.ConnectionState.queueLock.Lock()
if !hostInfo.ConnectionState.ready {
hostInfo.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
hostInfo.ConnectionState.queueLock.Unlock()
return
}
hostInfo.ConnectionState.queueLock.Unlock()
if !ready {
return
}
f.SendMessageToHostInfo(t, st, hostInfo, p, nb, out)
@ -239,7 +176,7 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0, nil)
}
// sendVia sends a payload through a Relay tunnel. No authentication or encryption is done
// SendVia sends a payload through a Relay tunnel. No authentication or encryption is done
// to the payload for the ultimate target host, making this a useful method for sending
// handshake messages to peers through relay tunnels.
// via is the HostInfo through which the message is relayed.

View File

@ -13,7 +13,6 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
@ -26,9 +25,9 @@ const mtu = 9001
type InterfaceConfig struct {
HostMap *HostMap
Outside *udp.Conn
Outside udp.Conn
Inside overlay.Device
certState *CertState
pki *PKI
Cipher string
Firewall *Firewall
ServeDns bool
@ -41,20 +40,23 @@ type InterfaceConfig struct {
routines int
MessageMetrics *MessageMetrics
version string
caPool *cert.NebulaCAPool
disconnectInvalid bool
relayManager *relayManager
punchy *Punchy
tryPromoteEvery uint32
reQueryEvery uint32
reQueryWait time.Duration
ConntrackCacheTimeout time.Duration
l *logrus.Logger
}
type Interface struct {
hostMap *HostMap
outside *udp.Conn
outside udp.Conn
inside overlay.Device
certState atomic.Pointer[CertState]
pki *PKI
cipher string
firewall *Firewall
connectionManager *connectionManager
@ -67,11 +69,14 @@ type Interface struct {
dropLocalBroadcast bool
dropMulticast bool
routines int
caPool *cert.NebulaCAPool
disconnectInvalid bool
closed atomic.Bool
relayManager *relayManager
tryPromoteEvery atomic.Uint32
reQueryEvery atomic.Uint32
reQueryWait atomic.Int64
sendRecvErrorConfig sendRecvErrorConfig
// rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
@ -80,7 +85,7 @@ type Interface struct {
conntrackCacheTimeout time.Duration
writers []*udp.Conn
writers []udp.Conn
readers []io.ReadWriteCloser
udpRaw *udp.RawConn
@ -156,15 +161,17 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
if c.Inside == nil {
return nil, errors.New("no inside interface (tun)")
}
if c.certState == nil {
if c.pki == nil {
return nil, errors.New("no certificate state")
}
if c.Firewall == nil {
return nil, errors.New("no firewall rules")
}
myVpnIp := iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].IP)
certificate := c.pki.GetCertState().Certificate
myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP)
ifce := &Interface{
pki: c.pki,
hostMap: c.HostMap,
outside: c.Outside,
inside: c.Inside,
@ -174,14 +181,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
handshakeManager: c.HandshakeManager,
createTime: time.Now(),
lightHouse: c.lightHouse,
localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].Mask),
localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask),
dropLocalBroadcast: c.DropLocalBroadcast,
dropMulticast: c.DropMulticast,
routines: c.routines,
version: c.version,
writers: make([]*udp.Conn, c.routines),
writers: make([]udp.Conn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines),
caPool: c.caPool,
disconnectInvalid: c.disconnectInvalid,
myVpnIp: myVpnIp,
relayManager: c.relayManager,
@ -198,7 +204,10 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
l: c.l,
}
ifce.certState.Store(c.certState)
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait))
ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)
return ifce, nil
@ -257,7 +266,7 @@ func (f *Interface) run() {
func (f *Interface) listenOut(i int) {
runtime.LockOSThread()
var li *udp.Conn
var li udp.Conn
// TODO clean this up with a coherent interface for each outside connection
if i > 0 {
li = f.writers[i]
@ -297,49 +306,14 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
}
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
c.RegisterReloadCallback(f.reloadCA)
c.RegisterReloadCallback(f.reloadCertKey)
c.RegisterReloadCallback(f.reloadFirewall)
c.RegisterReloadCallback(f.reloadSendRecvError)
c.RegisterReloadCallback(f.reloadMisc)
for _, udpConn := range f.writers {
c.RegisterReloadCallback(udpConn.ReloadConfig)
}
}
func (f *Interface) reloadCA(c *config.C) {
// reload and check regardless
// todo: need mutex?
newCAs, err := loadCAFromConfig(f.l, c)
if err != nil {
f.l.WithError(err).Error("Could not refresh trusted CA certificates")
return
}
f.caPool = newCAs
f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed")
}
func (f *Interface) reloadCertKey(c *config.C) {
// reload and check in all cases
cs, err := NewCertStateFromConfig(c)
if err != nil {
f.l.WithError(err).Error("Could not refresh client cert")
return
}
// did IP in cert change? if so, don't set
currentCert := f.certState.Load().certificate
oldIPs := currentCert.Details.Ips
newIPs := cs.certificate.Details.Ips
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
return
}
f.certState.Store(cs)
f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
}
func (f *Interface) reloadFirewall(c *config.C) {
//TODO: need to trigger/detect if the certificate changed too
if c.HasChanged("firewall") == false {
@ -347,7 +321,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
return
}
fw, err := NewFirewallFromConfig(f.l, f.certState.Load().certificate, c)
fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c)
if err != nil {
f.l.WithError(err).Error("Error while creating firewall during reload")
return
@ -403,6 +377,26 @@ func (f *Interface) reloadSendRecvError(c *config.C) {
}
}
func (f *Interface) reloadMisc(c *config.C) {
if c.HasChanged("counters.try_promote") {
n := c.GetUint32("counters.try_promote", defaultPromoteEvery)
f.tryPromoteEvery.Store(n)
f.l.Info("counters.try_promote has changed")
}
if c.HasChanged("counters.requery_every_packets") {
n := c.GetUint32("counters.requery_every_packets", defaultReQueryEvery)
f.reQueryEvery.Store(n)
f.l.Info("counters.requery_every_packets has changed")
}
if c.HasChanged("timers.requery_wait_duration") {
n := c.GetDuration("timers.requery_wait_duration", defaultReQueryWait)
f.reQueryWait.Store(int64(n))
f.l.Info("timers.requery_wait_duration has changed")
}
}
func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
ticker := time.NewTicker(i)
defer ticker.Stop()
@ -427,7 +421,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
}
rawStats()
}
certExpirationGauge.Update(int64(f.certState.Load().certificate.Details.NotAfter.Sub(time.Now()) / time.Second))
certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.Details.NotAfter.Sub(time.Now()) / time.Second))
}
}
}
@ -435,6 +429,13 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
func (f *Interface) Close() error {
f.closed.Store(true)
for _, u := range f.writers {
err := u.Close()
if err != nil {
f.l.WithError(err).Error("Error while closing udp socket")
}
}
// Release the tun device
return f.inside.Close()
}

View File

@ -39,7 +39,7 @@ type LightHouse struct {
myVpnIp iputil.VpnIp
myVpnZeros iputil.VpnIp
myVpnNet *net.IPNet
punchConn *udp.Conn
punchConn udp.Conn
punchy *Punchy
// Local cache of answers from light houses
@ -64,11 +64,10 @@ type LightHouse struct {
staticList atomic.Pointer[map[iputil.VpnIp]struct{}]
lighthouses atomic.Pointer[map[iputil.VpnIp]struct{}]
interval atomic.Int64
updateCancel context.CancelFunc
updateParentCtx context.Context
updateUdp EncWriter
nebulaPort uint32 // 32 bits because protobuf does not have a uint16
interval atomic.Int64
updateCancel context.CancelFunc
ifce EncWriter
nebulaPort uint32 // 32 bits because protobuf does not have a uint16
advertiseAddrs atomic.Pointer[[]netIpAndPort]
@ -84,7 +83,7 @@ type LightHouse struct {
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
// addrMap should be nil unless this is during a config reload
func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) {
func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc udp.Conn, p *Punchy) (*LightHouse, error) {
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
nebulaPort := uint32(c.GetInt("listen.port", 0))
if amLighthouse && nebulaPort == 0 {
@ -133,7 +132,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
c.RegisterReloadCallback(func(c *config.C) {
err := h.reload(c, false)
switch v := err.(type) {
case util.ContextualError:
case *util.ContextualError:
v.Log(l)
case error:
l.WithError(err).Error("failed to reload lighthouse")
@ -217,7 +216,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
lh.updateCancel()
}
lh.LhUpdateWorker(lh.updateParentCtx, lh.updateUdp)
lh.StartUpdateWorker()
}
}
@ -262,6 +261,18 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
//NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config
if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") {
// Clean up. Entries still in the static_host_map will be re-built.
// Entries no longer present must have their (possible) background DNS goroutines stopped.
if existingStaticList := lh.staticList.Load(); existingStaticList != nil {
lh.RLock()
for staticVpnIp := range *existingStaticList {
if am, ok := lh.addrMap[staticVpnIp]; ok && am != nil {
am.hr.Cancel()
}
}
lh.RUnlock()
}
// Build a new list based on current config.
staticList := make(map[iputil.VpnIp]struct{})
err := lh.loadStaticMap(c, lh.myVpnNet, staticList)
if err != nil {
@ -742,33 +753,33 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
}
func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
lh.updateParentCtx = ctx
lh.updateUdp = f
func (lh *LightHouse) StartUpdateWorker() {
interval := lh.GetUpdateInterval()
if lh.amLighthouse || interval == 0 {
return
}
clockSource := time.NewTicker(time.Second * time.Duration(interval))
updateCtx, cancel := context.WithCancel(ctx)
updateCtx, cancel := context.WithCancel(lh.ctx)
lh.updateCancel = cancel
defer clockSource.Stop()
for {
lh.SendUpdate(f)
go func() {
defer clockSource.Stop()
select {
case <-updateCtx.Done():
return
case <-clockSource.C:
continue
for {
lh.SendUpdate()
select {
case <-updateCtx.Done():
return
case <-clockSource.C:
continue
}
}
}
}()
}
func (lh *LightHouse) SendUpdate(f EncWriter) {
func (lh *LightHouse) SendUpdate() {
var v4 []*Ip4AndPort
var v6 []*Ip6AndPort
@ -821,7 +832,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
}
for vpnIp := range lighthouses {
f.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out)
lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out)
}
}

View File

@ -12,6 +12,7 @@ import (
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2"
)
//TODO: Add a test to ensure udpAddr is copied and not reused
@ -65,6 +66,35 @@ func Test_lhStaticMapping(t *testing.T) {
assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
}
func TestReloadLighthouseInterval(t *testing.T) {
l := test.NewLogger()
_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16")
lh1 := "10.128.0.2"
c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{
"hosts": []interface{}{lh1},
"interval": "1s",
}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
assert.NoError(t, err)
lh.ifce = &mockEncWriter{}
// The first one routine is kicked off by main.go currently, lets make sure that one dies
c.ReloadConfigString("lighthouse:\n interval: 5")
assert.Equal(t, int64(5), lh.interval.Load())
// Subsequent calls are killed off by the LightHouse.Reload function
c.ReloadConfigString("lighthouse:\n interval: 10")
assert.Equal(t, int64(10), lh.interval.Load())
// If this completes then nothing is stealing our reload routine
c.ReloadConfigString("lighthouse:\n interval: 11")
assert.Equal(t, int64(11), lh.interval.Load())
}
func BenchmarkLighthouseHandleRequest(b *testing.B) {
l := test.NewLogger()
_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0")
@ -242,8 +272,17 @@ func TestLighthouse_reload(t *testing.T) {
lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
assert.NoError(t, err)
c.Settings["static_host_map"] = map[interface{}]interface{}{"10.128.0.2": []interface{}{"1.1.1.1:4242"}}
lh.reload(c, false)
nc := map[interface{}]interface{}{
"static_host_map": map[interface{}]interface{}{
"10.128.0.2": []interface{}{"1.1.1.1:4242"},
},
}
rc, err := yaml.Marshal(nc)
assert.NoError(t, err)
c.ReloadConfigString(string(rc))
err = lh.reload(c, false)
assert.NoError(t, err)
}
func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply {

83
main.go
View File

@ -3,7 +3,6 @@ package nebula
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"time"
@ -46,7 +45,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
err := configLogger(l, c)
if err != nil {
return nil, util.NewContextualError("Failed to configure the logger", nil, err)
return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err)
}
c.RegisterReloadCallback(func(c *config.C) {
@ -56,28 +55,20 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
})
caPool, err := loadCAFromConfig(l, c)
pki, err := NewPKIFromConfig(l, c)
if err != nil {
//The errors coming out of loadCA are already nicely formatted
return nil, util.NewContextualError("Failed to load ca from config", nil, err)
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
}
l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
cs, err := NewCertStateFromConfig(c)
certificate := pki.GetCertState().Certificate
fw, err := NewFirewallFromConfig(l, certificate, c)
if err != nil {
//The errors coming out of NewCertStateFromConfig are already nicely formatted
return nil, util.NewContextualError("Failed to load certificate from config", nil, err)
}
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
fw, err := NewFirewallFromConfig(l, cs.certificate, c)
if err != nil {
return nil, util.NewContextualError("Error while loading firewall rules", nil, err)
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
}
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
// TODO: make sure mask is 4 bytes
tunCidr := cs.certificate.Details.Ips[0]
tunCidr := certificate.Details.Ips[0]
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
wireSSHReload(l, ssh, c)
@ -85,7 +76,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if c.GetBool("sshd.enabled", false) {
sshStart, err = configSSH(l, ssh, c)
if err != nil {
return nil, util.NewContextualError("Error while configuring the sshd", nil, err)
return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
}
}
@ -136,7 +127,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
tun, err = overlay.NewDeviceFromConfig(c, l, tunCidr, tunFd, routines)
if err != nil {
return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err)
return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
}
defer func() {
@ -147,7 +138,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
// set up our UDP listener
udpConns := make([]*udp.Conn, routines)
udpConns := make([]udp.Conn, routines)
port := c.GetInt("listen.port", 0)
if !configTest {
@ -160,7 +151,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} else {
listenHost, err = net.ResolveIPAddr("ip", rawListenHost)
if err != nil {
return nil, util.NewContextualError("Failed to resolve listen.host", nil, err)
return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
}
}
@ -182,7 +173,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
for _, rawPreferredRange := range rawPreferredRanges {
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
if err != nil {
return nil, util.NewContextualError("Failed to parse preferred ranges", nil, err)
return nil, util.ContextualizeIfNeeded("Failed to parse preferred ranges", err)
}
preferredRanges = append(preferredRanges, preferredRange)
}
@ -195,7 +186,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if rawLocalRange != "" {
_, localRange, err := net.ParseCIDR(rawLocalRange)
if err != nil {
return nil, util.NewContextualError("Failed to parse local_range", nil, err)
return nil, util.ContextualizeIfNeeded("Failed to parse local_range", err)
}
// Check if the entry for local_range was already specified in
@ -212,7 +203,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
}
hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
hostMap := NewHostMap(l, tunCidr, preferredRanges)
hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)
l.
@ -220,18 +211,10 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
WithField("preferredRanges", hostMap.preferredRanges).
Info("Main HostMap created")
/*
config.SetDefault("promoter.interval", 10)
go hostMap.Promoter(config.GetInt("promoter.interval"))
*/
punchy := NewPunchyFromConfig(l, c)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
switch {
case errors.As(err, &util.ContextualError{}):
return nil, err
case err != nil:
return nil, util.NewContextualError("Failed to initialize lighthouse handler", nil, err)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
}
var messageMetrics *MessageMetrics
@ -252,13 +235,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
messageMetrics: messageMetrics,
}
handshakeManager := NewHandshakeManager(l, tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig)
handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig)
lightHouse.handshakeTrigger = handshakeManager.trigger
//TODO: These will be reused for psk
//handshakeMACKey := config.GetString("handshake_mac.key", "")
//handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{})
serveDns := false
if c.GetBool("lighthouse.serve_dns", false) {
if c.GetBool("lighthouse.am_lighthouse", false) {
@ -270,11 +249,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
checkInterval := c.GetInt("timers.connection_alive_interval", 5)
pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
ifConfig := &InterfaceConfig{
HostMap: hostMap,
Inside: tun,
Outside: udpConns[0],
certState: cs,
pki: pki,
Cipher: c.GetString("cipher", "aes"),
Firewall: fw,
ServeDns: serveDns,
@ -282,12 +262,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
lightHouse: lightHouse,
checkInterval: time.Second * time.Duration(checkInterval),
pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval),
tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery),
reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false),
DropMulticast: c.GetBool("tun.drop_multicast", false),
routines: routines,
MessageMetrics: messageMetrics,
version: buildVersion,
caPool: caPool,
disconnectInvalid: c.GetBool("pki.disconnect_invalid", false),
relayManager: NewRelayManager(ctx, l, hostMap, c),
punchy: punchy,
@ -315,6 +297,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
// TODO: Better way to attach these, probably want a new interface in InterfaceConfig
// I don't want to make this initial commit too far-reaching though
ifce.writers = udpConns
lightHouse.ifce = ifce
loadMultiPortConfig := func(c *config.C) {
ifce.multiPort.Rx = c.GetBool("tun.multiport.rx_enabled", false)
@ -350,19 +333,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
c.RegisterReloadCallback(loadMultiPortConfig)
ifce.RegisterConfigChangeCallbacks(c)
ifce.reloadSendRecvError(c)
go handshakeManager.Run(ctx, ifce)
go lightHouse.LhUpdateWorker(ctx, ifce)
handshakeManager.f = ifce
go handshakeManager.Run(ctx)
}
// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
// a context so that they can exit when the context is Done.
statsStart, err := startStats(l, c, buildVersion, configTest)
if err != nil {
return nil, util.NewContextualError("Failed to start stats emitter", nil, err)
return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
}
if configTest {
@ -372,7 +353,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
//TODO: check if we _should_ be emitting stats
go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10))
attachCommands(l, c, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
attachCommands(l, c, ssh, ifce)
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
var dnsStart func()
@ -381,5 +362,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
dnsStart = dnsMain(l, hostMap, c)
}
return &Control{ifce, l, cancel, sshStart, statsStart, dnsStart}, nil
return &Control{
ifce,
l,
cancel,
sshStart,
statsStart,
dnsStart,
lightHouse.StartUpdateWorker,
}, nil
}

View File

@ -64,9 +64,9 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
var hostinfo *HostInfo
// verify if we've seen this index before, otherwise respond to the handshake initiation
if h.Type == header.Message && h.Subtype == header.MessageRelay {
hostinfo, _ = f.hostMap.QueryRelayIndex(h.RemoteIndex)
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
} else {
hostinfo, _ = f.hostMap.QueryIndex(h.RemoteIndex)
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
}
var ci *ConnectionState
@ -417,7 +417,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false
}
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.caPool, localCache)
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, out, q)
if f.l.Level >= logrus.DebugLevel {
@ -462,12 +462,9 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
Debug("Recv error received")
}
// First, clean up in the pending hostmap
f.handshakeManager.pendingHostMap.DeleteReverseIndex(h.RemoteIndex)
hostinfo, err := f.hostMap.QueryReverseIndex(h.RemoteIndex)
if err != nil {
f.l.Debugln(err, ": ", h.RemoteIndex)
hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex)
if hostinfo == nil {
f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap")
return
}
@ -477,14 +474,14 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
if !hostinfo.RecvErrorExceeded() {
return
}
if hostinfo.remote != nil && !hostinfo.remote.Equals(addr) {
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
return
}
f.closeTunnel(hostinfo)
// We also delete it from pending hostmap to allow for
// fast reconnect.
// We also delete it from pending hostmap to allow for fast reconnect.
f.handshakeManager.DeleteHostInfo(hostinfo)
}

View File

@ -47,14 +47,6 @@ type ifReq struct {
pad [8]byte
}
func ioctl(a1, a2, a3 uintptr) error {
_, _, errno := unix.Syscall(unix.SYS_IOCTL, a1, a2, a3)
if errno != 0 {
return errno
}
return nil
}
var sockaddrCtlSize uintptr = 32
const (
@ -194,10 +186,10 @@ func (t *tun) Activate() error {
unix.SOCK_DGRAM,
unix.IPPROTO_IP,
)
if err != nil {
return err
}
defer unix.Close(s)
fd := uintptr(s)

View File

@ -4,21 +4,44 @@
package overlay
import (
"bytes"
"errors"
"fmt"
"io"
"io/fs"
"net"
"os"
"os/exec"
"regexp"
"strconv"
"strings"
"syscall"
"unsafe"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/iputil"
)
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
const (
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
FIODGNAME = 0x80106678
)
type fiodgnameArg struct {
length int32
pad [4]byte
buf unsafe.Pointer
}
type ifreqRename struct {
Name [16]byte
Data uintptr
}
type ifreqDestroy struct {
Name [16]byte
pad [16]byte
}
type tun struct {
Device string
@ -33,8 +56,23 @@ type tun struct {
func (t *tun) Close() error {
if t.ReadWriteCloser != nil {
return t.ReadWriteCloser.Close()
if err := t.ReadWriteCloser.Close(); err != nil {
return err
}
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifreq := ifreqDestroy{Name: t.deviceBytes()}
// Destroy the interface
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
return err
}
return nil
}
@ -43,34 +81,87 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
// Try to open existing tun device
var file *os.File
var err error
if deviceName != "" {
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
}
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
// If the device doesn't already exist, request a new one and rename it
file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0)
}
if err != nil {
return nil, err
}
rawConn, err := file.SyscallConn()
if err != nil {
return nil, fmt.Errorf("SyscallConn: %v", err)
}
var name [16]byte
var ctrlErr error
rawConn.Control(func(fd uintptr) {
// Read the name of the interface
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg)))
})
if ctrlErr != nil {
return nil, err
}
ifName := string(bytes.TrimRight(name[:], "\x00"))
if deviceName == "" {
deviceName = ifName
}
// If the name doesn't match the desired interface name, rename it now
if ifName != deviceName {
s, err := syscall.Socket(
syscall.AF_INET,
syscall.SOCK_DGRAM,
syscall.IPPROTO_IP,
)
if err != nil {
return nil, err
}
defer syscall.Close(s)
fd := uintptr(s)
var fromName [16]byte
var toName [16]byte
copy(fromName[:], ifName)
copy(toName[:], deviceName)
ifrr := ifreqRename{
Name: fromName,
Data: uintptr(unsafe.Pointer(&toName)),
}
// Set the device name
ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr)))
}
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}
if strings.HasPrefix(deviceName, "/dev/") {
deviceName = strings.TrimPrefix(deviceName, "/dev/")
}
if !deviceNameRE.MatchString(deviceName) {
return nil, fmt.Errorf("tun.dev must match `tun[0-9]+`")
}
return &tun{
Device: deviceName,
cidr: cidr,
MTU: defaultMTU,
Routes: routes,
routeTree: routeTree,
l: l,
ReadWriteCloser: file,
Device: deviceName,
cidr: cidr,
MTU: defaultMTU,
Routes: routes,
routeTree: routeTree,
l: l,
}, nil
}
func (t *tun) Activate() error {
var err error
t.ReadWriteCloser, err = os.OpenFile("/dev/"+t.Device, os.O_RDWR, 0)
if err != nil {
return fmt.Errorf("activate failed: %v", err)
}
// TODO use syscalls instead of exec.Command
t.l.Debug("command: ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
if err = exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()).Run(); err != nil {
@ -120,3 +211,10 @@ func (t *tun) Name() string {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
}
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
}
return
}

View File

@ -43,14 +43,6 @@ type ifReq struct {
pad [8]byte
}
func ioctl(a1, a2, a3 uintptr) error {
_, _, errno := unix.Syscall(unix.SYS_IOCTL, a1, a2, a3)
if errno != 0 {
return errno
}
return nil
}
type ifreqAddr struct {
Name [16]byte
Addr unix.RawSockaddrInet4

162
overlay/tun_netbsd.go Normal file
View File

@ -0,0 +1,162 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import (
"fmt"
"io"
"net"
"os"
"os/exec"
"regexp"
"strconv"
"syscall"
"unsafe"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/iputil"
)
type ifreqDestroy struct {
Name [16]byte
pad [16]byte
}
type tun struct {
Device string
cidr *net.IPNet
MTU int
Routes []Route
routeTree *cidr.Tree4
l *logrus.Logger
io.ReadWriteCloser
}
func (t *tun) Close() error {
if t.ReadWriteCloser != nil {
if err := t.ReadWriteCloser.Close(); err != nil {
return err
}
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifreq := ifreqDestroy{Name: t.deviceBytes()}
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
return err
}
return nil
}
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
// Try to open tun device
var file *os.File
var err error
if deviceName == "" {
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
}
if !deviceNameRE.MatchString(deviceName) {
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
}
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
if err != nil {
return nil, err
}
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}
return &tun{
ReadWriteCloser: file,
Device: deviceName,
cidr: cidr,
MTU: defaultMTU,
Routes: routes,
routeTree: routeTree,
l: l,
}, nil
}
func (t *tun) Activate() error {
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.IP.String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
for _, r := range t.Routes {
if r.Via == nil || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
cmd = exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err)
}
}
return nil
}
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.routeTree.MostSpecificContains(ip)
if r != nil {
return r.(iputil.VpnIp)
}
return 0
}
func (t *tun) Cidr() *net.IPNet {
return t.cidr
}
func (t *tun) Name() string {
return t.Device
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
}
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
}
return
}

14
overlay/tun_notwin.go Normal file
View File

@ -0,0 +1,14 @@
//go:build !windows
// +build !windows
package overlay
import "syscall"
func ioctl(a1, a2, a3 uintptr) error {
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, a1, a2, a3)
if errno != 0 {
return errno
}
return nil
}

174
overlay/tun_openbsd.go Normal file
View File

@ -0,0 +1,174 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import (
"fmt"
"io"
"net"
"os"
"os/exec"
"regexp"
"strconv"
"syscall"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/iputil"
)
type tun struct {
Device string
cidr *net.IPNet
MTU int
Routes []Route
routeTree *cidr.Tree4
l *logrus.Logger
io.ReadWriteCloser
// cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte
}
func (t *tun) Close() error {
if t.ReadWriteCloser != nil {
return t.ReadWriteCloser.Close()
}
return nil
}
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
if deviceName == "" {
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
}
if !deviceNameRE.MatchString(deviceName) {
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
}
file, err := os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
if err != nil {
return nil, err
}
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}
return &tun{
ReadWriteCloser: file,
Device: deviceName,
cidr: cidr,
MTU: defaultMTU,
Routes: routes,
routeTree: routeTree,
l: l,
}, nil
}
func (t *tun) Activate() error {
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.IP.String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
// Unsafe path routes
for _, r := range t.Routes {
if r.Via == nil || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err)
}
}
return nil
}
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.routeTree.MostSpecificContains(ip)
if r != nil {
return r.(iputil.VpnIp)
}
return 0
}
func (t *tun) Cidr() *net.IPNet {
return t.cidr
}
func (t *tun) Name() string {
return t.Device
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
}
func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4)
n, err := t.ReadWriteCloser.Read(buf)
copy(to, buf[4:])
return n - 4, err
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
buf := t.out
if cap(buf) < len(from)+4 {
buf = make([]byte, len(from)+4)
t.out = buf
}
buf = buf[:len(from)+4]
if len(from) == 0 {
return 0, syscall.EIO
}
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
buf[3] = syscall.AF_INET
} else if ipVer == 6 {
buf[3] = syscall.AF_INET6
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
copy(buf[4:], from)
n, err := t.ReadWriteCloser.Write(buf)
return n - 4, err
}

View File

@ -8,6 +8,7 @@ import (
"io"
"net"
"os"
"sync/atomic"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
@ -21,6 +22,7 @@ type TestTun struct {
routeTree *cidr.Tree4
l *logrus.Logger
closed atomic.Bool
rxPackets chan []byte // Packets to receive into nebula
TxPackets chan []byte // Packets transmitted outside by nebula
}
@ -50,6 +52,10 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int
// These are unencrypted ip layer frames destined for another nebula node.
// packets should exit the udp side, capture them with udpConn.Get
func (t *TestTun) Send(packet []byte) {
if t.closed.Load() {
return
}
if t.l.Level >= logrus.DebugLevel {
t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet")
}
@ -98,6 +104,10 @@ func (t *TestTun) Name() string {
}
func (t *TestTun) Write(b []byte) (n int, err error) {
if t.closed.Load() {
return 0, io.ErrClosedPipe
}
packet := make([]byte, len(b), len(b))
copy(packet, b)
t.TxPackets <- packet
@ -105,7 +115,10 @@ func (t *TestTun) Write(b []byte) (n int, err error) {
}
func (t *TestTun) Close() error {
close(t.rxPackets)
if t.closed.CompareAndSwap(false, true) {
close(t.rxPackets)
close(t.TxPackets)
}
return nil
}

View File

@ -54,9 +54,16 @@ func newWinTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU
return nil, fmt.Errorf("generate GUID failed: %w", err)
}
tunDevice, err := wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU)
var tunDevice wintun.Device
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU)
if err != nil {
return nil, fmt.Errorf("create TUN device failed: %w", err)
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
// Trying a second time resolves the issue.
l.WithError(err).Debug("Failed to create wintun device, retrying")
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU)
if err != nil {
return nil, fmt.Errorf("create TUN device failed: %w", err)
}
}
routeTree, err := makeRouteTree(l, routes, false)

248
pki.go Normal file
View File

@ -0,0 +1,248 @@
package nebula
import (
"errors"
"fmt"
"os"
"strings"
"sync/atomic"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
)
type PKI struct {
cs atomic.Pointer[CertState]
caPool atomic.Pointer[cert.NebulaCAPool]
l *logrus.Logger
}
type CertState struct {
Certificate *cert.NebulaCertificate
RawCertificate []byte
RawCertificateNoKey []byte
PublicKey []byte
PrivateKey []byte
}
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
pki := &PKI{l: l}
err := pki.reload(c, true)
if err != nil {
return nil, err
}
c.RegisterReloadCallback(func(c *config.C) {
rErr := pki.reload(c, false)
if rErr != nil {
util.LogWithContextIfNeeded("Failed to reload PKI from config", rErr, l)
}
})
return pki, nil
}
func (p *PKI) GetCertState() *CertState {
return p.cs.Load()
}
func (p *PKI) GetCAPool() *cert.NebulaCAPool {
return p.caPool.Load()
}
func (p *PKI) reload(c *config.C, initial bool) error {
err := p.reloadCert(c, initial)
if err != nil {
if initial {
return err
}
err.Log(p.l)
}
err = p.reloadCAPool(c)
if err != nil {
if initial {
return err
}
err.Log(p.l)
}
return nil
}
func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
cs, err := newCertStateFromConfig(c)
if err != nil {
return util.NewContextualError("Could not load client cert", nil, err)
}
if !initial {
// did IP in cert change? if so, don't set
currentCert := p.cs.Load().Certificate
oldIPs := currentCert.Details.Ips
newIPs := cs.Certificate.Details.Ips
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
return util.NewContextualError(
"IP in new cert was different from old",
m{"new_ip": newIPs[0], "old_ip": oldIPs[0]},
nil,
)
}
}
p.cs.Store(cs)
if initial {
p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate")
} else {
p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk")
}
return nil
}
func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
caPool, err := loadCAPoolFromConfig(p.l, c)
if err != nil {
return util.NewContextualError("Failed to load ca from config", nil, err)
}
p.caPool.Store(caPool)
p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
return nil
}
func newCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) {
// Marshal the certificate to ensure it is valid
rawCertificate, err := certificate.Marshal()
if err != nil {
return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
}
publicKey := certificate.Details.PublicKey
cs := &CertState{
RawCertificate: rawCertificate,
Certificate: certificate,
PrivateKey: privateKey,
PublicKey: publicKey,
}
cs.Certificate.Details.PublicKey = nil
rawCertNoKey, err := cs.Certificate.Marshal()
if err != nil {
return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
}
cs.RawCertificateNoKey = rawCertNoKey
// put public key back
cs.Certificate.Details.PublicKey = cs.PublicKey
return cs, nil
}
func newCertStateFromConfig(c *config.C) (*CertState, error) {
var pemPrivateKey []byte
var err error
privPathOrPEM := c.GetString("pki.key", "")
if privPathOrPEM == "" {
return nil, errors.New("no pki.key path or PEM data provided")
}
if strings.Contains(privPathOrPEM, "-----BEGIN") {
pemPrivateKey = []byte(privPathOrPEM)
privPathOrPEM = "<inline>"
} else {
pemPrivateKey, err = os.ReadFile(privPathOrPEM)
if err != nil {
return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
}
}
rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey)
if err != nil {
return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
var rawCert []byte
pubPathOrPEM := c.GetString("pki.cert", "")
if pubPathOrPEM == "" {
return nil, errors.New("no pki.cert path or PEM data provided")
}
if strings.Contains(pubPathOrPEM, "-----BEGIN") {
rawCert = []byte(pubPathOrPEM)
pubPathOrPEM = "<inline>"
} else {
rawCert, err = os.ReadFile(pubPathOrPEM)
if err != nil {
return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err)
}
}
nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert)
if err != nil {
return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
}
if nebulaCert.Expired(time.Now()) {
return nil, fmt.Errorf("nebula certificate for this host is expired")
}
if len(nebulaCert.Details.Ips) == 0 {
return nil, fmt.Errorf("no IPs encoded in certificate")
}
if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil {
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
}
return newCertState(nebulaCert, rawKey)
}
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) {
var rawCA []byte
var err error
caPathOrPEM := c.GetString("pki.ca", "")
if caPathOrPEM == "" {
return nil, errors.New("no pki.ca path or PEM data provided")
}
if strings.Contains(caPathOrPEM, "-----BEGIN") {
rawCA = []byte(caPathOrPEM)
} else {
rawCA, err = os.ReadFile(caPathOrPEM)
if err != nil {
return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
}
}
caPool, err := cert.NewCAPoolFromBytes(rawCA)
if errors.Is(err, cert.ErrExpired) {
var expired int
for _, crt := range caPool.CAs {
if crt.Expired(time.Now()) {
expired++
l.WithField("cert", crt).Warn("expired certificate present in CA pool")
}
}
if expired >= len(caPool.CAs) {
return nil, errors.New("no valid CA certificates present")
}
} else if err != nil {
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
}
for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
l.WithField("fingerprint", fp).Info("Blocklisting cert")
caPool.BlocklistFingerprint(fp)
}
return caPool, nil
}

View File

@ -131,9 +131,9 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
return
}
// I'm the middle man. Let the initiator know that the I've established the relay they requested.
peerHostInfo, err := rm.hostmap.QueryVpnIp(relay.PeerIp)
if err != nil {
rm.l.WithError(err).WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer")
peerHostInfo := rm.hostmap.QueryVpnIp(relay.PeerIp)
if peerHostInfo == nil {
rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer")
return
}
peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(target)
@ -179,6 +179,12 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
"vpnIp": h.vpnIp})
logMsg.Info("handleCreateRelayRequest")
// Is the source of the relay me? This should never happen, but did happen due to
// an issue migrating relays over to newly re-handshaked host info objects.
if from == f.myVpnIp {
logMsg.WithField("myIP", f.myVpnIp).Error("Discarding relay request from myself")
return
}
// Is the target of the relay me?
if target == f.myVpnIp {
existingRelay, ok := h.relayState.QueryRelayForByIp(from)
@ -240,11 +246,11 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
if !rm.GetAmRelay() {
return
}
peer, err := rm.hostmap.QueryVpnIp(target)
if err != nil {
peer := rm.hostmap.QueryVpnIp(target)
if peer == nil {
// Try to establish a connection to this host. If we get a future relay request,
// we'll be ready!
f.getOrHandshake(target)
f.Handshake(target)
return
}
if peer.remote == nil {
@ -253,6 +259,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
}
sendCreateRequest := false
var index uint32
var err error
targetRelay, ok := peer.relayState.QueryRelayForByIp(from)
if ok {
index = targetRelay.LocalIndex

View File

@ -70,7 +70,7 @@ type hostnamesResults struct {
hostnames []hostnamePort
network string
lookupTimeout time.Duration
stop chan struct{}
cancelFn func()
l *logrus.Logger
ips atomic.Pointer[map[netip.AddrPort]struct{}]
}
@ -80,7 +80,6 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
hostnames: make([]hostnamePort, len(hostPorts)),
network: network,
lookupTimeout: timeout,
stop: make(chan (struct{})),
l: l,
}
@ -115,6 +114,8 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
// Time for the DNS lookup goroutine
if performBackgroundLookup {
newCtx, cancel := context.WithCancel(ctx)
r.cancelFn = cancel
ticker := time.NewTicker(d)
go func() {
defer ticker.Stop()
@ -154,9 +155,7 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
onUpdate()
}
select {
case <-ctx.Done():
return
case <-r.stop:
case <-newCtx.Done():
return
case <-ticker.C:
continue
@ -169,8 +168,8 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
}
func (hr *hostnamesResults) Cancel() {
if hr != nil {
hr.stop <- struct{}{}
if hr != nil && hr.cancelFn != nil {
hr.cancelFn()
}
}
@ -582,20 +581,11 @@ func (r *RemoteList) unlockedCollect() {
dnsAddrs := r.hr.GetIPs()
for _, addr := range dnsAddrs {
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
switch {
case addr.Addr().Is4():
v4 := addr.Addr().As4()
addrs = append(addrs, &udp.Addr{
IP: v4[:],
Port: addr.Port(),
})
case addr.Addr().Is6():
v6 := addr.Addr().As16()
addrs = append(addrs, &udp.Addr{
IP: v6[:],
Port: addr.Port(),
})
}
v6 := addr.Addr().As16()
addrs = append(addrs, &udp.Addr{
IP: v6[:],
Port: addr.Port(),
})
}
}

66
ssh.go
View File

@ -3,6 +3,7 @@ package nebula
import (
"bytes"
"encoding/json"
"errors"
"flag"
"fmt"
"io/ioutil"
@ -168,7 +169,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
return runner, nil
}
func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) {
func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) {
ssh.RegisterCommand(&sshd.Command{
Name: "list-hostmap",
ShortDescription: "List all known previously connected hosts",
@ -181,7 +182,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshListHostMap(hostMap, fs, w)
return sshListHostMap(f.hostMap, fs, w)
},
})
@ -197,7 +198,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshListHostMap(pendingHostMap, fs, w)
return sshListHostMap(f.handshakeManager, fs, w)
},
})
@ -212,7 +213,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshListLighthouseMap(lightHouse, fs, w)
return sshListLighthouseMap(f.lightHouse, fs, w)
},
})
@ -277,7 +278,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
Name: "version",
ShortDescription: "Prints the currently running version of nebula",
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshVersion(ifce, fs, a, w)
return sshVersion(f, fs, a, w)
},
})
@ -293,7 +294,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshPrintCert(ifce, fs, a, w)
return sshPrintCert(f, fs, a, w)
},
})
@ -307,7 +308,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshPrintTunnel(ifce, fs, a, w)
return sshPrintTunnel(f, fs, a, w)
},
})
@ -321,7 +322,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshPrintRelays(ifce, fs, a, w)
return sshPrintRelays(f, fs, a, w)
},
})
@ -335,7 +336,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshChangeRemote(ifce, fs, a, w)
return sshChangeRemote(f, fs, a, w)
},
})
@ -349,7 +350,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshCloseTunnel(ifce, fs, a, w)
return sshCloseTunnel(f, fs, a, w)
},
})
@ -364,7 +365,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshCreateTunnel(ifce, fs, a, w)
return sshCreateTunnel(f, fs, a, w)
},
})
@ -373,12 +374,12 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
ShortDescription: "Query the lighthouses for the provided vpn ip",
Help: "This command is asynchronous. Only currently known udp ips will be printed.",
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshQueryLighthouse(ifce, fs, a, w)
return sshQueryLighthouse(f, fs, a, w)
},
})
}
func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error {
func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error {
fs, ok := a.(*sshListHostMapFlags)
if !ok {
//TODO: error
@ -387,9 +388,9 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
var hm []ControlHostInfo
if fs.ByIndex {
hm = listHostMapIndexes(hostMap)
hm = listHostMapIndexes(hl)
} else {
hm = listHostMapHosts(hostMap)
hm = listHostMapHosts(hl)
}
sort.Slice(hm, func(i, j int) bool {
@ -546,8 +547,8 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
if err != nil {
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
@ -588,12 +589,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, _ := ifce.hostMap.QueryVpnIp(vpnIp)
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
}
hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
hostInfo = ifce.handshakeManager.QueryVpnIp(vpnIp)
if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
}
@ -606,11 +607,10 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
}
}
hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp, ifce.initHostInfo)
hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil)
if addr != nil {
hostInfo.SetRemote(addr)
}
ifce.getOrHandshake(vpnIp)
return w.WriteLine("Created")
}
@ -645,8 +645,8 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
if err != nil {
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
@ -753,7 +753,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
return nil
}
cert := ifce.certState.Load().certificate
cert := ifce.pki.GetCertState().Certificate
if len(a) > 0 {
parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
@ -765,8 +765,8 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
if err != nil {
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
@ -851,9 +851,9 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
for k, v := range relays {
ro := RelayOutput{NebulaIp: v.vpnIp}
co.Relays = append(co.Relays, &ro)
relayHI, err := ifce.hostMap.QueryVpnIp(v.vpnIp)
if err != nil {
ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: err})
relayHI := ifce.hostMap.QueryVpnIp(v.vpnIp)
if relayHI == nil {
ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")})
continue
}
for _, vpnIp := range relayHI.relayState.CopyRelayForIps() {
@ -889,8 +889,8 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k)
}
}
relayedHI, err := ifce.hostMap.QueryVpnIp(vpnIp)
if err == nil {
relayedHI := ifce.hostMap.QueryVpnIp(vpnIp)
if relayedHI != nil {
rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...)
}
@ -925,8 +925,8 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
if err != nil {
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}

View File

@ -1,6 +1,7 @@
package udp
import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
)
@ -18,3 +19,33 @@ type EncReader func(
q int,
localCache firewall.ConntrackCache,
)
type Conn interface {
Rebind() error
LocalAddr() (*Addr, error)
ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int)
WriteTo(b []byte, addr *Addr) error
ReloadConfig(c *config.C)
Close() error
}
type NoopConn struct{}
func (NoopConn) Rebind() error {
return nil
}
func (NoopConn) LocalAddr() (*Addr, error) {
return nil, nil
}
func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) {
return
}
func (NoopConn) WriteTo(_ []byte, _ *Addr) error {
return nil
}
func (NoopConn) ReloadConfig(_ *config.C) {
return
}
func (NoopConn) Close() error {
return nil
}

View File

@ -8,9 +8,14 @@ import (
"net"
"syscall"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch)
}
func NewListenConfig(multi bool) net.ListenConfig {
return net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
@ -34,6 +39,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
}
}
func (u *Conn) Rebind() error {
func (u *GenericConn) Rebind() error {
return nil
}

47
udp/udp_bsd.go Normal file
View File

@ -0,0 +1,47 @@
//go:build (openbsd || freebsd) && !e2e_testing
// +build openbsd freebsd
// +build !e2e_testing
package udp
// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig
import (
"fmt"
"net"
"syscall"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch)
}
func NewListenConfig(multi bool) net.ListenConfig {
return net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
if multi {
var controlErr error
err := c.Control(func(fd uintptr) {
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err)
return
}
})
if err != nil {
return err
}
if controlErr != nil {
return controlErr
}
}
return nil
},
}
}
func (u *GenericConn) Rebind() error {
return nil
}

View File

@ -10,9 +10,14 @@ import (
"net"
"syscall"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch)
}
func NewListenConfig(multi bool) net.ListenConfig {
return net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
@ -37,11 +42,16 @@ func NewListenConfig(multi bool) net.ListenConfig {
}
}
func (u *Conn) Rebind() error {
file, err := u.File()
func (u *GenericConn) Rebind() error {
rc, err := u.UDPConn.SyscallConn()
if err != nil {
return err
}
return syscall.SetsockoptInt(int(file.Fd()), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0)
return rc.Control(func(fd uintptr) {
err := syscall.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0)
if err != nil {
u.l.WithError(err).Error("Failed to rebind udp socket")
}
})
}

View File

@ -18,30 +18,32 @@ import (
"github.com/slackhq/nebula/header"
)
type Conn struct {
type GenericConn struct {
*net.UDPConn
l *logrus.Logger
}
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (*Conn, error) {
var _ Conn = &GenericConn{}
func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
lc := NewListenConfig(multi)
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
if err != nil {
return nil, err
}
if uc, ok := pc.(*net.UDPConn); ok {
return &Conn{UDPConn: uc, l: l}, nil
return &GenericConn{UDPConn: uc, l: l}, nil
}
return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
}
func (uc *Conn) WriteTo(b []byte, addr *Addr) error {
_, err := uc.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)})
func (u *GenericConn) WriteTo(b []byte, addr *Addr) error {
_, err := u.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)})
return err
}
func (uc *Conn) LocalAddr() (*Addr, error) {
a := uc.UDPConn.LocalAddr()
func (u *GenericConn) LocalAddr() (*Addr, error) {
a := u.UDPConn.LocalAddr()
switch v := a.(type) {
case *net.UDPAddr:
@ -55,11 +57,11 @@ func (uc *Conn) LocalAddr() (*Addr, error) {
}
}
func (u *Conn) ReloadConfig(c *config.C) {
func (u *GenericConn) ReloadConfig(c *config.C) {
// TODO
}
func NewUDPStatsEmitter(udpConns []*Conn) func() {
func NewUDPStatsEmitter(udpConns []Conn) func() {
// No UDP stats for non-linux
return func() {}
}
@ -68,7 +70,7 @@ type rawMessage struct {
Len uint32
}
func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU)
buffer := make([]byte, MTU)
h := &header.H{}
@ -80,8 +82,8 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
// Just read one packet at a time
n, rua, err := u.ReadFromUDP(buffer)
if err != nil {
u.l.WithError(err).Error("Failed to read packets")
continue
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
udpAddr.IP = rua.IP

View File

@ -20,7 +20,7 @@ import (
//TODO: make it support reload as best you can!
type Conn struct {
type StdConn struct {
sysFd int
l *logrus.Logger
batch int
@ -45,7 +45,7 @@ const (
type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (*Conn, error) {
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
syscall.ForkLock.RLock()
fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
if err == nil {
@ -77,30 +77,30 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (
//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
//l.Println(v, err)
return &Conn{sysFd: fd, l: l, batch: batch}, err
return &StdConn{sysFd: fd, l: l, batch: batch}, err
}
func (u *Conn) Rebind() error {
func (u *StdConn) Rebind() error {
return nil
}
func (u *Conn) SetRecvBuffer(n int) error {
func (u *StdConn) SetRecvBuffer(n int) error {
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
}
func (u *Conn) SetSendBuffer(n int) error {
func (u *StdConn) SetSendBuffer(n int) error {
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n)
}
func (u *Conn) GetRecvBuffer() (int, error) {
func (u *StdConn) GetRecvBuffer() (int, error) {
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF)
}
func (u *Conn) GetSendBuffer() (int, error) {
func (u *StdConn) GetSendBuffer() (int, error) {
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
}
func (u *Conn) LocalAddr() (*Addr, error) {
func (u *StdConn) LocalAddr() (*Addr, error) {
sa, err := unix.Getsockname(u.sysFd)
if err != nil {
return nil, err
@ -119,7 +119,7 @@ func (u *Conn) LocalAddr() (*Addr, error) {
return addr, nil
}
func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
@ -137,8 +137,8 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
for {
n, err := read(msgs)
if err != nil {
u.l.WithError(err).Error("Failed to read packets")
continue
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
//metric.Update(int64(n))
@ -150,7 +150,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
}
}
func (u *Conn) ReadSingle(msgs []rawMessage) (int, error) {
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
for {
n, _, err := unix.Syscall6(
unix.SYS_RECVMSG,
@ -171,7 +171,7 @@ func (u *Conn) ReadSingle(msgs []rawMessage) (int, error) {
}
}
func (u *Conn) ReadMulti(msgs []rawMessage) (int, error) {
func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
for {
n, _, err := unix.Syscall6(
unix.SYS_RECVMMSG,
@ -191,7 +191,7 @@ func (u *Conn) ReadMulti(msgs []rawMessage) (int, error) {
}
}
func (u *Conn) WriteTo(b []byte, addr *Addr) error {
func (u *StdConn) WriteTo(b []byte, addr *Addr) error {
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
@ -221,7 +221,7 @@ func (u *Conn) WriteTo(b []byte, addr *Addr) error {
}
}
func (u *Conn) ReloadConfig(c *config.C) {
func (u *StdConn) ReloadConfig(c *config.C) {
b := c.GetInt("listen.read_buffer", 0)
if b > 0 {
err := u.SetRecvBuffer(b)
@ -253,7 +253,7 @@ func (u *Conn) ReloadConfig(c *config.C) {
}
}
func (u *Conn) getMemInfo(meminfo *_SK_MEMINFO) error {
func (u *StdConn) getMemInfo(meminfo *_SK_MEMINFO) error {
var vallen uint32 = 4 * _SK_MEMINFO_VARS
_, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)
if err != 0 {
@ -262,11 +262,16 @@ func (u *Conn) getMemInfo(meminfo *_SK_MEMINFO) error {
return nil
}
func NewUDPStatsEmitter(udpConns []*Conn) func() {
func (u *StdConn) Close() error {
//TODO: this will not interrupt the read loop
return syscall.Close(u.sysFd)
}
func NewUDPStatsEmitter(udpConns []Conn) func() {
// Check if our kernel supports SO_MEMINFO before registering the gauges
var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge
var meminfo _SK_MEMINFO
if err := udpConns[0].getMemInfo(&meminfo); err == nil {
if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil {
udpGauges = make([][_SK_MEMINFO_VARS]metrics.Gauge, len(udpConns))
for i := range udpConns {
udpGauges[i] = [_SK_MEMINFO_VARS]metrics.Gauge{
@ -285,7 +290,7 @@ func NewUDPStatsEmitter(udpConns []*Conn) func() {
return func() {
for i, gauges := range udpGauges {
if err := udpConns[i].getMemInfo(&meminfo); err == nil {
if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil {
for j := 0; j < _SK_MEMINFO_VARS; j++ {
gauges[j].Update(int64(meminfo[j]))
}

View File

@ -30,7 +30,7 @@ type rawMessage struct {
Len uint32
}
func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
msgs := make([]rawMessage, n)
buffers := make([][]byte, n)
names := make([][]byte, n)

View File

@ -33,7 +33,7 @@ type rawMessage struct {
Pad0 [4]byte
}
func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
msgs := make([]rawMessage, n)
buffers := make([][]byte, n)
names := make([][]byte, n)

View File

@ -10,9 +10,14 @@ import (
"net"
"syscall"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch)
}
func NewListenConfig(multi bool) net.ListenConfig {
return net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
@ -36,6 +41,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
}
}
func (u *Conn) Rebind() error {
func (u *GenericConn) Rebind() error {
return nil
}

403
udp/udp_rio_windows.go Normal file
View File

@ -0,0 +1,403 @@
//go:build !e2e_testing
// +build !e2e_testing
// Inspired by https://git.zx2c4.com/wireguard-go/tree/conn/bind_windows.go
package udp
import (
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"syscall"
"unsafe"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/conn/winrio"
)
// Assert we meet the standard conn interface
var _ Conn = &RIOConn{}
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
const (
packetsPerRing = 1024
bytesPerPacket = 2048 - 32
receiveSpins = 15
)
type ringPacket struct {
addr windows.RawSockaddrInet6
data [bytesPerPacket]byte
}
type ringBuffer struct {
packets uintptr
head, tail uint32
id winrio.BufferId
iocp windows.Handle
isFull bool
cq winrio.Cq
mu sync.Mutex
overlapped windows.Overlapped
}
type RIOConn struct {
isOpen atomic.Bool
l *logrus.Logger
sock windows.Handle
rx, tx ringBuffer
rq winrio.Rq
results [packetsPerRing]winrio.Result
}
func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) {
if !winrio.Initialize() {
return nil, errors.New("could not initialize winrio")
}
u := &RIOConn{l: l}
addr := [16]byte{}
copy(addr[:], ip.To16())
err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port})
if err != nil {
return nil, fmt.Errorf("bind: %w", err)
}
for i := 0; i < packetsPerRing; i++ {
err = u.insertReceiveRequest()
if err != nil {
return nil, fmt.Errorf("init rx ring: %w", err)
}
}
u.isOpen.Store(true)
return u, nil
}
func (u *RIOConn) bind(sa windows.Sockaddr) error {
var err error
u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
if err != nil {
return err
}
// Enable v4 for this socket
syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
err = u.rx.Open()
if err != nil {
return err
}
err = u.tx.Open()
if err != nil {
return err
}
u.rq, err = winrio.CreateRequestQueue(u.sock, packetsPerRing, 1, packetsPerRing, 1, u.rx.cq, u.tx.cq, 0)
if err != nil {
return err
}
err = windows.Bind(u.sock, sa)
if err != nil {
return err
}
return nil
}
func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU)
buffer := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
udpAddr := &Addr{IP: make([]byte, 16)}
nb := make([]byte, 12, 12)
for {
// Just read one packet at a time
n, rua, err := u.receive(buffer)
if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
udpAddr.IP = rua.Addr[:]
p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port))
p[0] = byte(rua.Port >> 8)
p[1] = byte(rua.Port)
r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
}
}
func (u *RIOConn) insertReceiveRequest() error {
packet := u.rx.Push()
dataBuffer := &winrio.Buffer{
Id: u.rx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.rx.packets),
Length: uint32(len(packet.data)),
}
addressBuffer := &winrio.Buffer{
Id: u.rx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.rx.packets),
Length: uint32(unsafe.Sizeof(packet.addr)),
}
return winrio.ReceiveEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
}
func (u *RIOConn) receive(buf []byte) (int, windows.RawSockaddrInet6, error) {
if !u.isOpen.Load() {
return 0, windows.RawSockaddrInet6{}, net.ErrClosed
}
u.rx.mu.Lock()
defer u.rx.mu.Unlock()
var err error
var count uint32
var results [1]winrio.Result
retry:
count = 0
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
if tries > 0 {
if !u.isOpen.Load() {
return 0, windows.RawSockaddrInet6{}, net.ErrClosed
}
procyield(1)
}
count = winrio.DequeueCompletion(u.rx.cq, results[:])
}
if count == 0 {
err = winrio.Notify(u.rx.cq)
if err != nil {
return 0, windows.RawSockaddrInet6{}, err
}
var bytes uint32
var key uintptr
var overlapped *windows.Overlapped
err = windows.GetQueuedCompletionStatus(u.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
if err != nil {
return 0, windows.RawSockaddrInet6{}, err
}
if !u.isOpen.Load() {
return 0, windows.RawSockaddrInet6{}, net.ErrClosed
}
count = winrio.DequeueCompletion(u.rx.cq, results[:])
if count == 0 {
return 0, windows.RawSockaddrInet6{}, io.ErrNoProgress
}
}
u.rx.Return(1)
err = u.insertReceiveRequest()
if err != nil {
return 0, windows.RawSockaddrInet6{}, err
}
// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
// attacker bandwidth, just like the rest of the receive path.
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
goto retry
}
if results[0].Status != 0 {
return 0, windows.RawSockaddrInet6{}, windows.Errno(results[0].Status)
}
packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
ep := packet.addr
n := copy(buf, packet.data[:results[0].BytesTransferred])
return n, ep, nil
}
func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
if !u.isOpen.Load() {
return net.ErrClosed
}
if len(buf) > bytesPerPacket {
return io.ErrShortBuffer
}
u.tx.mu.Lock()
defer u.tx.mu.Unlock()
count := winrio.DequeueCompletion(u.tx.cq, u.results[:])
if count == 0 && u.tx.isFull {
err := winrio.Notify(u.tx.cq)
if err != nil {
return err
}
var bytes uint32
var key uintptr
var overlapped *windows.Overlapped
err = windows.GetQueuedCompletionStatus(u.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
if err != nil {
return err
}
if !u.isOpen.Load() {
return net.ErrClosed
}
count = winrio.DequeueCompletion(u.tx.cq, u.results[:])
if count == 0 {
return io.ErrNoProgress
}
}
if count > 0 {
u.tx.Return(count)
}
packet := u.tx.Push()
packet.addr.Family = windows.AF_INET6
p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port))
p[0] = byte(addr.Port >> 8)
p[1] = byte(addr.Port)
copy(packet.addr.Addr[:], addr.IP.To16())
copy(packet.data[:], buf)
dataBuffer := &winrio.Buffer{
Id: u.tx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.tx.packets),
Length: uint32(len(buf)),
}
addressBuffer := &winrio.Buffer{
Id: u.tx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.tx.packets),
Length: uint32(unsafe.Sizeof(packet.addr)),
}
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}
func (u *RIOConn) LocalAddr() (*Addr, error) {
sa, err := windows.Getsockname(u.sock)
if err != nil {
return nil, err
}
v6 := sa.(*windows.SockaddrInet6)
return &Addr{
IP: v6.Addr[:],
Port: uint16(v6.Port),
}, nil
}
func (u *RIOConn) Rebind() error {
return nil
}
func (u *RIOConn) ReloadConfig(*config.C) {}
func (u *RIOConn) Close() error {
if !u.isOpen.CompareAndSwap(true, false) {
return nil
}
windows.PostQueuedCompletionStatus(u.rx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(u.tx.iocp, 0, 0, nil)
u.rx.CloseAndZero()
u.tx.CloseAndZero()
if u.sock != 0 {
windows.CloseHandle(u.sock)
}
return nil
}
func (ring *ringBuffer) Push() *ringPacket {
for ring.isFull {
panic("ring is full")
}
ret := (*ringPacket)(unsafe.Pointer(ring.packets + (uintptr(ring.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
ring.tail += 1
if ring.tail%packetsPerRing == ring.head%packetsPerRing {
ring.isFull = true
}
return ret
}
func (ring *ringBuffer) Return(count uint32) {
if ring.head%packetsPerRing == ring.tail%packetsPerRing && !ring.isFull {
return
}
ring.head += count
ring.isFull = false
}
func (ring *ringBuffer) CloseAndZero() {
if ring.cq != 0 {
winrio.CloseCompletionQueue(ring.cq)
ring.cq = 0
}
if ring.iocp != 0 {
windows.CloseHandle(ring.iocp)
ring.iocp = 0
}
if ring.id != 0 {
winrio.DeregisterBuffer(ring.id)
ring.id = 0
}
if ring.packets != 0 {
windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
ring.packets = 0
}
ring.head = 0
ring.tail = 0
ring.isFull = false
}
func (ring *ringBuffer) Open() error {
var err error
packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
if err != nil {
return err
}
ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
if err != nil {
return err
}
ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
if err != nil {
return err
}
ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
if err != nil {
return err
}
return nil
}

View File

@ -5,7 +5,9 @@ package udp
import (
"fmt"
"io"
"net"
"sync/atomic"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
@ -36,17 +38,18 @@ func (u *Packet) Copy() *Packet {
return n
}
type Conn struct {
type TesterConn struct {
Addr *Addr
RxPackets chan *Packet // Packets to receive into nebula
TxPackets chan *Packet // Packets transmitted outside by nebula
l *logrus.Logger
closed atomic.Bool
l *logrus.Logger
}
func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (*Conn, error) {
return &Conn{
func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, error) {
return &TesterConn{
Addr: &Addr{ip, uint16(port)},
RxPackets: make(chan *Packet, 10),
TxPackets: make(chan *Packet, 10),
@ -57,7 +60,11 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (*Conn, e
// Send will place a UdpPacket onto the receive queue for nebula to consume
// this is an encrypted packet or a handshake message in most cases
// packets were transmitted from another nebula node, you can send them with Tun.Send
func (u *Conn) Send(packet *Packet) {
func (u *TesterConn) Send(packet *Packet) {
if u.closed.Load() {
return
}
h := &header.H{}
if err := h.Parse(packet.Data); err != nil {
panic(err)
@ -74,7 +81,7 @@ func (u *Conn) Send(packet *Packet) {
// Get will pull a UdpPacket from the transmit queue
// nebula meant to send this message on the network, it will be encrypted
// packets were ingested from the tun side (in most cases), you can send them with Tun.Send
func (u *Conn) Get(block bool) *Packet {
func (u *TesterConn) Get(block bool) *Packet {
if block {
return <-u.TxPackets
}
@ -91,7 +98,11 @@ func (u *Conn) Get(block bool) *Packet {
// Below this is boilerplate implementation to make nebula actually work
//********************************************************************************************************************//
func (u *Conn) WriteTo(b []byte, addr *Addr) error {
func (u *TesterConn) WriteTo(b []byte, addr *Addr) error {
if u.closed.Load() {
return io.ErrClosedPipe
}
p := &Packet{
Data: make([]byte, len(b), len(b)),
FromIp: make([]byte, 16),
@ -108,7 +119,7 @@ func (u *Conn) WriteTo(b []byte, addr *Addr) error {
return nil
}
func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
@ -126,17 +137,25 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
}
}
func (u *Conn) ReloadConfig(*config.C) {}
func (u *TesterConn) ReloadConfig(*config.C) {}
func NewUDPStatsEmitter(_ []*Conn) func() {
func NewUDPStatsEmitter(_ []Conn) func() {
// No UDP stats for non-linux
return func() {}
}
func (u *Conn) LocalAddr() (*Addr, error) {
func (u *TesterConn) LocalAddr() (*Addr, error) {
return u.Addr, nil
}
func (u *Conn) Rebind() error {
func (u *TesterConn) Rebind() error {
return nil
}
func (u *TesterConn) Close() error {
if u.closed.CompareAndSwap(false, true) {
close(u.RxPackets)
close(u.TxPackets)
}
return nil
}

View File

@ -3,14 +3,31 @@
package udp
// Windows support is primarily implemented in udp_generic, besides NewListenConfig
import (
"fmt"
"net"
"syscall"
"github.com/sirupsen/logrus"
)
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
if multi {
//NOTE: Technically we can support it with RIO but it wouldn't be at the socket level
// The udp stack would need to be reworked to hide away the implementation differences between
// Windows and Linux
return nil, fmt.Errorf("multiple udp listeners not supported on windows")
}
rc, err := NewRIOListener(l, ip, port)
if err == nil {
return rc, nil
}
l.WithError(err).Error("Falling back to standard udp sockets")
return NewGenericListener(l, ip, port, multi, batch)
}
func NewListenConfig(multi bool) net.ListenConfig {
return net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
@ -24,6 +41,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
}
}
func (u *Conn) Rebind() error {
func (u *GenericConn) Rebind() error {
return nil
}

View File

@ -12,18 +12,38 @@ type ContextualError struct {
Context string
}
func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
return ContextualError{Context: msg, Fields: fields, RealError: realError}
func NewContextualError(msg string, fields map[string]interface{}, realError error) *ContextualError {
return &ContextualError{Context: msg, Fields: fields, RealError: realError}
}
func (ce ContextualError) Error() string {
// ContextualizeIfNeeded is a helper function to turn an error into a ContextualError if it is not already one
func ContextualizeIfNeeded(msg string, err error) error {
switch err.(type) {
case *ContextualError:
return err
default:
return NewContextualError(msg, nil, err)
}
}
// LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError
func LogWithContextIfNeeded(msg string, err error, l *logrus.Logger) {
switch v := err.(type) {
case *ContextualError:
v.Log(l)
default:
l.WithError(err).Error(msg)
}
}
func (ce *ContextualError) Error() string {
if ce.RealError == nil {
return ce.Context
}
return ce.RealError.Error()
}
func (ce ContextualError) Unwrap() error {
func (ce *ContextualError) Unwrap() error {
if ce.RealError == nil {
return errors.New(ce.Context)
}

View File

@ -2,6 +2,7 @@ package util
import (
"errors"
"fmt"
"testing"
"github.com/sirupsen/logrus"
@ -67,3 +68,44 @@ func TestContextualError_Log(t *testing.T) {
e.Log(l)
assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs)
}
func TestLogWithContextIfNeeded(t *testing.T) {
l := logrus.New()
l.Formatter = &logrus.TextFormatter{
DisableTimestamp: true,
DisableColors: true,
}
tl := NewTestLogWriter()
l.Out = tl
// Test ignoring fallback context
tl.Reset()
e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
LogWithContextIfNeeded("This should get thrown away", e, l)
assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs)
// Test using fallback context
tl.Reset()
err := fmt.Errorf("this is a normal error")
LogWithContextIfNeeded("Fallback context woo", err, l)
assert.Equal(t, []string{"level=error msg=\"Fallback context woo\" error=\"this is a normal error\"\n"}, tl.Logs)
}
func TestContextualizeIfNeeded(t *testing.T) {
// Test ignoring fallback context
e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
assert.Same(t, e, ContextualizeIfNeeded("should be ignored", e))
// Test using fallback context
err := fmt.Errorf("this is a normal error")
cErr := ContextualizeIfNeeded("Fallback context woo", err)
switch v := cErr.(type) {
case *ContextualError:
assert.Equal(t, err, v.RealError)
default:
t.Error("Error was not wrapped")
t.Fail()
}
}