diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..abf74a0 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,22 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + + - package-ecosystem: "gomod" + directory: "/" + schedule: + interval: "weekly" + groups: + golang-x-dependencies: + patterns: + - "golang.org/x/*" + zx2c4-dependencies: + patterns: + - "golang.zx2c4.com/*" + protobuf-dependencies: + patterns: + - "github.com/golang/protobuf" + - "google.golang.org/protobuf" diff --git a/.github/workflows/gofmt.yml b/.github/workflows/gofmt.yml index e1a49b6..1552cc6 100644 --- a/.github/workflows/gofmt.yml +++ b/.github/workflows/gofmt.yml @@ -14,21 +14,12 @@ jobs: runs-on: ubuntu-latest steps: - - name: Set up Go 1.20 - uses: actions/setup-go@v2 - with: - go-version: "1.20" - id: go + - uses: actions/checkout@v4 - - name: Check out code into the Go module directory - uses: actions/checkout@v2 - - - uses: actions/cache@v2 + - uses: actions/setup-go@v4 with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-gofmt1.20-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-gofmt1.20- + go-version-file: 'go.mod' + check-latest: true - name: Install goimports run: | diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 81203ad..ef4e507 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -7,25 +7,24 @@ name: Create release and upload binaries jobs: build-linux: - name: Build Linux All + name: Build Linux/BSD All runs-on: ubuntu-latest steps: - - name: Set up Go 1.20 - uses: actions/setup-go@v2 - with: - go-version: "1.20" + - uses: actions/checkout@v4 - - name: Checkout code - uses: actions/checkout@v2 + - uses: actions/setup-go@v4 + with: + go-version-file: 'go.mod' + check-latest: true - name: Build run: | - make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" release-linux release-freebsd + make 
BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" release-linux release-freebsd release-openbsd release-netbsd mkdir release mv build/*.tar.gz release - name: Upload artifacts - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: linux-latest path: release @@ -34,13 +33,12 @@ jobs: name: Build Windows runs-on: windows-latest steps: - - name: Set up Go 1.20 - uses: actions/setup-go@v2 - with: - go-version: "1.20" + - uses: actions/checkout@v4 - - name: Checkout code - uses: actions/checkout@v2 + - uses: actions/setup-go@v4 + with: + go-version-file: 'go.mod' + check-latest: true - name: Build run: | @@ -57,7 +55,7 @@ jobs: mv dist\windows\wintun build\dist\windows\ - name: Upload artifacts - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: windows-latest path: build @@ -68,17 +66,16 @@ jobs: HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }} runs-on: macos-11 steps: - - name: Set up Go 1.20 - uses: actions/setup-go@v2 - with: - go-version: "1.20" + - uses: actions/checkout@v4 - - name: Checkout code - uses: actions/checkout@v2 + - uses: actions/setup-go@v4 + with: + go-version-file: 'go.mod' + check-latest: true - name: Import certificates if: env.HAS_SIGNING_CREDS == 'true' - uses: Apple-Actions/import-codesign-certs@v1 + uses: Apple-Actions/import-codesign-certs@v2 with: p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }} p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }} @@ -107,7 +104,7 @@ jobs: fi - name: Upload artifacts - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: darwin-latest path: ./release/* @@ -117,12 +114,16 @@ jobs: needs: [build-linux, build-darwin, build-windows] runs-on: ubuntu-latest steps: + - uses: actions/checkout@v4 + - name: Download artifacts - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v3 + with: + path: artifacts - name: Zip Windows run: | - cd windows-latest + cd artifacts/windows-latest cp 
windows-amd64/* . zip -r nebula-windows-amd64.zip nebula.exe nebula-cert.exe dist cp windows-arm64/* . @@ -130,6 +131,7 @@ jobs: - name: Create sha256sum run: | + cd artifacts for dir in linux-latest darwin-latest windows-latest do ( @@ -159,195 +161,12 @@ jobs: - name: Create Release id: create_release - uses: actions/create-release@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - tag_name: ${{ github.ref }} - release_name: Release ${{ github.ref }} - draft: false - prerelease: false - - ## - ## Upload assets (I wish we could just upload the whole folder at once... - ## - - - name: Upload SHASUM256.txt - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./SHASUM256.txt - asset_name: SHASUM256.txt - asset_content_type: text/plain - - - name: Upload darwin zip - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./darwin-latest/nebula-darwin.zip - asset_name: nebula-darwin.zip - asset_content_type: application/zip - - - name: Upload windows-amd64 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./windows-latest/nebula-windows-amd64.zip - asset_name: nebula-windows-amd64.zip - asset_content_type: application/zip - - - name: Upload windows-arm64 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./windows-latest/nebula-windows-arm64.zip - asset_name: nebula-windows-arm64.zip - asset_content_type: application/zip - - - name: Upload linux-amd64 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: 
${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-amd64.tar.gz - asset_name: nebula-linux-amd64.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-386 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-386.tar.gz - asset_name: nebula-linux-386.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-ppc64le - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-ppc64le.tar.gz - asset_name: nebula-linux-ppc64le.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-arm-5 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-arm-5.tar.gz - asset_name: nebula-linux-arm-5.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-arm-6 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-arm-6.tar.gz - asset_name: nebula-linux-arm-6.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-arm-7 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-arm-7.tar.gz - asset_name: nebula-linux-arm-7.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-arm64 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ 
steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-arm64.tar.gz - asset_name: nebula-linux-arm64.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-mips - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-mips.tar.gz - asset_name: nebula-linux-mips.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-mipsle - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-mipsle.tar.gz - asset_name: nebula-linux-mipsle.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-mips64 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-mips64.tar.gz - asset_name: nebula-linux-mips64.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-mips64le - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-mips64le.tar.gz - asset_name: nebula-linux-mips64le.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-mips-softfloat - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-mips-softfloat.tar.gz - asset_name: nebula-linux-mips-softfloat.tar.gz - asset_content_type: application/gzip - - - name: Upload linux-riscv64 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - 
upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-linux-riscv64.tar.gz - asset_name: nebula-linux-riscv64.tar.gz - asset_content_type: application/gzip - - - name: Upload freebsd-amd64 - uses: actions/upload-release-asset@v1.0.1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./linux-latest/nebula-freebsd-amd64.tar.gz - asset_name: nebula-freebsd-amd64.tar.gz - asset_content_type: application/gzip + run: | + cd artifacts + gh release create \ + --verify-tag \ + --title "Release ${{ github.ref_name }}" \ + "${{ github.ref_name }}" \ + SHASUM256.txt *-latest/*.zip *-latest/*.tar.gz diff --git a/.github/workflows/smoke.yml b/.github/workflows/smoke.yml index 9334ffc..99c7e82 100644 --- a/.github/workflows/smoke.yml +++ b/.github/workflows/smoke.yml @@ -18,21 +18,12 @@ jobs: runs-on: ubuntu-latest steps: - - name: Set up Go 1.20 - uses: actions/setup-go@v2 - with: - go-version: "1.20" - id: go + - uses: actions/checkout@v4 - - name: Check out code into the Go module directory - uses: actions/checkout@v2 - - - uses: actions/cache@v2 + - uses: actions/setup-go@v4 with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go1.20-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go1.20- + go-version-file: 'go.mod' + check-latest: true - name: build run: make bin-docker diff --git a/.github/workflows/smoke/build-relay.sh b/.github/workflows/smoke/build-relay.sh index 1ec23c7..70b07f4 100755 --- a/.github/workflows/smoke/build-relay.sh +++ b/.github/workflows/smoke/build-relay.sh @@ -41,4 +41,4 @@ EOF ../../../../nebula-cert sign -name "host4" -groups "host,host4" -ip "192.168.100.4/24" ) -sudo docker build -t nebula:smoke-relay . +docker build -t nebula:smoke-relay . 
diff --git a/.github/workflows/smoke/build.sh b/.github/workflows/smoke/build.sh index 00b2346..9cbb200 100755 --- a/.github/workflows/smoke/build.sh +++ b/.github/workflows/smoke/build.sh @@ -36,4 +36,4 @@ mkdir ./build ../../../../nebula-cert sign -name "host4" -groups "host,host4" -ip "192.168.100.4/24" ) -sudo docker build -t "nebula:${NAME:-smoke}" . +docker build -t "nebula:${NAME:-smoke}" . diff --git a/.github/workflows/smoke/smoke-relay.sh b/.github/workflows/smoke/smoke-relay.sh index 91954d6..8926091 100755 --- a/.github/workflows/smoke/smoke-relay.sh +++ b/.github/workflows/smoke/smoke-relay.sh @@ -14,24 +14,24 @@ cleanup() { set +e if [ "$(jobs -r)" ] then - sudo docker kill lighthouse1 host2 host3 host4 + docker kill lighthouse1 host2 host3 host4 fi } trap cleanup EXIT -sudo docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test -sudo docker run --name host2 --rm nebula:smoke-relay -config host2.yml -test -sudo docker run --name host3 --rm nebula:smoke-relay -config host3.yml -test -sudo docker run --name host4 --rm nebula:smoke-relay -config host4.yml -test +docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test +docker run --name host2 --rm nebula:smoke-relay -config host2.yml -test +docker run --name host3 --rm nebula:smoke-relay -config host3.yml -test +docker run --name host4 --rm nebula:smoke-relay -config host4.yml -test -sudo docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & +docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & sleep 1 -sudo docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] 
/' & +docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & sleep 1 -sudo docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & +docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & sleep 1 -sudo docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & +docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & sleep 1 set +x @@ -39,43 +39,43 @@ echo echo " *** Testing ping from lighthouse1" echo set -x -sudo docker exec lighthouse1 ping -c1 192.168.100.2 -sudo docker exec lighthouse1 ping -c1 192.168.100.3 -sudo docker exec lighthouse1 ping -c1 192.168.100.4 +docker exec lighthouse1 ping -c1 192.168.100.2 +docker exec lighthouse1 ping -c1 192.168.100.3 +docker exec lighthouse1 ping -c1 192.168.100.4 set +x echo echo " *** Testing ping from host2" echo set -x -sudo docker exec host2 ping -c1 192.168.100.1 +docker exec host2 ping -c1 192.168.100.1 # Should fail because no relay configured in this direction -! sudo docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1 -! sudo docker exec host2 ping -c1 192.168.100.4 -w5 || exit 1 +! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1 +! 
docker exec host2 ping -c1 192.168.100.4 -w5 || exit 1 set +x echo echo " *** Testing ping from host3" echo set -x -sudo docker exec host3 ping -c1 192.168.100.1 -sudo docker exec host3 ping -c1 192.168.100.2 -sudo docker exec host3 ping -c1 192.168.100.4 +docker exec host3 ping -c1 192.168.100.1 +docker exec host3 ping -c1 192.168.100.2 +docker exec host3 ping -c1 192.168.100.4 set +x echo echo " *** Testing ping from host4" echo set -x -sudo docker exec host4 ping -c1 192.168.100.1 +docker exec host4 ping -c1 192.168.100.1 # Should fail because relays not allowed -! sudo docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1 -sudo docker exec host4 ping -c1 192.168.100.3 +! docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1 +docker exec host4 ping -c1 192.168.100.3 -sudo docker exec host4 sh -c 'kill 1' -sudo docker exec host3 sh -c 'kill 1' -sudo docker exec host2 sh -c 'kill 1' -sudo docker exec lighthouse1 sh -c 'kill 1' +docker exec host4 sh -c 'kill 1' +docker exec host3 sh -c 'kill 1' +docker exec host2 sh -c 'kill 1' +docker exec lighthouse1 sh -c 'kill 1' sleep 1 if [ "$(jobs -r)" ] diff --git a/.github/workflows/smoke/smoke.sh b/.github/workflows/smoke/smoke.sh index 4aa8029..3177255 100755 --- a/.github/workflows/smoke/smoke.sh +++ b/.github/workflows/smoke/smoke.sh @@ -14,7 +14,7 @@ cleanup() { set +e if [ "$(jobs -r)" ] then - sudo docker kill lighthouse1 host2 host3 host4 + docker kill lighthouse1 host2 host3 host4 fi } @@ -22,51 +22,51 @@ trap cleanup EXIT CONTAINER="nebula:${NAME:-smoke}" -sudo docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test -sudo docker run --name host2 --rm "$CONTAINER" -config host2.yml -test -sudo docker run --name host3 --rm "$CONTAINER" -config host3.yml -test -sudo docker run --name host4 --rm "$CONTAINER" -config host4.yml -test +docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test +docker run --name host2 --rm "$CONTAINER" -config host2.yml -test +docker run --name 
host3 --rm "$CONTAINER" -config host3.yml -test +docker run --name host4 --rm "$CONTAINER" -config host4.yml -test -sudo docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & +docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & sleep 1 -sudo docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & +docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & sleep 1 -sudo docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & +docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & sleep 1 -sudo docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & +docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & sleep 1 # grab tcpdump pcaps for debugging -sudo docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap & -sudo docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap & -sudo docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap & -sudo docker exec host2 tcpdump -i eth0 -q -w - -U 
2>logs/host2.outside.log >logs/host2.outside.pcap & -sudo docker exec host3 tcpdump -i nebula1 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap & -sudo docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap & -sudo docker exec host4 tcpdump -i nebula1 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap & -sudo docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap & +docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap & +docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap & +docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap & +docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap & +docker exec host3 tcpdump -i nebula1 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap & +docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap & +docker exec host4 tcpdump -i nebula1 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap & +docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap & -sudo docker exec host2 ncat -nklv 0.0.0.0 2000 & -sudo docker exec host3 ncat -nklv 0.0.0.0 2000 & -sudo docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 & -sudo docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 & +docker exec host2 ncat -nklv 0.0.0.0 2000 & +docker exec host3 ncat -nklv 0.0.0.0 2000 & +docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 & +docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 & set +x echo echo " *** Testing ping from lighthouse1" echo set -x -sudo docker exec lighthouse1 ping -c1 192.168.100.2 -sudo docker exec lighthouse1 ping -c1 192.168.100.3 +docker exec lighthouse1 ping -c1 192.168.100.2 
+docker exec lighthouse1 ping -c1 192.168.100.3 set +x echo echo " *** Testing ping from host2" echo set -x -sudo docker exec host2 ping -c1 192.168.100.1 +docker exec host2 ping -c1 192.168.100.1 # Should fail because not allowed by host3 inbound firewall -! sudo docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1 +! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1 set +x echo @@ -74,34 +74,34 @@ echo " *** Testing ncat from host2" echo set -x # Should fail because not allowed by host3 inbound firewall -! sudo docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1 -! sudo docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1 +! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1 +! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1 set +x echo echo " *** Testing ping from host3" echo set -x -sudo docker exec host3 ping -c1 192.168.100.1 -sudo docker exec host3 ping -c1 192.168.100.2 +docker exec host3 ping -c1 192.168.100.1 +docker exec host3 ping -c1 192.168.100.2 set +x echo echo " *** Testing ncat from host3" echo set -x -sudo docker exec host3 ncat -nzv -w5 192.168.100.2 2000 -sudo docker exec host3 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2 +docker exec host3 ncat -nzv -w5 192.168.100.2 2000 +docker exec host3 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2 set +x echo echo " *** Testing ping from host4" echo set -x -sudo docker exec host4 ping -c1 192.168.100.1 +docker exec host4 ping -c1 192.168.100.1 # Should fail because not allowed by host4 outbound firewall -! sudo docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1 -! sudo docker exec host4 ping -c1 192.168.100.3 -w5 || exit 1 +! docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1 +! docker exec host4 ping -c1 192.168.100.3 -w5 || exit 1 set +x echo @@ -109,10 +109,10 @@ echo " *** Testing ncat from host4" echo set -x # Should fail because not allowed by host4 outbound firewall -! 
sudo docker exec host4 ncat -nzv -w5 192.168.100.2 2000 || exit 1 -! sudo docker exec host4 ncat -nzv -w5 192.168.100.3 2000 || exit 1 -! sudo docker exec host4 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2 || exit 1 -! sudo docker exec host4 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1 +! docker exec host4 ncat -nzv -w5 192.168.100.2 2000 || exit 1 +! docker exec host4 ncat -nzv -w5 192.168.100.3 2000 || exit 1 +! docker exec host4 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2 || exit 1 +! docker exec host4 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1 set +x echo @@ -120,15 +120,15 @@ echo " *** Testing conntrack" echo set -x # host2 can ping host3 now that host3 pinged it first -sudo docker exec host2 ping -c1 192.168.100.3 +docker exec host2 ping -c1 192.168.100.3 # host4 can ping host2 once conntrack established -sudo docker exec host2 ping -c1 192.168.100.4 -sudo docker exec host4 ping -c1 192.168.100.2 +docker exec host2 ping -c1 192.168.100.4 +docker exec host4 ping -c1 192.168.100.2 -sudo docker exec host4 sh -c 'kill 1' -sudo docker exec host3 sh -c 'kill 1' -sudo docker exec host2 sh -c 'kill 1' -sudo docker exec lighthouse1 sh -c 'kill 1' +docker exec host4 sh -c 'kill 1' +docker exec host3 sh -c 'kill 1' +docker exec host2 sh -c 'kill 1' +docker exec lighthouse1 sh -c 'kill 1' sleep 1 if [ "$(jobs -r)" ] diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 05aff78..cc3725f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,21 +18,12 @@ jobs: runs-on: ubuntu-latest steps: - - name: Set up Go 1.20 - uses: actions/setup-go@v2 - with: - go-version: "1.20" - id: go + - uses: actions/checkout@v4 - - name: Check out code into the Go module directory - uses: actions/checkout@v2 - - - uses: actions/cache@v2 + - uses: actions/setup-go@v4 with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go1.20-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go1.20- + 
go-version-file: 'go.mod' + check-latest: true - name: Build run: make all @@ -57,21 +48,12 @@ jobs: runs-on: ubuntu-latest steps: - - name: Set up Go 1.20 - uses: actions/setup-go@v2 - with: - go-version: "1.20" - id: go + - uses: actions/checkout@v4 - - name: Check out code into the Go module directory - uses: actions/checkout@v2 - - - uses: actions/cache@v2 + - uses: actions/setup-go@v4 with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go1.20-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go1.20- + go-version-file: 'go.mod' + check-latest: true - name: Build run: make bin-boringcrypto @@ -90,21 +72,12 @@ jobs: os: [windows-latest, macos-11] steps: - - name: Set up Go 1.20 - uses: actions/setup-go@v2 - with: - go-version: "1.20" - id: go + - uses: actions/checkout@v4 - - name: Check out code into the Go module directory - uses: actions/checkout@v2 - - - uses: actions/cache@v2 + - uses: actions/setup-go@v4 with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go1.20-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go1.20- + go-version-file: 'go.mod' + check-latest: true - name: Build nebula run: go build ./cmd/nebula diff --git a/CHANGELOG.md b/CHANGELOG.md index 3febb32..6951a4a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.7.2] - 2023-06-01 + +### Fixed + +- Fix a freeze during config reload if the `static_host_map` config was changed. (#886) + +## [1.7.1] - 2023-05-18 + +### Fixed + +- Fix IPv4 addresses returned by `static_host_map` DNS lookup queries being + treated as IPv6 addresses. (#877) + ## [1.7.0] - 2023-05-17 ### Added @@ -475,7 +488,9 @@ created.) - Initial public release. 
-[Unreleased]: https://github.com/slackhq/nebula/compare/v1.7.0...HEAD +[Unreleased]: https://github.com/slackhq/nebula/compare/v1.7.2...HEAD +[1.7.2]: https://github.com/slackhq/nebula/releases/tag/v1.7.2 +[1.7.1]: https://github.com/slackhq/nebula/releases/tag/v1.7.1 [1.7.0]: https://github.com/slackhq/nebula/releases/tag/v1.7.0 [1.6.1]: https://github.com/slackhq/nebula/releases/tag/v1.6.1 [1.6.0]: https://github.com/slackhq/nebula/releases/tag/v1.6.0 diff --git a/Makefile b/Makefile index 1d4a7a2..89bd284 100644 --- a/Makefile +++ b/Makefile @@ -12,6 +12,8 @@ ifeq ($(OS),Windows_NT) GOISMIN := $(shell IF "$(GOVERSION)" GEQ "$(GOMINVERSION)" ECHO 1) NEBULA_CMD_SUFFIX = .exe NULL_FILE = nul + # RIO on windows does pointer stuff that makes go vet angry + VET_FLAGS = -unsafeptr=false else GOVERSION := $(shell go version | awk '{print substr($$3, 3)}') GOISMIN := $(shell expr "$(GOVERSION)" ">=" "$(GOMINVERSION)") @@ -44,10 +46,21 @@ ALL_LINUX = linux-amd64 \ linux-mips-softfloat \ linux-riscv64 +ALL_FREEBSD = freebsd-amd64 \ + freebsd-arm64 + +ALL_OPENBSD = openbsd-amd64 \ + openbsd-arm64 + +ALL_NETBSD = netbsd-amd64 \ + netbsd-arm64 + ALL = $(ALL_LINUX) \ + $(ALL_FREEBSD) \ + $(ALL_OPENBSD) \ + $(ALL_NETBSD) \ darwin-amd64 \ darwin-arm64 \ - freebsd-amd64 \ windows-amd64 \ windows-arm64 @@ -75,7 +88,11 @@ release: $(ALL:%=build/nebula-%.tar.gz) release-linux: $(ALL_LINUX:%=build/nebula-%.tar.gz) -release-freebsd: build/nebula-freebsd-amd64.tar.gz +release-freebsd: $(ALL_FREEBSD:%=build/nebula-%.tar.gz) + +release-openbsd: $(ALL_OPENBSD:%=build/nebula-%.tar.gz) + +release-netbsd: $(ALL_NETBSD:%=build/nebula-%.tar.gz) release-boringcrypto: build/nebula-linux-$(shell go env GOARCH)-boringcrypto.tar.gz @@ -93,6 +110,9 @@ bin-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert bin-freebsd: build/freebsd-amd64/nebula build/freebsd-amd64/nebula-cert mv $? . +bin-freebsd-arm64: build/freebsd-arm64/nebula build/freebsd-arm64/nebula-cert + mv $? . 
+ bin-boringcrypto: build/linux-$(shell go env GOARCH)-boringcrypto/nebula build/linux-$(shell go env GOARCH)-boringcrypto/nebula-cert mv $? . @@ -137,7 +157,7 @@ build/nebula-%.zip: build/%/nebula.exe build/%/nebula-cert.exe cd build/$* && zip ../nebula-$*.zip nebula.exe nebula-cert.exe vet: - go vet -v ./... + go vet $(VET_FLAGS) -v ./... test: go test -v ./... diff --git a/README.md b/README.md index 925aa61..6a7e5f2 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,7 @@ For each host, copy the nebula binary to the host, along with `config.yml` from ## Building Nebula from source -Download go and clone this repo. Change to the nebula directory. +Make sure you have [go](https://go.dev/doc/install) installed and clone this repo. Change to the nebula directory. To build nebula for all platforms: `make all` diff --git a/cert.go b/cert.go deleted file mode 100644 index bbd29c6..0000000 --- a/cert.go +++ /dev/null @@ -1,163 +0,0 @@ -package nebula - -import ( - "errors" - "fmt" - "io/ioutil" - "strings" - "time" - - "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/config" -) - -type CertState struct { - certificate *cert.NebulaCertificate - rawCertificate []byte - rawCertificateNoKey []byte - publicKey []byte - privateKey []byte -} - -func NewCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) { - // Marshal the certificate to ensure it is valid - rawCertificate, err := certificate.Marshal() - if err != nil { - return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err) - } - - publicKey := certificate.Details.PublicKey - cs := &CertState{ - rawCertificate: rawCertificate, - certificate: certificate, // PublicKey has been set to nil above - privateKey: privateKey, - publicKey: publicKey, - } - - cs.certificate.Details.PublicKey = nil - rawCertNoKey, err := cs.certificate.Marshal() - if err != nil { - return nil, fmt.Errorf("error marshalling certificate no key: %s", err) - 
} - cs.rawCertificateNoKey = rawCertNoKey - // put public key back - cs.certificate.Details.PublicKey = cs.publicKey - return cs, nil -} - -func NewCertStateFromConfig(c *config.C) (*CertState, error) { - var pemPrivateKey []byte - var err error - - privPathOrPEM := c.GetString("pki.key", "") - - if privPathOrPEM == "" { - return nil, errors.New("no pki.key path or PEM data provided") - } - - if strings.Contains(privPathOrPEM, "-----BEGIN") { - pemPrivateKey = []byte(privPathOrPEM) - privPathOrPEM = "" - } else { - pemPrivateKey, err = ioutil.ReadFile(privPathOrPEM) - if err != nil { - return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) - } - } - - rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey) - if err != nil { - return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) - } - - var rawCert []byte - - pubPathOrPEM := c.GetString("pki.cert", "") - - if pubPathOrPEM == "" { - return nil, errors.New("no pki.cert path or PEM data provided") - } - - if strings.Contains(pubPathOrPEM, "-----BEGIN") { - rawCert = []byte(pubPathOrPEM) - pubPathOrPEM = "" - } else { - rawCert, err = ioutil.ReadFile(pubPathOrPEM) - if err != nil { - return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err) - } - } - - nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert) - if err != nil { - return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err) - } - - if nebulaCert.Expired(time.Now()) { - return nil, fmt.Errorf("nebula certificate for this host is expired") - } - - if len(nebulaCert.Details.Ips) == 0 { - return nil, fmt.Errorf("no IPs encoded in certificate") - } - - if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil { - return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") - } - - return NewCertState(nebulaCert, rawKey) -} - -func loadCAFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) { 
- var rawCA []byte - var err error - - caPathOrPEM := c.GetString("pki.ca", "") - if caPathOrPEM == "" { - return nil, errors.New("no pki.ca path or PEM data provided") - } - - if strings.Contains(caPathOrPEM, "-----BEGIN") { - rawCA = []byte(caPathOrPEM) - - } else { - rawCA, err = ioutil.ReadFile(caPathOrPEM) - if err != nil { - return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err) - } - } - - CAs, err := cert.NewCAPoolFromBytes(rawCA) - if errors.Is(err, cert.ErrExpired) { - var expired int - for _, cert := range CAs.CAs { - if cert.Expired(time.Now()) { - expired++ - l.WithField("cert", cert).Warn("expired certificate present in CA pool") - } - } - - if expired >= len(CAs.CAs) { - return nil, errors.New("no valid CA certificates present") - } - - } else if err != nil { - return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err) - } - - for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) { - l.WithField("fingerprint", fp).Info("Blocklisting cert") - CAs.BlocklistFingerprint(fp) - } - - // Support deprecated config for at least one minor release to allow for migrations - //TODO: remove in 2022 or later - for _, fp := range c.GetStringSlice("pki.blacklist", []string{}) { - l.WithField("fingerprint", fp).Info("Blocklisting cert") - l.Warn("pki.blacklist is deprecated and will not be supported in a future release. 
Please migrate your config to use pki.blocklist") - CAs.BlocklistFingerprint(fp) - } - - return CAs, nil -} diff --git a/cert/cert.go b/cert/cert.go index 24a75e3..4f1b776 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -272,6 +272,9 @@ func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte }, Ciphertext: ciphertext, }) + if err != nil { + return nil, err + } switch curve { case Curve_CURVE25519: diff --git a/cert/crypto.go b/cert/crypto.go index 94f4c48..3558e1a 100644 --- a/cert/crypto.go +++ b/cert/crypto.go @@ -77,6 +77,9 @@ func aes256Decrypt(passphrase []byte, kdfParams *Argon2Parameters, data []byte) } gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } nonce, ciphertext, err := splitNonceCiphertext(data, gcm.NonceSize()) if err != nil { diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index c1de267..8d0eaa1 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -59,13 +59,8 @@ func main() { } ctrl, err := nebula.Main(c, *configTest, Build, l, nil) - - switch v := err.(type) { - case util.ContextualError: - v.Log(l) - os.Exit(1) - case error: - l.WithError(err).Error("Failed to start") + if err != nil { + util.LogWithContextIfNeeded("Failed to start", err, l) os.Exit(1) } diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index e9b285e..5cf0a02 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -53,18 +53,14 @@ func main() { } ctrl, err := nebula.Main(c, *configTest, Build, l, nil) - - switch v := err.(type) { - case util.ContextualError: - v.Log(l) - os.Exit(1) - case error: - l.WithError(err).Error("Failed to start") + if err != nil { + util.LogWithContextIfNeeded("Failed to start", err, l) os.Exit(1) } if !*configTest { ctrl.Start() + notifyReady(l) ctrl.ShutdownBlock() } diff --git a/cmd/nebula/notify_linux.go b/cmd/nebula/notify_linux.go new file mode 100644 index 0000000..8c3dca5 --- /dev/null +++ b/cmd/nebula/notify_linux.go @@ -0,0 +1,42 @@ +package 
main + +import ( + "net" + "os" + "time" + + "github.com/sirupsen/logrus" +) + +// SdNotifyReady tells systemd the service is ready and dependent services can now be started +// https://www.freedesktop.org/software/systemd/man/sd_notify.html +// https://www.freedesktop.org/software/systemd/man/systemd.service.html +const SdNotifyReady = "READY=1" + +func notifyReady(l *logrus.Logger) { + sockName := os.Getenv("NOTIFY_SOCKET") + if sockName == "" { + l.Debugln("NOTIFY_SOCKET systemd env var not set, not sending ready signal") + return + } + + conn, err := net.DialTimeout("unixgram", sockName, time.Second) + if err != nil { + l.WithError(err).Error("failed to connect to systemd notification socket") + return + } + defer conn.Close() + + err = conn.SetWriteDeadline(time.Now().Add(time.Second)) + if err != nil { + l.WithError(err).Error("failed to set the write deadline for the systemd notification socket") + return + } + + if _, err = conn.Write([]byte(SdNotifyReady)); err != nil { + l.WithError(err).Error("failed to signal the systemd notification socket") + return + } + + l.Debugln("notified systemd the service is ready") +} diff --git a/cmd/nebula/notify_notlinux.go b/cmd/nebula/notify_notlinux.go new file mode 100644 index 0000000..e7758e0 --- /dev/null +++ b/cmd/nebula/notify_notlinux.go @@ -0,0 +1,10 @@ +//go:build !linux +// +build !linux + +package main + +import "github.com/sirupsen/logrus" + +func notifyReady(_ *logrus.Logger) { + // No init service to notify +} diff --git a/config/config.go b/config/config.go index 966e905..bc3818d 100644 --- a/config/config.go +++ b/config/config.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io/ioutil" + "math" "os" "os/signal" "path/filepath" @@ -15,7 +16,7 @@ import ( "syscall" "time" - "github.com/imdario/mergo" + "dario.cat/mergo" "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" ) @@ -236,6 +237,15 @@ func (c *C) GetInt(k string, d int) int { return v } +// GetUint32 will get the uint32 for k or return the default d if 
not found or invalid +func (c *C) GetUint32(k string, d uint32) uint32 { + r := c.GetInt(k, int(d)) + if uint64(r) > uint64(math.MaxUint32) { + return d + } + return uint32(r) +} + // GetBool will get the bool for k or return the default d if not found or invalid func (c *C) GetBool(k string, d bool) bool { r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d))) diff --git a/config/config_test.go b/config/config_test.go index 52bf2e4..1001f8d 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/imdario/mergo" + "dario.cat/mergo" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/connection_manager.go b/connection_manager.go index 528cf1b..ce11f19 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -231,7 +231,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) index = existing.LocalIndex switch r.Type { case TerminalType: - relayFrom = newhostinfo.vpnIp + relayFrom = n.intf.myVpnIp relayTo = existing.PeerIp case ForwardingType: relayFrom = existing.PeerIp @@ -256,7 +256,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) } switch r.Type { case TerminalType: - relayFrom = newhostinfo.vpnIp + relayFrom = n.intf.myVpnIp relayTo = r.PeerIp case ForwardingType: relayFrom = r.PeerIp @@ -405,8 +405,8 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { return false } - certState := n.intf.certState.Load() - return bytes.Equal(current.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) + certState := n.intf.pki.GetCertState() + return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature) } func (n *connectionManager) swapPrimary(current, primary *HostInfo) { @@ -427,7 +427,7 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn return 
false } - valid, err := remoteCert.VerifyWithCache(now, n.intf.caPool) + valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool()) if valid { return false } @@ -464,8 +464,8 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { } func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { - certState := n.intf.certState.Load() - if bytes.Equal(hostinfo.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) { + certState := n.intf.pki.GetCertState() + if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) { return } @@ -473,18 +473,5 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { WithField("reason", "local certificate is not current"). Info("Re-handshaking with remote") - //TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out - newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp, n.intf.initHostInfo) - if !newHostinfo.HandshakeReady { - ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo) - } - - //If this is a static host, we don't need to wait for the HostQueryReply - //We can trigger the handshake right now - if _, ok := n.intf.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok { - select { - case n.intf.handshakeManager.trigger <- hostinfo.vpnIp: - default: - } - } + n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil) } diff --git a/connection_manager_test.go b/connection_manager_test.go index 3a25611..e802904 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -42,25 +42,26 @@ func Test_NewConnectionManagerTest(t *testing.T) { preferredRanges := []*net.IPNet{localrange} // Very incomplete mock objects - hostMap := NewHostMap(l, "test", vpncidr, preferredRanges) + hostMap := NewHostMap(l, vpncidr, preferredRanges) cs := &CertState{ - rawCertificate: []byte{}, - privateKey: []byte{}, - certificate: &cert.NebulaCertificate{}, - rawCertificateNoKey: []byte{}, + 
RawCertificate: []byte{}, + PrivateKey: []byte{}, + Certificate: &cert.NebulaCertificate{}, + RawCertificateNoKey: []byte{}, } lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, inside: &test.NoopTun{}, - outside: &udp.Conn{}, + outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, - handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig), + pki: &PKI{}, + handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), l: l, } - ifce.certState.Store(cs) + ifce.pki.cs.Store(cs) // Create manager ctx, cancel := context.WithCancel(context.Background()) @@ -78,8 +79,8 @@ func Test_NewConnectionManagerTest(t *testing.T) { remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - certState: cs, - H: &noise.HandshakeState{}, + myCert: &cert.NebulaCertificate{}, + H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -121,25 +122,26 @@ func Test_NewConnectionManagerTest2(t *testing.T) { preferredRanges := []*net.IPNet{localrange} // Very incomplete mock objects - hostMap := NewHostMap(l, "test", vpncidr, preferredRanges) + hostMap := NewHostMap(l, vpncidr, preferredRanges) cs := &CertState{ - rawCertificate: []byte{}, - privateKey: []byte{}, - certificate: &cert.NebulaCertificate{}, - rawCertificateNoKey: []byte{}, + RawCertificate: []byte{}, + PrivateKey: []byte{}, + Certificate: &cert.NebulaCertificate{}, + RawCertificateNoKey: []byte{}, } lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, inside: &test.NoopTun{}, - outside: &udp.Conn{}, + outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, - handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig), + pki: &PKI{}, + handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), l: l, } - ifce.certState.Store(cs) + ifce.pki.cs.Store(cs) // Create 
manager ctx, cancel := context.WithCancel(context.Background()) @@ -157,8 +159,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) { remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - certState: cs, - H: &noise.HandshakeState{}, + myCert: &cert.NebulaCertificate{}, + H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -207,7 +209,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") preferredRanges := []*net.IPNet{localrange} - hostMap := NewHostMap(l, "test", vpncidr, preferredRanges) + hostMap := NewHostMap(l, vpncidr, preferredRanges) // Generate keys for CA and peer's cert. pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader) @@ -220,7 +222,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { PublicKey: pubCA, }, } - caCert.Sign(cert.Curve_CURVE25519, privCA) + + assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA)) ncp := &cert.NebulaCAPool{ CAs: cert.NewCAPool().CAs, } @@ -239,28 +242,29 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { Issuer: "ca", }, } - peerCert.Sign(cert.Curve_CURVE25519, privCA) + assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA)) cs := &CertState{ - rawCertificate: []byte{}, - privateKey: []byte{}, - certificate: &cert.NebulaCertificate{}, - rawCertificateNoKey: []byte{}, + RawCertificate: []byte{}, + PrivateKey: []byte{}, + Certificate: &cert.NebulaCertificate{}, + RawCertificateNoKey: []byte{}, } lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, inside: &test.NoopTun{}, - outside: &udp.Conn{}, + outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, - handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig), + handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), l: l, 
disconnectInvalid: true, - caPool: ncp, + pki: &PKI{}, } - ifce.certState.Store(cs) + ifce.pki.cs.Store(cs) + ifce.pki.caPool.Store(ncp) // Create manager ctx, cancel := context.WithCancel(context.Background()) @@ -268,12 +272,16 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { punchy := NewPunchyFromConfig(l, config.NewC(l)) nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) ifce.connectionManager = nc - hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil) - hostinfo.ConnectionState = &ConnectionState{ - certState: cs, - peerCert: &peerCert, - H: &noise.HandshakeState{}, + + hostinfo := &HostInfo{ + vpnIp: vpnIp, + ConnectionState: &ConnectionState{ + myCert: &cert.NebulaCertificate{}, + peerCert: &peerCert, + H: &noise.HandshakeState{}, + }, } + nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) // Move ahead 45s. // Check if to disconnect with invalid certificate. diff --git a/connection_state.go b/connection_state.go index ab818c9..f8c31f6 100644 --- a/connection_state.go +++ b/connection_state.go @@ -18,35 +18,35 @@ type ConnectionState struct { eKey *NebulaCipherState dKey *NebulaCipherState H *noise.HandshakeState - certState *CertState + myCert *cert.NebulaCertificate peerCert *cert.NebulaCertificate initiator bool messageCounter atomic.Uint64 window *Bits - queueLock sync.Mutex writeLock sync.Mutex ready bool } -func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { +func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { var dhFunc noise.DHFunc - curCertState := f.certState.Load() - - switch curCertState.certificate.Details.Curve { + switch certState.Certificate.Details.Curve { case cert.Curve_CURVE25519: dhFunc = noise.DH25519 case cert.Curve_P256: dhFunc = noiseutil.DHP256 default: - l.Errorf("invalid curve: %s", 
curCertState.certificate.Details.Curve) + l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve) return nil } - cs := noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) - if f.cipher == "chachapoly" { + + var cs noise.CipherSuite + if cipher == "chachapoly" { cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) + } else { + cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) } - static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey} + static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey} b := NewBits(ReplayWindow) // Clear out bit 0, we never transmit it and we don't want it showing as packet loss @@ -72,7 +72,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern initiator: initiator, window: b, ready: false, - certState: curCertState, + myCert: certState.Certificate, } return ci diff --git a/control.go b/control.go index 203278d..4af115c 100644 --- a/control.go +++ b/control.go @@ -17,13 +17,23 @@ import ( // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching // core. 
This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc +type controlEach func(h *HostInfo) + +type controlHostLister interface { + QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo + ForEachIndex(each controlEach) + ForEachVpnIp(each controlEach) + GetPreferredRanges() []*net.IPNet +} + type Control struct { - f *Interface - l *logrus.Logger - cancel context.CancelFunc - sshStart func() - statsStart func() - dnsStart func() + f *Interface + l *logrus.Logger + cancel context.CancelFunc + sshStart func() + statsStart func() + dnsStart func() + lighthouseStart func() } type ControlHostInfo struct { @@ -54,12 +64,15 @@ func (c *Control) Start() { if c.dnsStart != nil { go c.dnsStart() } + if c.lighthouseStart != nil { + c.lighthouseStart() + } // Start reading packets. c.f.run() } -// Stop signals nebula to shutdown, returns after the shutdown is complete +// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete func (c *Control) Stop() { // Stop the handshakeManager (and other services), to prevent new tunnels from // being created while we're shutting them all down. 
@@ -89,7 +102,7 @@ func (c *Control) RebindUDPServer() { _ = c.f.outside.Rebind() // Trigger a lighthouse update, useful for mobile clients that should have an update interval of 0 - c.f.lightHouse.SendUpdate(c.f) + c.f.lightHouse.SendUpdate() // Let the main interface know that we rebound so that underlying tunnels know to trigger punches from their remotes c.f.rebindCount++ @@ -98,7 +111,7 @@ func (c *Control) RebindUDPServer() { // ListHostmapHosts returns details about the actual or pending (handshaking) hostmap by vpn ip func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo { if pendingMap { - return listHostMapHosts(c.f.handshakeManager.pendingHostMap) + return listHostMapHosts(c.f.handshakeManager) } else { return listHostMapHosts(c.f.hostMap) } @@ -107,7 +120,7 @@ func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo { // ListHostmapIndexes returns details about the actual or pending (handshaking) hostmap by local index id func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { if pendingMap { - return listHostMapIndexes(c.f.handshakeManager.pendingHostMap) + return listHostMapIndexes(c.f.handshakeManager) } else { return listHostMapIndexes(c.f.hostMap) } @@ -115,15 +128,15 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { // GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo { - var hm *HostMap + var hl controlHostLister if pending { - hm = c.f.handshakeManager.pendingHostMap + hl = c.f.handshakeManager } else { - hm = c.f.hostMap + hl = c.f.hostMap } - h, err := hm.QueryVpnIp(vpnIp) - if err != nil { + h := hl.QueryVpnIp(vpnIp) + if h == nil { return nil } @@ -133,8 +146,8 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH // SetRemoteForTunnel forces a tunnel to use a specific remote func (c *Control) SetRemoteForTunnel(vpnIp 
iputil.VpnIp, addr udp.Addr) *ControlHostInfo { - hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp) - if err != nil { + hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) + if hostInfo == nil { return nil } @@ -145,8 +158,8 @@ func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *Control // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well. func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool { - hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp) - if err != nil { + hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) + if hostInfo == nil { return false } @@ -241,28 +254,20 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { return chi } -func listHostMapHosts(hm *HostMap) []ControlHostInfo { - hm.RLock() - hosts := make([]ControlHostInfo, len(hm.Hosts)) - i := 0 - for _, v := range hm.Hosts { - hosts[i] = copyHostInfo(v, hm.preferredRanges) - i++ - } - hm.RUnlock() - +func listHostMapHosts(hl controlHostLister) []ControlHostInfo { + hosts := make([]ControlHostInfo, 0) + pr := hl.GetPreferredRanges() + hl.ForEachVpnIp(func(hostinfo *HostInfo) { + hosts = append(hosts, copyHostInfo(hostinfo, pr)) + }) return hosts } -func listHostMapIndexes(hm *HostMap) []ControlHostInfo { - hm.RLock() - hosts := make([]ControlHostInfo, len(hm.Indexes)) - i := 0 - for _, v := range hm.Indexes { - hosts[i] = copyHostInfo(v, hm.preferredRanges) - i++ - } - hm.RUnlock() - +func listHostMapIndexes(hl controlHostLister) []ControlHostInfo { + hosts := make([]ControlHostInfo, 0) + pr := hl.GetPreferredRanges() + hl.ForEachIndex(func(hostinfo *HostInfo) { + hosts = append(hosts, copyHostInfo(hostinfo, pr)) + }) return hosts } diff --git a/control_test.go b/control_test.go index de46991..56a2b2f 100644 --- a/control_test.go +++ b/control_test.go @@ -18,7 +18,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { l := test.NewLogger() // Special care must be taken to re-use all objects provided 
to the hostmap and certificate in the expectedInfo object // To properly ensure we are not exposing core memory to the caller - hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0)) + hm := NewHostMap(l, &net.IPNet{}, make([]*net.IPNet, 0)) remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444) remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444) ipNet := net.IPNet{ @@ -50,7 +50,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { remotes := NewRemoteList(nil) remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port))) remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port))) - hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{ + hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, ConnectionState: &ConnectionState{ @@ -64,9 +64,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { relayForByIp: map[iputil.VpnIp]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, - }) + }, &Interface{}) - hm.Add(iputil.Ip2VpnIp(ipNet2.IP), &HostInfo{ + hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, ConnectionState: &ConnectionState{ @@ -80,7 +80,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { relayForByIp: map[iputil.VpnIp]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, - }) + }, &Interface{}) c := Control{ f: &Interface{ diff --git a/control_tester.go b/control_tester.go index 48deb13..b786ba3 100644 --- a/control_tester.go +++ b/control_tester.go @@ -21,7 +21,7 @@ import ( func (c *Control) WaitForType(msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) { h := &header.H{} for { - p := c.f.outside.Get(true) + p := c.f.outside.(*udp.TesterConn).Get(true) if err := h.Parse(p.Data); err != nil { panic(err) } @@ -37,7 +37,7 @@ func (c *Control) WaitForType(msgType header.MessageType, subType header.Message func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) { h := &header.H{} for { - 
p := c.f.outside.Get(true) + p := c.f.outside.(*udp.TesterConn).Get(true) if err := h.Parse(p.Data); err != nil { panic(err) } @@ -90,11 +90,11 @@ func (c *Control) GetFromTun(block bool) []byte { // GetFromUDP will pull a udp packet off the udp side of nebula func (c *Control) GetFromUDP(block bool) *udp.Packet { - return c.f.outside.Get(block) + return c.f.outside.(*udp.TesterConn).Get(block) } func (c *Control) GetUDPTxChan() <-chan *udp.Packet { - return c.f.outside.TxPackets + return c.f.outside.(*udp.TesterConn).TxPackets } func (c *Control) GetTunTxChan() <-chan []byte { @@ -103,7 +103,7 @@ func (c *Control) GetTunTxChan() <-chan []byte { // InjectUDPPacket will inject a packet into the udp side of nebula func (c *Control) InjectUDPPacket(p *udp.Packet) { - c.f.outside.Send(p) + c.f.outside.(*udp.TesterConn).Send(p) } // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol @@ -143,16 +143,16 @@ func (c *Control) GetVpnIp() iputil.VpnIp { } func (c *Control) GetUDPAddr() string { - return c.f.outside.Addr.String() + return c.f.outside.(*udp.TesterConn).Addr.String() } func (c *Control) KillPendingTunnel(vpnIp net.IP) bool { - hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[iputil.Ip2VpnIp(vpnIp)] - if !ok { + hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp)) + if hostinfo == nil { return false } - c.f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo) + c.f.handshakeManager.DeleteHostInfo(hostinfo) return true } @@ -161,19 +161,9 @@ func (c *Control) GetHostmap() *HostMap { } func (c *Control) GetCert() *cert.NebulaCertificate { - return c.f.certState.Load().certificate + return c.f.pki.GetCertState().Certificate } func (c *Control) ReHandshake(vpnIp iputil.VpnIp) { - hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp, c.f.initHostInfo) - ixHandshakeStage0(c.f, vpnIp, hostinfo) - - // If this is a static host, we don't need to wait for the HostQueryReply - // We can trigger 
the handshake right now - if _, ok := c.f.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok { - select { - case c.f.handshakeManager.trigger <- hostinfo.vpnIp: - default: - } - } + c.f.handshakeManager.StartHandshake(vpnIp, nil) } diff --git a/dist/arch/nebula.service b/dist/arch/nebula.service index 7e5335a..831c71a 100644 --- a/dist/arch/nebula.service +++ b/dist/arch/nebula.service @@ -4,6 +4,8 @@ Wants=basic.target network-online.target nss-lookup.target time-sync.target After=basic.target network.target network-online.target [Service] +Type=notify +NotifyAccess=main SyslogIdentifier=nebula ExecReload=/bin/kill -HUP $MAINPID ExecStart=/usr/bin/nebula -config /etc/nebula/config.yml diff --git a/dist/fedora/nebula.service b/dist/fedora/nebula.service index 21a99c5..0f947ea 100644 --- a/dist/fedora/nebula.service +++ b/dist/fedora/nebula.service @@ -5,6 +5,8 @@ After=basic.target network.target network-online.target Before=sshd.service [Service] +Type=notify +NotifyAccess=main SyslogIdentifier=nebula ExecReload=/bin/kill -HUP $MAINPID ExecStart=/usr/bin/nebula -config /etc/nebula/config.yml diff --git a/dns_server.go b/dns_server.go index 19bc5ce..3109b4c 100644 --- a/dns_server.go +++ b/dns_server.go @@ -47,8 +47,8 @@ func (d *dnsRecords) QueryCert(data string) string { return "" } iip := iputil.Ip2VpnIp(ip) - hostinfo, err := d.hostMap.QueryVpnIp(iip) - if err != nil { + hostinfo := d.hostMap.QueryVpnIp(iip) + if hostinfo == nil { return "" } q := hostinfo.GetCert() diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index aa62603..022b5a3 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -410,6 +410,8 @@ func TestStage1RaceRelays(t *testing.T) { p := r.RouteForAllUntilTxTun(myControl) _ = p + r.FlushAll() + myControl.Stop() theirControl.Stop() relayControl.Stop() @@ -608,6 +610,110 @@ func TestRehandshakingRelays(t *testing.T) { t.Logf("relayControl hostinfos got cleaned up!") } +func TestRehandshakingRelaysPrimary(t *testing.T) { + 
// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner + ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + relayControl.Start() + theirControl.Start() + + t.Log("Trigger a handshake from me to them via the relay") + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + + p := r.RouteForAllUntilTxTun(theirControl) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) + + // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, + // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. 
+ r.Log("Renew relay certificate and spin until me and them sees it") + _, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + + caB, err := ca.MarshalToPEM() + if err != nil { + panic(err) + } + + relayConfig.Settings["pki"] = m{ + "ca": string(caB), + "cert": string(myNextPEM), + "key": string(myNextPrivKey), + } + rc, err := yaml.Marshal(relayConfig.Settings) + assert.NoError(t, err) + relayConfig.ReloadConfigString(string(rc)) + + for { + r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") + assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) + c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + if len(c.Cert.Details.Groups) != 0 { + // We have a new certificate now + r.Log("Certificate between my and relay is updated!") + break + } + + time.Sleep(time.Second) + } + + for { + r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") + assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + if len(c.Cert.Details.Groups) != 0 { + // We have a new certificate now + r.Log("Certificate between their and relay is updated!") + break + } + + time.Sleep(time.Second) + } + + r.Log("Assert the relay tunnel still works") + assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) + // We should have two hostinfos on all sides + for len(myControl.GetHostmap().Indexes) != 2 { + t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) + r.Log("Assert the relay tunnel still works") + assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + r.Log("yupitdoes") + time.Sleep(time.Second) + } + 
t.Logf("myControl hostinfos got cleaned up!") + for len(theirControl.GetHostmap().Indexes) != 2 { + t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) + r.Log("Assert the relay tunnel still works") + assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + r.Log("yupitdoes") + time.Sleep(time.Second) + } + t.Logf("theirControl hostinfos got cleaned up!") + for len(relayControl.GetHostmap().Indexes) != 2 { + t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) + r.Log("Assert the relay tunnel still works") + assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + r.Log("yupitdoes") + time.Sleep(time.Second) + } + t.Logf("relayControl hostinfos got cleaned up!") +} + func TestRehandshaking(t *testing.T) { ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index f2e3128..8440a72 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -12,9 +12,9 @@ import ( "testing" "time" + "dario.cat/mergo" "github.com/google/gopacket" "github.com/google/gopacket/layers" - "github.com/imdario/mergo" "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" diff --git a/examples/config.yml b/examples/config.yml index e1a556e..96ae8de 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -21,6 +21,19 @@ pki: static_host_map: "192.168.100.1": ["100.64.22.11:4242"] +# The static_map config stanza can be used to configure how the static_host_map behaves. +#static_map: + # cadence determines how frequently DNS is re-queried for updated IP addresses when a static_host_map entry contains + # a DNS name. 
+ #cadence: 30s + + # network determines the type of IP addresses to ask the DNS server for. The default is "ip4" because nodes typically + # do not know their public IPv4 address. Connecting to the Lighthouse via IPv4 allows the Lighthouse to detect the + # public address. Other valid options are "ip6" and "ip" (returns both.) + #network: ip4 + + # lookup_timeout is the DNS query timeout. + #lookup_timeout: 250ms lighthouse: # am_lighthouse is used to enable lighthouse functionality for a node. This should ONLY be true on nodes @@ -158,7 +171,8 @@ punchy: # and has been deprecated for "preferred_ranges" #preferred_ranges: ["172.16.0.0/24"] -# sshd can expose informational and administrative functions via ssh this is a +# sshd can expose informational and administrative functions via ssh. This can expose informational and administrative +# functions, and allows manual tweaking of various network settings when debugging or testing. #sshd: # Toggles the feature #enabled: true @@ -194,7 +208,7 @@ tun: disabled: false # Name of the device. If not set, a default will be chosen by the OS. # For macOS: if set, must be in the form `utun[0-9]+`. - # For FreeBSD: Required to be set, must be in the form `tun[0-9]+`. 
+ # For NetBSD: Required to be set, must be in the form `tun[0-9]+` dev: nebula1 # Toggles forwarding of local broadcast packets, the address of which depends on the ip/mask encoded in pki.cert drop_local_broadcast: false diff --git a/examples/service_scripts/nebula.service b/examples/service_scripts/nebula.service index fd7a067..ab5218f 100644 --- a/examples/service_scripts/nebula.service +++ b/examples/service_scripts/nebula.service @@ -5,6 +5,8 @@ After=basic.target network.target network-online.target Before=sshd.service [Service] +Type=notify +NotifyAccess=main SyslogIdentifier=nebula ExecReload=/bin/kill -HUP $MAINPID ExecStart=/usr/local/bin/nebula -config /etc/nebula/config.yml diff --git a/go.mod b/go.mod index 52c2e92..ba57aa1 100644 --- a/go.mod +++ b/go.mod @@ -3,31 +3,32 @@ module github.com/slackhq/nebula go 1.20 require ( + dario.cat/mergo v1.0.0 github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/flynn/noise v1.0.0 github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 - github.com/imdario/mergo v0.3.15 github.com/kardianos/service v1.2.2 - github.com/miekg/dns v1.1.54 + github.com/miekg/dns v1.1.56 github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f - github.com/prometheus/client_golang v1.15.1 + github.com/prometheus/client_golang v1.16.0 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 - github.com/sirupsen/logrus v1.9.0 + github.com/sirupsen/logrus v1.9.3 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 - github.com/stretchr/testify v1.8.2 + github.com/stretchr/testify v1.8.4 github.com/vishvananda/netlink v1.1.0 - golang.org/x/crypto v0.8.0 + golang.org/x/crypto v0.14.0 golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 - golang.org/x/net v0.9.0 - golang.org/x/sys v0.8.0 - 
golang.org/x/term v0.8.0 + golang.org/x/net v0.17.0 + golang.org/x/sys v0.13.0 + golang.org/x/term v0.13.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 + golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/protobuf v1.30.0 + google.golang.org/protobuf v1.31.0 gopkg.in/yaml.v2 v2.4.0 ) @@ -40,10 +41,10 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.4.0 // indirect github.com/prometheus/common v0.42.0 // indirect - github.com/prometheus/procfs v0.9.0 // indirect + github.com/prometheus/procfs v0.10.1 // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect - golang.org/x/mod v0.10.0 // indirect - golang.org/x/tools v0.8.0 // indirect + golang.org/x/mod v0.12.0 // indirect + golang.org/x/tools v0.13.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 452a1d2..445f18a 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= +dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= @@ -54,8 +56,6 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod 
h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= -github.com/imdario/mergo v0.3.15 h1:M8XP7IuFNsqUx6VPK2P9OSmsYsI/YFaGil0uD21V3dM= -github.com/imdario/mergo v0.3.15/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= @@ -78,8 +78,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= -github.com/miekg/dns v1.1.54 h1:5jon9mWcb0sFJGpnI99tOMhCPyJ+RPVz5b63MQG0VWI= -github.com/miekg/dns v1.1.54/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY= +github.com/miekg/dns v1.1.56 h1:5imZaSeoRNvpM9SzWNhEcP9QliKiz20/dA2QabIGVnE= +github.com/miekg/dns v1.1.56/go.mod h1:cRm6Oo2C8TY9ZS/TqsSrseAcncm74lfK5G+ikN2SWWY= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= @@ -97,8 +97,8 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= 
-github.com/prometheus/client_golang v1.15.1 h1:8tXpTmJbyH5lydzFPoxSIJ0J46jdh3tylbvM1xCv0LI= -github.com/prometheus/client_golang v1.15.1/go.mod h1:e9yaBhRPU2pPNsZwE+JdQl0KEt1N9XgF6zxWmaC0xOk= +github.com/prometheus/client_golang v1.16.0 h1:yk/hx9hDbrGHovbci4BY+pRMfSuuat626eFsHb7tmT8= +github.com/prometheus/client_golang v1.16.0/go.mod h1:Zsulrv/L9oM40tJ7T815tM89lFEugiJ9HzIqaAx4LKc= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -113,8 +113,8 @@ github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/prometheus/procfs v0.9.0 h1:wzCHvIvM5SxWqYvwgVL7yJY8Lz3PKn49KQtpgMYJfhI= -github.com/prometheus/procfs v0.9.0/go.mod h1:+pB4zwohETzFnmlpe6yd2lSc+0/46IYZRB/chUwxUZY= +github.com/prometheus/procfs v0.10.1 h1:kYK1Va/YMlutzCGazswoHKo//tZVlFpKYh+PymziUAg= +github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= @@ -122,24 +122,20 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= 
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= -github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= -github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= -github.com/stretchr/testify v1.8.2/go.mod 
h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0= github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE= github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= @@ -152,16 +148,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ= -golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= +golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 h1:5llv2sWeaMSnA3w2kS57ouQQ4pudlXrR0dCgw51QK9o= golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= -golang.org/x/mod v0.10.0/go.mod 
h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -172,8 +168,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= -golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -181,7 +177,7 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod 
h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -198,11 +194,11 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek= +golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -211,14 +207,16 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod 
h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.8.0 h1:vSDcovVPld282ceKgDimkRSC8kpaH1dgyc9UMzlt84Y= -golang.org/x/tools v0.8.0/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4= +golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ= +golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= +golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo= +golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -230,8 +228,8 @@ google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzi google.golang.org/protobuf v1.23.0/go.mod 
h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/handshake.go b/handshake.go index 1f2f03a..8cfba21 100644 --- a/handshake.go +++ b/handshake.go @@ -20,7 +20,7 @@ func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packe case 1: ixHandshakeStage1(f, addr, via, packet, h) case 2: - newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex) + newHostinfo := f.handshakeManager.QueryIndex(h.RemoteIndex) tearDown := ixHandshakeStage2(f, addr, via, newHostinfo, packet, h) if tearDown && newHostinfo != nil { f.handshakeManager.DeleteHostInfo(newHostinfo) diff --git a/handshake_ix.go b/handshake_ix.go index 39615b1..26cc983 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -13,27 +13,22 @@ import ( // This function constructs a handshake packet, but does not actually send it // Sending is done by the handshake manager -func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { - // This queries the lighthouse if we don't know a remote for the host - // We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send - // more quickly, effect 
is a quicker handshake. - if hostinfo.remote == nil { - f.lightHouse.QueryServer(vpnIp, f) - } - - err := f.handshakeManager.AddIndexHostInfo(hostinfo) +func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool { + err := f.handshakeManager.allocateIndex(hostinfo) if err != nil { - f.l.WithError(err).WithField("vpnIp", vpnIp). + f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") - return + return false } - ci := hostinfo.ConnectionState + certState := f.pki.GetCertState() + ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0) + hostinfo.ConnectionState = ci hsProto := &NebulaHandshakeDetails{ InitiatorIndex: hostinfo.localIndexId, Time: uint64(time.Now().UnixNano()), - Cert: ci.certState.rawCertificateNoKey, + Cert: certState.RawCertificateNoKey, } if f.multiPort.Tx || f.multiPort.Rx { @@ -53,9 +48,9 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { hsBytes, err = hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnIp", vpnIp). + f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") - return + return false } h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1) @@ -63,9 +58,9 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { msg, _, _, err := ci.H.WriteMessage(h, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnIp", vpnIp). + f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp). 
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") - return + return false } // We are sending handshake packet 1, so we don't expect to receive @@ -75,10 +70,12 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { hostinfo.HandshakePacket[0] = msg hostinfo.HandshakeReady = true hostinfo.handshakeStart = time.Now() + return true } func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { - ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0) + certState := f.pki.GetCertState() + ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0) // Mark packet 1 as seen so it doesn't show up as missed ci.window.Update(f.l, 1) @@ -100,7 +97,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by return } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool) + remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) if err != nil { f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). @@ -190,7 +187,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by Info("Handshake message received") hs.Details.ResponderIndex = myIndex - hs.Details.Cert = ci.certState.rawCertificateNoKey + hs.Details.Cert = certState.RawCertificateNoKey // Update the time in case their clock is way off from ours hs.Details.Time = uint64(time.Now().UnixNano()) @@ -467,7 +464,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H } } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool) + remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) if err != nil { f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). 
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). @@ -490,34 +487,30 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H Info("Incorrect host responded to handshake") // Release our old handshake from pending, it should not continue - f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo) + f.handshakeManager.DeleteHostInfo(hostinfo) // Create a new hostinfo/handshake for the intended vpn ip - //TODO: this adds it to the timer wheel in a way that aggressively retries - newHostInfo := f.getOrHandshake(hostinfo.vpnIp) - newHostInfo.Lock() + f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHostInfo *HostInfo) { + //TODO: this doesnt know if its being added or is being used for caching a packet + // Block the current used address + newHostInfo.remotes = hostinfo.remotes + newHostInfo.remotes.BlockRemote(addr) - // Block the current used address - newHostInfo.remotes = hostinfo.remotes - newHostInfo.remotes.BlockRemote(addr) + // Get the correct remote list for the host we did handshake with + hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) - // Get the correct remote list for the host we did handshake with - hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) + f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp). + WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)). + Info("Blocked addresses for handshakes") - f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp). - WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)). 
- Info("Blocked addresses for handshakes") + // Swap the packet store to benefit the original intended recipient + newHostInfo.packetStore = hostinfo.packetStore + hostinfo.packetStore = []*cachedPacket{} - // Swap the packet store to benefit the original intended recipient - hostinfo.ConnectionState.queueLock.Lock() - newHostInfo.packetStore = hostinfo.packetStore - hostinfo.packetStore = []*cachedPacket{} - hostinfo.ConnectionState.queueLock.Unlock() - - // Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down - hostinfo.vpnIp = vpnIp - f.sendCloseTunnel(hostinfo) - newHostInfo.Unlock() + // Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down + hostinfo.vpnIp = vpnIp + f.sendCloseTunnel(hostinfo) + }) return true } diff --git a/handshake_manager.go b/handshake_manager.go index 0918056..107b1f3 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "errors" "net" + "sync" "time" "github.com/rcrowley/go-metrics" @@ -42,15 +43,21 @@ type HandshakeConfig struct { } type HandshakeManager struct { - pendingHostMap *HostMap + // Mutex for interacting with the vpnIps and indexes maps + sync.RWMutex + + vpnIps map[iputil.VpnIp]*HostInfo + indexes map[uint32]*HostInfo + mainHostMap *HostMap lightHouse *LightHouse - outside *udp.Conn + outside udp.Conn config HandshakeConfig OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp] messageMetrics *MessageMetrics metricInitiated metrics.Counter metricTimedOut metrics.Counter + f *Interface l *logrus.Logger multiPort MultiPortConfig @@ -60,9 +67,10 @@ type HandshakeManager struct { trigger chan iputil.VpnIp } -func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udp.Conn, config HandshakeConfig) *HandshakeManager { +func NewHandshakeManager(l *logrus.Logger, mainHostMap 
*HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { return &HandshakeManager{ - pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges), + vpnIps: map[iputil.VpnIp]*HostInfo{}, + indexes: map[uint32]*HostInfo{}, mainHostMap: mainHostMap, lightHouse: lightHouse, outside: outside, @@ -76,7 +84,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [ } } -func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) { +func (c *HandshakeManager) Run(ctx context.Context) { clockSource := time.NewTicker(c.config.tryInterval) defer clockSource.Stop() @@ -85,27 +93,27 @@ func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) { case <-ctx.Done(): return case vpnIP := <-c.trigger: - c.handleOutbound(vpnIP, f, true) + c.handleOutbound(vpnIP, true) case now := <-clockSource.C: - c.NextOutboundHandshakeTimerTick(now, f) + c.NextOutboundHandshakeTimerTick(now) } } } -func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) { +func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { c.OutboundHandshakeTimer.Advance(now) for { vpnIp, has := c.OutboundHandshakeTimer.Purge() if !has { break } - c.handleOutbound(vpnIp, f, false) + c.handleOutbound(vpnIp, false) } } -func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, lighthouseTriggered bool) { - hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp) - if err != nil { +func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) { + hostinfo := c.QueryVpnIp(vpnIp) + if hostinfo == nil { return } hostinfo.Lock() @@ -114,31 +122,34 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light // We may have raced to completion but now that we have a lock we should ensure we have not yet completed. 
if hostinfo.HandshakeComplete { // Ensure we don't exist in the pending hostmap anymore since we have completed - c.pendingHostMap.DeleteHostInfo(hostinfo) - return - } - - // Check if we have a handshake packet to transmit yet - if !hostinfo.HandshakeReady { - // There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly - // Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState - c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) + c.DeleteHostInfo(hostinfo) return } // If we are out of time, clean up if hostinfo.HandshakeCounter >= c.config.retries { - hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)). + hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges)). WithField("initiatorIndex", hostinfo.localIndexId). WithField("remoteIndex", hostinfo.remoteIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("durationNs", time.Since(hostinfo.handshakeStart).Nanoseconds()). Info("Handshake timed out") c.metricTimedOut.Inc(1) - c.pendingHostMap.DeleteHostInfo(hostinfo) + c.DeleteHostInfo(hostinfo) return } + // Increment the counter to increase our delay, linear backoff + hostinfo.HandshakeCounter++ + + // Check if we have a handshake packet to transmit yet + if !hostinfo.HandshakeReady { + if !ixHandshakeStage0(c.f, hostinfo) { + c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) + return + } + } + // Get a remotes object if we don't already have one. // This is mainly to protect us as this should never be the case // NB ^ This comment doesn't jive. It's how the thing gets initialized. 
@@ -147,7 +158,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light hostinfo.remotes = c.lightHouse.QueryCache(vpnIp) } - remotes := hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges) + remotes := hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges) remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hostinfo.HandshakeLastRemotes) // We only care about a lighthouse trigger if we have new remotes to send to. @@ -166,15 +177,15 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light // If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse // Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about // the learned public ip for them. Query again to short circuit the promotion counter - c.lightHouse.QueryServer(vpnIp, f) + c.lightHouse.QueryServer(vpnIp, c.f) } // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply var sentTo []*udp.Addr var sentMultiport bool - hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) { + hostinfo.remotes.ForEach(c.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) { c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) - err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr) + err := c.outside.WriteTo(hostinfo.HandshakePacket[0], addr) if err != nil { hostinfo.logger(c.l).WithField("udpAddr", addr). WithField("initiatorIndex", hostinfo.localIndexId). 
@@ -230,10 +241,10 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light if *relay == vpnIp || *relay == c.lightHouse.myVpnIp { continue } - relayHostInfo, err := c.mainHostMap.QueryVpnIp(*relay) - if err != nil || relayHostInfo.remote == nil { - hostinfo.logger(c.l).WithError(err).WithField("relay", relay.String()).Info("Establish tunnel to relay target") - f.Handshake(*relay) + relayHostInfo := c.mainHostMap.QueryVpnIp(*relay) + if relayHostInfo == nil || relayHostInfo.remote == nil { + hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") + c.f.Handshake(*relay) continue } // Check the relay HostInfo to see if we already established a relay through it @@ -241,7 +252,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light switch existingRelay.State { case Established: hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Send handshake via relay") - f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) + c.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) case Requested: hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") // Re-send the CreateRelay request, in case the previous one was lost. 
@@ -258,7 +269,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light Error("Failed to marshal Control message to create relay") } else { // This must send over the hostinfo, not over hm.Hosts[ip] - f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + c.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) c.l.WithFields(logrus.Fields{ "relayFrom": c.lightHouse.myVpnIp, "relayTo": vpnIp, @@ -293,7 +304,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light WithError(err). Error("Failed to marshal Control message to create relay") } else { - f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + c.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) c.l.WithFields(logrus.Fields{ "relayFrom": c.lightHouse.myVpnIp, "relayTo": vpnIp, @@ -306,23 +317,78 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light } } - // Increment the counter to increase our delay, linear backoff - hostinfo.HandshakeCounter++ - // If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add if !lighthouseTriggered { c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) } } -func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo { - hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init) - - if created { - c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval) - c.metricInitiated.Inc(1) +// GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present +// The 2nd argument will be true if the hostinfo is ready to transmit traffic +func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HostInfo)) (*HostInfo, bool) { + // 
Check the main hostmap and maintain a read lock if our host is not there + hm.mainHostMap.RLock() + if h, ok := hm.mainHostMap.Hosts[vpnIp]; ok { + hm.mainHostMap.RUnlock() + // Do not attempt promotion if you are a lighthouse + if !hm.lightHouse.amLighthouse { + h.TryPromoteBest(hm.mainHostMap.preferredRanges, hm.f) + } + return h, true } + defer hm.mainHostMap.RUnlock() + return hm.StartHandshake(vpnIp, cacheCb), false +} + +// StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip +func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HostInfo)) *HostInfo { + hm.Lock() + + if hostinfo, ok := hm.vpnIps[vpnIp]; ok { + // We are already trying to handshake with this vpn ip + if cacheCb != nil { + cacheCb(hostinfo) + } + hm.Unlock() + return hostinfo + } + + hostinfo := &HostInfo{ + vpnIp: vpnIp, + HandshakePacket: make(map[uint8][]byte, 0), + relayState: RelayState{ + relays: map[iputil.VpnIp]struct{}{}, + relayForByIp: map[iputil.VpnIp]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, + }, + } + + hm.vpnIps[vpnIp] = hostinfo + hm.metricInitiated.Inc(1) + hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval) + + if cacheCb != nil { + cacheCb(hostinfo) + } + + // If this is a static host, we don't need to wait for the HostQueryReply + // We can trigger the handshake right now + _, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp] + if !doTrigger { + // Add any calculated remotes, and trigger early handshake if one found + doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp) + } + + if doTrigger { + select { + case hm.trigger <- vpnIp: + default: + } + } + + hm.Unlock() + hm.lightHouse.QueryServer(vpnIp, hm.f) return hostinfo } @@ -344,10 +410,10 @@ var ( // ErrLocalIndexCollision if we already have an entry in the main or pending // hostmap for the hostinfo.localIndexId. 
func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) { - c.pendingHostMap.Lock() - defer c.pendingHostMap.Unlock() c.mainHostMap.Lock() defer c.mainHostMap.Unlock() + c.Lock() + defer c.Unlock() // Check if we already have a tunnel with this vpn ip existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp] @@ -376,7 +442,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket return existingIndex, ErrLocalIndexCollision } - existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId] + existingIndex, found = c.indexes[hostinfo.localIndexId] if found && existingIndex != hostinfo { // We have a collision, but for a different hostinfo return existingIndex, ErrLocalIndexCollision @@ -398,47 +464,47 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket // Complete is a simpler version of CheckAndComplete when we already know we // won't have a localIndexId collision because we already have an entry in the // pendingHostMap. An existing hostinfo is returned if there was one. -func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { - c.pendingHostMap.Lock() - defer c.pendingHostMap.Unlock() - c.mainHostMap.Lock() - defer c.mainHostMap.Unlock() +func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { + hm.mainHostMap.Lock() + defer hm.mainHostMap.Unlock() + hm.Lock() + defer hm.Unlock() - existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] + existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] if found && existingRemoteIndex != nil { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. - hostinfo.logger(c.l). + hostinfo.logger(hm.l). WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp). 
Info("New host shadows existing host remoteIndex") } // We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap. - c.pendingHostMap.unlockedDeleteHostInfo(hostinfo) - c.mainHostMap.unlockedAddHostInfo(hostinfo, f) + hm.unlockedDeleteHostInfo(hostinfo) + hm.mainHostMap.unlockedAddHostInfo(hostinfo, f) } -// AddIndexHostInfo generates a unique localIndexId for this HostInfo +// allocateIndex generates a unique localIndexId for this HostInfo // and adds it to the pendingHostMap. Will error if we are unable to generate // a unique localIndexId -func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error { - c.pendingHostMap.Lock() - defer c.pendingHostMap.Unlock() - c.mainHostMap.RLock() - defer c.mainHostMap.RUnlock() +func (hm *HandshakeManager) allocateIndex(h *HostInfo) error { + hm.mainHostMap.RLock() + defer hm.mainHostMap.RUnlock() + hm.Lock() + defer hm.Unlock() for i := 0; i < 32; i++ { - index, err := generateIndex(c.l) + index, err := generateIndex(hm.l) if err != nil { return err } - _, inPending := c.pendingHostMap.Indexes[index] - _, inMain := c.mainHostMap.Indexes[index] + _, inPending := hm.indexes[index] + _, inMain := hm.mainHostMap.Indexes[index] if !inMain && !inPending { h.localIndexId = index - c.pendingHostMap.Indexes[index] = h + hm.indexes[index] = h return nil } } @@ -446,22 +512,73 @@ func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error { return errors.New("failed to generate unique localIndexId") } -func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) { - c.pendingHostMap.addRemoteIndexHostInfo(index, h) -} - func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { - //l.Debugln("Deleting pending hostinfo :", hostinfo) - c.pendingHostMap.DeleteHostInfo(hostinfo) + c.Lock() + defer c.Unlock() + c.unlockedDeleteHostInfo(hostinfo) } -func (c *HandshakeManager) QueryIndex(index uint32) (*HostInfo, error) { - return c.pendingHostMap.QueryIndex(index) +func 
(c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { + delete(c.vpnIps, hostinfo.vpnIp) + if len(c.vpnIps) == 0 { + c.vpnIps = map[iputil.VpnIp]*HostInfo{} + } + + delete(c.indexes, hostinfo.localIndexId) + if len(c.indexes) == 0 { + c.indexes = map[uint32]*HostInfo{} + } + + if c.l.Level >= logrus.DebugLevel { + c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps), + "vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). + Debug("Pending hostmap hostInfo deleted") + } +} + +func (c *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { + c.RLock() + defer c.RUnlock() + return c.vpnIps[vpnIp] +} + +func (c *HandshakeManager) QueryIndex(index uint32) *HostInfo { + c.RLock() + defer c.RUnlock() + return c.indexes[index] +} + +func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet { + return c.mainHostMap.preferredRanges +} + +func (c *HandshakeManager) ForEachVpnIp(f controlEach) { + c.RLock() + defer c.RUnlock() + + for _, v := range c.vpnIps { + f(v) + } +} + +func (c *HandshakeManager) ForEachIndex(f controlEach) { + c.RLock() + defer c.RUnlock() + + for _, v := range c.indexes { + f(v) + } } func (c *HandshakeManager) EmitStats() { - c.pendingHostMap.EmitStats("pending") - c.mainHostMap.EmitStats("main") + c.RLock() + hostLen := len(c.vpnIps) + indexLen := len(c.indexes) + c.RUnlock() + + metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen)) + metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen)) + c.mainHostMap.EmitStats() } // Utility functions below diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 3e39e48..d318a9d 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -14,31 +14,20 @@ import ( func Test_NewHandshakeManagerVpnIp(t *testing.T) { l := test.NewLogger() - _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := 
net.ParseCIDR("10.1.1.1/24") ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) preferredRanges := []*net.IPNet{localrange} - mw := &mockEncWriter{} - mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) + mainHM := NewHostMap(l, vpncidr, preferredRanges) lh := newTestLighthouse() - blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig) + blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) now := time.Now() - blah.NextOutboundHandshakeTimerTick(now, mw) + blah.NextOutboundHandshakeTimerTick(now) - var initCalled bool - initFunc := func(*HostInfo) { - initCalled = true - } - - i := blah.AddVpnIp(ip, initFunc) - assert.True(t, initCalled) - - initCalled = false - i2 := blah.AddVpnIp(ip, initFunc) - assert.False(t, initCalled) + i := blah.StartHandshake(ip, nil) + i2 := blah.StartHandshake(ip, nil) assert.Same(t, i, i2) i.remotes = NewRemoteList(nil) @@ -48,22 +37,22 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { assert.Len(t, mainHM.Hosts, 0) // Confirm they are in the pending index list - assert.Contains(t, blah.pendingHostMap.Hosts, ip) + assert.Contains(t, blah.vpnIps, ip) // Jump ahead `HandshakeRetries` ticks, offset by one to get the sleep logic right for i := 1; i <= DefaultHandshakeRetries+1; i++ { now = now.Add(time.Duration(i) * DefaultHandshakeTryInterval) - blah.NextOutboundHandshakeTimerTick(now, mw) + blah.NextOutboundHandshakeTimerTick(now) } // Confirm they are still in the pending index list - assert.Contains(t, blah.pendingHostMap.Hosts, ip) + assert.Contains(t, blah.vpnIps, ip) // Tick 1 more time, a minute will certainly flush it out - blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute), mw) + blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute)) // Confirm they have been removed - assert.NotContains(t, blah.pendingHostMap.Hosts, ip) + assert.NotContains(t, blah.vpnIps, ip) } func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c 
int) { diff --git a/hostmap.go b/hostmap.go index 29c0177..8f36f65 100644 --- a/hostmap.go +++ b/hostmap.go @@ -2,7 +2,6 @@ package nebula import ( "errors" - "fmt" "net" "sync" "sync/atomic" @@ -18,8 +17,9 @@ import ( ) // const ProbeLen = 100 -const PromoteEvery = 1000 -const ReQueryEvery = 5000 +const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address +const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse +const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery const MaxRemotes = 10 // MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip @@ -52,7 +52,6 @@ type Relay struct { type HostMap struct { sync.RWMutex //Because we concurrently read and write to our maps - name string Indexes map[uint32]*HostInfo Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object RemoteIndexes map[uint32]*HostInfo @@ -205,13 +204,13 @@ type HostInfo struct { multiportTx bool multiportRx bool ConnectionState *ConnectionState - handshakeStart time.Time //todo: this an entry in the handshake manager - HandshakeReady bool //todo: being in the manager means you are ready - HandshakeCounter int //todo: another handshake manager entry - HandshakeLastRemotes []*udp.Addr //todo: another handshake manager entry, which remotes we sent to last time - HandshakeComplete bool //todo: this should go away in favor of ConnectionState.ready - HandshakePacket map[uint8][]byte //todo: this is other handshake manager entry - packetStore []*cachedPacket //todo: this is other handshake manager entry + handshakeStart time.Time //todo: this an entry in the handshake manager + HandshakeReady bool //todo: being in the manager means you are ready + HandshakeCounter int //todo: another handshake manager entry + HandshakeLastRemotes []*udp.Addr //todo: 
another handshake manager entry, which remotes we sent to last time + HandshakeComplete bool //todo: this should go away in favor of ConnectionState.ready + HandshakePacket map[uint8][]byte + packetStore []*cachedPacket //todo: this is other handshake manager entry remoteIndexId uint32 localIndexId uint32 vpnIp iputil.VpnIp @@ -219,6 +218,10 @@ type HostInfo struct { remoteCidr *cidr.Tree4 relayState RelayState + // nextLHQuery is the earliest we can ask the lighthouse for new information. + // This is used to limit lighthouse re-queries in chatty clients + nextLHQuery atomic.Int64 + // lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH // for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like // with a handshake @@ -257,13 +260,12 @@ type cachedPacketMetrics struct { dropped metrics.Counter } -func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap { +func NewHostMap(l *logrus.Logger, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap { h := map[iputil.VpnIp]*HostInfo{} i := map[uint32]*HostInfo{} r := map[uint32]*HostInfo{} relays := map[uint32]*HostInfo{} m := HostMap{ - name: name, Indexes: i, Relays: relays, RemoteIndexes: r, @@ -275,8 +277,8 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang return &m } -// UpdateStats takes a name and reports host and index counts to the stats collection system -func (hm *HostMap) EmitStats(name string) { +// EmitStats reports host, index, and relay counts to the stats collection system +func (hm *HostMap) EmitStats() { hm.RLock() hostLen := len(hm.Hosts) indexLen := len(hm.Indexes) @@ -284,10 +286,10 @@ func (hm *HostMap) EmitStats(name string) { relaysLen := len(hm.Relays) hm.RUnlock() - metrics.GetOrRegisterGauge("hostmap."+name+".hosts", nil).Update(int64(hostLen)) - 
metrics.GetOrRegisterGauge("hostmap."+name+".indexes", nil).Update(int64(indexLen)) - metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen)) - metrics.GetOrRegisterGauge("hostmap."+name+".relayIndexes", nil).Update(int64(relaysLen)) + metrics.GetOrRegisterGauge("hostmap.main.hosts", nil).Update(int64(hostLen)) + metrics.GetOrRegisterGauge("hostmap.main.indexes", nil).Update(int64(indexLen)) + metrics.GetOrRegisterGauge("hostmap.main.remoteIndexes", nil).Update(int64(remoteIndexLen)) + metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen)) } func (hm *HostMap) RemoveRelay(localIdx uint32) { @@ -301,88 +303,6 @@ func (hm *HostMap) RemoveRelay(localIdx uint32) { hm.Unlock() } -func (hm *HostMap) GetIndexByVpnIp(vpnIp iputil.VpnIp) (uint32, error) { - hm.RLock() - if i, ok := hm.Hosts[vpnIp]; ok { - index := i.localIndexId - hm.RUnlock() - return index, nil - } - hm.RUnlock() - return 0, errors.New("vpn IP not found") -} - -func (hm *HostMap) Add(ip iputil.VpnIp, hostinfo *HostInfo) { - hm.Lock() - hm.Hosts[ip] = hostinfo - hm.Unlock() -} - -func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp, init func(hostinfo *HostInfo)) (hostinfo *HostInfo, created bool) { - hm.RLock() - if h, ok := hm.Hosts[vpnIp]; !ok { - hm.RUnlock() - h = &HostInfo{ - vpnIp: vpnIp, - HandshakePacket: make(map[uint8][]byte, 0), - relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, - }, - } - if init != nil { - init(h) - } - hm.Lock() - hm.Hosts[vpnIp] = h - hm.Unlock() - return h, true - } else { - hm.RUnlock() - return h, false - } -} - -// Only used by pendingHostMap when the remote index is not initially known -func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) { - hm.Lock() - h.remoteIndexId = index - hm.RemoteIndexes[index] = h - hm.Unlock() - - if hm.l.Level > logrus.DebugLevel { - hm.l.WithField("hostMap", 
m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes), - "hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": h.vpnIp}}). - Debug("Hostmap remoteIndex added") - } -} - -// DeleteReverseIndex is used to clean up on recv_error -// This function should only ever be called on the pending hostmap -func (hm *HostMap) DeleteReverseIndex(index uint32) { - hm.Lock() - hostinfo, ok := hm.RemoteIndexes[index] - if ok { - delete(hm.Indexes, hostinfo.localIndexId) - delete(hm.RemoteIndexes, index) - - // Check if we have an entry under hostId that matches the same hostinfo - // instance. Clean it up as well if we do (they might not match in pendingHostmap) - var hostinfo2 *HostInfo - hostinfo2, ok = hm.Hosts[hostinfo.vpnIp] - if ok && hostinfo2 == hostinfo { - delete(hm.Hosts, hostinfo.vpnIp) - } - } - hm.Unlock() - - if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}). - Debug("Hostmap remote index deleted") - } -} - // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool { // Delete the host itself, ensuring it's not modified anymore @@ -395,12 +315,6 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool { return final } -func (hm *HostMap) DeleteRelayIdx(localIdx uint32) { - hm.Lock() - defer hm.Unlock() - delete(hm.RemoteIndexes, localIdx) -} - func (hm *HostMap) MakePrimary(hostinfo *HostInfo) { hm.Lock() defer hm.Unlock() @@ -478,7 +392,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { } if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts), + hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts), "vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). 
Debug("Hostmap hostInfo deleted") } @@ -488,55 +402,41 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { } } -func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) { - //TODO: we probably just want to return bool instead of error, or at least a static error +func (hm *HostMap) QueryIndex(index uint32) *HostInfo { hm.RLock() if h, ok := hm.Indexes[index]; ok { hm.RUnlock() - return h, nil + return h } else { hm.RUnlock() - return nil, errors.New("unable to find index") + return nil } } -// Retrieves a HostInfo by Index. Returns whether the HostInfo is primary at time of query. -// This helper exists so that the hostinfo.prev pointer can be read while the hostmap lock is held. -func (hm *HostMap) QueryIndexIsPrimary(index uint32) (*HostInfo, bool, error) { - //TODO: we probably just want to return bool instead of error, or at least a static error - hm.RLock() - if h, ok := hm.Indexes[index]; ok { - hm.RUnlock() - return h, h.prev == nil, nil - } else { - hm.RUnlock() - return nil, false, errors.New("unable to find index") - } -} -func (hm *HostMap) QueryRelayIndex(index uint32) (*HostInfo, error) { +func (hm *HostMap) QueryRelayIndex(index uint32) *HostInfo { //TODO: we probably just want to return bool instead of error, or at least a static error hm.RLock() if h, ok := hm.Relays[index]; ok { hm.RUnlock() - return h, nil + return h } else { hm.RUnlock() - return nil, errors.New("unable to find index") + return nil } } -func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) { +func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo { hm.RLock() if h, ok := hm.RemoteIndexes[index]; ok { hm.RUnlock() - return h, nil + return h } else { hm.RUnlock() - return nil, fmt.Errorf("unable to find reverse index or connectionstate nil in %s hostmap", hm.name) + return nil } } -func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) (*HostInfo, error) { +func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { return 
hm.queryVpnIp(vpnIp, nil) } @@ -558,13 +458,7 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host return nil, nil, errors.New("unable to find host with relay") } -// PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every -// `PromoteEvery` calls to this function for a given host. -func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) (*HostInfo, error) { - return hm.queryVpnIp(vpnIp, ifce) -} - -func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*HostInfo, error) { +func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo { hm.RLock() if h, ok := hm.Hosts[vpnIp]; ok { hm.RUnlock() @@ -572,12 +466,12 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*Host if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse { h.TryPromoteBest(hm.preferredRanges, promoteIfce) } - return h, nil + return h } hm.RUnlock() - return nil, errors.New("unable to find host") + return nil } // unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps. @@ -600,7 +494,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts), + hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts), "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}). 
Debug("Hostmap vpnIp added") } @@ -616,11 +510,33 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { } } +func (hm *HostMap) GetPreferredRanges() []*net.IPNet { + return hm.preferredRanges +} + +func (hm *HostMap) ForEachVpnIp(f controlEach) { + hm.RLock() + defer hm.RUnlock() + + for _, v := range hm.Hosts { + f(v) + } +} + +func (hm *HostMap) ForEachIndex(f controlEach) { + hm.RLock() + defer hm.RUnlock() + + for _, v := range hm.Indexes { + f(v) + } +} + // TryPromoteBest handles re-querying lighthouses and probing for better paths // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients! func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) { c := i.promoteCounter.Add(1) - if c%PromoteEvery == 0 { + if c%ifce.tryPromoteEvery.Load() == 0 { // The lock here is currently protecting i.remote access i.RLock() remote := i.remote @@ -648,12 +564,18 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) } // Re query our lighthouses for new remotes occasionally - if c%ReQueryEvery == 0 && ifce.lightHouse != nil { + if c%ifce.reQueryEvery.Load() == 0 && ifce.lightHouse != nil { + now := time.Now().UnixNano() + if now < i.nextLHQuery.Load() { + return + } + + i.nextLHQuery.Store(now + ifce.reQueryWait.Load()) ifce.lightHouse.QueryServer(i.vpnIp, ifce) } } -func (i *HostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { +func (i *HostInfo) unlockedCachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { //TODO: return the error so we can log with more context if len(i.packetStore) < 100 { tempPacket := make([]byte, len(packet)) @@ -682,7 +604,6 @@ func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) { //TODO: HandshakeComplete means send stored packets and 
ConnectionState.ready means we are ready to send //TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical - i.ConnectionState.queueLock.Lock() i.HandshakeComplete = true //TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen. // Clamping it to 2 gets us out of the woods for now @@ -704,7 +625,6 @@ func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) { i.remotes.ResetBlockedRemotes() i.packetStore = make([]*cachedPacket, 0) i.ConnectionState.ready = true - i.ConnectionState.queueLock.Unlock() } func (i *HostInfo) GetCert() *cert.NebulaCertificate { diff --git a/hostmap_test.go b/hostmap_test.go index e523a21..c1c0dce 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -11,7 +11,7 @@ import ( func TestHostMap_MakePrimary(t *testing.T) { l := test.NewLogger() hm := NewHostMap( - l, "test", + l, &net.IPNet{ IP: net.IP{10, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}, @@ -32,7 +32,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.unlockedAddHostInfo(h1, f) // Make sure we go h1 -> h2 -> h3 -> h4 - prim, _ := hm.QueryVpnIp(1) + prim := hm.QueryVpnIp(1) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -47,7 +47,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h3) // Make sure we go h3 -> h1 -> h2 -> h4 - prim, _ = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(1) assert.Equal(t, h3.localIndexId, prim.localIndexId) assert.Equal(t, h1.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -62,7 +62,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim, _ = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(1) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -77,7 +77,7 @@ func 
TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim, _ = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(1) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -92,7 +92,7 @@ func TestHostMap_MakePrimary(t *testing.T) { func TestHostMap_DeleteHostInfo(t *testing.T) { l := test.NewLogger() hm := NewHostMap( - l, "test", + l, &net.IPNet{ IP: net.IP{10, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}, @@ -119,11 +119,11 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { // h6 should be deleted assert.Nil(t, h6.next) assert.Nil(t, h6.prev) - _, err := hm.QueryIndex(h6.localIndexId) - assert.Error(t, err) + h := hm.QueryIndex(h6.localIndexId) + assert.Nil(t, h) // Make sure we go h1 -> h2 -> h3 -> h4 -> h5 - prim, _ := hm.QueryVpnIp(1) + prim := hm.QueryVpnIp(1) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -142,7 +142,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h1.next) // Make sure we go h2 -> h3 -> h4 -> h5 - prim, _ = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(1) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -160,7 +160,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h3.next) // Make sure we go h2 -> h4 -> h5 - prim, _ = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(1) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -176,7 +176,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h5.next) // Make sure we go h2 -> h4 - prim, _ = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(1) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -190,7 +190,7 @@ func 
TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h2.next) // Make sure we only have h4 - prim, _ = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(1) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Nil(t, prim.prev) assert.Nil(t, prim.next) @@ -202,6 +202,6 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h4.next) // Make sure we have nil - prim, _ = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(1) assert.Nil(t, prim) } diff --git a/inside.go b/inside.go index 58cd7b2..728dddd 100644 --- a/inside.go +++ b/inside.go @@ -1,7 +1,6 @@ package nebula import ( - "github.com/flynn/noise" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -45,7 +44,10 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - hostinfo := f.getOrHandshake(fwPacket.RemoteIP) + hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(h *HostInfo) { + h.unlockedCachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + }) + if hostinfo == nil { f.rejectInside(packet, out, q) if f.l.Level >= logrus.DebugLevel { @@ -55,23 +57,14 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } return } - ci := hostinfo.ConnectionState - if !ci.ready { - // Because we might be sending stored packets, lock here to stop new things going to - // the packet queue. 
- ci.queueLock.Lock() - if !ci.ready { - hostinfo.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) - ci.queueLock.Unlock() - return - } - ci.queueLock.Unlock() + if !ready { + return } - dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache) + dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { - f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, packet, nb, out, q, fwPacket) + f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, nil, packet, nb, out, q, fwPacket) } else { f.rejectInside(packet, out, q) @@ -110,71 +103,20 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * } func (f *Interface) Handshake(vpnIp iputil.VpnIp) { - f.getOrHandshake(vpnIp) + f.getOrHandshake(vpnIp, nil) } -// getOrHandshake returns nil if the vpnIp is not routable -func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo { +// getOrHandshake returns nil if the vpnIp is not routable. 
+// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel +func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(info *HostInfo)) (*HostInfo, bool) { if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) { vpnIp = f.inside.RouteFor(vpnIp) if vpnIp == 0 { - return nil - } - } - hostinfo, err := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f) - - //if err != nil || hostinfo.ConnectionState == nil { - if err != nil { - hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp) - if err != nil { - hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo) - } - } - ci := hostinfo.ConnectionState - - if ci != nil && ci.eKey != nil && ci.ready { - return hostinfo - } - - // Handshake is not ready, we need to grab the lock now before we start the handshake process - hostinfo.Lock() - defer hostinfo.Unlock() - - // Double check, now that we have the lock - ci = hostinfo.ConnectionState - if ci != nil && ci.eKey != nil && ci.ready { - return hostinfo - } - - // If we have already created the handshake packet, we don't want to call the function at all. - if !hostinfo.HandshakeReady { - ixHandshakeStage0(f, vpnIp, hostinfo) - // FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us. 
- //xx_handshakeStage0(f, ip, hostinfo) - - // If this is a static host, we don't need to wait for the HostQueryReply - // We can trigger the handshake right now - _, doTrigger := f.lightHouse.GetStaticHostList()[vpnIp] - if !doTrigger { - // Add any calculated remotes, and trigger early handshake if one found - doTrigger = f.lightHouse.addCalculatedRemotes(vpnIp) - } - - if doTrigger { - select { - case f.handshakeManager.trigger <- vpnIp: - default: - } + return nil, false } } - return hostinfo -} - -// initHostInfo is the init function to pass to (*HandshakeManager).AddVpnIP that -// will create the initial Noise ConnectionState -func (f *Interface) initHostInfo(hostinfo *HostInfo) { - hostinfo.ConnectionState = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0) + return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback) } func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { @@ -186,7 +128,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp } // check if packet is in outbound fw rules - dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.caPool, nil) + dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.pki.GetCAPool(), nil) if dropReason != nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("fwPacket", fp). @@ -201,7 +143,10 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) { - hostInfo := f.getOrHandshake(vpnIp) + hostInfo, ready := f.getOrHandshake(vpnIp, func(h *HostInfo) { + h.unlockedCachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) + }) + if hostInfo == nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("vpnIp", vpnIp). 
@@ -210,16 +155,8 @@ func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSu return } - if !hostInfo.ConnectionState.ready { - // Because we might be sending stored packets, lock here to stop new things going to - // the packet queue. - hostInfo.ConnectionState.queueLock.Lock() - if !hostInfo.ConnectionState.ready { - hostInfo.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) - hostInfo.ConnectionState.queueLock.Unlock() - return - } - hostInfo.ConnectionState.queueLock.Unlock() + if !ready { + return } f.SendMessageToHostInfo(t, st, hostInfo, p, nb, out) @@ -239,7 +176,7 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0, nil) } -// sendVia sends a payload through a Relay tunnel. No authentication or encryption is done +// SendVia sends a payload through a Relay tunnel. No authentication or encryption is done // to the payload for the ultimate target host, making this a useful method for sending // handshake messages to peers through relay tunnels. // via is the HostInfo through which the message is relayed. 
diff --git a/interface.go b/interface.go index 474cdf0..58fa2a2 100644 --- a/interface.go +++ b/interface.go @@ -13,7 +13,6 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -26,9 +25,9 @@ const mtu = 9001 type InterfaceConfig struct { HostMap *HostMap - Outside *udp.Conn + Outside udp.Conn Inside overlay.Device - certState *CertState + pki *PKI Cipher string Firewall *Firewall ServeDns bool @@ -41,20 +40,23 @@ type InterfaceConfig struct { routines int MessageMetrics *MessageMetrics version string - caPool *cert.NebulaCAPool disconnectInvalid bool relayManager *relayManager punchy *Punchy + tryPromoteEvery uint32 + reQueryEvery uint32 + reQueryWait time.Duration + ConntrackCacheTimeout time.Duration l *logrus.Logger } type Interface struct { hostMap *HostMap - outside *udp.Conn + outside udp.Conn inside overlay.Device - certState atomic.Pointer[CertState] + pki *PKI cipher string firewall *Firewall connectionManager *connectionManager @@ -67,11 +69,14 @@ type Interface struct { dropLocalBroadcast bool dropMulticast bool routines int - caPool *cert.NebulaCAPool disconnectInvalid bool closed atomic.Bool relayManager *relayManager + tryPromoteEvery atomic.Uint32 + reQueryEvery atomic.Uint32 + reQueryWait atomic.Int64 + sendRecvErrorConfig sendRecvErrorConfig // rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse @@ -80,7 +85,7 @@ type Interface struct { conntrackCacheTimeout time.Duration - writers []*udp.Conn + writers []udp.Conn readers []io.ReadWriteCloser udpRaw *udp.RawConn @@ -156,15 +161,17 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { if c.Inside == nil { return nil, errors.New("no inside interface (tun)") } - if c.certState == nil { + if c.pki == nil { return nil, errors.New("no certificate state") } 
if c.Firewall == nil { return nil, errors.New("no firewall rules") } - myVpnIp := iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].IP) + certificate := c.pki.GetCertState().Certificate + myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP) ifce := &Interface{ + pki: c.pki, hostMap: c.HostMap, outside: c.Outside, inside: c.Inside, @@ -174,14 +181,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { handshakeManager: c.HandshakeManager, createTime: time.Now(), lightHouse: c.lightHouse, - localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].Mask), + localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask), dropLocalBroadcast: c.DropLocalBroadcast, dropMulticast: c.DropMulticast, routines: c.routines, version: c.version, - writers: make([]*udp.Conn, c.routines), + writers: make([]udp.Conn, c.routines), readers: make([]io.ReadWriteCloser, c.routines), - caPool: c.caPool, disconnectInvalid: c.disconnectInvalid, myVpnIp: myVpnIp, relayManager: c.relayManager, @@ -198,7 +204,10 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { l: c.l, } - ifce.certState.Store(c.certState) + ifce.tryPromoteEvery.Store(c.tryPromoteEvery) + ifce.reQueryEvery.Store(c.reQueryEvery) + ifce.reQueryWait.Store(int64(c.reQueryWait)) + ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy) return ifce, nil @@ -257,7 +266,7 @@ func (f *Interface) run() { func (f *Interface) listenOut(i int) { runtime.LockOSThread() - var li *udp.Conn + var li udp.Conn // TODO clean this up with a coherent interface for each outside connection if i > 0 { li = f.writers[i] @@ -297,49 +306,14 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { } func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { - c.RegisterReloadCallback(f.reloadCA) - c.RegisterReloadCallback(f.reloadCertKey) 
c.RegisterReloadCallback(f.reloadFirewall) c.RegisterReloadCallback(f.reloadSendRecvError) + c.RegisterReloadCallback(f.reloadMisc) for _, udpConn := range f.writers { c.RegisterReloadCallback(udpConn.ReloadConfig) } } -func (f *Interface) reloadCA(c *config.C) { - // reload and check regardless - // todo: need mutex? - newCAs, err := loadCAFromConfig(f.l, c) - if err != nil { - f.l.WithError(err).Error("Could not refresh trusted CA certificates") - return - } - - f.caPool = newCAs - f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed") -} - -func (f *Interface) reloadCertKey(c *config.C) { - // reload and check in all cases - cs, err := NewCertStateFromConfig(c) - if err != nil { - f.l.WithError(err).Error("Could not refresh client cert") - return - } - - // did IP in cert change? if so, don't set - currentCert := f.certState.Load().certificate - oldIPs := currentCert.Details.Ips - newIPs := cs.certificate.Details.Ips - if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { - f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old") - return - } - - f.certState.Store(cs) - f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk") -} - func (f *Interface) reloadFirewall(c *config.C) { //TODO: need to trigger/detect if the certificate changed too if c.HasChanged("firewall") == false { @@ -347,7 +321,7 @@ func (f *Interface) reloadFirewall(c *config.C) { return } - fw, err := NewFirewallFromConfig(f.l, f.certState.Load().certificate, c) + fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c) if err != nil { f.l.WithError(err).Error("Error while creating firewall during reload") return @@ -403,6 +377,26 @@ func (f *Interface) reloadSendRecvError(c *config.C) { } } +func (f *Interface) reloadMisc(c *config.C) { + if c.HasChanged("counters.try_promote") { + n := c.GetUint32("counters.try_promote", 
defaultPromoteEvery) + f.tryPromoteEvery.Store(n) + f.l.Info("counters.try_promote has changed") + } + + if c.HasChanged("counters.requery_every_packets") { + n := c.GetUint32("counters.requery_every_packets", defaultReQueryEvery) + f.reQueryEvery.Store(n) + f.l.Info("counters.requery_every_packets has changed") + } + + if c.HasChanged("timers.requery_wait_duration") { + n := c.GetDuration("timers.requery_wait_duration", defaultReQueryWait) + f.reQueryWait.Store(int64(n)) + f.l.Info("timers.requery_wait_duration has changed") + } +} + func (f *Interface) emitStats(ctx context.Context, i time.Duration) { ticker := time.NewTicker(i) defer ticker.Stop() @@ -427,7 +421,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { } rawStats() } - certExpirationGauge.Update(int64(f.certState.Load().certificate.Details.NotAfter.Sub(time.Now()) / time.Second)) + certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.Details.NotAfter.Sub(time.Now()) / time.Second)) } } } @@ -435,6 +429,13 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { func (f *Interface) Close() error { f.closed.Store(true) + for _, u := range f.writers { + err := u.Close() + if err != nil { + f.l.WithError(err).Error("Error while closing udp socket") + } + } + // Release the tun device return f.inside.Close() } diff --git a/lighthouse.go b/lighthouse.go index 460a1cb..9b3b837 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -39,7 +39,7 @@ type LightHouse struct { myVpnIp iputil.VpnIp myVpnZeros iputil.VpnIp myVpnNet *net.IPNet - punchConn *udp.Conn + punchConn udp.Conn punchy *Punchy // Local cache of answers from light houses @@ -64,11 +64,10 @@ type LightHouse struct { staticList atomic.Pointer[map[iputil.VpnIp]struct{}] lighthouses atomic.Pointer[map[iputil.VpnIp]struct{}] - interval atomic.Int64 - updateCancel context.CancelFunc - updateParentCtx context.Context - updateUdp EncWriter - nebulaPort uint32 // 32 bits because protobuf does not have a 
uint16 + interval atomic.Int64 + updateCancel context.CancelFunc + ifce EncWriter + nebulaPort uint32 // 32 bits because protobuf does not have a uint16 advertiseAddrs atomic.Pointer[[]netIpAndPort] @@ -84,7 +83,7 @@ type LightHouse struct { // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // addrMap should be nil unless this is during a config reload -func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) { +func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc udp.Conn, p *Punchy) (*LightHouse, error) { amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) nebulaPort := uint32(c.GetInt("listen.port", 0)) if amLighthouse && nebulaPort == 0 { @@ -133,7 +132,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, c.RegisterReloadCallback(func(c *config.C) { err := h.reload(c, false) switch v := err.(type) { - case util.ContextualError: + case *util.ContextualError: v.Log(l) case error: l.WithError(err).Error("failed to reload lighthouse") @@ -217,7 +216,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.updateCancel() } - lh.LhUpdateWorker(lh.updateParentCtx, lh.updateUdp) + lh.StartUpdateWorker() } } @@ -262,6 +261,18 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { //NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") { + // Clean up. Entries still in the static_host_map will be re-built. + // Entries no longer present must have their (possible) background DNS goroutines stopped. 
+ if existingStaticList := lh.staticList.Load(); existingStaticList != nil { + lh.RLock() + for staticVpnIp := range *existingStaticList { + if am, ok := lh.addrMap[staticVpnIp]; ok && am != nil { + am.hr.Cancel() + } + } + lh.RUnlock() + } + // Build a new list based on current config. staticList := make(map[iputil.VpnIp]struct{}) err := lh.loadStaticMap(c, lh.myVpnNet, staticList) if err != nil { @@ -742,33 +753,33 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr { return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port)) } -func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) { - lh.updateParentCtx = ctx - lh.updateUdp = f - +func (lh *LightHouse) StartUpdateWorker() { interval := lh.GetUpdateInterval() if lh.amLighthouse || interval == 0 { return } clockSource := time.NewTicker(time.Second * time.Duration(interval)) - updateCtx, cancel := context.WithCancel(ctx) + updateCtx, cancel := context.WithCancel(lh.ctx) lh.updateCancel = cancel - defer clockSource.Stop() - for { - lh.SendUpdate(f) + go func() { + defer clockSource.Stop() - select { - case <-updateCtx.Done(): - return - case <-clockSource.C: - continue + for { + lh.SendUpdate() + + select { + case <-updateCtx.Done(): + return + case <-clockSource.C: + continue + } } - } + }() } -func (lh *LightHouse) SendUpdate(f EncWriter) { +func (lh *LightHouse) SendUpdate() { var v4 []*Ip4AndPort var v6 []*Ip6AndPort @@ -821,7 +832,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) { } for vpnIp := range lighthouses { - f.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out) + lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out) } } diff --git a/lighthouse_test.go b/lighthouse_test.go index aa4da4c..66427e3 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -12,6 +12,7 @@ import ( "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v2" ) //TODO: Add a test to ensure udpAddr is copied and not reused @@ 
-65,6 +66,35 @@ func Test_lhStaticMapping(t *testing.T) { assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") } +func TestReloadLighthouseInterval(t *testing.T) { + l := test.NewLogger() + _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16") + lh1 := "10.128.0.2" + + c := config.NewC(l) + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "hosts": []interface{}{lh1}, + "interval": "1s", + } + + c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} + lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) + assert.NoError(t, err) + lh.ifce = &mockEncWriter{} + + // The first one routine is kicked off by main.go currently, lets make sure that one dies + c.ReloadConfigString("lighthouse:\n interval: 5") + assert.Equal(t, int64(5), lh.interval.Load()) + + // Subsequent calls are killed off by the LightHouse.Reload function + c.ReloadConfigString("lighthouse:\n interval: 10") + assert.Equal(t, int64(10), lh.interval.Load()) + + // If this completes then nothing is stealing our reload routine + c.ReloadConfigString("lighthouse:\n interval: 11") + assert.Equal(t, int64(11), lh.interval.Load()) +} + func BenchmarkLighthouseHandleRequest(b *testing.B) { l := test.NewLogger() _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0") @@ -242,8 +272,17 @@ func TestLighthouse_reload(t *testing.T) { lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) assert.NoError(t, err) - c.Settings["static_host_map"] = map[interface{}]interface{}{"10.128.0.2": []interface{}{"1.1.1.1:4242"}} - lh.reload(c, false) + nc := map[interface{}]interface{}{ + "static_host_map": map[interface{}]interface{}{ + "10.128.0.2": []interface{}{"1.1.1.1:4242"}, + }, + } + rc, err := yaml.Marshal(nc) + assert.NoError(t, err) + c.ReloadConfigString(string(rc)) + + err = lh.reload(c, false) + assert.NoError(t, err) } 
func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply { diff --git a/main.go b/main.go index 4fa9e71..4398328 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,6 @@ package nebula import ( "context" "encoding/binary" - "errors" "fmt" "net" "time" @@ -46,7 +45,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg err := configLogger(l, c) if err != nil { - return nil, util.NewContextualError("Failed to configure the logger", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err) } c.RegisterReloadCallback(func(c *config.C) { @@ -56,28 +55,20 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } }) - caPool, err := loadCAFromConfig(l, c) + pki, err := NewPKIFromConfig(l, c) if err != nil { - //The errors coming out of loadCA are already nicely formatted - return nil, util.NewContextualError("Failed to load ca from config", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) } - l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") - cs, err := NewCertStateFromConfig(c) + certificate := pki.GetCertState().Certificate + fw, err := NewFirewallFromConfig(l, certificate, c) if err != nil { - //The errors coming out of NewCertStateFromConfig are already nicely formatted - return nil, util.NewContextualError("Failed to load certificate from config", nil, err) - } - l.WithField("cert", cs.certificate).Debug("Client nebula certificate") - - fw, err := NewFirewallFromConfig(l, cs.certificate, c) - if err != nil { - return nil, util.NewContextualError("Error while loading firewall rules", nil, err) + return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) } l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started") // TODO: make sure mask is 4 bytes - tunCidr := cs.certificate.Details.Ips[0] + tunCidr := 
certificate.Details.Ips[0] ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) wireSSHReload(l, ssh, c) @@ -85,7 +76,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if c.GetBool("sshd.enabled", false) { sshStart, err = configSSH(l, ssh, c) if err != nil { - return nil, util.NewContextualError("Error while configuring the sshd", nil, err) + return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err) } } @@ -136,7 +127,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg tun, err = overlay.NewDeviceFromConfig(c, l, tunCidr, tunFd, routines) if err != nil { - return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err) } defer func() { @@ -147,7 +138,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } // set up our UDP listener - udpConns := make([]*udp.Conn, routines) + udpConns := make([]udp.Conn, routines) port := c.GetInt("listen.port", 0) if !configTest { @@ -160,7 +151,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } else { listenHost, err = net.ResolveIPAddr("ip", rawListenHost) if err != nil { - return nil, util.NewContextualError("Failed to resolve listen.host", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err) } } @@ -182,7 +173,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg for _, rawPreferredRange := range rawPreferredRanges { _, preferredRange, err := net.ParseCIDR(rawPreferredRange) if err != nil { - return nil, util.NewContextualError("Failed to parse preferred ranges", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to parse preferred ranges", err) } preferredRanges = append(preferredRanges, preferredRange) } @@ -195,7 +186,7 @@ func Main(c *config.C, configTest bool, buildVersion string, 
logger *logrus.Logg if rawLocalRange != "" { _, localRange, err := net.ParseCIDR(rawLocalRange) if err != nil { - return nil, util.NewContextualError("Failed to parse local_range", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to parse local_range", err) } // Check if the entry for local_range was already specified in @@ -212,7 +203,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } } - hostMap := NewHostMap(l, "main", tunCidr, preferredRanges) + hostMap := NewHostMap(l, tunCidr, preferredRanges) hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false) l. @@ -220,18 +211,10 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg WithField("preferredRanges", hostMap.preferredRanges). Info("Main HostMap created") - /* - config.SetDefault("promoter.interval", 10) - go hostMap.Promoter(config.GetInt("promoter.interval")) - */ - punchy := NewPunchyFromConfig(l, c) lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy) - switch { - case errors.As(err, &util.ContextualError{}): - return nil, err - case err != nil: - return nil, util.NewContextualError("Failed to initialize lighthouse handler", nil, err) + if err != nil { + return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err) } var messageMetrics *MessageMetrics @@ -252,13 +235,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg messageMetrics: messageMetrics, } - handshakeManager := NewHandshakeManager(l, tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig) + handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig) lightHouse.handshakeTrigger = handshakeManager.trigger - //TODO: These will be reused for psk - //handshakeMACKey := config.GetString("handshake_mac.key", "") - //handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{}) - serveDns := 
false if c.GetBool("lighthouse.serve_dns", false) { if c.GetBool("lighthouse.am_lighthouse", false) { @@ -270,11 +249,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg checkInterval := c.GetInt("timers.connection_alive_interval", 5) pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10) + ifConfig := &InterfaceConfig{ HostMap: hostMap, Inside: tun, Outside: udpConns[0], - certState: cs, + pki: pki, Cipher: c.GetString("cipher", "aes"), Firewall: fw, ServeDns: serveDns, @@ -282,12 +262,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg lightHouse: lightHouse, checkInterval: time.Second * time.Duration(checkInterval), pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval), + tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery), + reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery), + reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait), DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false), DropMulticast: c.GetBool("tun.drop_multicast", false), routines: routines, MessageMetrics: messageMetrics, version: buildVersion, - caPool: caPool, disconnectInvalid: c.GetBool("pki.disconnect_invalid", false), relayManager: NewRelayManager(ctx, l, hostMap, c), punchy: punchy, @@ -315,6 +297,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg // TODO: Better way to attach these, probably want a new interface in InterfaceConfig // I don't want to make this initial commit too far-reaching though ifce.writers = udpConns + lightHouse.ifce = ifce loadMultiPortConfig := func(c *config.C) { ifce.multiPort.Rx = c.GetBool("tun.multiport.rx_enabled", false) @@ -350,19 +333,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg c.RegisterReloadCallback(loadMultiPortConfig) ifce.RegisterConfigChangeCallbacks(c) - 
ifce.reloadSendRecvError(c) - go handshakeManager.Run(ctx, ifce) - go lightHouse.LhUpdateWorker(ctx, ifce) + handshakeManager.f = ifce + go handshakeManager.Run(ctx) } // TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept // a context so that they can exit when the context is Done. statsStart, err := startStats(l, c, buildVersion, configTest) if err != nil { - return nil, util.NewContextualError("Failed to start stats emitter", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err) } if configTest { @@ -372,7 +353,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg //TODO: check if we _should_ be emitting stats go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10)) - attachCommands(l, c, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce) + attachCommands(l, c, ssh, ifce) // Start DNS server last to allow using the nebula IP as lighthouse.dns.host var dnsStart func() @@ -381,5 +362,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg dnsStart = dnsMain(l, hostMap, c) } - return &Control{ifce, l, cancel, sshStart, statsStart, dnsStart}, nil + return &Control{ + ifce, + l, + cancel, + sshStart, + statsStart, + dnsStart, + lightHouse.StartUpdateWorker, + }, nil } diff --git a/outside.go b/outside.go index 59820ae..970f299 100644 --- a/outside.go +++ b/outside.go @@ -64,9 +64,9 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt var hostinfo *HostInfo // verify if we've seen this index before, otherwise respond to the handshake initiation if h.Type == header.Message && h.Subtype == header.MessageRelay { - hostinfo, _ = f.hostMap.QueryRelayIndex(h.RemoteIndex) + hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) } else { - hostinfo, _ = f.hostMap.QueryIndex(h.RemoteIndex) + hostinfo = f.hostMap.QueryIndex(h.RemoteIndex) } var ci *ConnectionState @@ -417,7 +417,7 @@ 
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return false } - dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.caPool, localCache) + dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) if dropReason != nil { f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, out, q) if f.l.Level >= logrus.DebugLevel { @@ -462,12 +462,9 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { Debug("Recv error received") } - // First, clean up in the pending hostmap - f.handshakeManager.pendingHostMap.DeleteReverseIndex(h.RemoteIndex) - - hostinfo, err := f.hostMap.QueryReverseIndex(h.RemoteIndex) - if err != nil { - f.l.Debugln(err, ": ", h.RemoteIndex) + hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex) + if hostinfo == nil { + f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap") return } @@ -477,14 +474,14 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { if !hostinfo.RecvErrorExceeded() { return } + if hostinfo.remote != nil && !hostinfo.remote.Equals(addr) { f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) return } f.closeTunnel(hostinfo) - // We also delete it from pending hostmap to allow for - // fast reconnect. + // We also delete it from pending hostmap to allow for fast reconnect. 
f.handshakeManager.DeleteHostInfo(hostinfo) } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index fd3429d..428e38f 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -47,14 +47,6 @@ type ifReq struct { pad [8]byte } -func ioctl(a1, a2, a3 uintptr) error { - _, _, errno := unix.Syscall(unix.SYS_IOCTL, a1, a2, a3) - if errno != 0 { - return errno - } - return nil -} - var sockaddrCtlSize uintptr = 32 const ( @@ -194,10 +186,10 @@ func (t *tun) Activate() error { unix.SOCK_DGRAM, unix.IPPROTO_IP, ) - if err != nil { return err } + defer unix.Close(s) fd := uintptr(s) diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 99cbdb0..8a52954 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -4,21 +4,44 @@ package overlay import ( + "bytes" + "errors" "fmt" "io" + "io/fs" "net" "os" "os/exec" - "regexp" "strconv" - "strings" + "syscall" + "unsafe" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/iputil" ) -var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) +const ( + // FIODGNAME is defined in sys/sys/filio.h on FreeBSD + // For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678) + FIODGNAME = 0x80106678 +) + +type fiodgnameArg struct { + length int32 + pad [4]byte + buf unsafe.Pointer +} + +type ifreqRename struct { + Name [16]byte + Data uintptr +} + +type ifreqDestroy struct { + Name [16]byte + pad [16]byte +} type tun struct { Device string @@ -33,8 +56,23 @@ type tun struct { func (t *tun) Close() error { if t.ReadWriteCloser != nil { - return t.ReadWriteCloser.Close() + if err := t.ReadWriteCloser.Close(); err != nil { + return err + } + + s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) + if err != nil { + return err + } + defer syscall.Close(s) + + ifreq := ifreqDestroy{Name: t.deviceBytes()} + + // Destroy the interface + err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) + return err 
 } + return nil } @@ -43,34 +81,87 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) { } func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) { + // Try to open existing tun device + var file *os.File + var err error + if deviceName != "" { + file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0) + } + if errors.Is(err, fs.ErrNotExist) || deviceName == "" { + // If the device doesn't already exist, request a new one and rename it + file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0) + } + if err != nil { + return nil, err + } + + rawConn, err := file.SyscallConn() + if err != nil { + return nil, fmt.Errorf("SyscallConn: %v", err) + } + + var name [16]byte + var ctrlErr error + rawConn.Control(func(fd uintptr) { + // Read the name of the interface + arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)} + ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg))) + }) + if ctrlErr != nil { + return nil, ctrlErr + } + + ifName := string(bytes.TrimRight(name[:], "\x00")) + if deviceName == "" { + deviceName = ifName + } + + // If the name doesn't match the desired interface name, rename it now + if ifName != deviceName { + s, err := syscall.Socket( + syscall.AF_INET, + syscall.SOCK_DGRAM, + syscall.IPPROTO_IP, + ) + if err != nil { + return nil, err + } + defer syscall.Close(s) + + fd := uintptr(s) + + var fromName [16]byte + var toName [16]byte + copy(fromName[:], ifName) + copy(toName[:], deviceName) + + ifrr := ifreqRename{ + Name: fromName, + Data: uintptr(unsafe.Pointer(&toName)), + } + + // Set the device name + ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr))) + } + routeTree, err := makeRouteTree(l, routes, false) if err != nil { return nil, err } - if strings.HasPrefix(deviceName, "/dev/") { - deviceName = strings.TrimPrefix(deviceName, "/dev/") - } - if !deviceNameRE.MatchString(deviceName) { - return nil, fmt.Errorf("tun.dev 
must match `tun[0-9]+`") - } return &tun{ - Device: deviceName, - cidr: cidr, - MTU: defaultMTU, - Routes: routes, - routeTree: routeTree, - l: l, + ReadWriteCloser: file, + Device: deviceName, + cidr: cidr, + MTU: defaultMTU, + Routes: routes, + routeTree: routeTree, + l: l, }, nil } func (t *tun) Activate() error { var err error - t.ReadWriteCloser, err = os.OpenFile("/dev/"+t.Device, os.O_RDWR, 0) - if err != nil { - return fmt.Errorf("activate failed: %v", err) - } - // TODO use syscalls instead of exec.Command t.l.Debug("command: ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) if err = exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()).Run(); err != nil { @@ -120,3 +211,10 @@ func (t *tun) Name() string { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") } + +func (t *tun) deviceBytes() (o [16]byte) { + for i, c := range t.Device { + o[i] = byte(c) + } + return +} diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 7833186..8751a3f 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -43,14 +43,6 @@ type ifReq struct { pad [8]byte } -func ioctl(a1, a2, a3 uintptr) error { - _, _, errno := unix.Syscall(unix.SYS_IOCTL, a1, a2, a3) - if errno != 0 { - return errno - } - return nil -} - type ifreqAddr struct { Name [16]byte Addr unix.RawSockaddrInet4 diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go new file mode 100644 index 0000000..4d7f897 --- /dev/null +++ b/overlay/tun_netbsd.go @@ -0,0 +1,162 @@ +//go:build !e2e_testing +// +build !e2e_testing + +package overlay + +import ( + "fmt" + "io" + "net" + "os" + "os/exec" + "regexp" + "strconv" + "syscall" + "unsafe" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/iputil" +) + +type ifreqDestroy struct { + Name [16]byte + pad [16]byte +} + +type tun struct { + Device string + cidr *net.IPNet + MTU int + Routes 
[]Route + routeTree *cidr.Tree4 + l *logrus.Logger + + io.ReadWriteCloser +} + +func (t *tun) Close() error { + if t.ReadWriteCloser != nil { + if err := t.ReadWriteCloser.Close(); err != nil { + return err + } + + s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) + if err != nil { + return err + } + defer syscall.Close(s) + + ifreq := ifreqDestroy{Name: t.deviceBytes()} + + err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) + + return err + } + return nil +} + +func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { + return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") +} + +var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) + +func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) { + // Try to open tun device + var file *os.File + var err error + if deviceName == "" { + return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified") + } + if !deviceNameRE.MatchString(deviceName) { + return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified") + } + file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0) + + if err != nil { + return nil, err + } + + routeTree, err := makeRouteTree(l, routes, false) + + if err != nil { + return nil, err + } + + return &tun{ + ReadWriteCloser: file, + Device: deviceName, + cidr: cidr, + MTU: defaultMTU, + Routes: routes, + routeTree: routeTree, + l: l, + }, nil +} + +func (t *tun) Activate() error { + var err error + + // TODO use syscalls instead of exec.Command + cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) + t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { + return fmt.Errorf("failed to run 'ifconfig': %s", err) + } + + cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.IP.String()) + 
t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { + return fmt.Errorf("failed to run 'route add': %s", err) + } + + cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)) + t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { + return fmt.Errorf("failed to run 'ifconfig': %s", err) + } + // Unsafe path routes + for _, r := range t.Routes { + if r.Via == nil || !r.Install { + // We don't allow route MTUs so only install routes with a via + continue + } + + cmd = exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String()) + t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { + return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err) + } + } + + return nil +} + +func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { + r := t.routeTree.MostSpecificContains(ip) + if r != nil { + return r.(iputil.VpnIp) + } + + return 0 +} + +func (t *tun) Cidr() *net.IPNet { + return t.cidr +} + +func (t *tun) Name() string { + return t.Device +} + +func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { + return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd") +} + +func (t *tun) deviceBytes() (o [16]byte) { + for i, c := range t.Device { + o[i] = byte(c) + } + return +} diff --git a/overlay/tun_notwin.go b/overlay/tun_notwin.go new file mode 100644 index 0000000..2fab927 --- /dev/null +++ b/overlay/tun_notwin.go @@ -0,0 +1,14 @@ +//go:build !windows +// +build !windows + +package overlay + +import "syscall" + +func ioctl(a1, a2, a3 uintptr) error { + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, a1, a2, a3) + if errno != 0 { + return errno + } + return nil +} diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go new file mode 100644 index 0000000..709fb42 --- /dev/null +++ b/overlay/tun_openbsd.go @@ -0,0 +1,174 @@ +//go:build !e2e_testing +// +build !e2e_testing + +package overlay + +import ( + "fmt" + "io" + 
"net" + "os" + "os/exec" + "regexp" + "strconv" + "syscall" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/iputil" +) + +type tun struct { + Device string + cidr *net.IPNet + MTU int + Routes []Route + routeTree *cidr.Tree4 + l *logrus.Logger + + io.ReadWriteCloser + + // cache out buffer since we need to prepend 4 bytes for tun metadata + out []byte +} + +func (t *tun) Close() error { + if t.ReadWriteCloser != nil { + return t.ReadWriteCloser.Close() + } + + return nil +} + +func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { + return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD") +} + +var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) + +func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) { + if deviceName == "" { + return nil, fmt.Errorf("a device name in the format of tunN must be specified") + } + + if !deviceNameRE.MatchString(deviceName) { + return nil, fmt.Errorf("a device name in the format of tunN must be specified") + } + + file, err := os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0) + if err != nil { + return nil, err + } + + routeTree, err := makeRouteTree(l, routes, false) + if err != nil { + return nil, err + } + + return &tun{ + ReadWriteCloser: file, + Device: deviceName, + cidr: cidr, + MTU: defaultMTU, + Routes: routes, + routeTree: routeTree, + l: l, + }, nil +} + +func (t *tun) Activate() error { + var err error + // TODO use syscalls instead of exec.Command + cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) + t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { + return fmt.Errorf("failed to run 'ifconfig': %s", err) + } + + cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)) + t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { + return 
fmt.Errorf("failed to run 'ifconfig': %s", err) + } + + cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.IP.String()) + t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { + return fmt.Errorf("failed to run 'route add': %s", err) + } + + // Unsafe path routes + for _, r := range t.Routes { + if r.Via == nil || !r.Install { + // We don't allow route MTUs so only install routes with a via + continue + } + + cmd = exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String()) + t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { + return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err) + } + } + + return nil +} + +func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { + r := t.routeTree.MostSpecificContains(ip) + if r != nil { + return r.(iputil.VpnIp) + } + + return 0 +} + +func (t *tun) Cidr() *net.IPNet { + return t.cidr +} + +func (t *tun) Name() string { + return t.Device +} + +func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { + return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd") +} + +func (t *tun) Read(to []byte) (int, error) { + buf := make([]byte, len(to)+4) + + n, err := t.ReadWriteCloser.Read(buf) + + copy(to, buf[4:]) + return n - 4, err +} + +// Write is only valid for single threaded use +func (t *tun) Write(from []byte) (int, error) { + buf := t.out + if cap(buf) < len(from)+4 { + buf = make([]byte, len(from)+4) + t.out = buf + } + buf = buf[:len(from)+4] + + if len(from) == 0 { + return 0, syscall.EIO + } + + // Determine the IP Family for the NULL L2 Header + ipVer := from[0] >> 4 + if ipVer == 4 { + buf[3] = syscall.AF_INET + } else if ipVer == 6 { + buf[3] = syscall.AF_INET6 + } else { + return 0, fmt.Errorf("unable to determine IP version from packet") + } + + copy(buf[4:], from) + + n, err := t.ReadWriteCloser.Write(buf) + return n - 4, err +} diff --git a/overlay/tun_tester.go 
b/overlay/tun_tester.go index 3a49dcb..a2a57e1 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -8,6 +8,7 @@ import ( "io" "net" "os" + "sync/atomic" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cidr" @@ -21,6 +22,7 @@ type TestTun struct { routeTree *cidr.Tree4 l *logrus.Logger + closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula } @@ -50,6 +52,10 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int // These are unencrypted ip layer frames destined for another nebula node. // packets should exit the udp side, capture them with udpConn.Get func (t *TestTun) Send(packet []byte) { + if t.closed.Load() { + return + } + if t.l.Level >= logrus.DebugLevel { t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet") } @@ -98,6 +104,10 @@ func (t *TestTun) Name() string { } func (t *TestTun) Write(b []byte) (n int, err error) { + if t.closed.Load() { + return 0, io.ErrClosedPipe + } + packet := make([]byte, len(b), len(b)) copy(packet, b) t.TxPackets <- packet @@ -105,7 +115,10 @@ func (t *TestTun) Write(b []byte) (n int, err error) { } func (t *TestTun) Close() error { - close(t.rxPackets) + if t.closed.CompareAndSwap(false, true) { + close(t.rxPackets) + close(t.TxPackets) + } return nil } diff --git a/overlay/tun_wintun_windows.go b/overlay/tun_wintun_windows.go index 9146c88..a406123 100644 --- a/overlay/tun_wintun_windows.go +++ b/overlay/tun_wintun_windows.go @@ -54,9 +54,16 @@ func newWinTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU return nil, fmt.Errorf("generate GUID failed: %w", err) } - tunDevice, err := wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU) + var tunDevice wintun.Device + tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU) if err != nil { - return nil, fmt.Errorf("create TUN device failed: %w", err) + // Windows 10 has an 
issue with unclean shutdowns not fully cleaning up the wintun device. + // Trying a second time resolves the issue. + l.WithError(err).Debug("Failed to create wintun device, retrying") + tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU) + if err != nil { + return nil, fmt.Errorf("create TUN device failed: %w", err) + } } routeTree, err := makeRouteTree(l, routes, false) diff --git a/pki.go b/pki.go new file mode 100644 index 0000000..91478ce --- /dev/null +++ b/pki.go @@ -0,0 +1,248 @@ +package nebula + +import ( + "errors" + "fmt" + "os" + "strings" + "sync/atomic" + "time" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" +) + +type PKI struct { + cs atomic.Pointer[CertState] + caPool atomic.Pointer[cert.NebulaCAPool] + l *logrus.Logger +} + +type CertState struct { + Certificate *cert.NebulaCertificate + RawCertificate []byte + RawCertificateNoKey []byte + PublicKey []byte + PrivateKey []byte +} + +func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { + pki := &PKI{l: l} + err := pki.reload(c, true) + if err != nil { + return nil, err + } + + c.RegisterReloadCallback(func(c *config.C) { + rErr := pki.reload(c, false) + if rErr != nil { + util.LogWithContextIfNeeded("Failed to reload PKI from config", rErr, l) + } + }) + + return pki, nil +} + +func (p *PKI) GetCertState() *CertState { + return p.cs.Load() +} + +func (p *PKI) GetCAPool() *cert.NebulaCAPool { + return p.caPool.Load() +} + +func (p *PKI) reload(c *config.C, initial bool) error { + err := p.reloadCert(c, initial) + if err != nil { + if initial { + return err + } + err.Log(p.l) + } + + err = p.reloadCAPool(c) + if err != nil { + if initial { + return err + } + err.Log(p.l) + } + + return nil +} + +func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError { + cs, err := newCertStateFromConfig(c) + if err != nil { + return util.NewContextualError("Could 
not load client cert", nil, err) + } + + if !initial { + // did IP in cert change? if so, don't set + currentCert := p.cs.Load().Certificate + oldIPs := currentCert.Details.Ips + newIPs := cs.Certificate.Details.Ips + if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { + return util.NewContextualError( + "IP in new cert was different from old", + m{"new_ip": newIPs[0], "old_ip": oldIPs[0]}, + nil, + ) + } + } + + p.cs.Store(cs) + if initial { + p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate") + } else { + p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk") + } + return nil +} + +func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError { + caPool, err := loadCAPoolFromConfig(p.l, c) + if err != nil { + return util.NewContextualError("Failed to load ca from config", nil, err) + } + + p.caPool.Store(caPool) + p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") + return nil +} + +func newCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) { + // Marshal the certificate to ensure it is valid + rawCertificate, err := certificate.Marshal() + if err != nil { + return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err) + } + + publicKey := certificate.Details.PublicKey + cs := &CertState{ + RawCertificate: rawCertificate, + Certificate: certificate, + PrivateKey: privateKey, + PublicKey: publicKey, + } + + cs.Certificate.Details.PublicKey = nil + rawCertNoKey, err := cs.Certificate.Marshal() + if err != nil { + return nil, fmt.Errorf("error marshalling certificate no key: %s", err) + } + cs.RawCertificateNoKey = rawCertNoKey + // put public key back + cs.Certificate.Details.PublicKey = cs.PublicKey + return cs, nil +} + +func newCertStateFromConfig(c *config.C) (*CertState, error) { + var pemPrivateKey []byte + var err error + + privPathOrPEM := c.GetString("pki.key", "") + if privPathOrPEM == "" { + 
return nil, errors.New("no pki.key path or PEM data provided") + } + + if strings.Contains(privPathOrPEM, "-----BEGIN") { + pemPrivateKey = []byte(privPathOrPEM) + privPathOrPEM = "" + + } else { + pemPrivateKey, err = os.ReadFile(privPathOrPEM) + if err != nil { + return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) + } + } + + rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey) + if err != nil { + return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + } + + var rawCert []byte + + pubPathOrPEM := c.GetString("pki.cert", "") + if pubPathOrPEM == "" { + return nil, errors.New("no pki.cert path or PEM data provided") + } + + if strings.Contains(pubPathOrPEM, "-----BEGIN") { + rawCert = []byte(pubPathOrPEM) + pubPathOrPEM = "" + + } else { + rawCert, err = os.ReadFile(pubPathOrPEM) + if err != nil { + return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err) + } + } + + nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert) + if err != nil { + return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err) + } + + if nebulaCert.Expired(time.Now()) { + return nil, fmt.Errorf("nebula certificate for this host is expired") + } + + if len(nebulaCert.Details.Ips) == 0 { + return nil, fmt.Errorf("no IPs encoded in certificate") + } + + if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil { + return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + } + + return newCertState(nebulaCert, rawKey) +} + +func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) { + var rawCA []byte + var err error + + caPathOrPEM := c.GetString("pki.ca", "") + if caPathOrPEM == "" { + return nil, errors.New("no pki.ca path or PEM data provided") + } + + if strings.Contains(caPathOrPEM, "-----BEGIN") { + rawCA = []byte(caPathOrPEM) + + } else { + rawCA, err = os.ReadFile(caPathOrPEM) + if err != nil { 
+ return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err) + } + } + + caPool, err := cert.NewCAPoolFromBytes(rawCA) + if errors.Is(err, cert.ErrExpired) { + var expired int + for _, crt := range caPool.CAs { + if crt.Expired(time.Now()) { + expired++ + l.WithField("cert", crt).Warn("expired certificate present in CA pool") + } + } + + if expired >= len(caPool.CAs) { + return nil, errors.New("no valid CA certificates present") + } + + } else if err != nil { + return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err) + } + + for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) { + l.WithField("fingerprint", fp).Info("Blocklisting cert") + caPool.BlocklistFingerprint(fp) + } + + return caPool, nil +} diff --git a/relay_manager.go b/relay_manager.go index fb90eec..7aa06cc 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -131,9 +131,9 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * return } // I'm the middle man. Let the initiator know that the I've established the relay they requested. - peerHostInfo, err := rm.hostmap.QueryVpnIp(relay.PeerIp) - if err != nil { - rm.l.WithError(err).WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer") + peerHostInfo := rm.hostmap.QueryVpnIp(relay.PeerIp) + if peerHostInfo == nil { + rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer") return } peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(target) @@ -179,6 +179,12 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N "vpnIp": h.vpnIp}) logMsg.Info("handleCreateRelayRequest") + // Is the source of the relay me? This should never happen, but did happen due to + // an issue migrating relays over to newly re-handshaked host info objects. + if from == f.myVpnIp { + logMsg.WithField("myIP", f.myVpnIp).Error("Discarding relay request from myself") + return + } // Is the target of the relay me? 
if target == f.myVpnIp { existingRelay, ok := h.relayState.QueryRelayForByIp(from) @@ -240,11 +246,11 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N if !rm.GetAmRelay() { return } - peer, err := rm.hostmap.QueryVpnIp(target) - if err != nil { + peer := rm.hostmap.QueryVpnIp(target) + if peer == nil { // Try to establish a connection to this host. If we get a future relay request, // we'll be ready! - f.getOrHandshake(target) + f.Handshake(target) return } if peer.remote == nil { @@ -253,6 +259,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } sendCreateRequest := false var index uint32 + var err error targetRelay, ok := peer.relayState.QueryRelayForByIp(from) if ok { index = targetRelay.LocalIndex diff --git a/remote_list.go b/remote_list.go index 4540714..60a1afd 100644 --- a/remote_list.go +++ b/remote_list.go @@ -70,7 +70,7 @@ type hostnamesResults struct { hostnames []hostnamePort network string lookupTimeout time.Duration - stop chan struct{} + cancelFn func() l *logrus.Logger ips atomic.Pointer[map[netip.AddrPort]struct{}] } @@ -80,7 +80,6 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, hostnames: make([]hostnamePort, len(hostPorts)), network: network, lookupTimeout: timeout, - stop: make(chan (struct{})), l: l, } @@ -115,6 +114,8 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, // Time for the DNS lookup goroutine if performBackgroundLookup { + newCtx, cancel := context.WithCancel(ctx) + r.cancelFn = cancel ticker := time.NewTicker(d) go func() { defer ticker.Stop() @@ -154,9 +155,7 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, onUpdate() } select { - case <-ctx.Done(): - return - case <-r.stop: + case <-newCtx.Done(): return case <-ticker.C: continue @@ -169,8 +168,8 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, } func (hr *hostnamesResults) 
Cancel() { - if hr != nil { - hr.stop <- struct{}{} + if hr != nil && hr.cancelFn != nil { + hr.cancelFn() } } @@ -582,20 +581,11 @@ func (r *RemoteList) unlockedCollect() { dnsAddrs := r.hr.GetIPs() for _, addr := range dnsAddrs { if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { - switch { - case addr.Addr().Is4(): - v4 := addr.Addr().As4() - addrs = append(addrs, &udp.Addr{ - IP: v4[:], - Port: addr.Port(), - }) - case addr.Addr().Is6(): - v6 := addr.Addr().As16() - addrs = append(addrs, &udp.Addr{ - IP: v6[:], - Port: addr.Port(), - }) - } + v6 := addr.Addr().As16() + addrs = append(addrs, &udp.Addr{ + IP: v6[:], + Port: addr.Port(), + }) } } diff --git a/ssh.go b/ssh.go index 6223314..30f9aea 100644 --- a/ssh.go +++ b/ssh.go @@ -3,6 +3,7 @@ package nebula import ( "bytes" "encoding/json" + "errors" "flag" "fmt" "io/ioutil" @@ -168,7 +169,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro return runner, nil } -func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) { +func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) { ssh.RegisterCommand(&sshd.Command{ Name: "list-hostmap", ShortDescription: "List all known previously connected hosts", @@ -181,7 +182,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshListHostMap(hostMap, fs, w) + return sshListHostMap(f.hostMap, fs, w) }, }) @@ -197,7 +198,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshListHostMap(pendingHostMap, fs, w) + return sshListHostMap(f.handshakeManager, fs, w) }, }) @@ -212,7 +213,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, 
hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshListLighthouseMap(lightHouse, fs, w) + return sshListLighthouseMap(f.lightHouse, fs, w) }, }) @@ -277,7 +278,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap Name: "version", ShortDescription: "Prints the currently running version of nebula", Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshVersion(ifce, fs, a, w) + return sshVersion(f, fs, a, w) }, }) @@ -293,7 +294,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshPrintCert(ifce, fs, a, w) + return sshPrintCert(f, fs, a, w) }, }) @@ -307,7 +308,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshPrintTunnel(ifce, fs, a, w) + return sshPrintTunnel(f, fs, a, w) }, }) @@ -321,7 +322,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshPrintRelays(ifce, fs, a, w) + return sshPrintRelays(f, fs, a, w) }, }) @@ -335,7 +336,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshChangeRemote(ifce, fs, a, w) + return sshChangeRemote(f, fs, a, w) }, }) @@ -349,7 +350,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshCloseTunnel(ifce, fs, a, w) + return sshCloseTunnel(f, fs, a, w) }, }) @@ -364,7 +365,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return 
fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshCreateTunnel(ifce, fs, a, w) + return sshCreateTunnel(f, fs, a, w) }, }) @@ -373,12 +374,12 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap ShortDescription: "Query the lighthouses for the provided vpn ip", Help: "This command is asynchronous. Only currently known udp ips will be printed.", Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshQueryLighthouse(ifce, fs, a, w) + return sshQueryLighthouse(f, fs, a, w) }, }) } -func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error { +func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error { fs, ok := a.(*sshListHostMapFlags) if !ok { //TODO: error @@ -387,9 +388,9 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error var hm []ControlHostInfo if fs.ByIndex { - hm = listHostMapIndexes(hostMap) + hm = listHostMapIndexes(hl) } else { - hm = listHostMapHosts(hostMap) + hm = listHostMapHosts(hl) } sort.Slice(hm, func(i, j int) bool { @@ -546,8 +547,8 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp) - if err != nil { + hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + if hostInfo == nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } @@ -588,12 +589,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, _ := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already exists")) } - hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp) + hostInfo = 
ifce.handshakeManager.QueryVpnIp(vpnIp) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) } @@ -606,11 +607,10 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW } } - hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp, ifce.initHostInfo) + hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil) if addr != nil { hostInfo.SetRemote(addr) } - ifce.getOrHandshake(vpnIp) return w.WriteLine("Created") } @@ -645,8 +645,8 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp) - if err != nil { + hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + if hostInfo == nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } @@ -753,7 +753,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit return nil } - cert := ifce.certState.Load().certificate + cert := ifce.pki.GetCertState().Certificate if len(a) > 0 { parsedIp := net.ParseIP(a[0]) if parsedIp == nil { @@ -765,8 +765,8 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp) - if err != nil { + hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + if hostInfo == nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } @@ -851,9 +851,9 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr for k, v := range relays { ro := RelayOutput{NebulaIp: v.vpnIp} co.Relays = append(co.Relays, &ro) - relayHI, err := ifce.hostMap.QueryVpnIp(v.vpnIp) - if err != nil { - ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: err}) + relayHI := ifce.hostMap.QueryVpnIp(v.vpnIp) + if relayHI == nil { + ro.RelayForIps = append(ro.RelayForIps, 
RelayFor{Error: errors.New("could not find hostinfo")}) continue } for _, vpnIp := range relayHI.relayState.CopyRelayForIps() { @@ -889,8 +889,8 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k) } } - relayedHI, err := ifce.hostMap.QueryVpnIp(vpnIp) - if err == nil { + relayedHI := ifce.hostMap.QueryVpnIp(vpnIp) + if relayedHI != nil { rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...) } @@ -925,8 +925,8 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp) - if err != nil { + hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + if hostInfo == nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } diff --git a/udp/conn.go b/udp/conn.go index f967a9a..a2c24a1 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -1,6 +1,7 @@ package udp import ( + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" ) @@ -18,3 +19,33 @@ type EncReader func( q int, localCache firewall.ConntrackCache, ) + +type Conn interface { + Rebind() error + LocalAddr() (*Addr, error) + ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) + WriteTo(b []byte, addr *Addr) error + ReloadConfig(c *config.C) + Close() error +} + +type NoopConn struct{} + +func (NoopConn) Rebind() error { + return nil +} +func (NoopConn) LocalAddr() (*Addr, error) { + return nil, nil +} +func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) { + return +} +func (NoopConn) WriteTo(_ []byte, _ *Addr) error { + return nil +} +func (NoopConn) ReloadConfig(_ *config.C) { + return +} +func (NoopConn) Close() error { + return nil +} diff --git 
a/udp/udp_android.go b/udp/udp_android.go index d2812a8..8d69074 100644 --- a/udp/udp_android.go +++ b/udp/udp_android.go @@ -8,9 +8,14 @@ import ( "net" "syscall" + "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { + return NewGenericListener(l, ip, port, multi, batch) +} + func NewListenConfig(multi bool) net.ListenConfig { return net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { @@ -34,6 +39,6 @@ func NewListenConfig(multi bool) net.ListenConfig { } } -func (u *Conn) Rebind() error { +func (u *GenericConn) Rebind() error { return nil } diff --git a/udp/udp_bsd.go b/udp/udp_bsd.go new file mode 100644 index 0000000..785aa6a --- /dev/null +++ b/udp/udp_bsd.go @@ -0,0 +1,47 @@ +//go:build (openbsd || freebsd) && !e2e_testing +// +build openbsd freebsd +// +build !e2e_testing + +package udp + +// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig + +import ( + "fmt" + "net" + "syscall" + + "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" +) + +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { + return NewGenericListener(l, ip, port, multi, batch) +} + +func NewListenConfig(multi bool) net.ListenConfig { + return net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + if multi { + var controlErr error + err := c.Control(func(fd uintptr) { + if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { + controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err) + return + } + }) + if err != nil { + return err + } + if controlErr != nil { + return controlErr + } + } + return nil + }, + } +} + +func (u *GenericConn) Rebind() error { + return nil +} diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 69d0c58..08e1b6a 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -10,9 +10,14 @@ 
import ( "net" "syscall" + "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { + return NewGenericListener(l, ip, port, multi, batch) +} + func NewListenConfig(multi bool) net.ListenConfig { return net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { @@ -37,11 +42,16 @@ func NewListenConfig(multi bool) net.ListenConfig { } } -func (u *Conn) Rebind() error { - file, err := u.File() +func (u *GenericConn) Rebind() error { + rc, err := u.UDPConn.SyscallConn() if err != nil { return err } - return syscall.SetsockoptInt(int(file.Fd()), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0) + return rc.Control(func(fd uintptr) { + err := syscall.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0) + if err != nil { + u.l.WithError(err).Error("Failed to rebind udp socket") + } + }) } diff --git a/udp/udp_generic.go b/udp/udp_generic.go index ff254eb..1dd6d1d 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -18,30 +18,32 @@ import ( "github.com/slackhq/nebula/header" ) -type Conn struct { +type GenericConn struct { *net.UDPConn l *logrus.Logger } -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (*Conn, error) { +var _ Conn = &GenericConn{} + +func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { lc := NewListenConfig(multi) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { return nil, err } if uc, ok := pc.(*net.UDPConn); ok { - return &Conn{UDPConn: uc, l: l}, nil + return &GenericConn{UDPConn: uc, l: l}, nil } return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc) } -func (uc *Conn) WriteTo(b []byte, addr *Addr) error { - _, err := uc.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)}) +func (u *GenericConn) WriteTo(b []byte, addr *Addr) error { + _, err := 
u.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)}) return err } -func (uc *Conn) LocalAddr() (*Addr, error) { - a := uc.UDPConn.LocalAddr() +func (u *GenericConn) LocalAddr() (*Addr, error) { + a := u.UDPConn.LocalAddr() switch v := a.(type) { case *net.UDPAddr: @@ -55,11 +57,11 @@ func (uc *Conn) LocalAddr() (*Addr, error) { } } -func (u *Conn) ReloadConfig(c *config.C) { +func (u *GenericConn) ReloadConfig(c *config.C) { // TODO } -func NewUDPStatsEmitter(udpConns []*Conn) func() { +func NewUDPStatsEmitter(udpConns []Conn) func() { // No UDP stats for non-linux return func() {} } @@ -68,7 +70,7 @@ type rawMessage struct { Len uint32 } -func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { +func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { plaintext := make([]byte, MTU) buffer := make([]byte, MTU) h := &header.H{} @@ -80,8 +82,8 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall // Just read one packet at a time n, rua, err := u.ReadFromUDP(buffer) if err != nil { - u.l.WithError(err).Error("Failed to read packets") - continue + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return } udpAddr.IP = rua.IP diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 26bbe36..ca050bb 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -20,7 +20,7 @@ import ( //TODO: make it support reload as best you can! 
-type Conn struct { +type StdConn struct { sysFd int l *logrus.Logger batch int @@ -45,7 +45,7 @@ const ( type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32 -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (*Conn, error) { +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { syscall.ForkLock.RLock() fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP) if err == nil { @@ -77,30 +77,30 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) ( //v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU) //l.Println(v, err) - return &Conn{sysFd: fd, l: l, batch: batch}, err + return &StdConn{sysFd: fd, l: l, batch: batch}, err } -func (u *Conn) Rebind() error { +func (u *StdConn) Rebind() error { return nil } -func (u *Conn) SetRecvBuffer(n int) error { +func (u *StdConn) SetRecvBuffer(n int) error { return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) } -func (u *Conn) SetSendBuffer(n int) error { +func (u *StdConn) SetSendBuffer(n int) error { return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n) } -func (u *Conn) GetRecvBuffer() (int, error) { +func (u *StdConn) GetRecvBuffer() (int, error) { return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF) } -func (u *Conn) GetSendBuffer() (int, error) { +func (u *StdConn) GetSendBuffer() (int, error) { return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) } -func (u *Conn) LocalAddr() (*Addr, error) { +func (u *StdConn) LocalAddr() (*Addr, error) { sa, err := unix.Getsockname(u.sysFd) if err != nil { return nil, err @@ -119,7 +119,7 @@ func (u *Conn) LocalAddr() (*Addr, error) { return addr, nil } -func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { +func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { 
plaintext := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} @@ -137,8 +137,8 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall for { n, err := read(msgs) if err != nil { - u.l.WithError(err).Error("Failed to read packets") - continue + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return } //metric.Update(int64(n)) @@ -150,7 +150,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall } } -func (u *Conn) ReadSingle(msgs []rawMessage) (int, error) { +func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { for { n, _, err := unix.Syscall6( unix.SYS_RECVMSG, @@ -171,7 +171,7 @@ func (u *Conn) ReadSingle(msgs []rawMessage) (int, error) { } } -func (u *Conn) ReadMulti(msgs []rawMessage) (int, error) { +func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { for { n, _, err := unix.Syscall6( unix.SYS_RECVMMSG, @@ -191,7 +191,7 @@ func (u *Conn) ReadMulti(msgs []rawMessage) (int, error) { } } -func (u *Conn) WriteTo(b []byte, addr *Addr) error { +func (u *StdConn) WriteTo(b []byte, addr *Addr) error { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 @@ -221,7 +221,7 @@ func (u *Conn) WriteTo(b []byte, addr *Addr) error { } } -func (u *Conn) ReloadConfig(c *config.C) { +func (u *StdConn) ReloadConfig(c *config.C) { b := c.GetInt("listen.read_buffer", 0) if b > 0 { err := u.SetRecvBuffer(b) @@ -253,7 +253,7 @@ func (u *Conn) ReloadConfig(c *config.C) { } } -func (u *Conn) getMemInfo(meminfo *_SK_MEMINFO) error { +func (u *StdConn) getMemInfo(meminfo *_SK_MEMINFO) error { var vallen uint32 = 4 * _SK_MEMINFO_VARS _, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) if err != 0 { @@ -262,11 +262,16 @@ func (u *Conn) getMemInfo(meminfo *_SK_MEMINFO) error { return nil } -func NewUDPStatsEmitter(udpConns []*Conn) func() 
{ +func (u *StdConn) Close() error { + //TODO: this will not interrupt the read loop + return syscall.Close(u.sysFd) +} + +func NewUDPStatsEmitter(udpConns []Conn) func() { // Check if our kernel supports SO_MEMINFO before registering the gauges var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge var meminfo _SK_MEMINFO - if err := udpConns[0].getMemInfo(&meminfo); err == nil { + if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil { udpGauges = make([][_SK_MEMINFO_VARS]metrics.Gauge, len(udpConns)) for i := range udpConns { udpGauges[i] = [_SK_MEMINFO_VARS]metrics.Gauge{ @@ -285,7 +290,7 @@ func NewUDPStatsEmitter(udpConns []*Conn) func() { return func() { for i, gauges := range udpGauges { - if err := udpConns[i].getMemInfo(&meminfo); err == nil { + if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil { for j := 0; j < _SK_MEMINFO_VARS; j++ { gauges[j].Update(int64(meminfo[j])) } diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go index 06cd382..523968c 100644 --- a/udp/udp_linux_32.go +++ b/udp/udp_linux_32.go @@ -30,7 +30,7 @@ type rawMessage struct { Len uint32 } -func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { +func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index c442405..a54f1df 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -33,7 +33,7 @@ type rawMessage struct { Pad0 [4]byte } -func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { +func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) diff --git a/udp/udp_freebsd.go b/udp/udp_netbsd.go similarity index 77% rename from udp/udp_freebsd.go rename to udp/udp_netbsd.go index 10ff94b..3c14fac 100644 --- a/udp/udp_freebsd.go +++ 
b/udp/udp_netbsd.go @@ -10,9 +10,14 @@ import ( "net" "syscall" + "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { + return NewGenericListener(l, ip, port, multi, batch) +} + func NewListenConfig(multi bool) net.ListenConfig { return net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { @@ -36,6 +41,6 @@ func NewListenConfig(multi bool) net.ListenConfig { } } -func (u *Conn) Rebind() error { +func (u *GenericConn) Rebind() error { return nil } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go new file mode 100644 index 0000000..31c1a55 --- /dev/null +++ b/udp/udp_rio_windows.go @@ -0,0 +1,403 @@ +//go:build !e2e_testing +// +build !e2e_testing + +// Inspired by https://git.zx2c4.com/wireguard-go/tree/conn/bind_windows.go + +package udp + +import ( + "errors" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "syscall" + "unsafe" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/header" + + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/conn/winrio" +) + +// Assert we meet the standard conn interface +var _ Conn = &RIOConn{} + +//go:linkname procyield runtime.procyield +func procyield(cycles uint32) + +const ( + packetsPerRing = 1024 + bytesPerPacket = 2048 - 32 + receiveSpins = 15 +) + +type ringPacket struct { + addr windows.RawSockaddrInet6 + data [bytesPerPacket]byte +} + +type ringBuffer struct { + packets uintptr + head, tail uint32 + id winrio.BufferId + iocp windows.Handle + isFull bool + cq winrio.Cq + mu sync.Mutex + overlapped windows.Overlapped +} + +type RIOConn struct { + isOpen atomic.Bool + l *logrus.Logger + sock windows.Handle + rx, tx ringBuffer + rq winrio.Rq + results [packetsPerRing]winrio.Result +} + +func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) { + if !winrio.Initialize() { + 
return nil, errors.New("could not initialize winrio") + } + + u := &RIOConn{l: l} + + addr := [16]byte{} + copy(addr[:], ip.To16()) + err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port}) + if err != nil { + return nil, fmt.Errorf("bind: %w", err) + } + + for i := 0; i < packetsPerRing; i++ { + err = u.insertReceiveRequest() + if err != nil { + return nil, fmt.Errorf("init rx ring: %w", err) + } + } + + u.isOpen.Store(true) + return u, nil +} + +func (u *RIOConn) bind(sa windows.Sockaddr) error { + var err error + u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP) + if err != nil { + return err + } + + // Enable v4 for this socket + syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) + + err = u.rx.Open() + if err != nil { + return err + } + + err = u.tx.Open() + if err != nil { + return err + } + + u.rq, err = winrio.CreateRequestQueue(u.sock, packetsPerRing, 1, packetsPerRing, 1, u.rx.cq, u.tx.cq, 0) + if err != nil { + return err + } + + err = windows.Bind(u.sock, sa) + if err != nil { + return err + } + + return nil +} + +func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { + plaintext := make([]byte, MTU) + buffer := make([]byte, MTU) + h := &header.H{} + fwPacket := &firewall.Packet{} + udpAddr := &Addr{IP: make([]byte, 16)} + nb := make([]byte, 12, 12) + + for { + // Just read one packet at a time + n, rua, err := u.receive(buffer) + if err != nil { + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return + } + + udpAddr.IP = rua.Addr[:] + p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port)) + p[0] = byte(rua.Port >> 8) + p[1] = byte(rua.Port) + r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + } +} + +func (u *RIOConn) insertReceiveRequest() error { + packet := u.rx.Push() + dataBuffer := &winrio.Buffer{ + Id: u.rx.id, + Offset: 
uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.rx.packets), + Length: uint32(len(packet.data)), + } + addressBuffer := &winrio.Buffer{ + Id: u.rx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.rx.packets), + Length: uint32(unsafe.Sizeof(packet.addr)), + } + + return winrio.ReceiveEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet))) +} + +func (u *RIOConn) receive(buf []byte) (int, windows.RawSockaddrInet6, error) { + if !u.isOpen.Load() { + return 0, windows.RawSockaddrInet6{}, net.ErrClosed + } + + u.rx.mu.Lock() + defer u.rx.mu.Unlock() + + var err error + var count uint32 + var results [1]winrio.Result + +retry: + count = 0 + for tries := 0; count == 0 && tries < receiveSpins; tries++ { + if tries > 0 { + if !u.isOpen.Load() { + return 0, windows.RawSockaddrInet6{}, net.ErrClosed + } + procyield(1) + } + + count = winrio.DequeueCompletion(u.rx.cq, results[:]) + } + + if count == 0 { + err = winrio.Notify(u.rx.cq) + if err != nil { + return 0, windows.RawSockaddrInet6{}, err + } + var bytes uint32 + var key uintptr + var overlapped *windows.Overlapped + err = windows.GetQueuedCompletionStatus(u.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE) + if err != nil { + return 0, windows.RawSockaddrInet6{}, err + } + + if !u.isOpen.Load() { + return 0, windows.RawSockaddrInet6{}, net.ErrClosed + } + + count = winrio.DequeueCompletion(u.rx.cq, results[:]) + if count == 0 { + return 0, windows.RawSockaddrInet6{}, io.ErrNoProgress + + } + } + + u.rx.Return(1) + err = u.insertReceiveRequest() + if err != nil { + return 0, windows.RawSockaddrInet6{}, err + } + + // We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us + // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to + // attacker bandwidth, just like the rest of the receive path. 
+ if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE { + goto retry + } + + if results[0].Status != 0 { + return 0, windows.RawSockaddrInet6{}, windows.Errno(results[0].Status) + } + + packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext))) + ep := packet.addr + n := copy(buf, packet.data[:results[0].BytesTransferred]) + return n, ep, nil +} + +func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error { + if !u.isOpen.Load() { + return net.ErrClosed + } + + if len(buf) > bytesPerPacket { + return io.ErrShortBuffer + } + + u.tx.mu.Lock() + defer u.tx.mu.Unlock() + + count := winrio.DequeueCompletion(u.tx.cq, u.results[:]) + if count == 0 && u.tx.isFull { + err := winrio.Notify(u.tx.cq) + if err != nil { + return err + } + + var bytes uint32 + var key uintptr + var overlapped *windows.Overlapped + err = windows.GetQueuedCompletionStatus(u.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE) + if err != nil { + return err + } + + if !u.isOpen.Load() { + return net.ErrClosed + } + + count = winrio.DequeueCompletion(u.tx.cq, u.results[:]) + if count == 0 { + return io.ErrNoProgress + } + } + + if count > 0 { + u.tx.Return(count) + } + + packet := u.tx.Push() + packet.addr.Family = windows.AF_INET6 + p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port)) + p[0] = byte(addr.Port >> 8) + p[1] = byte(addr.Port) + copy(packet.addr.Addr[:], addr.IP.To16()) + copy(packet.data[:], buf) + + dataBuffer := &winrio.Buffer{ + Id: u.tx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.tx.packets), + Length: uint32(len(buf)), + } + + addressBuffer := &winrio.Buffer{ + Id: u.tx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.tx.packets), + Length: uint32(unsafe.Sizeof(packet.addr)), + } + + return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) +} + +func (u *RIOConn) LocalAddr() (*Addr, error) { + sa, err := windows.Getsockname(u.sock) + if err != nil { + return nil, err + } + + v6 := 
sa.(*windows.SockaddrInet6) + return &Addr{ + IP: v6.Addr[:], + Port: uint16(v6.Port), + }, nil +} + +func (u *RIOConn) Rebind() error { + return nil +} + +func (u *RIOConn) ReloadConfig(*config.C) {} + +func (u *RIOConn) Close() error { + if !u.isOpen.CompareAndSwap(true, false) { + return nil + } + + windows.PostQueuedCompletionStatus(u.rx.iocp, 0, 0, nil) + windows.PostQueuedCompletionStatus(u.tx.iocp, 0, 0, nil) + + u.rx.CloseAndZero() + u.tx.CloseAndZero() + if u.sock != 0 { + windows.CloseHandle(u.sock) + } + return nil +} + +func (ring *ringBuffer) Push() *ringPacket { + for ring.isFull { + panic("ring is full") + } + ret := (*ringPacket)(unsafe.Pointer(ring.packets + (uintptr(ring.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{})))) + ring.tail += 1 + if ring.tail%packetsPerRing == ring.head%packetsPerRing { + ring.isFull = true + } + return ret +} + +func (ring *ringBuffer) Return(count uint32) { + if ring.head%packetsPerRing == ring.tail%packetsPerRing && !ring.isFull { + return + } + ring.head += count + ring.isFull = false +} + +func (ring *ringBuffer) CloseAndZero() { + if ring.cq != 0 { + winrio.CloseCompletionQueue(ring.cq) + ring.cq = 0 + } + + if ring.iocp != 0 { + windows.CloseHandle(ring.iocp) + ring.iocp = 0 + } + + if ring.id != 0 { + winrio.DeregisterBuffer(ring.id) + ring.id = 0 + } + + if ring.packets != 0 { + windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE) + ring.packets = 0 + } + + ring.head = 0 + ring.tail = 0 + ring.isFull = false +} + +func (ring *ringBuffer) Open() error { + var err error + packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing + ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) + if err != nil { + return err + } + + ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen)) + if err != nil { + return err + } + + ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) + if err != nil 
{ + return err + } + + ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped) + if err != nil { + return err + } + + return nil +} diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 8b5e531..55985f4 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -5,7 +5,9 @@ package udp import ( "fmt" + "io" "net" + "sync/atomic" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -36,17 +38,18 @@ func (u *Packet) Copy() *Packet { return n } -type Conn struct { +type TesterConn struct { Addr *Addr RxPackets chan *Packet // Packets to receive into nebula TxPackets chan *Packet // Packets transmitted outside by nebula - l *logrus.Logger + closed atomic.Bool + l *logrus.Logger } -func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (*Conn, error) { - return &Conn{ +func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, error) { + return &TesterConn{ Addr: &Addr{ip, uint16(port)}, RxPackets: make(chan *Packet, 10), TxPackets: make(chan *Packet, 10), @@ -57,7 +60,11 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (*Conn, e // Send will place a UdpPacket onto the receive queue for nebula to consume // this is an encrypted packet or a handshake message in most cases // packets were transmitted from another nebula node, you can send them with Tun.Send -func (u *Conn) Send(packet *Packet) { +func (u *TesterConn) Send(packet *Packet) { + if u.closed.Load() { + return + } + h := &header.H{} if err := h.Parse(packet.Data); err != nil { panic(err) @@ -74,7 +81,7 @@ func (u *Conn) Send(packet *Packet) { // Get will pull a UdpPacket from the transmit queue // nebula meant to send this message on the network, it will be encrypted // packets were ingested from the tun side (in most cases), you can send them with Tun.Send -func (u *Conn) Get(block bool) *Packet { +func (u *TesterConn) Get(block bool) *Packet { if block { return <-u.TxPackets } @@ -91,7 +98,11 @@ 
func (u *Conn) Get(block bool) *Packet { // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// -func (u *Conn) WriteTo(b []byte, addr *Addr) error { +func (u *TesterConn) WriteTo(b []byte, addr *Addr) error { + if u.closed.Load() { + return io.ErrClosedPipe + } + p := &Packet{ Data: make([]byte, len(b), len(b)), FromIp: make([]byte, 16), @@ -108,7 +119,7 @@ func (u *Conn) WriteTo(b []byte, addr *Addr) error { return nil } -func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { +func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { plaintext := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} @@ -126,17 +137,25 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall } } -func (u *Conn) ReloadConfig(*config.C) {} +func (u *TesterConn) ReloadConfig(*config.C) {} -func NewUDPStatsEmitter(_ []*Conn) func() { +func NewUDPStatsEmitter(_ []Conn) func() { // No UDP stats for non-linux return func() {} } -func (u *Conn) LocalAddr() (*Addr, error) { +func (u *TesterConn) LocalAddr() (*Addr, error) { return u.Addr, nil } -func (u *Conn) Rebind() error { +func (u *TesterConn) Rebind() error { + return nil +} + +func (u *TesterConn) Close() error { + if u.closed.CompareAndSwap(false, true) { + close(u.RxPackets) + close(u.TxPackets) + } return nil } diff --git a/udp/udp_windows.go b/udp/udp_windows.go index 1f2ce64..ebcace6 100644 --- a/udp/udp_windows.go +++ b/udp/udp_windows.go @@ -3,14 +3,31 @@ package udp -// Windows support is primarily implemented in udp_generic, besides NewListenConfig - import ( "fmt" "net" "syscall" + + "github.com/sirupsen/logrus" ) +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { + if multi { + //NOTE: 
Technically we can support it with RIO but it wouldn't be at the socket level + // The udp stack would need to be reworked to hide away the implementation differences between + // Windows and Linux + return nil, fmt.Errorf("multiple udp listeners not supported on windows") + } + + rc, err := NewRIOListener(l, ip, port) + if err == nil { + return rc, nil + } + + l.WithError(err).Error("Falling back to standard udp sockets") + return NewGenericListener(l, ip, port, multi, batch) +} + func NewListenConfig(multi bool) net.ListenConfig { return net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { @@ -24,6 +41,6 @@ func NewListenConfig(multi bool) net.ListenConfig { } } -func (u *Conn) Rebind() error { +func (u *GenericConn) Rebind() error { return nil } diff --git a/util/error.go b/util/error.go index 7f9bc47..a11c9c4 100644 --- a/util/error.go +++ b/util/error.go @@ -12,18 +12,38 @@ type ContextualError struct { Context string } -func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError { - return ContextualError{Context: msg, Fields: fields, RealError: realError} +func NewContextualError(msg string, fields map[string]interface{}, realError error) *ContextualError { + return &ContextualError{Context: msg, Fields: fields, RealError: realError} } -func (ce ContextualError) Error() string { +// ContextualizeIfNeeded is a helper function to turn an error into a ContextualError if it is not already one +func ContextualizeIfNeeded(msg string, err error) error { + switch err.(type) { + case *ContextualError: + return err + default: + return NewContextualError(msg, nil, err) + } +} + +// LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError +func LogWithContextIfNeeded(msg string, err error, l *logrus.Logger) { + switch v := err.(type) { + case *ContextualError: + v.Log(l) + default: + l.WithError(err).Error(msg) + } +} + +func (ce *ContextualError) Error() 
string { if ce.RealError == nil { return ce.Context } return ce.RealError.Error() } -func (ce ContextualError) Unwrap() error { +func (ce *ContextualError) Unwrap() error { if ce.RealError == nil { return errors.New(ce.Context) } diff --git a/util/error_test.go b/util/error_test.go index 747d04e..5041f82 100644 --- a/util/error_test.go +++ b/util/error_test.go @@ -2,6 +2,7 @@ package util import ( "errors" + "fmt" "testing" "github.com/sirupsen/logrus" @@ -67,3 +68,44 @@ func TestContextualError_Log(t *testing.T) { e.Log(l) assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs) } + +func TestLogWithContextIfNeeded(t *testing.T) { + l := logrus.New() + l.Formatter = &logrus.TextFormatter{ + DisableTimestamp: true, + DisableColors: true, + } + + tl := NewTestLogWriter() + l.Out = tl + + // Test ignoring fallback context + tl.Reset() + e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) + LogWithContextIfNeeded("This should get thrown away", e, l) + assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs) + + // Test using fallback context + tl.Reset() + err := fmt.Errorf("this is a normal error") + LogWithContextIfNeeded("Fallback context woo", err, l) + assert.Equal(t, []string{"level=error msg=\"Fallback context woo\" error=\"this is a normal error\"\n"}, tl.Logs) +} + +func TestContextualizeIfNeeded(t *testing.T) { + // Test ignoring fallback context + e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) + assert.Same(t, e, ContextualizeIfNeeded("should be ignored", e)) + + // Test using fallback context + err := fmt.Errorf("this is a normal error") + cErr := ContextualizeIfNeeded("Fallback context woo", err) + + switch v := cErr.(type) { + case *ContextualError: + assert.Equal(t, err, v.RealError) + default: + t.Error("Error was not wrapped") + t.Fail() + } +}