mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 16:34:25 +01:00
Compare commits
109 Commits
cert-v2-mo
...
e2e-bench-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
662a3358ee | ||
|
|
64f202fa17 | ||
|
|
6d7cf611c9 | ||
|
|
83ae8077f5 | ||
|
|
12cf348c80 | ||
|
|
a5ee928990 | ||
|
|
7aff313a17 | ||
|
|
297767b2e3 | ||
|
|
99faab505c | ||
|
|
584c2668b3 | ||
|
|
27ea667aee | ||
|
|
4df8bcb1f5 | ||
|
|
36c890eaad | ||
|
|
44001244f2 | ||
|
|
a89f95182c | ||
|
|
6a8a2992ff | ||
|
|
3d94dfe6a1 | ||
|
|
3670e24fa0 | ||
|
|
b348ee726e | ||
|
|
a941b65114 | ||
|
|
17101d425f | ||
|
|
52f1908126 | ||
|
|
48f1ae98ba | ||
|
|
97b3972c11 | ||
|
|
0f305d5397 | ||
|
|
01909f4715 | ||
|
|
770147264d | ||
|
|
fa8c013b97 | ||
|
|
2710f2af06 | ||
|
|
ad6d3e6bac | ||
|
|
2b0aa74e85 | ||
|
|
b126d88963 | ||
|
|
45c1d3eab3 | ||
|
|
634181ba66 | ||
|
|
eb89839d13 | ||
|
|
fb7f0c3657 | ||
|
|
b1f53d8d25 | ||
|
|
8824eeaea2 | ||
|
|
071589f7c7 | ||
|
|
f1e992f6dd | ||
|
|
1ea5f776d7 | ||
|
|
4cdeb284ef | ||
|
|
5cccd39465 | ||
|
|
8196c22b5a | ||
|
|
65cc253c19 | ||
|
|
73cfa7b5b1 | ||
|
|
768325c9b4 | ||
|
|
932e329164 | ||
|
|
4bea299265 | ||
|
|
5cff83b282 | ||
|
|
7da79685ff | ||
|
|
91eff03418 | ||
|
|
52623820c2 | ||
|
|
c2420642a0 | ||
|
|
b3a1f7b0a3 | ||
|
|
94142aded5 | ||
|
|
b158eb0c4c | ||
|
|
e4b7dbcfb0 | ||
|
|
882edf11d7 | ||
|
|
d34c2b8e06 | ||
|
|
442a52879b | ||
|
|
061e733007 | ||
|
|
92a9248083 | ||
|
|
83ff2461e2 | ||
|
|
8536c57645 | ||
|
|
15b5a43300 | ||
|
|
e5ce8966d6 | ||
|
|
2dc30fc300 | ||
|
|
b8ea55eb90 | ||
|
|
4eb056af9d | ||
|
|
e49f279004 | ||
|
|
459cb38a6d | ||
|
|
18279ed17b | ||
|
|
c7fb3ad9cf | ||
|
|
d4a7df3083 | ||
|
|
e83a1c6c84 | ||
|
|
f5d096dd2b | ||
|
|
e2d6f4e444 | ||
|
|
d99fd60e06 | ||
|
|
e4bae15825 | ||
|
|
58ead4116f | ||
|
|
e136d1d47a | ||
|
|
d2adebf26d | ||
|
|
36bc9dd261 | ||
|
|
879852c32a | ||
|
|
75faa5f2e5 | ||
|
|
4444ed166a | ||
|
|
f86953ca56 | ||
|
|
3de36c99b6 | ||
|
|
50473bd2a8 | ||
|
|
1d3c85338c | ||
|
|
2fb018ced8 | ||
|
|
088af8edb2 | ||
|
|
612637f529 | ||
|
|
94e89a1045 | ||
|
|
f7540ad355 | ||
|
|
096179a8c9 | ||
|
|
f8734ffa43 | ||
|
|
c58e223b3d | ||
|
|
c46ef43590 | ||
|
|
775c6bc83d | ||
|
|
13799f425d | ||
|
|
8a090e59d7 | ||
|
|
9feda811a6 | ||
|
|
750e4a81bf | ||
|
|
32d3a6e091 | ||
|
|
351dbd6059 | ||
|
|
d97ed57a19 | ||
|
|
2b427a7e89 |
22
.github/ISSUE_TEMPLATE/config.yml
vendored
22
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,13 +1,21 @@
|
|||||||
blank_issues_enabled: true
|
blank_issues_enabled: true
|
||||||
contact_links:
|
contact_links:
|
||||||
|
- name: 💨 Performance Issues
|
||||||
|
url: https://github.com/slackhq/nebula/discussions/new/choose
|
||||||
|
about: 'We ask that you create a discussion instead of an issue for performance-related questions. This allows us to have a more open conversation about the issue and helps us to better understand the problem.'
|
||||||
|
|
||||||
|
- name: 📄 Documentation Issues
|
||||||
|
url: https://github.com/definednet/nebula-docs
|
||||||
|
about: "If you've found an issue with the website documentation, please file it in the nebula-docs repository."
|
||||||
|
|
||||||
|
- name: 📱 Mobile Nebula Issues
|
||||||
|
url: https://github.com/definednet/mobile_nebula
|
||||||
|
about: "If you're using the mobile Nebula app and have found an issue, please file it in the mobile_nebula repository."
|
||||||
|
|
||||||
- name: 📘 Documentation
|
- name: 📘 Documentation
|
||||||
url: https://nebula.defined.net/docs/
|
url: https://nebula.defined.net/docs/
|
||||||
about: Review documentation.
|
about: 'The documentation is the best place to start if you are new to Nebula.'
|
||||||
|
|
||||||
- name: 💁 Support/Chat
|
- name: 💁 Support/Chat
|
||||||
url: https://join.slack.com/t/nebulaoss/shared_invite/enQtOTA5MDI4NDg3MTg4LTkwY2EwNTI4NzQyMzc0M2ZlODBjNWI3NTY1MzhiOThiMmZlZjVkMTI0NGY4YTMyNjUwMWEyNzNkZTJmYzQxOGU
|
url: https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA
|
||||||
about: 'This issue tracker is not for support questions. Join us on Slack for assistance!'
|
about: 'For faster support, join us on Slack for assistance!'
|
||||||
|
|
||||||
- name: 📱 Mobile Nebula
|
|
||||||
url: https://github.com/definednet/mobile_nebula
|
|
||||||
about: 'This issue tracker is not for mobile support. Try the Mobile Nebula repo instead!'
|
|
||||||
|
|||||||
11
.github/pull_request_template.md
vendored
Normal file
11
.github/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
<!--
|
||||||
|
Thank you for taking the time to submit a pull request!
|
||||||
|
|
||||||
|
Please be sure to provide a clear description of what you're trying to achieve with the change.
|
||||||
|
|
||||||
|
- If you're submitting a new feature, please explain how to use it and document any new config options in the example config.
|
||||||
|
- If you're submitting a bugfix, please link the related issue or describe the circumstances surrounding the issue.
|
||||||
|
- If you're changing a default, explain why you believe the new default is appropriate for most users.
|
||||||
|
|
||||||
|
P.S. If you're only updating the README or other docs, please file a pull request here instead: https://github.com/DefinedNet/nebula-docs
|
||||||
|
-->
|
||||||
6
.github/workflows/gofmt.yml
vendored
6
.github/workflows/gofmt.yml
vendored
@@ -14,11 +14,11 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v5
|
||||||
|
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
go-version: '1.22'
|
go-version: '1.25'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Install goimports
|
- name: Install goimports
|
||||||
|
|||||||
34
.github/workflows/release.yml
vendored
34
.github/workflows/release.yml
vendored
@@ -10,11 +10,11 @@ jobs:
|
|||||||
name: Build Linux/BSD All
|
name: Build Linux/BSD All
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v5
|
||||||
|
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
go-version: '1.22'
|
go-version: '1.25'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
@@ -24,7 +24,7 @@ jobs:
|
|||||||
mv build/*.tar.gz release
|
mv build/*.tar.gz release
|
||||||
|
|
||||||
- name: Upload artifacts
|
- name: Upload artifacts
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v5
|
||||||
with:
|
with:
|
||||||
name: linux-latest
|
name: linux-latest
|
||||||
path: release
|
path: release
|
||||||
@@ -33,11 +33,11 @@ jobs:
|
|||||||
name: Build Windows
|
name: Build Windows
|
||||||
runs-on: windows-latest
|
runs-on: windows-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v5
|
||||||
|
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
go-version: '1.22'
|
go-version: '1.25'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
@@ -55,7 +55,7 @@ jobs:
|
|||||||
mv dist\windows\wintun build\dist\windows\
|
mv dist\windows\wintun build\dist\windows\
|
||||||
|
|
||||||
- name: Upload artifacts
|
- name: Upload artifacts
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v5
|
||||||
with:
|
with:
|
||||||
name: windows-latest
|
name: windows-latest
|
||||||
path: build
|
path: build
|
||||||
@@ -66,16 +66,16 @@ jobs:
|
|||||||
HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
|
HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
|
||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v5
|
||||||
|
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
go-version: '1.22'
|
go-version: '1.25'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Import certificates
|
- name: Import certificates
|
||||||
if: env.HAS_SIGNING_CREDS == 'true'
|
if: env.HAS_SIGNING_CREDS == 'true'
|
||||||
uses: Apple-Actions/import-codesign-certs@v3
|
uses: Apple-Actions/import-codesign-certs@v5
|
||||||
with:
|
with:
|
||||||
p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }}
|
p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }}
|
||||||
p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }}
|
p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }}
|
||||||
@@ -104,7 +104,7 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Upload artifacts
|
- name: Upload artifacts
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v5
|
||||||
with:
|
with:
|
||||||
name: darwin-latest
|
name: darwin-latest
|
||||||
path: ./release/*
|
path: ./release/*
|
||||||
@@ -124,11 +124,11 @@ jobs:
|
|||||||
# be overwritten
|
# be overwritten
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
- name: Download artifacts
|
- name: Download artifacts
|
||||||
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
||||||
uses: actions/download-artifact@v4
|
uses: actions/download-artifact@v6
|
||||||
with:
|
with:
|
||||||
name: linux-latest
|
name: linux-latest
|
||||||
path: artifacts
|
path: artifacts
|
||||||
@@ -160,10 +160,10 @@ jobs:
|
|||||||
needs: [build-linux, build-darwin, build-windows]
|
needs: [build-linux, build-darwin, build-windows]
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v5
|
||||||
|
|
||||||
- name: Download artifacts
|
- name: Download artifacts
|
||||||
uses: actions/download-artifact@v4
|
uses: actions/download-artifact@v6
|
||||||
with:
|
with:
|
||||||
path: artifacts
|
path: artifacts
|
||||||
|
|
||||||
|
|||||||
9
.github/workflows/smoke-extra.yml
vendored
9
.github/workflows/smoke-extra.yml
vendored
@@ -20,13 +20,16 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v5
|
||||||
|
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
go-version-file: 'go.mod'
|
go-version: '1.25'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
|
- name: add hashicorp source
|
||||||
|
run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list
|
||||||
|
|
||||||
- name: install vagrant
|
- name: install vagrant
|
||||||
run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox
|
run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox
|
||||||
|
|
||||||
|
|||||||
6
.github/workflows/smoke.yml
vendored
6
.github/workflows/smoke.yml
vendored
@@ -18,11 +18,11 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v5
|
||||||
|
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
go-version: '1.22'
|
go-version: '1.25'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: build
|
- name: build
|
||||||
|
|||||||
10
.github/workflows/smoke/build.sh
vendored
10
.github/workflows/smoke/build.sh
vendored
@@ -5,6 +5,10 @@ set -e -x
|
|||||||
rm -rf ./build
|
rm -rf ./build
|
||||||
mkdir ./build
|
mkdir ./build
|
||||||
|
|
||||||
|
# TODO: Assumes your docker bridge network is a /24, and the first container that launches will be .1
|
||||||
|
# - We could make this better by launching the lighthouse first and then fetching what IP it is.
|
||||||
|
NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{ end }}' | cut -d. -f1-3)"
|
||||||
|
|
||||||
(
|
(
|
||||||
cd build
|
cd build
|
||||||
|
|
||||||
@@ -21,16 +25,16 @@ mkdir ./build
|
|||||||
../genconfig.sh >lighthouse1.yml
|
../genconfig.sh >lighthouse1.yml
|
||||||
|
|
||||||
HOST="host2" \
|
HOST="host2" \
|
||||||
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
|
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
|
||||||
../genconfig.sh >host2.yml
|
../genconfig.sh >host2.yml
|
||||||
|
|
||||||
HOST="host3" \
|
HOST="host3" \
|
||||||
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
|
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
|
||||||
INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
||||||
../genconfig.sh >host3.yml
|
../genconfig.sh >host3.yml
|
||||||
|
|
||||||
HOST="host4" \
|
HOST="host4" \
|
||||||
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
|
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
|
||||||
OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
||||||
../genconfig.sh >host4.yml
|
../genconfig.sh >host4.yml
|
||||||
|
|
||||||
|
|||||||
34
.github/workflows/smoke/smoke-vagrant.sh
vendored
34
.github/workflows/smoke/smoke-vagrant.sh
vendored
@@ -29,13 +29,13 @@ docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
|
|||||||
docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
|
docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
|
||||||
|
|
||||||
vagrant up
|
vagrant up
|
||||||
vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test"
|
vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T
|
||||||
|
|
||||||
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
|
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
|
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
|
||||||
sleep 1
|
sleep 1
|
||||||
vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" &
|
vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/ [host3] /' &
|
||||||
sleep 15
|
sleep 15
|
||||||
|
|
||||||
# grab tcpdump pcaps for debugging
|
# grab tcpdump pcaps for debugging
|
||||||
@@ -46,8 +46,8 @@ docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host
|
|||||||
# vagrant ssh -c "tcpdump -i nebula1 -q -w - -U" 2>logs/host3.inside.log >logs/host3.inside.pcap &
|
# vagrant ssh -c "tcpdump -i nebula1 -q -w - -U" 2>logs/host3.inside.log >logs/host3.inside.pcap &
|
||||||
# vagrant ssh -c "tcpdump -i eth0 -q -w - -U" 2>logs/host3.outside.log >logs/host3.outside.pcap &
|
# vagrant ssh -c "tcpdump -i eth0 -q -w - -U" 2>logs/host3.outside.log >logs/host3.outside.pcap &
|
||||||
|
|
||||||
docker exec host2 ncat -nklv 0.0.0.0 2000 &
|
#docker exec host2 ncat -nklv 0.0.0.0 2000 &
|
||||||
vagrant ssh -c "ncat -nklv 0.0.0.0 2000" &
|
#vagrant ssh -c "ncat -nklv 0.0.0.0 2000" &
|
||||||
#docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
|
#docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
|
||||||
#vagrant ssh -c "ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000" &
|
#vagrant ssh -c "ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000" &
|
||||||
|
|
||||||
@@ -68,11 +68,11 @@ docker exec host2 ping -c1 192.168.100.1
|
|||||||
# Should fail because not allowed by host3 inbound firewall
|
# Should fail because not allowed by host3 inbound firewall
|
||||||
! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
|
! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
|
||||||
|
|
||||||
set +x
|
#set +x
|
||||||
echo
|
#echo
|
||||||
echo " *** Testing ncat from host2"
|
#echo " *** Testing ncat from host2"
|
||||||
echo
|
#echo
|
||||||
set -x
|
#set -x
|
||||||
# Should fail because not allowed by host3 inbound firewall
|
# Should fail because not allowed by host3 inbound firewall
|
||||||
#! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1
|
#! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1
|
||||||
#! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
|
#! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
|
||||||
@@ -82,18 +82,18 @@ echo
|
|||||||
echo " *** Testing ping from host3"
|
echo " *** Testing ping from host3"
|
||||||
echo
|
echo
|
||||||
set -x
|
set -x
|
||||||
vagrant ssh -c "ping -c1 192.168.100.1"
|
vagrant ssh -c "ping -c1 192.168.100.1" -- -T
|
||||||
vagrant ssh -c "ping -c1 192.168.100.2"
|
vagrant ssh -c "ping -c1 192.168.100.2" -- -T
|
||||||
|
|
||||||
set +x
|
#set +x
|
||||||
echo
|
#echo
|
||||||
echo " *** Testing ncat from host3"
|
#echo " *** Testing ncat from host3"
|
||||||
echo
|
#echo
|
||||||
set -x
|
#set -x
|
||||||
#vagrant ssh -c "ncat -nzv -w5 192.168.100.2 2000"
|
#vagrant ssh -c "ncat -nzv -w5 192.168.100.2 2000"
|
||||||
#vagrant ssh -c "ncat -nzuv -w5 192.168.100.2 3000" | grep -q host2
|
#vagrant ssh -c "ncat -nzuv -w5 192.168.100.2 3000" | grep -q host2
|
||||||
|
|
||||||
vagrant ssh -c "sudo xargs kill </nebula/pid"
|
vagrant ssh -c "sudo xargs kill </nebula/pid" -- -T
|
||||||
docker exec host2 sh -c 'kill 1'
|
docker exec host2 sh -c 'kill 1'
|
||||||
docker exec lighthouse1 sh -c 'kill 1'
|
docker exec lighthouse1 sh -c 'kill 1'
|
||||||
sleep 1
|
sleep 1
|
||||||
|
|||||||
40
.github/workflows/test.yml
vendored
40
.github/workflows/test.yml
vendored
@@ -18,11 +18,11 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v5
|
||||||
|
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
go-version: '1.22'
|
go-version: '1.25'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
@@ -31,6 +31,11 @@ jobs:
|
|||||||
- name: Vet
|
- name: Vet
|
||||||
run: make vet
|
run: make vet
|
||||||
|
|
||||||
|
- name: golangci-lint
|
||||||
|
uses: golangci/golangci-lint-action@v9
|
||||||
|
with:
|
||||||
|
version: v2.5
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: make test
|
run: make test
|
||||||
|
|
||||||
@@ -40,7 +45,7 @@ jobs:
|
|||||||
- name: Build test mobile
|
- name: Build test mobile
|
||||||
run: make build-test-mobile
|
run: make build-test-mobile
|
||||||
|
|
||||||
- uses: actions/upload-artifact@v4
|
- uses: actions/upload-artifact@v5
|
||||||
with:
|
with:
|
||||||
name: e2e packet flow linux-latest
|
name: e2e packet flow linux-latest
|
||||||
path: e2e/mermaid/linux-latest
|
path: e2e/mermaid/linux-latest
|
||||||
@@ -51,11 +56,11 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v5
|
||||||
|
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
go-version: '1.22'
|
go-version: '1.25'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
@@ -65,18 +70,18 @@ jobs:
|
|||||||
run: make test-boringcrypto
|
run: make test-boringcrypto
|
||||||
|
|
||||||
- name: End 2 end
|
- name: End 2 end
|
||||||
run: make e2evv GOEXPERIMENT=boringcrypto CGO_ENABLED=1
|
run: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0"
|
||||||
|
|
||||||
test-linux-pkcs11:
|
test-linux-pkcs11:
|
||||||
name: Build and test on linux with pkcs11
|
name: Build and test on linux with pkcs11
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v5
|
||||||
|
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
go-version: '1.22'
|
go-version: '1.25'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
@@ -93,11 +98,11 @@ jobs:
|
|||||||
os: [windows-latest, macos-latest]
|
os: [windows-latest, macos-latest]
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v5
|
||||||
|
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
go-version: '1.22'
|
go-version: '1.25'
|
||||||
check-latest: true
|
check-latest: true
|
||||||
|
|
||||||
- name: Build nebula
|
- name: Build nebula
|
||||||
@@ -109,13 +114,18 @@ jobs:
|
|||||||
- name: Vet
|
- name: Vet
|
||||||
run: make vet
|
run: make vet
|
||||||
|
|
||||||
|
- name: golangci-lint
|
||||||
|
uses: golangci/golangci-lint-action@v9
|
||||||
|
with:
|
||||||
|
version: v2.5
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: make test
|
run: make test
|
||||||
|
|
||||||
- name: End 2 end
|
- name: End 2 end
|
||||||
run: make e2evv
|
run: make e2evv
|
||||||
|
|
||||||
- uses: actions/upload-artifact@v4
|
- uses: actions/upload-artifact@v5
|
||||||
with:
|
with:
|
||||||
name: e2e packet flow ${{ matrix.os }}
|
name: e2e packet flow ${{ matrix.os }}
|
||||||
path: e2e/mermaid/${{ matrix.os }}
|
path: e2e/mermaid/${{ matrix.os }}
|
||||||
|
|||||||
23
.golangci.yaml
Normal file
23
.golangci.yaml
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
version: "2"
|
||||||
|
linters:
|
||||||
|
default: none
|
||||||
|
enable:
|
||||||
|
- testifylint
|
||||||
|
exclusions:
|
||||||
|
generated: lax
|
||||||
|
presets:
|
||||||
|
- comments
|
||||||
|
- common-false-positives
|
||||||
|
- legacy
|
||||||
|
- std-error-handling
|
||||||
|
paths:
|
||||||
|
- third_party$
|
||||||
|
- builtin$
|
||||||
|
- examples$
|
||||||
|
formatters:
|
||||||
|
exclusions:
|
||||||
|
generated: lax
|
||||||
|
paths:
|
||||||
|
- third_party$
|
||||||
|
- builtin$
|
||||||
|
- examples$
|
||||||
@@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
|
||||||
|
intended to target an `unsafe_routes` entry must explicitly declare it via the
|
||||||
|
`local_cidr` field. This is almost always the intended behavior. This flag is
|
||||||
|
deprecated and will be removed in a future release.
|
||||||
|
|
||||||
## [1.9.4] - 2024-09-09
|
## [1.9.4] - 2024-09-09
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|||||||
4
Makefile
4
Makefile
@@ -137,6 +137,8 @@ build/linux-mips-softfloat/%: LDFLAGS += -s -w
|
|||||||
# boringcrypto
|
# boringcrypto
|
||||||
build/linux-amd64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
|
build/linux-amd64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
|
||||||
build/linux-arm64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
|
build/linux-arm64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
|
||||||
|
build/linux-amd64-boringcrypto/%: LDFLAGS += -checklinkname=0
|
||||||
|
build/linux-arm64-boringcrypto/%: LDFLAGS += -checklinkname=0
|
||||||
|
|
||||||
build/%/nebula: .FORCE
|
build/%/nebula: .FORCE
|
||||||
GOOS=$(firstword $(subst -, , $*)) \
|
GOOS=$(firstword $(subst -, , $*)) \
|
||||||
@@ -170,7 +172,7 @@ test:
|
|||||||
go test -v ./...
|
go test -v ./...
|
||||||
|
|
||||||
test-boringcrypto:
|
test-boringcrypto:
|
||||||
GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -v ./...
|
GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -ldflags "-checklinkname=0" -v ./...
|
||||||
|
|
||||||
test-pkcs11:
|
test-pkcs11:
|
||||||
CGO_ENABLED=1 go test -v -tags pkcs11 ./...
|
CGO_ENABLED=1 go test -v -tags pkcs11 ./...
|
||||||
|
|||||||
69
README.md
69
README.md
@@ -4,7 +4,7 @@ It lets you seamlessly connect computers anywhere in the world. Nebula is portab
|
|||||||
It can be used to connect a small number of computers, but is also able to connect tens of thousands of computers.
|
It can be used to connect a small number of computers, but is also able to connect tens of thousands of computers.
|
||||||
|
|
||||||
Nebula incorporates a number of existing concepts like encryption, security groups, certificates,
|
Nebula incorporates a number of existing concepts like encryption, security groups, certificates,
|
||||||
and tunneling, and each of those individual pieces existed before Nebula in various forms.
|
and tunneling.
|
||||||
What makes Nebula different to existing offerings is that it brings all of these ideas together,
|
What makes Nebula different to existing offerings is that it brings all of these ideas together,
|
||||||
resulting in a sum that is greater than its individual parts.
|
resulting in a sum that is greater than its individual parts.
|
||||||
|
|
||||||
@@ -12,7 +12,7 @@ Further documentation can be found [here](https://nebula.defined.net/docs/).
|
|||||||
|
|
||||||
You can read more about Nebula [here](https://medium.com/p/884110a5579).
|
You can read more about Nebula [here](https://medium.com/p/884110a5579).
|
||||||
|
|
||||||
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/enQtOTA5MDI4NDg3MTg4LTkwY2EwNTI4NzQyMzc0M2ZlODBjNWI3NTY1MzhiOThiMmZlZjVkMTI0NGY4YTMyNjUwMWEyNzNkZTJmYzQxOGU).
|
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA).
|
||||||
|
|
||||||
## Supported Platforms
|
## Supported Platforms
|
||||||
|
|
||||||
@@ -28,33 +28,33 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
|
|||||||
#### Distribution Packages
|
#### Distribution Packages
|
||||||
|
|
||||||
- [Arch Linux](https://archlinux.org/packages/extra/x86_64/nebula/)
|
- [Arch Linux](https://archlinux.org/packages/extra/x86_64/nebula/)
|
||||||
```
|
```sh
|
||||||
$ sudo pacman -S nebula
|
sudo pacman -S nebula
|
||||||
```
|
```
|
||||||
|
|
||||||
- [Fedora Linux](https://src.fedoraproject.org/rpms/nebula)
|
- [Fedora Linux](https://src.fedoraproject.org/rpms/nebula)
|
||||||
```
|
```sh
|
||||||
$ sudo dnf install nebula
|
sudo dnf install nebula
|
||||||
```
|
```
|
||||||
|
|
||||||
- [Debian Linux](https://packages.debian.org/source/stable/nebula)
|
- [Debian Linux](https://packages.debian.org/source/stable/nebula)
|
||||||
```
|
```sh
|
||||||
$ sudo apt install nebula
|
sudo apt install nebula
|
||||||
```
|
```
|
||||||
|
|
||||||
- [Alpine Linux](https://pkgs.alpinelinux.org/packages?name=nebula)
|
- [Alpine Linux](https://pkgs.alpinelinux.org/packages?name=nebula)
|
||||||
```
|
```sh
|
||||||
$ sudo apk add nebula
|
sudo apk add nebula
|
||||||
```
|
```
|
||||||
|
|
||||||
- [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/nebula.rb)
|
- [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/n/nebula.rb)
|
||||||
```
|
```sh
|
||||||
$ brew install nebula
|
brew install nebula
|
||||||
```
|
```
|
||||||
|
|
||||||
- [Docker](https://hub.docker.com/r/nebulaoss/nebula)
|
- [Docker](https://hub.docker.com/r/nebulaoss/nebula)
|
||||||
```
|
```sh
|
||||||
$ docker pull nebulaoss/nebula
|
docker pull nebulaoss/nebula
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Mobile
|
#### Mobile
|
||||||
@@ -64,10 +64,10 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
|
|||||||
|
|
||||||
## Technical Overview
|
## Technical Overview
|
||||||
|
|
||||||
Nebula is a mutually authenticated peer-to-peer software defined network based on the [Noise Protocol Framework](https://noiseprotocol.org/).
|
Nebula is a mutually authenticated peer-to-peer software-defined network based on the [Noise Protocol Framework](https://noiseprotocol.org/).
|
||||||
Nebula uses certificates to assert a node's IP address, name, and membership within user-defined groups.
|
Nebula uses certificates to assert a node's IP address, name, and membership within user-defined groups.
|
||||||
Nebula's user-defined groups allow for provider agnostic traffic filtering between nodes.
|
Nebula's user-defined groups allow for provider agnostic traffic filtering between nodes.
|
||||||
Discovery nodes allow individual peers to find each other and optionally use UDP hole punching to establish connections from behind most firewalls or NATs.
|
Discovery nodes (aka lighthouses) allow individual peers to find each other and optionally use UDP hole punching to establish connections from behind most firewalls or NATs.
|
||||||
Users can move data between nodes in any number of cloud service providers, datacenters, and endpoints, without needing to maintain a particular addressing scheme.
|
Users can move data between nodes in any number of cloud service providers, datacenters, and endpoints, without needing to maintain a particular addressing scheme.
|
||||||
|
|
||||||
Nebula uses Elliptic-curve Diffie-Hellman (`ECDH`) key exchange and `AES-256-GCM` in its default configuration.
|
Nebula uses Elliptic-curve Diffie-Hellman (`ECDH`) key exchange and `AES-256-GCM` in its default configuration.
|
||||||
@@ -82,28 +82,34 @@ To set up a Nebula network, you'll need:
|
|||||||
|
|
||||||
#### 2. (Optional, but you really should..) At least one discovery node with a routable IP address, which we call a lighthouse.
|
#### 2. (Optional, but you really should..) At least one discovery node with a routable IP address, which we call a lighthouse.
|
||||||
|
|
||||||
Nebula lighthouses allow nodes to find each other, anywhere in the world. A lighthouse is the only node in a Nebula network whose IP should not change. Running a lighthouse requires very few compute resources, and you can easily use the least expensive option from a cloud hosting provider. If you're not sure which provider to use, a number of us have used $5/mo [DigitalOcean](https://digitalocean.com) droplets as lighthouses.
|
Nebula lighthouses allow nodes to find each other, anywhere in the world. A lighthouse is the only node in a Nebula network whose IP should not change. Running a lighthouse requires very few compute resources, and you can easily use the least expensive option from a cloud hosting provider. If you're not sure which provider to use, a number of us have used $6/mo [DigitalOcean](https://digitalocean.com) droplets as lighthouses.
|
||||||
|
|
||||||
Once you have launched an instance, ensure that Nebula udp traffic (default port udp/4242) can reach it over the internet.
|
|
||||||
|
|
||||||
|
Once you have launched an instance, ensure that Nebula udp traffic (default port udp/4242) can reach it over the internet.
|
||||||
|
|
||||||
#### 3. A Nebula certificate authority, which will be the root of trust for a particular Nebula network.
|
#### 3. A Nebula certificate authority, which will be the root of trust for a particular Nebula network.
|
||||||
|
|
||||||
```
|
```sh
|
||||||
./nebula-cert ca -name "Myorganization, Inc"
|
./nebula-cert ca -name "Myorganization, Inc"
|
||||||
```
|
```
|
||||||
This will create files named `ca.key` and `ca.cert` in the current directory. The `ca.key` file is the most sensitive file you'll create, because it is the key used to sign the certificates for individual nebula nodes/hosts. Please store this file somewhere safe, preferably with strong encryption.
|
|
||||||
|
This will create files named `ca.key` and `ca.cert` in the current directory. The `ca.key` file is the most sensitive file you'll create, because it is the key used to sign the certificates for individual nebula nodes/hosts. Please store this file somewhere safe, preferably with strong encryption.
|
||||||
|
|
||||||
|
**Be aware!** By default, certificate authorities have a 1-year lifetime before expiration. See [this guide](https://nebula.defined.net/docs/guides/rotating-certificate-authority/) for details on rotating a CA.
|
||||||
|
|
||||||
#### 4. Nebula host keys and certificates generated from that certificate authority
|
#### 4. Nebula host keys and certificates generated from that certificate authority
|
||||||
|
|
||||||
This assumes you have four nodes, named lighthouse1, laptop, server1, host3. You can name the nodes any way you'd like, including FQDN. You'll also need to choose IP addresses and the associated subnet. In this example, we are creating a nebula network that will use 192.168.100.x/24 as its network range. This example also demonstrates nebula groups, which can later be used to define traffic rules in a nebula network.
|
This assumes you have four nodes, named lighthouse1, laptop, server1, host3. You can name the nodes any way you'd like, including FQDN. You'll also need to choose IP addresses and the associated subnet. In this example, we are creating a nebula network that will use 192.168.100.x/24 as its network range. This example also demonstrates nebula groups, which can later be used to define traffic rules in a nebula network.
|
||||||
```
|
```sh
|
||||||
./nebula-cert sign -name "lighthouse1" -ip "192.168.100.1/24"
|
./nebula-cert sign -name "lighthouse1" -ip "192.168.100.1/24"
|
||||||
./nebula-cert sign -name "laptop" -ip "192.168.100.2/24" -groups "laptop,home,ssh"
|
./nebula-cert sign -name "laptop" -ip "192.168.100.2/24" -groups "laptop,home,ssh"
|
||||||
./nebula-cert sign -name "server1" -ip "192.168.100.9/24" -groups "servers"
|
./nebula-cert sign -name "server1" -ip "192.168.100.9/24" -groups "servers"
|
||||||
./nebula-cert sign -name "host3" -ip "192.168.100.10/24"
|
./nebula-cert sign -name "host3" -ip "192.168.100.10/24"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
By default, host certificates will expire 1 second before the CA expires. Use the `-duration` flag to specify a shorter lifetime.
|
||||||
|
|
||||||
#### 5. Configuration files for each host
|
#### 5. Configuration files for each host
|
||||||
|
|
||||||
Download a copy of the nebula [example configuration](https://github.com/slackhq/nebula/blob/master/examples/config.yml).
|
Download a copy of the nebula [example configuration](https://github.com/slackhq/nebula/blob/master/examples/config.yml).
|
||||||
|
|
||||||
* On the lighthouse node, you'll need to ensure `am_lighthouse: true` is set.
|
* On the lighthouse node, you'll need to ensure `am_lighthouse: true` is set.
|
||||||
@@ -118,10 +124,13 @@ For each host, copy the nebula binary to the host, along with `config.yml` from
|
|||||||
**DO NOT COPY `ca.key` TO INDIVIDUAL NODES.**
|
**DO NOT COPY `ca.key` TO INDIVIDUAL NODES.**
|
||||||
|
|
||||||
#### 7. Run nebula on each host
|
#### 7. Run nebula on each host
|
||||||
```
|
|
||||||
|
```sh
|
||||||
./nebula -config /path/to/config.yml
|
./nebula -config /path/to/config.yml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
For more detailed instructions, [find the full documentation here](https://nebula.defined.net/docs/).
|
||||||
|
|
||||||
## Building Nebula from source
|
## Building Nebula from source
|
||||||
|
|
||||||
Make sure you have [go](https://go.dev/doc/install) installed and clone this repo. Change to the nebula directory.
|
Make sure you have [go](https://go.dev/doc/install) installed and clone this repo. Change to the nebula directory.
|
||||||
@@ -140,8 +149,10 @@ The default curve used for cryptographic handshakes and signatures is Curve25519
|
|||||||
|
|
||||||
In addition, Nebula can be built using the [BoringCrypto GOEXPERIMENT](https://github.com/golang/go/blob/go1.20/src/crypto/internal/boring/README.md) by running either of the following make targets:
|
In addition, Nebula can be built using the [BoringCrypto GOEXPERIMENT](https://github.com/golang/go/blob/go1.20/src/crypto/internal/boring/README.md) by running either of the following make targets:
|
||||||
|
|
||||||
make bin-boringcrypto
|
```sh
|
||||||
make release-boringcrypto
|
make bin-boringcrypto
|
||||||
|
make release-boringcrypto
|
||||||
|
```
|
||||||
|
|
||||||
This is not the recommended default deployment, but may be useful based on your compliance requirements.
|
This is not the recommended default deployment, but may be useful based on your compliance requirements.
|
||||||
|
|
||||||
@@ -149,5 +160,3 @@ This is not the recommended default deployment, but may be useful based on your
|
|||||||
|
|
||||||
Nebula was created at Slack Technologies, Inc by Nate Brown and Ryan Huber, with contributions from Oliver Fross, Alan Lam, Wade Simmons, and Lining Wang.
|
Nebula was created at Slack Technologies, Inc by Nate Brown and Ryan Huber, with contributions from Oliver Fross, Alan Lam, Wade Simmons, and Lining Wang.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ type AllowListNameRule struct {
|
|||||||
|
|
||||||
func NewLocalAllowListFromConfig(c *config.C, k string) (*LocalAllowList, error) {
|
func NewLocalAllowListFromConfig(c *config.C, k string) (*LocalAllowList, error) {
|
||||||
var nameRules []AllowListNameRule
|
var nameRules []AllowListNameRule
|
||||||
handleKey := func(key string, value interface{}) (bool, error) {
|
handleKey := func(key string, value any) (bool, error) {
|
||||||
if key == "interfaces" {
|
if key == "interfaces" {
|
||||||
var err error
|
var err error
|
||||||
nameRules, err = getAllowListInterfaces(k, value)
|
nameRules, err = getAllowListInterfaces(k, value)
|
||||||
@@ -70,7 +70,7 @@ func NewRemoteAllowListFromConfig(c *config.C, k, rangesKey string) (*RemoteAllo
|
|||||||
|
|
||||||
// If the handleKey func returns true, the rest of the parsing is skipped
|
// If the handleKey func returns true, the rest of the parsing is skipped
|
||||||
// for this key. This allows parsing of special values like `interfaces`.
|
// for this key. This allows parsing of special values like `interfaces`.
|
||||||
func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
|
func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value any) (bool, error)) (*AllowList, error) {
|
||||||
r := c.Get(k)
|
r := c.Get(k)
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -81,8 +81,8 @@ func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, va
|
|||||||
|
|
||||||
// If the handleKey func returns true, the rest of the parsing is skipped
|
// If the handleKey func returns true, the rest of the parsing is skipped
|
||||||
// for this key. This allows parsing of special values like `interfaces`.
|
// for this key. This allows parsing of special values like `interfaces`.
|
||||||
func newAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
|
func newAllowList(k string, raw any, handleKey func(key string, value any) (bool, error)) (*AllowList, error) {
|
||||||
rawMap, ok := raw.(map[interface{}]interface{})
|
rawMap, ok := raw.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
|
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
|
||||||
}
|
}
|
||||||
@@ -100,12 +100,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
|
|||||||
rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
|
rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
|
||||||
rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
|
rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
|
||||||
|
|
||||||
for rawKey, rawValue := range rawMap {
|
for rawCIDR, rawValue := range rawMap {
|
||||||
rawCIDR, ok := rawKey.(string)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
if handleKey != nil {
|
if handleKey != nil {
|
||||||
handled, err := handleKey(rawCIDR, rawValue)
|
handled, err := handleKey(rawCIDR, rawValue)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -116,7 +111,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
value, ok := rawValue.(bool)
|
value, ok := config.AsBool(rawValue)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
|
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
|
||||||
}
|
}
|
||||||
@@ -173,22 +168,18 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
|
|||||||
return &AllowList{cidrTree: tree}, nil
|
return &AllowList{cidrTree: tree}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
|
func getAllowListInterfaces(k string, v any) ([]AllowListNameRule, error) {
|
||||||
var nameRules []AllowListNameRule
|
var nameRules []AllowListNameRule
|
||||||
|
|
||||||
rawRules, ok := v.(map[interface{}]interface{})
|
rawRules, ok := v.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
|
return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
firstEntry := true
|
firstEntry := true
|
||||||
var allValues bool
|
var allValues bool
|
||||||
for rawName, rawAllow := range rawRules {
|
for name, rawAllow := range rawRules {
|
||||||
name, ok := rawName.(string)
|
allow, ok := config.AsBool(rawAllow)
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
|
|
||||||
}
|
|
||||||
allow, ok := rawAllow.(bool)
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
|
return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
|
||||||
}
|
}
|
||||||
@@ -224,16 +215,11 @@ func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error
|
|||||||
|
|
||||||
remoteAllowRanges := new(bart.Table[*AllowList])
|
remoteAllowRanges := new(bart.Table[*AllowList])
|
||||||
|
|
||||||
rawMap, ok := value.(map[interface{}]interface{})
|
rawMap, ok := value.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
|
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
|
||||||
}
|
}
|
||||||
for rawKey, rawValue := range rawMap {
|
for rawCIDR, rawValue := range rawMap {
|
||||||
rawCIDR, ok := rawKey.(string)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
allowList, err := newAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
|
allowList, err := newAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -9,32 +9,33 @@ import (
|
|||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewAllowListFromConfig(t *testing.T) {
|
func TestNewAllowListFromConfig(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
c.Settings["allowlist"] = map[string]any{
|
||||||
"192.168.0.0": true,
|
"192.168.0.0": true,
|
||||||
}
|
}
|
||||||
r, err := newAllowListFromConfig(c, "allowlist", nil)
|
r, err := newAllowListFromConfig(c, "allowlist", nil)
|
||||||
assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
|
require.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
|
||||||
assert.Nil(t, r)
|
assert.Nil(t, r)
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
c.Settings["allowlist"] = map[string]any{
|
||||||
"192.168.0.0/16": "abc",
|
"192.168.0.0/16": "abc",
|
||||||
}
|
}
|
||||||
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
||||||
assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
|
require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
c.Settings["allowlist"] = map[string]any{
|
||||||
"192.168.0.0/16": true,
|
"192.168.0.0/16": true,
|
||||||
"10.0.0.0/8": false,
|
"10.0.0.0/8": false,
|
||||||
}
|
}
|
||||||
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
||||||
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
|
require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
c.Settings["allowlist"] = map[string]any{
|
||||||
"0.0.0.0/0": true,
|
"0.0.0.0/0": true,
|
||||||
"10.0.0.0/8": false,
|
"10.0.0.0/8": false,
|
||||||
"10.42.42.0/24": true,
|
"10.42.42.0/24": true,
|
||||||
@@ -42,9 +43,9 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
|||||||
"fd00:fd00::/16": false,
|
"fd00:fd00::/16": false,
|
||||||
}
|
}
|
||||||
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
||||||
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
|
require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
c.Settings["allowlist"] = map[string]any{
|
||||||
"0.0.0.0/0": true,
|
"0.0.0.0/0": true,
|
||||||
"10.0.0.0/8": false,
|
"10.0.0.0/8": false,
|
||||||
"10.42.42.0/24": true,
|
"10.42.42.0/24": true,
|
||||||
@@ -54,7 +55,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
|||||||
assert.NotNil(t, r)
|
assert.NotNil(t, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
c.Settings["allowlist"] = map[string]any{
|
||||||
"0.0.0.0/0": true,
|
"0.0.0.0/0": true,
|
||||||
"10.0.0.0/8": false,
|
"10.0.0.0/8": false,
|
||||||
"10.42.42.0/24": true,
|
"10.42.42.0/24": true,
|
||||||
@@ -69,25 +70,25 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
|||||||
|
|
||||||
// Test interface names
|
// Test interface names
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
c.Settings["allowlist"] = map[string]any{
|
||||||
"interfaces": map[interface{}]interface{}{
|
"interfaces": map[string]any{
|
||||||
`docker.*`: "foo",
|
`docker.*`: "foo",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
lr, err := NewLocalAllowListFromConfig(c, "allowlist")
|
lr, err := NewLocalAllowListFromConfig(c, "allowlist")
|
||||||
assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
|
require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
c.Settings["allowlist"] = map[string]any{
|
||||||
"interfaces": map[interface{}]interface{}{
|
"interfaces": map[string]any{
|
||||||
`docker.*`: false,
|
`docker.*`: false,
|
||||||
`eth.*`: true,
|
`eth.*`: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
lr, err = NewLocalAllowListFromConfig(c, "allowlist")
|
lr, err = NewLocalAllowListFromConfig(c, "allowlist")
|
||||||
assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
|
require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
|
||||||
|
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
c.Settings["allowlist"] = map[string]any{
|
||||||
"interfaces": map[interface{}]interface{}{
|
"interfaces": map[string]any{
|
||||||
`docker.*`: false,
|
`docker.*`: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -98,7 +99,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAllowList_Allow(t *testing.T) {
|
func TestAllowList_Allow(t *testing.T) {
|
||||||
assert.Equal(t, true, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1")))
|
assert.True(t, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1")))
|
||||||
|
|
||||||
tree := new(bart.Table[bool])
|
tree := new(bart.Table[bool])
|
||||||
tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true)
|
tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true)
|
||||||
@@ -111,17 +112,17 @@ func TestAllowList_Allow(t *testing.T) {
|
|||||||
tree.Insert(netip.MustParsePrefix("::2/128"), false)
|
tree.Insert(netip.MustParsePrefix("::2/128"), false)
|
||||||
al := &AllowList{cidrTree: tree}
|
al := &AllowList{cidrTree: tree}
|
||||||
|
|
||||||
assert.Equal(t, true, al.Allow(netip.MustParseAddr("1.1.1.1")))
|
assert.True(t, al.Allow(netip.MustParseAddr("1.1.1.1")))
|
||||||
assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.0.0.4")))
|
assert.False(t, al.Allow(netip.MustParseAddr("10.0.0.4")))
|
||||||
assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.42.42")))
|
assert.True(t, al.Allow(netip.MustParseAddr("10.42.42.42")))
|
||||||
assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.42.42.41")))
|
assert.False(t, al.Allow(netip.MustParseAddr("10.42.42.41")))
|
||||||
assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.0.1")))
|
assert.True(t, al.Allow(netip.MustParseAddr("10.42.0.1")))
|
||||||
assert.Equal(t, true, al.Allow(netip.MustParseAddr("::1")))
|
assert.True(t, al.Allow(netip.MustParseAddr("::1")))
|
||||||
assert.Equal(t, false, al.Allow(netip.MustParseAddr("::2")))
|
assert.False(t, al.Allow(netip.MustParseAddr("::2")))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLocalAllowList_AllowName(t *testing.T) {
|
func TestLocalAllowList_AllowName(t *testing.T) {
|
||||||
assert.Equal(t, true, ((*LocalAllowList)(nil)).AllowName("docker0"))
|
assert.True(t, ((*LocalAllowList)(nil)).AllowName("docker0"))
|
||||||
|
|
||||||
rules := []AllowListNameRule{
|
rules := []AllowListNameRule{
|
||||||
{Name: regexp.MustCompile("^docker.*$"), Allow: false},
|
{Name: regexp.MustCompile("^docker.*$"), Allow: false},
|
||||||
@@ -129,9 +130,9 @@ func TestLocalAllowList_AllowName(t *testing.T) {
|
|||||||
}
|
}
|
||||||
al := &LocalAllowList{nameRules: rules}
|
al := &LocalAllowList{nameRules: rules}
|
||||||
|
|
||||||
assert.Equal(t, false, al.AllowName("docker0"))
|
assert.False(t, al.AllowName("docker0"))
|
||||||
assert.Equal(t, false, al.AllowName("tun0"))
|
assert.False(t, al.AllowName("tun0"))
|
||||||
assert.Equal(t, true, al.AllowName("eth0"))
|
assert.True(t, al.AllowName("eth0"))
|
||||||
|
|
||||||
rules = []AllowListNameRule{
|
rules = []AllowListNameRule{
|
||||||
{Name: regexp.MustCompile("^eth.*$"), Allow: true},
|
{Name: regexp.MustCompile("^eth.*$"), Allow: true},
|
||||||
@@ -139,7 +140,7 @@ func TestLocalAllowList_AllowName(t *testing.T) {
|
|||||||
}
|
}
|
||||||
al = &LocalAllowList{nameRules: rules}
|
al = &LocalAllowList{nameRules: rules}
|
||||||
|
|
||||||
assert.Equal(t, false, al.AllowName("docker0"))
|
assert.False(t, al.AllowName("docker0"))
|
||||||
assert.Equal(t, true, al.AllowName("eth0"))
|
assert.True(t, al.AllowName("eth0"))
|
||||||
assert.Equal(t, true, al.AllowName("ens5"))
|
assert.True(t, al.AllowName("ens5"))
|
||||||
}
|
}
|
||||||
|
|||||||
109
bits.go
109
bits.go
@@ -9,14 +9,13 @@ type Bits struct {
|
|||||||
length uint64
|
length uint64
|
||||||
current uint64
|
current uint64
|
||||||
bits []bool
|
bits []bool
|
||||||
firstSeen bool
|
|
||||||
lostCounter metrics.Counter
|
lostCounter metrics.Counter
|
||||||
dupeCounter metrics.Counter
|
dupeCounter metrics.Counter
|
||||||
outOfWindowCounter metrics.Counter
|
outOfWindowCounter metrics.Counter
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBits(bits uint64) *Bits {
|
func NewBits(bits uint64) *Bits {
|
||||||
return &Bits{
|
b := &Bits{
|
||||||
length: bits,
|
length: bits,
|
||||||
bits: make([]bool, bits, bits),
|
bits: make([]bool, bits, bits),
|
||||||
current: 0,
|
current: 0,
|
||||||
@@ -24,34 +23,37 @@ func NewBits(bits uint64) *Bits {
|
|||||||
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
|
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
|
||||||
outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
|
outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// There is no counter value 0, mark it to avoid counting a lost packet later.
|
||||||
|
b.bits[0] = true
|
||||||
|
b.current = 0
|
||||||
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
|
func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
|
||||||
// If i is the next number, return true.
|
// If i is the next number, return true.
|
||||||
if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
|
if i > b.current {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// If i is within the window, check if it's been set already. The first window will fail this check
|
// If i is within the window, check if it's been set already.
|
||||||
if i > b.current-b.length {
|
if i > b.current-b.length || i < b.length && b.current < b.length {
|
||||||
return !b.bits[i%b.length]
|
|
||||||
}
|
|
||||||
|
|
||||||
// If i is within the first window
|
|
||||||
if i < b.length {
|
|
||||||
return !b.bits[i%b.length]
|
return !b.bits[i%b.length]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Not within the window
|
// Not within the window
|
||||||
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
|
if l.Level >= logrus.DebugLevel {
|
||||||
|
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
||||||
// If i is the next number, return true and update current.
|
// If i is the next number, return true and update current.
|
||||||
if i == b.current+1 {
|
if i == b.current+1 {
|
||||||
// Report missed packets, we can only understand what was missed after the first window has been gone through
|
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
|
||||||
if i > b.length && b.bits[i%b.length] == false {
|
// The very first window can only be tracked as lost once we are on the 2nd window or greater
|
||||||
|
if b.bits[i%b.length] == false && i > b.length {
|
||||||
b.lostCounter.Inc(1)
|
b.lostCounter.Inc(1)
|
||||||
}
|
}
|
||||||
b.bits[i%b.length] = true
|
b.bits[i%b.length] = true
|
||||||
@@ -59,61 +61,32 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// If i packet is greater than current but less than the maximum length of our bitmap,
|
// If i is a jump, adjust the window, record lost, update current, and return true
|
||||||
// flip everything in between to false and move ahead.
|
if i > b.current {
|
||||||
if i > b.current && i < b.current+b.length {
|
lost := int64(0)
|
||||||
// In between current and i need to be zero'd to allow those packets to come in later
|
// Zero out the bits between the current and the new counter value, limited by the window size,
|
||||||
for n := b.current + 1; n < i; n++ {
|
// since the window is shifting
|
||||||
|
for n := b.current + 1; n <= min(i, b.current+b.length); n++ {
|
||||||
|
if b.bits[n%b.length] == false && n > b.length {
|
||||||
|
lost++
|
||||||
|
}
|
||||||
b.bits[n%b.length] = false
|
b.bits[n%b.length] = false
|
||||||
}
|
}
|
||||||
|
|
||||||
b.bits[i%b.length] = true
|
// Only record any skipped packets as a result of the window moving further than the window length
|
||||||
b.current = i
|
// Any loss within the new window will be accounted for in future calls
|
||||||
//l.Debugf("missed %d packets between %d and %d\n", i-b.current, i, b.current)
|
lost += max(0, int64(i-b.current-b.length))
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// If i is greater than the delta between current and the total length of our bitmap,
|
|
||||||
// just flip everything in the map and move ahead.
|
|
||||||
if i >= b.current+b.length {
|
|
||||||
// The current window loss will be accounted for later, only record the jump as loss up until then
|
|
||||||
lost := maxInt64(0, int64(i-b.current-b.length))
|
|
||||||
//TODO: explain this
|
|
||||||
if b.current == 0 {
|
|
||||||
lost++
|
|
||||||
}
|
|
||||||
|
|
||||||
for n := range b.bits {
|
|
||||||
// Don't want to count the first window as a loss
|
|
||||||
//TODO: this is likely wrong, we are wanting to track only the bit slots that we aren't going to track anymore and this is marking everything as missed
|
|
||||||
//if b.bits[n] == false {
|
|
||||||
// lost++
|
|
||||||
//}
|
|
||||||
b.bits[n] = false
|
|
||||||
}
|
|
||||||
|
|
||||||
b.lostCounter.Inc(lost)
|
b.lostCounter.Inc(lost)
|
||||||
|
|
||||||
if l.Level >= logrus.DebugLevel {
|
|
||||||
l.WithField("receiveWindow", m{"accepted": true, "currentCounter": b.current, "incomingCounter": i, "reason": "window shifting"}).
|
|
||||||
Debug("Receive window")
|
|
||||||
}
|
|
||||||
b.bits[i%b.length] = true
|
b.bits[i%b.length] = true
|
||||||
b.current = i
|
b.current = i
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allow for the 0 packet to come in within the first window
|
// If i is within the current window but below the current counter,
|
||||||
if i == 0 && b.firstSeen == false && b.current < b.length {
|
// Check to see if it's a duplicate
|
||||||
b.firstSeen = true
|
if i > b.current-b.length || i < b.length && b.current < b.length {
|
||||||
b.bits[i%b.length] = true
|
if b.current == i || b.bits[i%b.length] == true {
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// If i is within the window of current minus length (the total pat window size),
|
|
||||||
// allow it and flip to true but to NOT change current. We also have to account for the first window
|
|
||||||
if ((b.current >= b.length && i > b.current-b.length) || (b.current < b.length && i < b.length)) && i <= b.current {
|
|
||||||
if b.current == i {
|
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Level >= logrus.DebugLevel {
|
||||||
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
|
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
|
||||||
Debug("Receive window")
|
Debug("Receive window")
|
||||||
@@ -122,18 +95,8 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if b.bits[i%b.length] == true {
|
|
||||||
if l.Level >= logrus.DebugLevel {
|
|
||||||
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}).
|
|
||||||
Debug("Receive window")
|
|
||||||
}
|
|
||||||
b.dupeCounter.Inc(1)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
b.bits[i%b.length] = true
|
b.bits[i%b.length] = true
|
||||||
return true
|
return true
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// In all other cases, fail and don't change current.
|
// In all other cases, fail and don't change current.
|
||||||
@@ -147,11 +110,3 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func maxInt64(a, b int64) int64 {
|
|
||||||
if a > b {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|||||||
109
bits_test.go
109
bits_test.go
@@ -15,48 +15,41 @@ func TestBits(t *testing.T) {
|
|||||||
assert.Len(t, b.bits, 10)
|
assert.Len(t, b.bits, 10)
|
||||||
|
|
||||||
// This is initialized to zero - receive one. This should work.
|
// This is initialized to zero - receive one. This should work.
|
||||||
|
|
||||||
assert.True(t, b.Check(l, 1))
|
assert.True(t, b.Check(l, 1))
|
||||||
u := b.Update(l, 1)
|
assert.True(t, b.Update(l, 1))
|
||||||
assert.True(t, u)
|
|
||||||
assert.EqualValues(t, 1, b.current)
|
assert.EqualValues(t, 1, b.current)
|
||||||
g := []bool{false, true, false, false, false, false, false, false, false, false}
|
g := []bool{true, true, false, false, false, false, false, false, false, false}
|
||||||
assert.Equal(t, g, b.bits)
|
assert.Equal(t, g, b.bits)
|
||||||
|
|
||||||
// Receive two
|
// Receive two
|
||||||
assert.True(t, b.Check(l, 2))
|
assert.True(t, b.Check(l, 2))
|
||||||
u = b.Update(l, 2)
|
assert.True(t, b.Update(l, 2))
|
||||||
assert.True(t, u)
|
|
||||||
assert.EqualValues(t, 2, b.current)
|
assert.EqualValues(t, 2, b.current)
|
||||||
g = []bool{false, true, true, false, false, false, false, false, false, false}
|
g = []bool{true, true, true, false, false, false, false, false, false, false}
|
||||||
assert.Equal(t, g, b.bits)
|
assert.Equal(t, g, b.bits)
|
||||||
|
|
||||||
// Receive two again - it will fail
|
// Receive two again - it will fail
|
||||||
assert.False(t, b.Check(l, 2))
|
assert.False(t, b.Check(l, 2))
|
||||||
u = b.Update(l, 2)
|
assert.False(t, b.Update(l, 2))
|
||||||
assert.False(t, u)
|
|
||||||
assert.EqualValues(t, 2, b.current)
|
assert.EqualValues(t, 2, b.current)
|
||||||
|
|
||||||
// Jump ahead to 15, which should clear everything and set the 6th element
|
// Jump ahead to 15, which should clear everything and set the 6th element
|
||||||
assert.True(t, b.Check(l, 15))
|
assert.True(t, b.Check(l, 15))
|
||||||
u = b.Update(l, 15)
|
assert.True(t, b.Update(l, 15))
|
||||||
assert.True(t, u)
|
|
||||||
assert.EqualValues(t, 15, b.current)
|
assert.EqualValues(t, 15, b.current)
|
||||||
g = []bool{false, false, false, false, false, true, false, false, false, false}
|
g = []bool{false, false, false, false, false, true, false, false, false, false}
|
||||||
assert.Equal(t, g, b.bits)
|
assert.Equal(t, g, b.bits)
|
||||||
|
|
||||||
// Mark 14, which is allowed because it is in the window
|
// Mark 14, which is allowed because it is in the window
|
||||||
assert.True(t, b.Check(l, 14))
|
assert.True(t, b.Check(l, 14))
|
||||||
u = b.Update(l, 14)
|
assert.True(t, b.Update(l, 14))
|
||||||
assert.True(t, u)
|
|
||||||
assert.EqualValues(t, 15, b.current)
|
assert.EqualValues(t, 15, b.current)
|
||||||
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
||||||
assert.Equal(t, g, b.bits)
|
assert.Equal(t, g, b.bits)
|
||||||
|
|
||||||
// Mark 5, which is not allowed because it is not in the window
|
// Mark 5, which is not allowed because it is not in the window
|
||||||
assert.False(t, b.Check(l, 5))
|
assert.False(t, b.Check(l, 5))
|
||||||
u = b.Update(l, 5)
|
assert.False(t, b.Update(l, 5))
|
||||||
assert.False(t, u)
|
|
||||||
assert.EqualValues(t, 15, b.current)
|
assert.EqualValues(t, 15, b.current)
|
||||||
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
||||||
assert.Equal(t, g, b.bits)
|
assert.Equal(t, g, b.bits)
|
||||||
@@ -69,10 +62,29 @@ func TestBits(t *testing.T) {
|
|||||||
|
|
||||||
// Walk through a few windows in order
|
// Walk through a few windows in order
|
||||||
b = NewBits(10)
|
b = NewBits(10)
|
||||||
for i := uint64(0); i <= 100; i++ {
|
for i := uint64(1); i <= 100; i++ {
|
||||||
assert.True(t, b.Check(l, i), "Error while checking %v", i)
|
assert.True(t, b.Check(l, i), "Error while checking %v", i)
|
||||||
assert.True(t, b.Update(l, i), "Error while updating %v", i)
|
assert.True(t, b.Update(l, i), "Error while updating %v", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assert.False(t, b.Check(l, 1), "Out of window check")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBitsLargeJumps(t *testing.T) {
|
||||||
|
l := test.NewLogger()
|
||||||
|
b := NewBits(10)
|
||||||
|
b.lostCounter.Clear()
|
||||||
|
|
||||||
|
b = NewBits(10)
|
||||||
|
b.lostCounter.Clear()
|
||||||
|
assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54
|
||||||
|
assert.Equal(t, int64(45), b.lostCounter.Count())
|
||||||
|
|
||||||
|
assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99
|
||||||
|
assert.Equal(t, int64(89), b.lostCounter.Count())
|
||||||
|
|
||||||
|
assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199
|
||||||
|
assert.Equal(t, int64(188), b.lostCounter.Count())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBitsDupeCounter(t *testing.T) {
|
func TestBitsDupeCounter(t *testing.T) {
|
||||||
@@ -124,8 +136,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
|
|||||||
assert.False(t, b.Update(l, 0))
|
assert.False(t, b.Update(l, 0))
|
||||||
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
|
||||||
|
|
||||||
//tODO: make sure lostcounter doesn't increase in orderly increment
|
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
|
||||||
assert.Equal(t, int64(20), b.lostCounter.Count())
|
|
||||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||||
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
|
||||||
}
|
}
|
||||||
@@ -137,8 +148,6 @@ func TestBitsLostCounter(t *testing.T) {
|
|||||||
b.dupeCounter.Clear()
|
b.dupeCounter.Clear()
|
||||||
b.outOfWindowCounter.Clear()
|
b.outOfWindowCounter.Clear()
|
||||||
|
|
||||||
//assert.True(t, b.Update(0))
|
|
||||||
assert.True(t, b.Update(l, 0))
|
|
||||||
assert.True(t, b.Update(l, 20))
|
assert.True(t, b.Update(l, 20))
|
||||||
assert.True(t, b.Update(l, 21))
|
assert.True(t, b.Update(l, 21))
|
||||||
assert.True(t, b.Update(l, 22))
|
assert.True(t, b.Update(l, 22))
|
||||||
@@ -149,7 +158,7 @@ func TestBitsLostCounter(t *testing.T) {
|
|||||||
assert.True(t, b.Update(l, 27))
|
assert.True(t, b.Update(l, 27))
|
||||||
assert.True(t, b.Update(l, 28))
|
assert.True(t, b.Update(l, 28))
|
||||||
assert.True(t, b.Update(l, 29))
|
assert.True(t, b.Update(l, 29))
|
||||||
assert.Equal(t, int64(20), b.lostCounter.Count())
|
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
|
||||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||||
|
|
||||||
@@ -158,8 +167,6 @@ func TestBitsLostCounter(t *testing.T) {
|
|||||||
b.dupeCounter.Clear()
|
b.dupeCounter.Clear()
|
||||||
b.outOfWindowCounter.Clear()
|
b.outOfWindowCounter.Clear()
|
||||||
|
|
||||||
assert.True(t, b.Update(l, 0))
|
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
|
||||||
assert.True(t, b.Update(l, 9))
|
assert.True(t, b.Update(l, 9))
|
||||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
// 10 will set 0 index, 0 was already set, no lost packets
|
// 10 will set 0 index, 0 was already set, no lost packets
|
||||||
@@ -214,6 +221,62 @@ func TestBitsLostCounter(t *testing.T) {
|
|||||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBitsLostCounterIssue1(t *testing.T) {
|
||||||
|
l := test.NewLogger()
|
||||||
|
b := NewBits(10)
|
||||||
|
b.lostCounter.Clear()
|
||||||
|
b.dupeCounter.Clear()
|
||||||
|
b.outOfWindowCounter.Clear()
|
||||||
|
|
||||||
|
assert.True(t, b.Update(l, 4))
|
||||||
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
|
assert.True(t, b.Update(l, 1))
|
||||||
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
|
assert.True(t, b.Update(l, 9))
|
||||||
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
|
assert.True(t, b.Update(l, 2))
|
||||||
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
|
assert.True(t, b.Update(l, 3))
|
||||||
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
|
assert.True(t, b.Update(l, 5))
|
||||||
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
|
assert.True(t, b.Update(l, 6))
|
||||||
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
|
assert.True(t, b.Update(l, 7))
|
||||||
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
|
// assert.True(t, b.Update(l, 8))
|
||||||
|
assert.True(t, b.Update(l, 10))
|
||||||
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
|
assert.True(t, b.Update(l, 11))
|
||||||
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
|
|
||||||
|
assert.True(t, b.Update(l, 14))
|
||||||
|
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||||
|
// Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter
|
||||||
|
assert.True(t, b.Update(l, 19))
|
||||||
|
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||||
|
assert.True(t, b.Update(l, 12))
|
||||||
|
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||||
|
assert.True(t, b.Update(l, 13))
|
||||||
|
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||||
|
assert.True(t, b.Update(l, 15))
|
||||||
|
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||||
|
assert.True(t, b.Update(l, 16))
|
||||||
|
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||||
|
assert.True(t, b.Update(l, 17))
|
||||||
|
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||||
|
assert.True(t, b.Update(l, 18))
|
||||||
|
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||||
|
assert.True(t, b.Update(l, 20))
|
||||||
|
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||||
|
assert.True(t, b.Update(l, 21))
|
||||||
|
|
||||||
|
// We missed packet 8 above
|
||||||
|
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||||
|
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||||
|
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkBits(b *testing.B) {
|
func BenchmarkBits(b *testing.B) {
|
||||||
z := NewBits(10)
|
z := NewBits(10)
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
|
|||||||
@@ -84,16 +84,11 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
|
|||||||
|
|
||||||
calculatedRemotes := new(bart.Table[[]*calculatedRemote])
|
calculatedRemotes := new(bart.Table[[]*calculatedRemote])
|
||||||
|
|
||||||
rawMap, ok := value.(map[any]any)
|
rawMap, ok := value.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
|
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
|
||||||
}
|
}
|
||||||
for rawKey, rawValue := range rawMap {
|
for rawCIDR, rawValue := range rawMap {
|
||||||
rawCIDR, ok := rawKey.(string)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
cidr, err := netip.ParsePrefix(rawCIDR)
|
cidr, err := netip.ParsePrefix(rawCIDR)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
|
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
|
||||||
@@ -129,7 +124,7 @@ func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculat
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
|
func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
|
||||||
rawMap, ok := raw.(map[any]any)
|
rawMap, ok := raw.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid type: %T", raw)
|
return nil, fmt.Errorf("invalid type: %T", raw)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,10 +15,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
input, err := netip.ParseAddr("10.0.10.182")
|
input, err := netip.ParseAddr("10.0.10.182")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
expected, err := netip.ParseAddr("192.168.1.182")
|
expected, err := netip.ParseAddr("192.168.1.182")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input))
|
assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input))
|
||||||
|
|
||||||
@@ -28,10 +28,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
|
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef")
|
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
|
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
|
||||||
|
|
||||||
@@ -41,10 +41,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
|
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef")
|
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
|
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
|
||||||
|
|
||||||
@@ -54,10 +54,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
|
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef")
|
expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
|
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewCAPoolFromBytes(t *testing.T) {
|
func TestNewCAPoolFromBytes(t *testing.T) {
|
||||||
@@ -82,32 +83,32 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
|
|||||||
}
|
}
|
||||||
|
|
||||||
p, err := NewCAPoolFromPEM([]byte(noNewLines))
|
p, err := NewCAPoolFromPEM([]byte(noNewLines))
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
|
assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
|
||||||
assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
|
assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
|
||||||
|
|
||||||
pp, err := NewCAPoolFromPEM([]byte(withNewLines))
|
pp, err := NewCAPoolFromPEM([]byte(withNewLines))
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
|
assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
|
||||||
assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
|
assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
|
||||||
|
|
||||||
// expired cert, no valid certs
|
// expired cert, no valid certs
|
||||||
ppp, err := NewCAPoolFromPEM([]byte(expired))
|
ppp, err := NewCAPoolFromPEM([]byte(expired))
|
||||||
assert.Equal(t, ErrExpired, err)
|
assert.Equal(t, ErrExpired, err)
|
||||||
assert.Equal(t, ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired")
|
assert.Equal(t, "expired", ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name())
|
||||||
|
|
||||||
// expired cert, with valid certs
|
// expired cert, with valid certs
|
||||||
pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...))
|
pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...))
|
||||||
assert.Equal(t, ErrExpired, err)
|
assert.Equal(t, ErrExpired, err)
|
||||||
assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
|
assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
|
||||||
assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
|
assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
|
||||||
assert.Equal(t, pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired")
|
assert.Equal(t, "expired", pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name())
|
||||||
assert.Equal(t, len(pppp.CAs), 3)
|
assert.Len(t, pppp.CAs, 3)
|
||||||
|
|
||||||
ppppp, err := NewCAPoolFromPEM([]byte(p256))
|
ppppp, err := NewCAPoolFromPEM([]byte(p256))
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name)
|
assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name)
|
||||||
assert.Equal(t, len(ppppp.CAs), 1)
|
assert.Len(t, ppppp.CAs, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV1_Verify(t *testing.T) {
|
func TestCertificateV1_Verify(t *testing.T) {
|
||||||
@@ -115,21 +116,21 @@ func TestCertificateV1_Verify(t *testing.T) {
|
|||||||
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
||||||
|
|
||||||
caPool := NewCAPool()
|
caPool := NewCAPool()
|
||||||
assert.NoError(t, caPool.AddCA(ca))
|
require.NoError(t, caPool.AddCA(ca))
|
||||||
|
|
||||||
f, err := c.Fingerprint()
|
f, err := c.Fingerprint()
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
caPool.BlocklistFingerprint(f)
|
caPool.BlocklistFingerprint(f)
|
||||||
|
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.EqualError(t, err, "certificate is in the block list")
|
require.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
caPool.ResetCertBlocklist()
|
caPool.ResetCertBlocklist()
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
||||||
assert.EqualError(t, err, "root certificate is expired")
|
require.EqualError(t, err, "root certificate is expired")
|
||||||
|
|
||||||
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
||||||
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
|
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
@@ -138,11 +139,11 @@ func TestCertificateV1_Verify(t *testing.T) {
|
|||||||
// Test group assertion
|
// Test group assertion
|
||||||
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
||||||
caPem, err := ca.MarshalPEM()
|
caPem, err := ca.MarshalPEM()
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
caPool = NewCAPool()
|
caPool = NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err := caPool.AddCAFromPEM(caPem)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
||||||
@@ -150,9 +151,9 @@ func TestCertificateV1_Verify(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV1_VerifyP256(t *testing.T) {
|
func TestCertificateV1_VerifyP256(t *testing.T) {
|
||||||
@@ -160,21 +161,21 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
|
|||||||
c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
||||||
|
|
||||||
caPool := NewCAPool()
|
caPool := NewCAPool()
|
||||||
assert.NoError(t, caPool.AddCA(ca))
|
require.NoError(t, caPool.AddCA(ca))
|
||||||
|
|
||||||
f, err := c.Fingerprint()
|
f, err := c.Fingerprint()
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
caPool.BlocklistFingerprint(f)
|
caPool.BlocklistFingerprint(f)
|
||||||
|
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.EqualError(t, err, "certificate is in the block list")
|
require.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
caPool.ResetCertBlocklist()
|
caPool.ResetCertBlocklist()
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
||||||
assert.EqualError(t, err, "root certificate is expired")
|
require.EqualError(t, err, "root certificate is expired")
|
||||||
|
|
||||||
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
||||||
NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
@@ -183,11 +184,11 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
|
|||||||
// Test group assertion
|
// Test group assertion
|
||||||
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
||||||
caPem, err := ca.MarshalPEM()
|
caPem, err := ca.MarshalPEM()
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
caPool = NewCAPool()
|
caPool = NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err := caPool.AddCAFromPEM(caPem)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
||||||
@@ -196,7 +197,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
|
|||||||
|
|
||||||
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV1_Verify_IPs(t *testing.T) {
|
func TestCertificateV1_Verify_IPs(t *testing.T) {
|
||||||
@@ -205,11 +206,11 @@ func TestCertificateV1_Verify_IPs(t *testing.T) {
|
|||||||
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
||||||
|
|
||||||
caPem, err := ca.MarshalPEM()
|
caPem, err := ca.MarshalPEM()
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
caPool := NewCAPool()
|
caPool := NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err := caPool.AddCAFromPEM(caPem)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
// ip is outside the network
|
// ip is outside the network
|
||||||
@@ -245,25 +246,25 @@ func TestCertificateV1_Verify_IPs(t *testing.T) {
|
|||||||
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
||||||
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
|
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Exact matches
|
// Exact matches
|
||||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Exact matches reversed
|
// Exact matches reversed
|
||||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
|
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Exact matches reversed with just 1
|
// Exact matches reversed with just 1
|
||||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
|
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV1_Verify_Subnets(t *testing.T) {
|
func TestCertificateV1_Verify_Subnets(t *testing.T) {
|
||||||
@@ -272,11 +273,11 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) {
|
|||||||
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
||||||
|
|
||||||
caPem, err := ca.MarshalPEM()
|
caPem, err := ca.MarshalPEM()
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
caPool := NewCAPool()
|
caPool := NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err := caPool.AddCAFromPEM(caPem)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
// ip is outside the network
|
// ip is outside the network
|
||||||
@@ -311,27 +312,27 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) {
|
|||||||
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
|
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
|
||||||
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
||||||
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
|
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Exact matches
|
// Exact matches
|
||||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Exact matches reversed
|
// Exact matches reversed
|
||||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
|
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Exact matches reversed with just 1
|
// Exact matches reversed with just 1
|
||||||
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
|
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV2_Verify(t *testing.T) {
|
func TestCertificateV2_Verify(t *testing.T) {
|
||||||
@@ -339,21 +340,21 @@ func TestCertificateV2_Verify(t *testing.T) {
|
|||||||
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
||||||
|
|
||||||
caPool := NewCAPool()
|
caPool := NewCAPool()
|
||||||
assert.NoError(t, caPool.AddCA(ca))
|
require.NoError(t, caPool.AddCA(ca))
|
||||||
|
|
||||||
f, err := c.Fingerprint()
|
f, err := c.Fingerprint()
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
caPool.BlocklistFingerprint(f)
|
caPool.BlocklistFingerprint(f)
|
||||||
|
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.EqualError(t, err, "certificate is in the block list")
|
require.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
caPool.ResetCertBlocklist()
|
caPool.ResetCertBlocklist()
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
||||||
assert.EqualError(t, err, "root certificate is expired")
|
require.EqualError(t, err, "root certificate is expired")
|
||||||
|
|
||||||
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
||||||
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
|
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
@@ -362,11 +363,11 @@ func TestCertificateV2_Verify(t *testing.T) {
|
|||||||
// Test group assertion
|
// Test group assertion
|
||||||
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
||||||
caPem, err := ca.MarshalPEM()
|
caPem, err := ca.MarshalPEM()
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
caPool = NewCAPool()
|
caPool = NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err := caPool.AddCAFromPEM(caPem)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
||||||
@@ -374,9 +375,9 @@ func TestCertificateV2_Verify(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV2_VerifyP256(t *testing.T) {
|
func TestCertificateV2_VerifyP256(t *testing.T) {
|
||||||
@@ -384,21 +385,21 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
|
|||||||
c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
||||||
|
|
||||||
caPool := NewCAPool()
|
caPool := NewCAPool()
|
||||||
assert.NoError(t, caPool.AddCA(ca))
|
require.NoError(t, caPool.AddCA(ca))
|
||||||
|
|
||||||
f, err := c.Fingerprint()
|
f, err := c.Fingerprint()
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
caPool.BlocklistFingerprint(f)
|
caPool.BlocklistFingerprint(f)
|
||||||
|
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.EqualError(t, err, "certificate is in the block list")
|
require.EqualError(t, err, "certificate is in the block list")
|
||||||
|
|
||||||
caPool.ResetCertBlocklist()
|
caPool.ResetCertBlocklist()
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
|
||||||
assert.EqualError(t, err, "root certificate is expired")
|
require.EqualError(t, err, "root certificate is expired")
|
||||||
|
|
||||||
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
|
||||||
NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
@@ -407,11 +408,11 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
|
|||||||
// Test group assertion
|
// Test group assertion
|
||||||
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
|
||||||
caPem, err := ca.MarshalPEM()
|
caPem, err := ca.MarshalPEM()
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
caPool = NewCAPool()
|
caPool = NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err := caPool.AddCAFromPEM(caPem)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
|
||||||
@@ -420,7 +421,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
|
|||||||
|
|
||||||
c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV2_Verify_IPs(t *testing.T) {
|
func TestCertificateV2_Verify_IPs(t *testing.T) {
|
||||||
@@ -429,11 +430,11 @@ func TestCertificateV2_Verify_IPs(t *testing.T) {
|
|||||||
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
||||||
|
|
||||||
caPem, err := ca.MarshalPEM()
|
caPem, err := ca.MarshalPEM()
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
caPool := NewCAPool()
|
caPool := NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err := caPool.AddCAFromPEM(caPem)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
// ip is outside the network
|
// ip is outside the network
|
||||||
@@ -469,25 +470,25 @@ func TestCertificateV2_Verify_IPs(t *testing.T) {
|
|||||||
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
||||||
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
|
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Exact matches
|
// Exact matches
|
||||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Exact matches reversed
|
// Exact matches reversed
|
||||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
|
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Exact matches reversed with just 1
|
// Exact matches reversed with just 1
|
||||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
|
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV2_Verify_Subnets(t *testing.T) {
|
func TestCertificateV2_Verify_Subnets(t *testing.T) {
|
||||||
@@ -496,11 +497,11 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) {
|
|||||||
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
||||||
|
|
||||||
caPem, err := ca.MarshalPEM()
|
caPem, err := ca.MarshalPEM()
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
caPool := NewCAPool()
|
caPool := NewCAPool()
|
||||||
b, err := caPool.AddCAFromPEM(caPem)
|
b, err := caPool.AddCAFromPEM(caPem)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
|
|
||||||
// ip is outside the network
|
// ip is outside the network
|
||||||
@@ -535,25 +536,25 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) {
|
|||||||
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
|
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
|
||||||
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
|
||||||
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
|
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Exact matches
|
// Exact matches
|
||||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Exact matches reversed
|
// Exact matches reversed
|
||||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
|
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Exact matches reversed with just 1
|
// Exact matches reversed with just 1
|
||||||
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
|
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
_, err = caPool.VerifyCertificate(time.Now(), c)
|
_, err = caPool.VerifyCertificate(time.Now(), c)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|||||||
28
cert/cert.go
28
cert/cert.go
@@ -58,6 +58,9 @@ type Certificate interface {
|
|||||||
// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
|
// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
|
||||||
PublicKey() []byte
|
PublicKey() []byte
|
||||||
|
|
||||||
|
// MarshalPublicKeyPEM is the value of PublicKey marshalled to PEM
|
||||||
|
MarshalPublicKeyPEM() []byte
|
||||||
|
|
||||||
// Curve identifies which curve was used for the PublicKey and Signature.
|
// Curve identifies which curve was used for the PublicKey and Signature.
|
||||||
Curve() Curve
|
Curve() Curve
|
||||||
|
|
||||||
@@ -113,10 +116,10 @@ func (cc *CachedCertificate) String() string {
|
|||||||
return cc.Certificate.String()
|
return cc.Certificate.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecombineAndValidate will attempt to unmarshal a certificate received in a handshake.
|
// Recombine will attempt to unmarshal a certificate received in a handshake.
|
||||||
// Handshakes save space by placing the peers public key in a different part of the packet, we have to
|
// Handshakes save space by placing the peers public key in a different part of the packet, we have to
|
||||||
// reassemble the actual certificate structure with that in mind.
|
// reassemble the actual certificate structure with that in mind.
|
||||||
func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) {
|
func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certificate, error) {
|
||||||
if publicKey == nil {
|
if publicKey == nil {
|
||||||
return nil, ErrNoPeerStaticKey
|
return nil, ErrNoPeerStaticKey
|
||||||
}
|
}
|
||||||
@@ -125,32 +128,17 @@ func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve
|
|||||||
return nil, ErrNoPayload
|
return nil, ErrNoPayload
|
||||||
}
|
}
|
||||||
|
|
||||||
c, err := unmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error unmarshaling cert: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cc, err := caPool.VerifyCertificate(time.Now(), c)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("certificate validation failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return cc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func unmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) {
|
|
||||||
var c Certificate
|
var c Certificate
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
switch v {
|
switch v {
|
||||||
// Implementations must ensure the result is a valid cert!
|
// Implementations must ensure the result is a valid cert!
|
||||||
case VersionPre1, Version1:
|
case VersionPre1, Version1:
|
||||||
c, err = unmarshalCertificateV1(b, publicKey)
|
c, err = unmarshalCertificateV1(rawCertBytes, publicKey)
|
||||||
case Version2:
|
case Version2:
|
||||||
c, err = unmarshalCertificateV2(b, publicKey, curve)
|
c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve)
|
||||||
default:
|
default:
|
||||||
//TODO: CERT-V2 make a static var
|
return nil, ErrUnknownVersion
|
||||||
return nil, fmt.Errorf("unknown certificate version %d", v)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ type detailsV1 struct {
|
|||||||
curve Curve
|
curve Curve
|
||||||
}
|
}
|
||||||
|
|
||||||
type m map[string]interface{}
|
type m = map[string]any
|
||||||
|
|
||||||
func (c *certificateV1) Version() Version {
|
func (c *certificateV1) Version() Version {
|
||||||
return Version1
|
return Version1
|
||||||
@@ -83,6 +83,10 @@ func (c *certificateV1) PublicKey() []byte {
|
|||||||
return c.details.publicKey
|
return c.details.publicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *certificateV1) MarshalPublicKeyPEM() []byte {
|
||||||
|
return marshalCertPublicKeyToPEM(c)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *certificateV1) Signature() []byte {
|
func (c *certificateV1) Signature() []byte {
|
||||||
return c.signature
|
return c.signature
|
||||||
}
|
}
|
||||||
@@ -110,8 +114,10 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
|
|||||||
case Curve_CURVE25519:
|
case Curve_CURVE25519:
|
||||||
return ed25519.Verify(key, b, c.signature)
|
return ed25519.Verify(key, b, c.signature)
|
||||||
case Curve_P256:
|
case Curve_P256:
|
||||||
x, y := elliptic.Unmarshal(elliptic.P256(), key)
|
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
|
||||||
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
hashed := sha256.Sum256(b)
|
hashed := sha256.Sum256(b)
|
||||||
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -13,6 +14,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestCertificateV1_Marshal(t *testing.T) {
|
func TestCertificateV1_Marshal(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||||
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
||||||
@@ -39,14 +41,14 @@ func TestCertificateV1_Marshal(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
b, err := nc.Marshal()
|
b, err := nc.Marshal()
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
//t.Log("Cert size:", len(b))
|
//t.Log("Cert size:", len(b))
|
||||||
|
|
||||||
nc2, err := unmarshalCertificateV1(b, nil)
|
nc2, err := unmarshalCertificateV1(b, nil)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, nc.Version(), Version1)
|
assert.Equal(t, Version1, nc.Version())
|
||||||
assert.Equal(t, nc.Curve(), Curve_CURVE25519)
|
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||||
assert.Equal(t, nc.Signature(), nc2.Signature())
|
assert.Equal(t, nc.Signature(), nc2.Signature())
|
||||||
assert.Equal(t, nc.Name(), nc2.Name())
|
assert.Equal(t, nc.Name(), nc2.Name())
|
||||||
assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
|
assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
|
||||||
@@ -60,6 +62,58 @@ func TestCertificateV1_Marshal(t *testing.T) {
|
|||||||
assert.Equal(t, nc.Groups(), nc2.Groups())
|
assert.Equal(t, nc.Groups(), nc2.Groups())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCertificateV1_PublicKeyPem(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||||
|
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||||
|
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
|
||||||
|
|
||||||
|
nc := certificateV1{
|
||||||
|
details: detailsV1{
|
||||||
|
name: "testing",
|
||||||
|
networks: []netip.Prefix{},
|
||||||
|
unsafeNetworks: []netip.Prefix{},
|
||||||
|
groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||||
|
notBefore: before,
|
||||||
|
notAfter: after,
|
||||||
|
publicKey: pubKey,
|
||||||
|
isCA: false,
|
||||||
|
issuer: "1234567890abcedfghij1234567890ab",
|
||||||
|
},
|
||||||
|
signature: []byte("1234567890abcedfghij1234567890ab"),
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, Version1, nc.Version())
|
||||||
|
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||||
|
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
|
||||||
|
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
||||||
|
assert.False(t, nc.IsCA())
|
||||||
|
|
||||||
|
nc.details.isCA = true
|
||||||
|
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||||
|
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
|
||||||
|
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
||||||
|
assert.True(t, nc.IsCA())
|
||||||
|
|
||||||
|
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
|
||||||
|
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||||
|
AAAAAAAAAAAAAAAAAAAAAAA=
|
||||||
|
-----END NEBULA P256 PUBLIC KEY-----
|
||||||
|
`)
|
||||||
|
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
|
||||||
|
require.NoError(t, err)
|
||||||
|
nc.details.curve = Curve_P256
|
||||||
|
nc.details.publicKey = pubP256Key
|
||||||
|
assert.Equal(t, Curve_P256, nc.Curve())
|
||||||
|
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
||||||
|
assert.True(t, nc.IsCA())
|
||||||
|
|
||||||
|
nc.details.isCA = false
|
||||||
|
assert.Equal(t, Curve_P256, nc.Curve())
|
||||||
|
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
||||||
|
assert.False(t, nc.IsCA())
|
||||||
|
}
|
||||||
|
|
||||||
func TestCertificateV1_Expired(t *testing.T) {
|
func TestCertificateV1_Expired(t *testing.T) {
|
||||||
nc := certificateV1{
|
nc := certificateV1{
|
||||||
details: detailsV1{
|
details: detailsV1{
|
||||||
@@ -99,8 +153,8 @@ func TestCertificateV1_MarshalJSON(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
b, err := nc.MarshalJSON()
|
b, err := nc.MarshalJSON()
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(
|
assert.JSONEq(
|
||||||
t,
|
t,
|
||||||
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}",
|
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}",
|
||||||
string(b),
|
string(b),
|
||||||
@@ -110,47 +164,47 @@ func TestCertificateV1_MarshalJSON(t *testing.T) {
|
|||||||
func TestCertificateV1_VerifyPrivateKey(t *testing.T) {
|
func TestCertificateV1_VerifyPrivateKey(t *testing.T) {
|
||||||
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
|
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
|
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
|
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
|
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
|
||||||
assert.NotNil(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
|
err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, priv2 := X25519Keypair()
|
_, priv2 := X25519Keypair()
|
||||||
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
|
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
|
||||||
assert.NotNil(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) {
|
func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) {
|
||||||
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
err := ca.VerifyPrivateKey(Curve_P256, caKey)
|
err := ca.VerifyPrivateKey(Curve_P256, caKey)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
err = ca.VerifyPrivateKey(Curve_P256, caKey2)
|
err = ca.VerifyPrivateKey(Curve_P256, caKey2)
|
||||||
assert.NotNil(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
assert.Equal(t, Curve_P256, curve)
|
assert.Equal(t, Curve_P256, curve)
|
||||||
err = c.VerifyPrivateKey(Curve_P256, rawPriv)
|
err = c.VerifyPrivateKey(Curve_P256, rawPriv)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, priv2 := P256Keypair()
|
_, priv2 := P256Keypair()
|
||||||
err = c.VerifyPrivateKey(Curve_P256, priv2)
|
err = c.VerifyPrivateKey(Curve_P256, priv2)
|
||||||
assert.NotNil(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure that upgrading the protobuf library does not change how certificates
|
// Ensure that upgrading the protobuf library does not change how certificates
|
||||||
@@ -182,11 +236,11 @@ func TestMarshalingCertificateV1Consistency(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
b, err := nc.Marshal()
|
b, err := nc.Marshal()
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
|
assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
|
||||||
|
|
||||||
b, err = proto.Marshal(nc.getRawDetails())
|
b, err = proto.Marshal(nc.getRawDetails())
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
|
assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,7 +255,7 @@ func TestUnmarshalCertificateV1(t *testing.T) {
|
|||||||
// Test that we don't panic with an invalid certificate (#332)
|
// Test that we don't panic with an invalid certificate (#332)
|
||||||
data := []byte("\x98\x00\x00")
|
data := []byte("\x98\x00\x00")
|
||||||
_, err := unmarshalCertificateV1(data, nil)
|
_, err := unmarshalCertificateV1(data, nil)
|
||||||
assert.EqualError(t, err, "encoded Details was nil")
|
require.EqualError(t, err, "encoded Details was nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
func appendByteSlices(b ...[]byte) []byte {
|
func appendByteSlices(b ...[]byte) []byte {
|
||||||
|
|||||||
@@ -114,6 +114,10 @@ func (c *certificateV2) PublicKey() []byte {
|
|||||||
return c.publicKey
|
return c.publicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) MarshalPublicKeyPEM() []byte {
|
||||||
|
return marshalCertPublicKeyToPEM(c)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *certificateV2) Signature() []byte {
|
func (c *certificateV2) Signature() []byte {
|
||||||
return c.signature
|
return c.signature
|
||||||
}
|
}
|
||||||
@@ -149,8 +153,10 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
|
|||||||
case Curve_CURVE25519:
|
case Curve_CURVE25519:
|
||||||
return ed25519.Verify(key, b, c.signature)
|
return ed25519.Verify(key, b, c.signature)
|
||||||
case Curve_P256:
|
case Curve_P256:
|
||||||
x, y := elliptic.Unmarshal(elliptic.P256(), key)
|
pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
|
||||||
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
hashed := sha256.Sum256(b)
|
hashed := sha256.Sum256(b)
|
||||||
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestCertificateV2_Marshal(t *testing.T) {
|
func TestCertificateV2_Marshal(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||||
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||||
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
pubKey := []byte("1234567890abcedfghij1234567890ab")
|
||||||
@@ -45,14 +46,14 @@ func TestCertificateV2_Marshal(t *testing.T) {
|
|||||||
nc.rawDetails = db
|
nc.rawDetails = db
|
||||||
|
|
||||||
b, err := nc.Marshal()
|
b, err := nc.Marshal()
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
//t.Log("Cert size:", len(b))
|
//t.Log("Cert size:", len(b))
|
||||||
|
|
||||||
nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519)
|
nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, nc.Version(), Version2)
|
assert.Equal(t, Version2, nc.Version())
|
||||||
assert.Equal(t, nc.Curve(), Curve_CURVE25519)
|
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||||
assert.Equal(t, nc.Signature(), nc2.Signature())
|
assert.Equal(t, nc.Signature(), nc2.Signature())
|
||||||
assert.Equal(t, nc.Name(), nc2.Name())
|
assert.Equal(t, nc.Name(), nc2.Name())
|
||||||
assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
|
assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
|
||||||
@@ -75,6 +76,58 @@ func TestCertificateV2_Marshal(t *testing.T) {
|
|||||||
assert.Equal(t, nc.Groups(), nc2.Groups())
|
assert.Equal(t, nc.Groups(), nc2.Groups())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCertificateV2_PublicKeyPem(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
before := time.Now().Add(time.Second * -60).Round(time.Second)
|
||||||
|
after := time.Now().Add(time.Second * 60).Round(time.Second)
|
||||||
|
pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
|
||||||
|
|
||||||
|
nc := certificateV2{
|
||||||
|
details: detailsV2{
|
||||||
|
name: "testing",
|
||||||
|
networks: []netip.Prefix{},
|
||||||
|
unsafeNetworks: []netip.Prefix{},
|
||||||
|
groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||||
|
notBefore: before,
|
||||||
|
notAfter: after,
|
||||||
|
isCA: false,
|
||||||
|
issuer: "1234567890abcedfghij1234567890ab",
|
||||||
|
},
|
||||||
|
publicKey: pubKey,
|
||||||
|
signature: []byte("1234567890abcedfghij1234567890ab"),
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, Version2, nc.Version())
|
||||||
|
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||||
|
pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
|
||||||
|
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
||||||
|
assert.False(t, nc.IsCA())
|
||||||
|
|
||||||
|
nc.details.isCA = true
|
||||||
|
assert.Equal(t, Curve_CURVE25519, nc.Curve())
|
||||||
|
pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
|
||||||
|
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
|
||||||
|
assert.True(t, nc.IsCA())
|
||||||
|
|
||||||
|
pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
|
||||||
|
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||||
|
AAAAAAAAAAAAAAAAAAAAAAA=
|
||||||
|
-----END NEBULA P256 PUBLIC KEY-----
|
||||||
|
`)
|
||||||
|
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
|
||||||
|
require.NoError(t, err)
|
||||||
|
nc.curve = Curve_P256
|
||||||
|
nc.publicKey = pubP256Key
|
||||||
|
assert.Equal(t, Curve_P256, nc.Curve())
|
||||||
|
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
||||||
|
assert.True(t, nc.IsCA())
|
||||||
|
|
||||||
|
nc.details.isCA = false
|
||||||
|
assert.Equal(t, Curve_P256, nc.Curve())
|
||||||
|
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
|
||||||
|
assert.False(t, nc.IsCA())
|
||||||
|
}
|
||||||
|
|
||||||
func TestCertificateV2_Expired(t *testing.T) {
|
func TestCertificateV2_Expired(t *testing.T) {
|
||||||
nc := certificateV2{
|
nc := certificateV2{
|
||||||
details: detailsV2{
|
details: detailsV2{
|
||||||
@@ -114,15 +167,15 @@ func TestCertificateV2_MarshalJSON(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
b, err := nc.MarshalJSON()
|
b, err := nc.MarshalJSON()
|
||||||
assert.ErrorIs(t, err, ErrMissingDetails)
|
require.ErrorIs(t, err, ErrMissingDetails)
|
||||||
|
|
||||||
rd, err := nc.details.Marshal()
|
rd, err := nc.details.Marshal()
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
nc.rawDetails = rd
|
nc.rawDetails = rd
|
||||||
b, err = nc.MarshalJSON()
|
b, err = nc.MarshalJSON()
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(
|
assert.JSONEq(
|
||||||
t,
|
t,
|
||||||
"{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}",
|
"{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}",
|
||||||
string(b),
|
string(b),
|
||||||
@@ -132,85 +185,85 @@ func TestCertificateV2_MarshalJSON(t *testing.T) {
|
|||||||
func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
|
func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
|
||||||
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
|
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
|
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16])
|
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16])
|
||||||
assert.ErrorIs(t, err, ErrInvalidPrivateKey)
|
require.ErrorIs(t, err, ErrInvalidPrivateKey)
|
||||||
|
|
||||||
_, caKey2, err := ed25519.GenerateKey(rand.Reader)
|
_, caKey2, err := ed25519.GenerateKey(rand.Reader)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
|
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
|
||||||
assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
|
require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
|
||||||
|
|
||||||
c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
|
err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, priv2 := X25519Keypair()
|
_, priv2 := X25519Keypair()
|
||||||
err = c.VerifyPrivateKey(Curve_P256, priv2)
|
err = c.VerifyPrivateKey(Curve_P256, priv2)
|
||||||
assert.ErrorIs(t, err, ErrPublicPrivateCurveMismatch)
|
require.ErrorIs(t, err, ErrPublicPrivateCurveMismatch)
|
||||||
|
|
||||||
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
|
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
|
||||||
assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
|
require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
|
||||||
|
|
||||||
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16])
|
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16])
|
||||||
assert.ErrorIs(t, err, ErrInvalidPrivateKey)
|
require.ErrorIs(t, err, ErrInvalidPrivateKey)
|
||||||
|
|
||||||
ac, ok := c.(*certificateV2)
|
ac, ok := c.(*certificateV2)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
ac.curve = Curve(99)
|
ac.curve = Curve(99)
|
||||||
err = c.VerifyPrivateKey(Curve(99), priv2)
|
err = c.VerifyPrivateKey(Curve(99), priv2)
|
||||||
assert.EqualError(t, err, "invalid curve: 99")
|
require.EqualError(t, err, "invalid curve: 99")
|
||||||
|
|
||||||
ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
|
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16])
|
err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16])
|
||||||
assert.ErrorIs(t, err, ErrInvalidPrivateKey)
|
require.ErrorIs(t, err, ErrInvalidPrivateKey)
|
||||||
|
|
||||||
c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv)
|
rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv)
|
||||||
|
|
||||||
err = c.VerifyPrivateKey(Curve_P256, priv[:16])
|
err = c.VerifyPrivateKey(Curve_P256, priv[:16])
|
||||||
assert.ErrorIs(t, err, ErrInvalidPrivateKey)
|
require.ErrorIs(t, err, ErrInvalidPrivateKey)
|
||||||
|
|
||||||
err = c.VerifyPrivateKey(Curve_P256, priv)
|
err = c.VerifyPrivateKey(Curve_P256, priv)
|
||||||
assert.ErrorIs(t, err, ErrInvalidPrivateKey)
|
require.ErrorIs(t, err, ErrInvalidPrivateKey)
|
||||||
|
|
||||||
aCa, ok := ca2.(*certificateV2)
|
aCa, ok := ca2.(*certificateV2)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
aCa.curve = Curve(99)
|
aCa.curve = Curve(99)
|
||||||
err = aCa.VerifyPrivateKey(Curve(99), priv2)
|
err = aCa.VerifyPrivateKey(Curve(99), priv2)
|
||||||
assert.EqualError(t, err, "invalid curve: 99")
|
require.EqualError(t, err, "invalid curve: 99")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) {
|
func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) {
|
||||||
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
err := ca.VerifyPrivateKey(Curve_P256, caKey)
|
err := ca.VerifyPrivateKey(Curve_P256, caKey)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
_, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
err = ca.VerifyPrivateKey(Curve_P256, caKey2)
|
err = ca.VerifyPrivateKey(Curve_P256, caKey2)
|
||||||
assert.NotNil(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
|
||||||
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, b)
|
assert.Empty(t, b)
|
||||||
assert.Equal(t, Curve_P256, curve)
|
assert.Equal(t, Curve_P256, curve)
|
||||||
err = c.VerifyPrivateKey(Curve_P256, rawPriv)
|
err = c.VerifyPrivateKey(Curve_P256, rawPriv)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, priv2 := P256Keypair()
|
_, priv2 := P256Keypair()
|
||||||
err = c.VerifyPrivateKey(Curve_P256, priv2)
|
err = c.VerifyPrivateKey(Curve_P256, priv2)
|
||||||
assert.NotNil(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV2_Copy(t *testing.T) {
|
func TestCertificateV2_Copy(t *testing.T) {
|
||||||
@@ -223,7 +276,7 @@ func TestCertificateV2_Copy(t *testing.T) {
|
|||||||
func TestUnmarshalCertificateV2(t *testing.T) {
|
func TestUnmarshalCertificateV2(t *testing.T) {
|
||||||
data := []byte("\x98\x00\x00")
|
data := []byte("\x98\x00\x00")
|
||||||
_, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519)
|
_, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519)
|
||||||
assert.EqualError(t, err, "bad wire format")
|
require.EqualError(t, err, "bad wire format")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertificateV2_marshalForSigningStability(t *testing.T) {
|
func TestCertificateV2_marshalForSigningStability(t *testing.T) {
|
||||||
|
|||||||
@@ -4,19 +4,20 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/crypto/argon2"
|
"golang.org/x/crypto/argon2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewArgon2Parameters(t *testing.T) {
|
func TestNewArgon2Parameters(t *testing.T) {
|
||||||
p := NewArgon2Parameters(64*1024, 4, 3)
|
p := NewArgon2Parameters(64*1024, 4, 3)
|
||||||
assert.EqualValues(t, &Argon2Parameters{
|
assert.Equal(t, &Argon2Parameters{
|
||||||
version: argon2.Version,
|
version: argon2.Version,
|
||||||
Memory: 64 * 1024,
|
Memory: 64 * 1024,
|
||||||
Parallelism: 4,
|
Parallelism: 4,
|
||||||
Iterations: 3,
|
Iterations: 3,
|
||||||
}, p)
|
}, p)
|
||||||
p = NewArgon2Parameters(2*1024*1024, 2, 1)
|
p = NewArgon2Parameters(2*1024*1024, 2, 1)
|
||||||
assert.EqualValues(t, &Argon2Parameters{
|
assert.Equal(t, &Argon2Parameters{
|
||||||
version: argon2.Version,
|
version: argon2.Version,
|
||||||
Memory: 2 * 1024 * 1024,
|
Memory: 2 * 1024 * 1024,
|
||||||
Parallelism: 2,
|
Parallelism: 2,
|
||||||
@@ -25,21 +26,21 @@ func TestNewArgon2Parameters(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) {
|
func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) {
|
||||||
passphrase := []byte("DO NOT USE THIS KEY")
|
passphrase := []byte("DO NOT USE")
|
||||||
privKey := []byte(`# A good key
|
privKey := []byte(`# A good key
|
||||||
-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
|
-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
|
||||||
CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
|
CjsKC0FFUy0yNTYtR0NNEiwIExCAgAQYAyAEKiCPoDfGQiosxNPTbPn5EsMlc2MI
|
||||||
oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
|
c0Bt4oz6gTrFQhX3aBJcimhHKeAuhyTGvllD0Z19fe+DFPcLH3h5VrdjVfIAajg0
|
||||||
+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
|
KrbV3n9UHif/Au5skWmquNJzoW1E4MTdRbvpti6o+WdQ49DxjBFhx0YH8LBqrbPU
|
||||||
qrlJ69wer3ZUHFXA
|
0BGkUHmIO7daP24=
|
||||||
-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
|
-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
|
||||||
`)
|
`)
|
||||||
shortKey := []byte(`# A key which, once decrypted, is too short
|
shortKey := []byte(`# A key which, once decrypted, is too short
|
||||||
-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
|
-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
|
||||||
CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7
|
CjsKC0FFUy0yNTYtR0NNEiwIExCAgAQYAyAEKiAVJwdfl3r+eqi/vF6S7OMdpjfo
|
||||||
k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe
|
hAzmTCRnr58Su4AqmBJbCv3zleYCEKYJP6UI3S8ekLMGISsgO4hm5leukCCyqT0Z
|
||||||
GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs
|
cQ76yrberpzkJKoPLGisX8f+xdy4aXSZl7oEYWQte1+vqbtl/eY9PGZhxUQdcyq7
|
||||||
rQr3bdH3Oy/WiYU=
|
hqzIyrRqfUgVuA==
|
||||||
-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
|
-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
|
||||||
`)
|
`)
|
||||||
invalidBanner := []byte(`# Invalid banner (not encrypted)
|
invalidBanner := []byte(`# Invalid banner (not encrypted)
|
||||||
@@ -61,35 +62,35 @@ qrlJ69wer3ZUHFXA
|
|||||||
|
|
||||||
// Success test case
|
// Success test case
|
||||||
curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle)
|
curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
assert.Len(t, k, 64)
|
assert.Len(t, k, 64)
|
||||||
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
||||||
|
|
||||||
// Fail due to short key
|
// Fail due to short key
|
||||||
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
|
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
|
||||||
assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
|
require.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
||||||
|
|
||||||
// Fail due to invalid banner
|
// Fail due to invalid banner
|
||||||
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
|
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
|
||||||
assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
|
require.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
|
|
||||||
// Fail due to ivalid PEM format, because
|
// Fail due to ivalid PEM format, because
|
||||||
// it's missing the requisite pre-encapsulation boundary.
|
// it's missing the requisite pre-encapsulation boundary.
|
||||||
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
|
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
|
||||||
assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
|
|
||||||
// Fail due to invalid passphrase
|
// Fail due to invalid passphrase
|
||||||
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey)
|
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey)
|
||||||
assert.EqualError(t, err, "invalid passphrase or corrupt private key")
|
require.EqualError(t, err, "invalid passphrase or corrupt private key")
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, []byte{})
|
assert.Equal(t, []byte{}, rest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {
|
func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {
|
||||||
@@ -99,14 +100,14 @@ func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {
|
|||||||
bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
|
bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
|
||||||
kdfParams := NewArgon2Parameters(64*1024, 4, 3)
|
kdfParams := NewArgon2Parameters(64*1024, 4, 3)
|
||||||
key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams)
|
key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify the "key" can be decrypted successfully
|
// Verify the "key" can be decrypted successfully
|
||||||
curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key)
|
curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key)
|
||||||
assert.Len(t, k, 64)
|
assert.Len(t, k, 64)
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
assert.Equal(t, rest, []byte{})
|
assert.Equal(t, []byte{}, rest)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// EncryptAndMarshalEd25519PrivateKey does not create any errors itself
|
// EncryptAndMarshalEd25519PrivateKey does not create any errors itself
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ var (
|
|||||||
ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair")
|
ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair")
|
||||||
ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
|
ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
|
||||||
ErrCaNotFound = errors.New("could not find ca for the certificate")
|
ErrCaNotFound = errors.New("could not find ca for the certificate")
|
||||||
|
ErrUnknownVersion = errors.New("certificate version unrecognized")
|
||||||
|
|
||||||
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
|
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
|
||||||
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")
|
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")
|
||||||
|
|||||||
52
cert/pem.go
52
cert/pem.go
@@ -7,19 +7,26 @@ import (
|
|||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const ( //cert banners
|
||||||
CertificateBanner = "NEBULA CERTIFICATE"
|
CertificateBanner = "NEBULA CERTIFICATE"
|
||||||
CertificateV2Banner = "NEBULA CERTIFICATE V2"
|
CertificateV2Banner = "NEBULA CERTIFICATE V2"
|
||||||
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
|
)
|
||||||
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
|
|
||||||
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
|
|
||||||
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
|
|
||||||
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
|
|
||||||
|
|
||||||
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
|
const ( //key-agreement-key banners
|
||||||
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
|
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
|
||||||
|
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
|
||||||
|
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
|
||||||
|
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* including "ECDSA" in the P256 banners is a clue that these keys should be used only for signing */
|
||||||
|
const ( //signing key banners
|
||||||
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
|
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
|
||||||
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
|
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
|
||||||
|
ECDSAP256PublicKeyBanner = "NEBULA ECDSA P256 PUBLIC KEY"
|
||||||
|
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
|
||||||
|
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
|
||||||
|
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
|
||||||
)
|
)
|
||||||
|
|
||||||
// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
|
// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
|
||||||
@@ -51,6 +58,16 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func marshalCertPublicKeyToPEM(c Certificate) []byte {
|
||||||
|
if c.IsCA() {
|
||||||
|
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
|
||||||
|
} else {
|
||||||
|
return MarshalPublicKeyToPEM(c.Curve(), c.PublicKey())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalPublicKeyToPEM returns a PEM representation of a public key used for ECDH.
|
||||||
|
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
|
||||||
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
|
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
|
||||||
switch curve {
|
switch curve {
|
||||||
case Curve_CURVE25519:
|
case Curve_CURVE25519:
|
||||||
@@ -62,6 +79,19 @@ func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalSigningPublicKeyToPEM returns a PEM representation of a public key used for signing.
|
||||||
|
// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
|
||||||
|
func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte {
|
||||||
|
switch curve {
|
||||||
|
case Curve_CURVE25519:
|
||||||
|
return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b})
|
||||||
|
case Curve_P256:
|
||||||
|
return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
||||||
k, r := pem.Decode(b)
|
k, r := pem.Decode(b)
|
||||||
if k == nil {
|
if k == nil {
|
||||||
@@ -73,7 +103,7 @@ func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
|||||||
case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
|
case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
|
||||||
expectedLen = 32
|
expectedLen = 32
|
||||||
curve = Curve_CURVE25519
|
curve = Curve_CURVE25519
|
||||||
case P256PublicKeyBanner:
|
case P256PublicKeyBanner, ECDSAP256PublicKeyBanner:
|
||||||
// Uncompressed
|
// Uncompressed
|
||||||
expectedLen = 65
|
expectedLen = 65
|
||||||
curve = Curve_P256
|
curve = Curve_P256
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUnmarshalCertificateFromPEM(t *testing.T) {
|
func TestUnmarshalCertificateFromPEM(t *testing.T) {
|
||||||
@@ -35,20 +36,20 @@ bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
|
|||||||
cert, rest, err := UnmarshalCertificateFromPEM(certBundle)
|
cert, rest, err := UnmarshalCertificateFromPEM(certBundle)
|
||||||
assert.NotNil(t, cert)
|
assert.NotNil(t, cert)
|
||||||
assert.Equal(t, rest, append(badBanner, invalidPem...))
|
assert.Equal(t, rest, append(badBanner, invalidPem...))
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Fail due to invalid banner.
|
// Fail due to invalid banner.
|
||||||
cert, rest, err = UnmarshalCertificateFromPEM(rest)
|
cert, rest, err = UnmarshalCertificateFromPEM(rest)
|
||||||
assert.Nil(t, cert)
|
assert.Nil(t, cert)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
assert.EqualError(t, err, "bytes did not contain a proper certificate banner")
|
require.EqualError(t, err, "bytes did not contain a proper certificate banner")
|
||||||
|
|
||||||
// Fail due to ivalid PEM format, because
|
// Fail due to ivalid PEM format, because
|
||||||
// it's missing the requisite pre-encapsulation boundary.
|
// it's missing the requisite pre-encapsulation boundary.
|
||||||
cert, rest, err = UnmarshalCertificateFromPEM(rest)
|
cert, rest, err = UnmarshalCertificateFromPEM(rest)
|
||||||
assert.Nil(t, cert)
|
assert.Nil(t, cert)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) {
|
func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) {
|
||||||
@@ -84,33 +85,33 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
|||||||
assert.Len(t, k, 64)
|
assert.Len(t, k, 64)
|
||||||
assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Success test case
|
// Success test case
|
||||||
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
||||||
assert.Len(t, k, 32)
|
assert.Len(t, k, 32)
|
||||||
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
||||||
assert.Equal(t, Curve_P256, curve)
|
assert.Equal(t, Curve_P256, curve)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Fail due to short key
|
// Fail due to short key
|
||||||
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
||||||
assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key")
|
require.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key")
|
||||||
|
|
||||||
// Fail due to invalid banner
|
// Fail due to invalid banner
|
||||||
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
assert.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner")
|
require.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner")
|
||||||
|
|
||||||
// Fail due to ivalid PEM format, because
|
// Fail due to ivalid PEM format, because
|
||||||
// it's missing the requisite pre-encapsulation boundary.
|
// it's missing the requisite pre-encapsulation boundary.
|
||||||
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalPrivateKeyFromPEM(t *testing.T) {
|
func TestUnmarshalPrivateKeyFromPEM(t *testing.T) {
|
||||||
@@ -146,36 +147,37 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
assert.Len(t, k, 32)
|
assert.Len(t, k, 32)
|
||||||
assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Success test case
|
// Success test case
|
||||||
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
||||||
assert.Len(t, k, 32)
|
assert.Len(t, k, 32)
|
||||||
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
||||||
assert.Equal(t, Curve_P256, curve)
|
assert.Equal(t, Curve_P256, curve)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Fail due to short key
|
// Fail due to short key
|
||||||
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
||||||
assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key")
|
require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key")
|
||||||
|
|
||||||
// Fail due to invalid banner
|
// Fail due to invalid banner
|
||||||
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
assert.EqualError(t, err, "bytes did not contain a proper private key banner")
|
require.EqualError(t, err, "bytes did not contain a proper private key banner")
|
||||||
|
|
||||||
// Fail due to ivalid PEM format, because
|
// Fail due to ivalid PEM format, because
|
||||||
// it's missing the requisite pre-encapsulation boundary.
|
// it's missing the requisite pre-encapsulation boundary.
|
||||||
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
|
func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
pubKey := []byte(`# A good key
|
pubKey := []byte(`# A good key
|
||||||
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
|
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
|
||||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||||
@@ -200,9 +202,9 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
|
|
||||||
// Success test case
|
// Success test case
|
||||||
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
|
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
|
||||||
assert.Equal(t, 32, len(k))
|
assert.Len(t, k, 32)
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
||||||
|
|
||||||
// Fail due to short key
|
// Fail due to short key
|
||||||
@@ -210,13 +212,13 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
||||||
assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
|
require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
|
||||||
|
|
||||||
// Fail due to invalid banner
|
// Fail due to invalid banner
|
||||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
assert.EqualError(t, err, "bytes did not contain a proper public key banner")
|
require.EqualError(t, err, "bytes did not contain a proper public key banner")
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
|
|
||||||
// Fail due to ivalid PEM format, because
|
// Fail due to ivalid PEM format, because
|
||||||
@@ -225,10 +227,11 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalX25519PublicKey(t *testing.T) {
|
func TestUnmarshalX25519PublicKey(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
pubKey := []byte(`# A good key
|
pubKey := []byte(`# A good key
|
||||||
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
||||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||||
@@ -239,6 +242,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||||
AAAAAAAAAAAAAAAAAAAAAAA=
|
AAAAAAAAAAAAAAAAAAAAAAA=
|
||||||
-----END NEBULA P256 PUBLIC KEY-----
|
-----END NEBULA P256 PUBLIC KEY-----
|
||||||
|
`)
|
||||||
|
oldPubP256Key := []byte(`# A good key
|
||||||
|
-----BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
|
||||||
|
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||||
|
AAAAAAAAAAAAAAAAAAAAAAA=
|
||||||
|
-----END NEBULA ECDSA P256 PUBLIC KEY-----
|
||||||
`)
|
`)
|
||||||
shortKey := []byte(`# A short key
|
shortKey := []byte(`# A short key
|
||||||
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
-----BEGIN NEBULA X25519 PUBLIC KEY-----
|
||||||
@@ -255,19 +264,26 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||||
-END NEBULA X25519 PUBLIC KEY-----`)
|
-END NEBULA X25519 PUBLIC KEY-----`)
|
||||||
|
|
||||||
keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
|
keyBundle := appendByteSlices(pubKey, pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem)
|
||||||
|
|
||||||
// Success test case
|
// Success test case
|
||||||
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
|
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
|
||||||
assert.Equal(t, 32, len(k))
|
assert.Len(t, k, 32)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem))
|
||||||
assert.Equal(t, Curve_CURVE25519, curve)
|
assert.Equal(t, Curve_CURVE25519, curve)
|
||||||
|
|
||||||
// Success test case
|
// Success test case
|
||||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||||
assert.Equal(t, 65, len(k))
|
assert.Len(t, k, 65)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, rest, appendByteSlices(oldPubP256Key, shortKey, invalidBanner, invalidPem))
|
||||||
|
assert.Equal(t, Curve_P256, curve)
|
||||||
|
|
||||||
|
// Success test case
|
||||||
|
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||||
|
assert.Len(t, k, 65)
|
||||||
|
require.NoError(t, err)
|
||||||
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
|
||||||
assert.Equal(t, Curve_P256, curve)
|
assert.Equal(t, Curve_P256, curve)
|
||||||
|
|
||||||
@@ -275,12 +291,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
|
||||||
assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
|
require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
|
||||||
|
|
||||||
// Fail due to invalid banner
|
// Fail due to invalid banner
|
||||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.EqualError(t, err, "bytes did not contain a proper public key banner")
|
require.EqualError(t, err, "bytes did not contain a proper public key banner")
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
|
|
||||||
// Fail due to ivalid PEM format, because
|
// Fail due to ivalid PEM format, because
|
||||||
@@ -288,5 +304,5 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
|||||||
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
|
||||||
assert.Nil(t, k)
|
assert.Nil(t, k)
|
||||||
assert.Equal(t, rest, invalidPem)
|
assert.Equal(t, rest, invalidPem)
|
||||||
assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
|
||||||
}
|
}
|
||||||
|
|||||||
12
cert/sign.go
12
cert/sign.go
@@ -7,7 +7,6 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -55,15 +54,10 @@ func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Cert
|
|||||||
}
|
}
|
||||||
return t.SignWith(signer, curve, sp)
|
return t.SignWith(signer, curve, sp)
|
||||||
case Curve_P256:
|
case Curve_P256:
|
||||||
pk := &ecdsa.PrivateKey{
|
pk, err := ecdsa.ParseRawPrivateKey(elliptic.P256(), key)
|
||||||
PublicKey: ecdsa.PublicKey{
|
if err != nil {
|
||||||
Curve: elliptic.P256(),
|
return nil, err
|
||||||
},
|
|
||||||
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
|
|
||||||
D: new(big.Int).SetBytes(key),
|
|
||||||
}
|
}
|
||||||
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
|
|
||||||
pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
|
|
||||||
sp := func(certBytes []byte) ([]byte, error) {
|
sp := func(certBytes []byte) ([]byte, error) {
|
||||||
// We need to hash first for ECDSA
|
// We need to hash first for ECDSA
|
||||||
// - https://pkg.go.dev/crypto/ecdsa#SignASN1
|
// - https://pkg.go.dev/crypto/ecdsa#SignASN1
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCertificateV1_Sign(t *testing.T) {
|
func TestCertificateV1_Sign(t *testing.T) {
|
||||||
@@ -37,14 +38,14 @@ func TestCertificateV1_Sign(t *testing.T) {
|
|||||||
|
|
||||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv)
|
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, c)
|
assert.NotNil(t, c)
|
||||||
assert.True(t, c.CheckSignature(pub))
|
assert.True(t, c.CheckSignature(pub))
|
||||||
|
|
||||||
b, err := c.Marshal()
|
b, err := c.Marshal()
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
uc, err := unmarshalCertificateV1(b, nil)
|
uc, err := unmarshalCertificateV1(b, nil)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, uc)
|
assert.NotNil(t, uc)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -73,18 +74,18 @@ func TestCertificateV1_SignP256(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
|
pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
|
||||||
rawPriv := priv.D.FillBytes(make([]byte, 32))
|
rawPriv := priv.D.FillBytes(make([]byte, 32))
|
||||||
|
|
||||||
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv)
|
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, c)
|
assert.NotNil(t, c)
|
||||||
assert.True(t, c.CheckSignature(pub))
|
assert.True(t, c.CheckSignature(pub))
|
||||||
|
|
||||||
b, err := c.Marshal()
|
b, err := c.Marshal()
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
uc, err := unmarshalCertificateV1(b, nil)
|
uc, err := unmarshalCertificateV1(b, nil)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, uc)
|
assert.NotNil(t, uc)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -114,6 +114,33 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
|
|||||||
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
|
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
|
||||||
|
nc := &cert.TBSCertificate{
|
||||||
|
Version: v,
|
||||||
|
Curve: c.Curve(),
|
||||||
|
Name: c.Name(),
|
||||||
|
Networks: c.Networks(),
|
||||||
|
UnsafeNetworks: c.UnsafeNetworks(),
|
||||||
|
Groups: c.Groups(),
|
||||||
|
NotBefore: time.Unix(c.NotBefore().Unix(), 0),
|
||||||
|
NotAfter: time.Unix(c.NotAfter().Unix(), 0),
|
||||||
|
PublicKey: c.PublicKey(),
|
||||||
|
IsCA: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := nc.Sign(ca, ca.Curve(), key)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pem, err := c.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, pem
|
||||||
|
}
|
||||||
|
|
||||||
func X25519Keypair() ([]byte, []byte) {
|
func X25519Keypair() ([]byte, []byte) {
|
||||||
privkey := make([]byte, 32)
|
privkey := make([]byte, 32)
|
||||||
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
||||||
|
|||||||
@@ -173,23 +173,26 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
|||||||
|
|
||||||
var passphrase []byte
|
var passphrase []byte
|
||||||
if !isP11 && *cf.encryption {
|
if !isP11 && *cf.encryption {
|
||||||
for i := 0; i < 5; i++ {
|
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
|
||||||
out.Write([]byte("Enter passphrase: "))
|
|
||||||
passphrase, err = pr.ReadPassword()
|
|
||||||
|
|
||||||
if err == ErrNoTerminal {
|
|
||||||
return fmt.Errorf("out-key must be encrypted interactively")
|
|
||||||
} else if err != nil {
|
|
||||||
return fmt.Errorf("error reading passphrase: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(passphrase) > 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(passphrase) == 0 {
|
if len(passphrase) == 0 {
|
||||||
return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext")
|
for i := 0; i < 5; i++ {
|
||||||
|
out.Write([]byte("Enter passphrase: "))
|
||||||
|
passphrase, err = pr.ReadPassword()
|
||||||
|
|
||||||
|
if err == ErrNoTerminal {
|
||||||
|
return fmt.Errorf("out-key must be encrypted interactively")
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("error reading passphrase: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(passphrase) > 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(passphrase) == 0 {
|
||||||
|
return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_caSummary(t *testing.T) {
|
func Test_caSummary(t *testing.T) {
|
||||||
@@ -89,75 +90,75 @@ func Test_ca(t *testing.T) {
|
|||||||
assertHelpError(t, ca(
|
assertHelpError(t, ca(
|
||||||
[]string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
|
[]string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
|
||||||
), "-name is required")
|
), "-name is required")
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// ipv4 only ips
|
// ipv4 only ips
|
||||||
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100")
|
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100")
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// ipv4 only subnets
|
// ipv4 only subnets
|
||||||
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100")
|
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100")
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// failed key write
|
// failed key write
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
|
args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
|
||||||
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// create temp key file
|
// create temp key file
|
||||||
keyF, err := os.CreateTemp("", "test.key")
|
keyF, err := os.CreateTemp("", "test.key")
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Nil(t, os.Remove(keyF.Name()))
|
require.NoError(t, os.Remove(keyF.Name()))
|
||||||
|
|
||||||
// failed cert write
|
// failed cert write
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
|
||||||
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
|
require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// create temp cert file
|
// create temp cert file
|
||||||
crtF, err := os.CreateTemp("", "test.crt")
|
crtF, err := os.CreateTemp("", "test.crt")
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Nil(t, os.Remove(crtF.Name()))
|
require.NoError(t, os.Remove(crtF.Name()))
|
||||||
assert.Nil(t, os.Remove(keyF.Name()))
|
require.NoError(t, os.Remove(keyF.Name()))
|
||||||
|
|
||||||
// test proper cert with removed empty groups and subnets
|
// test proper cert with removed empty groups and subnets
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
assert.Nil(t, ca(args, ob, eb, nopw))
|
require.NoError(t, ca(args, ob, eb, nopw))
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// read cert and key files
|
// read cert and key files
|
||||||
rb, _ := os.ReadFile(keyF.Name())
|
rb, _ := os.ReadFile(keyF.Name())
|
||||||
lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb)
|
lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb)
|
||||||
assert.Equal(t, cert.Curve_CURVE25519, c)
|
assert.Equal(t, cert.Curve_CURVE25519, c)
|
||||||
assert.Len(t, b, 0)
|
assert.Empty(t, b)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, lKey, 64)
|
assert.Len(t, lKey, 64)
|
||||||
|
|
||||||
rb, _ = os.ReadFile(crtF.Name())
|
rb, _ = os.ReadFile(crtF.Name())
|
||||||
lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
|
lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
|
||||||
assert.Len(t, b, 0)
|
assert.Empty(t, b)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, "test", lCrt.Name())
|
assert.Equal(t, "test", lCrt.Name())
|
||||||
assert.Len(t, lCrt.Networks(), 0)
|
assert.Empty(t, lCrt.Networks())
|
||||||
assert.True(t, lCrt.IsCA())
|
assert.True(t, lCrt.IsCA())
|
||||||
assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups())
|
assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups())
|
||||||
assert.Len(t, lCrt.UnsafeNetworks(), 0)
|
assert.Empty(t, lCrt.UnsafeNetworks())
|
||||||
assert.Len(t, lCrt.PublicKey(), 32)
|
assert.Len(t, lCrt.PublicKey(), 32)
|
||||||
assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore()))
|
assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore()))
|
||||||
assert.Equal(t, "", lCrt.Issuer())
|
assert.Empty(t, lCrt.Issuer())
|
||||||
assert.True(t, lCrt.CheckSignature(lCrt.PublicKey()))
|
assert.True(t, lCrt.CheckSignature(lCrt.PublicKey()))
|
||||||
|
|
||||||
// test encrypted key
|
// test encrypted key
|
||||||
@@ -166,15 +167,26 @@ func Test_ca(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
assert.Nil(t, ca(args, ob, eb, testpw))
|
require.NoError(t, ca(args, ob, eb, testpw))
|
||||||
assert.Equal(t, pwPromptOb, ob.String())
|
assert.Equal(t, pwPromptOb, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
|
// test encrypted key with passphrase environment variable
|
||||||
|
os.Remove(keyF.Name())
|
||||||
|
os.Remove(crtF.Name())
|
||||||
|
ob.Reset()
|
||||||
|
eb.Reset()
|
||||||
|
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
|
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
|
||||||
|
require.NoError(t, ca(args, ob, eb, testpw))
|
||||||
|
assert.Empty(t, eb.String())
|
||||||
|
os.Setenv("NEBULA_CA_PASSPHRASE", "")
|
||||||
|
|
||||||
// read encrypted key file and verify default params
|
// read encrypted key file and verify default params
|
||||||
rb, _ = os.ReadFile(keyF.Name())
|
rb, _ = os.ReadFile(keyF.Name())
|
||||||
k, _ := pem.Decode(rb)
|
k, _ := pem.Decode(rb)
|
||||||
ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes)
|
ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
// we won't know salt in advance, so just check start of string
|
// we won't know salt in advance, so just check start of string
|
||||||
assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory)
|
assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory)
|
||||||
assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism)
|
assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism)
|
||||||
@@ -184,8 +196,8 @@ func Test_ca(t *testing.T) {
|
|||||||
var curve cert.Curve
|
var curve cert.Curve
|
||||||
curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb)
|
curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb)
|
||||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, b, 0)
|
assert.Empty(t, b)
|
||||||
assert.Len(t, lKey, 64)
|
assert.Len(t, lKey, 64)
|
||||||
|
|
||||||
// test when reading passsword results in an error
|
// test when reading passsword results in an error
|
||||||
@@ -194,9 +206,9 @@ func Test_ca(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
assert.Error(t, ca(args, ob, eb, errpw))
|
require.Error(t, ca(args, ob, eb, errpw))
|
||||||
assert.Equal(t, pwPromptOb, ob.String())
|
assert.Equal(t, pwPromptOb, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// test when user fails to enter a password
|
// test when user fails to enter a password
|
||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
@@ -204,9 +216,9 @@ func Test_ca(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
|
require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
|
||||||
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
|
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// create valid cert/key for overwrite tests
|
// create valid cert/key for overwrite tests
|
||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
@@ -214,24 +226,24 @@ func Test_ca(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
assert.Nil(t, ca(args, ob, eb, nopw))
|
require.NoError(t, ca(args, ob, eb, nopw))
|
||||||
|
|
||||||
// test that we won't overwrite existing certificate file
|
// test that we won't overwrite existing certificate file
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
|
require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// test that we won't overwrite existing key file
|
// test that we won't overwrite existing key file
|
||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
|
require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_keygenSummary(t *testing.T) {
|
func Test_keygenSummary(t *testing.T) {
|
||||||
@@ -36,59 +37,59 @@ func Test_keygen(t *testing.T) {
|
|||||||
|
|
||||||
// required args
|
// required args
|
||||||
assertHelpError(t, keygen([]string{"-out-pub", "nope"}, ob, eb), "-out-key is required")
|
assertHelpError(t, keygen([]string{"-out-pub", "nope"}, ob, eb), "-out-key is required")
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
assertHelpError(t, keygen([]string{"-out-key", "nope"}, ob, eb), "-out-pub is required")
|
assertHelpError(t, keygen([]string{"-out-key", "nope"}, ob, eb), "-out-pub is required")
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// failed key write
|
// failed key write
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"}
|
args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"}
|
||||||
assert.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
require.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// create temp key file
|
// create temp key file
|
||||||
keyF, err := os.CreateTemp("", "test.key")
|
keyF, err := os.CreateTemp("", "test.key")
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
defer os.Remove(keyF.Name())
|
defer os.Remove(keyF.Name())
|
||||||
|
|
||||||
// failed pub write
|
// failed pub write
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()}
|
args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()}
|
||||||
assert.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError)
|
require.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError)
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// create temp pub file
|
// create temp pub file
|
||||||
pubF, err := os.CreateTemp("", "test.pub")
|
pubF, err := os.CreateTemp("", "test.pub")
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
defer os.Remove(pubF.Name())
|
defer os.Remove(pubF.Name())
|
||||||
|
|
||||||
// test proper keygen
|
// test proper keygen
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()}
|
||||||
assert.Nil(t, keygen(args, ob, eb))
|
require.NoError(t, keygen(args, ob, eb))
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// read cert and key files
|
// read cert and key files
|
||||||
rb, _ := os.ReadFile(keyF.Name())
|
rb, _ := os.ReadFile(keyF.Name())
|
||||||
lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
|
lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
|
||||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||||
assert.Len(t, b, 0)
|
assert.Empty(t, b)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, lKey, 32)
|
assert.Len(t, lKey, 32)
|
||||||
|
|
||||||
rb, _ = os.ReadFile(pubF.Name())
|
rb, _ = os.ReadFile(pubF.Name())
|
||||||
lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb)
|
lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb)
|
||||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||||
assert.Len(t, b, 0)
|
assert.Empty(t, b)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, lPub, 32)
|
assert.Len(t, lPub, 32)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,10 +5,28 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime/debug"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// A version string that can be set with
|
||||||
|
//
|
||||||
|
// -ldflags "-X main.Build=SOMEVERSION"
|
||||||
|
//
|
||||||
|
// at compile-time.
|
||||||
var Build string
|
var Build string
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if Build == "" {
|
||||||
|
info, ok := debug.ReadBuildInfo()
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
Build = strings.TrimPrefix(info.Main.Version, "v")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type helpError struct {
|
type helpError struct {
|
||||||
s string
|
s string
|
||||||
}
|
}
|
||||||
@@ -17,7 +35,7 @@ func (he *helpError) Error() string {
|
|||||||
return he.s
|
return he.s
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHelpErrorf(s string, v ...interface{}) error {
|
func newHelpErrorf(s string, v ...any) error {
|
||||||
return &helpError{s: fmt.Sprintf(s, v...)}
|
return &helpError{s: fmt.Sprintf(s, v...)}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_help(t *testing.T) {
|
func Test_help(t *testing.T) {
|
||||||
@@ -79,7 +80,7 @@ func assertHelpError(t *testing.T, err error, msg string) {
|
|||||||
t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg))
|
t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg))
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.EqualError(t, err, msg)
|
require.EqualError(t, err, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func optionalPkcs11String(msg string) string {
|
func optionalPkcs11String(msg string) string {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_printSummary(t *testing.T) {
|
func Test_printSummary(t *testing.T) {
|
||||||
@@ -42,30 +43,30 @@ func Test_printCert(t *testing.T) {
|
|||||||
|
|
||||||
// no path
|
// no path
|
||||||
err := printCert([]string{}, ob, eb)
|
err := printCert([]string{}, ob, eb)
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
assertHelpError(t, err, "-path is required")
|
assertHelpError(t, err, "-path is required")
|
||||||
|
|
||||||
// no cert at path
|
// no cert at path
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
err = printCert([]string{"-path", "does_not_exist"}, ob, eb)
|
err = printCert([]string{"-path", "does_not_exist"}, ob, eb)
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
assert.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError)
|
require.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError)
|
||||||
|
|
||||||
// invalid cert at path
|
// invalid cert at path
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
tf, err := os.CreateTemp("", "print-cert")
|
tf, err := os.CreateTemp("", "print-cert")
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
defer os.Remove(tf.Name())
|
defer os.Remove(tf.Name())
|
||||||
|
|
||||||
tf.WriteString("-----BEGIN NOPE-----")
|
tf.WriteString("-----BEGIN NOPE-----")
|
||||||
err = printCert([]string{"-path", tf.Name()}, ob, eb)
|
err = printCert([]string{"-path", tf.Name()}, ob, eb)
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
assert.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block")
|
require.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block")
|
||||||
|
|
||||||
// test multiple certs
|
// test multiple certs
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
@@ -84,7 +85,7 @@ func Test_printCert(t *testing.T) {
|
|||||||
fp, _ := c.Fingerprint()
|
fp, _ := c.Fingerprint()
|
||||||
pk := hex.EncodeToString(c.PublicKey())
|
pk := hex.EncodeToString(c.PublicKey())
|
||||||
sig := hex.EncodeToString(c.Signature())
|
sig := hex.EncodeToString(c.Signature())
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(
|
assert.Equal(
|
||||||
t,
|
t,
|
||||||
//"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
|
//"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
|
||||||
@@ -154,7 +155,7 @@ func Test_printCert(t *testing.T) {
|
|||||||
`,
|
`,
|
||||||
ob.String(),
|
ob.String(),
|
||||||
)
|
)
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// test json
|
// test json
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
@@ -169,14 +170,14 @@ func Test_printCert(t *testing.T) {
|
|||||||
fp, _ = c.Fingerprint()
|
fp, _ = c.Fingerprint()
|
||||||
pk = hex.EncodeToString(c.PublicKey())
|
pk = hex.EncodeToString(c.PublicKey())
|
||||||
sig = hex.EncodeToString(c.Signature())
|
sig = hex.EncodeToString(c.Signature())
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(
|
assert.Equal(
|
||||||
t,
|
t,
|
||||||
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
|
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
|
||||||
`,
|
`,
|
||||||
ob.String(),
|
ob.String(),
|
||||||
)
|
)
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTestCaCert will generate a CA cert
|
// NewTestCaCert will generate a CA cert
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ type signFlags struct {
|
|||||||
func newSignFlags() *signFlags {
|
func newSignFlags() *signFlags {
|
||||||
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
|
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
|
||||||
sf.set.Usage = func() {}
|
sf.set.Usage = func() {}
|
||||||
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
|
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use. The default is to match the version of the signing CA")
|
||||||
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
|
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
|
||||||
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
|
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
|
||||||
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
|
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
|
||||||
@@ -116,26 +116,28 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
// naively attempt to decode the private key as though it is not encrypted
|
// naively attempt to decode the private key as though it is not encrypted
|
||||||
caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
|
caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
|
||||||
if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
|
if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
|
||||||
// ask for a passphrase until we get one
|
|
||||||
var passphrase []byte
|
var passphrase []byte
|
||||||
for i := 0; i < 5; i++ {
|
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
|
||||||
out.Write([]byte("Enter passphrase: "))
|
|
||||||
passphrase, err = pr.ReadPassword()
|
|
||||||
|
|
||||||
if errors.Is(err, ErrNoTerminal) {
|
|
||||||
return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
|
|
||||||
} else if err != nil {
|
|
||||||
return fmt.Errorf("error reading password: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(passphrase) > 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(passphrase) == 0 {
|
if len(passphrase) == 0 {
|
||||||
return fmt.Errorf("cannot open encrypted ca-key without passphrase")
|
// ask for a passphrase until we get one
|
||||||
}
|
for i := 0; i < 5; i++ {
|
||||||
|
out.Write([]byte("Enter passphrase: "))
|
||||||
|
passphrase, err = pr.ReadPassword()
|
||||||
|
|
||||||
|
if errors.Is(err, ErrNoTerminal) {
|
||||||
|
return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("error reading password: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(passphrase) > 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(passphrase) == 0 {
|
||||||
|
return fmt.Errorf("cannot open encrypted ca-key without passphrase")
|
||||||
|
}
|
||||||
|
}
|
||||||
curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
|
curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
|
return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
|
||||||
@@ -165,6 +167,10 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
return fmt.Errorf("ca certificate is expired")
|
return fmt.Errorf("ca certificate is expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if version == 0 {
|
||||||
|
version = caCert.Version()
|
||||||
|
}
|
||||||
|
|
||||||
// if no duration is given, expire one second before the root expires
|
// if no duration is given, expire one second before the root expires
|
||||||
if *sf.duration <= 0 {
|
if *sf.duration <= 0 {
|
||||||
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
|
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
|
||||||
@@ -277,21 +283,19 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
notBefore := time.Now()
|
notBefore := time.Now()
|
||||||
notAfter := notBefore.Add(*sf.duration)
|
notAfter := notBefore.Add(*sf.duration)
|
||||||
|
|
||||||
if version == 0 || version == cert.Version1 {
|
switch version {
|
||||||
// Make sure we at least have an ip
|
case cert.Version1:
|
||||||
|
// Make sure we have only one ipv4 address
|
||||||
if len(v4Networks) != 1 {
|
if len(v4Networks) != 1 {
|
||||||
return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
|
return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
|
||||||
}
|
}
|
||||||
|
|
||||||
if version == cert.Version1 {
|
if len(v6Networks) > 0 {
|
||||||
// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
|
return newHelpErrorf("invalid -networks definition: v1 certificates can only contain ipv4 addresses")
|
||||||
if len(v6Networks) > 0 {
|
}
|
||||||
return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(v6UnsafeNetworks) > 0 {
|
if len(v6UnsafeNetworks) > 0 {
|
||||||
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
|
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses")
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &cert.TBSCertificate{
|
t := &cert.TBSCertificate{
|
||||||
@@ -321,9 +325,8 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
}
|
}
|
||||||
|
|
||||||
crts = append(crts, nc)
|
crts = append(crts, nc)
|
||||||
}
|
|
||||||
|
|
||||||
if version == 0 || version == cert.Version2 {
|
case cert.Version2:
|
||||||
t := &cert.TBSCertificate{
|
t := &cert.TBSCertificate{
|
||||||
Version: cert.Version2,
|
Version: cert.Version2,
|
||||||
Name: *sf.name,
|
Name: *sf.name,
|
||||||
@@ -351,6 +354,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
}
|
}
|
||||||
|
|
||||||
crts = append(crts, nc)
|
crts = append(crts, nc)
|
||||||
|
default:
|
||||||
|
// this should be unreachable
|
||||||
|
return fmt.Errorf("invalid version: %d", version)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isP11 && *sf.inPubPath == "" {
|
if !isP11 && *sf.inPubPath == "" {
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -54,7 +55,7 @@ func Test_signHelp(t *testing.T) {
|
|||||||
" -unsafe-networks string\n"+
|
" -unsafe-networks string\n"+
|
||||||
" \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
|
" \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
|
||||||
" -version uint\n"+
|
" -version uint\n"+
|
||||||
" \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
|
" \tOptional: version of the certificate format to use. The default is to match the version of the signing CA\n",
|
||||||
ob.String(),
|
ob.String(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -103,17 +104,17 @@ func Test_signCert(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
|
require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
|
||||||
|
|
||||||
// failed to unmarshal key
|
// failed to unmarshal key
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
caKeyF, err := os.CreateTemp("", "sign-cert.key")
|
caKeyF, err := os.CreateTemp("", "sign-cert.key")
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
defer os.Remove(caKeyF.Name())
|
defer os.Remove(caKeyF.Name())
|
||||||
|
|
||||||
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
|
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -125,7 +126,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
|
|
||||||
// failed to read cert
|
// failed to read cert
|
||||||
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
|
require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -133,11 +134,11 @@ func Test_signCert(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
caCrtF, err := os.CreateTemp("", "sign-cert.crt")
|
caCrtF, err := os.CreateTemp("", "sign-cert.crt")
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
defer os.Remove(caCrtF.Name())
|
defer os.Remove(caCrtF.Name())
|
||||||
|
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
|
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -148,7 +149,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
|
|
||||||
// failed to read pub
|
// failed to read pub
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
|
require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -156,11 +157,11 @@ func Test_signCert(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
inPubF, err := os.CreateTemp("", "in.pub")
|
inPubF, err := os.CreateTemp("", "in.pub")
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
defer os.Remove(inPubF.Name())
|
defer os.Remove(inPubF.Name())
|
||||||
|
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
|
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -203,21 +204,21 @@ func Test_signCert(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
|
||||||
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
|
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// mismatched ca key
|
// mismatched ca key
|
||||||
_, caPriv2, _ := ed25519.GenerateKey(rand.Reader)
|
_, caPriv2, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
caKeyF2, err := os.CreateTemp("", "sign-cert-2.key")
|
caKeyF2, err := os.CreateTemp("", "sign-cert-2.key")
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
defer os.Remove(caKeyF2.Name())
|
defer os.Remove(caKeyF2.Name())
|
||||||
caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2))
|
caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2))
|
||||||
|
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
|
require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -225,34 +226,34 @@ func Test_signCert(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// create temp key file
|
// create temp key file
|
||||||
keyF, err := os.CreateTemp("", "test.key")
|
keyF, err := os.CreateTemp("", "test.key")
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
|
|
||||||
// failed cert write
|
// failed cert write
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
|
require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
|
|
||||||
// create temp cert file
|
// create temp cert file
|
||||||
crtF, err := os.CreateTemp("", "test.crt")
|
crtF, err := os.CreateTemp("", "test.crt")
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
os.Remove(crtF.Name())
|
os.Remove(crtF.Name())
|
||||||
|
|
||||||
// test proper cert with removed empty groups and subnets
|
// test proper cert with removed empty groups and subnets
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.Nil(t, signCert(args, ob, eb, nopw))
|
require.NoError(t, signCert(args, ob, eb, nopw))
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -260,14 +261,14 @@ func Test_signCert(t *testing.T) {
|
|||||||
rb, _ := os.ReadFile(keyF.Name())
|
rb, _ := os.ReadFile(keyF.Name())
|
||||||
lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
|
lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
|
||||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||||
assert.Len(t, b, 0)
|
assert.Empty(t, b)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, lKey, 32)
|
assert.Len(t, lKey, 32)
|
||||||
|
|
||||||
rb, _ = os.ReadFile(crtF.Name())
|
rb, _ = os.ReadFile(crtF.Name())
|
||||||
lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
|
lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
|
||||||
assert.Len(t, b, 0)
|
assert.Empty(t, b)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, "test", lCrt.Name())
|
assert.Equal(t, "test", lCrt.Name())
|
||||||
assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String())
|
assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String())
|
||||||
@@ -295,15 +296,15 @@ func Test_signCert(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
|
||||||
assert.Nil(t, signCert(args, ob, eb, nopw))
|
require.NoError(t, signCert(args, ob, eb, nopw))
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// read cert file and check pub key matches in-pub
|
// read cert file and check pub key matches in-pub
|
||||||
rb, _ = os.ReadFile(crtF.Name())
|
rb, _ = os.ReadFile(crtF.Name())
|
||||||
lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb)
|
lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb)
|
||||||
assert.Len(t, b, 0)
|
assert.Empty(t, b)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, lCrt.PublicKey(), inPub)
|
assert.Equal(t, lCrt.PublicKey(), inPub)
|
||||||
|
|
||||||
// test refuse to sign cert with duration beyond root
|
// test refuse to sign cert with duration beyond root
|
||||||
@@ -312,7 +313,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
os.Remove(crtF.Name())
|
os.Remove(crtF.Name())
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
|
require.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -320,14 +321,14 @@ func Test_signCert(t *testing.T) {
|
|||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
os.Remove(crtF.Name())
|
os.Remove(crtF.Name())
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.Nil(t, signCert(args, ob, eb, nopw))
|
require.NoError(t, signCert(args, ob, eb, nopw))
|
||||||
|
|
||||||
// test that we won't overwrite existing key file
|
// test that we won't overwrite existing key file
|
||||||
os.Remove(crtF.Name())
|
os.Remove(crtF.Name())
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
|
require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -335,14 +336,14 @@ func Test_signCert(t *testing.T) {
|
|||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
os.Remove(crtF.Name())
|
os.Remove(crtF.Name())
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.Nil(t, signCert(args, ob, eb, nopw))
|
require.NoError(t, signCert(args, ob, eb, nopw))
|
||||||
|
|
||||||
// test that we won't overwrite existing certificate file
|
// test that we won't overwrite existing certificate file
|
||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
|
require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@@ -355,11 +356,11 @@ func Test_signCert(t *testing.T) {
|
|||||||
eb.Reset()
|
eb.Reset()
|
||||||
|
|
||||||
caKeyF, err = os.CreateTemp("", "sign-cert.key")
|
caKeyF, err = os.CreateTemp("", "sign-cert.key")
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
defer os.Remove(caKeyF.Name())
|
defer os.Remove(caKeyF.Name())
|
||||||
|
|
||||||
caCrtF, err = os.CreateTemp("", "sign-cert.crt")
|
caCrtF, err = os.CreateTemp("", "sign-cert.crt")
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
defer os.Remove(caCrtF.Name())
|
defer os.Remove(caCrtF.Name())
|
||||||
|
|
||||||
// generate the encrypted key
|
// generate the encrypted key
|
||||||
@@ -374,26 +375,46 @@ func Test_signCert(t *testing.T) {
|
|||||||
|
|
||||||
// test with the proper password
|
// test with the proper password
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.Nil(t, signCert(args, ob, eb, testpw))
|
require.NoError(t, signCert(args, ob, eb, testpw))
|
||||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
|
// test with the proper password in the environment
|
||||||
|
os.Remove(crtF.Name())
|
||||||
|
os.Remove(keyF.Name())
|
||||||
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
|
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
|
||||||
|
require.NoError(t, signCert(args, ob, eb, testpw))
|
||||||
|
assert.Empty(t, eb.String())
|
||||||
|
os.Setenv("NEBULA_CA_PASSPHRASE", "")
|
||||||
|
|
||||||
// test with the wrong password
|
// test with the wrong password
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
|
|
||||||
testpw.password = []byte("invalid password")
|
testpw.password = []byte("invalid password")
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.Error(t, signCert(args, ob, eb, testpw))
|
require.Error(t, signCert(args, ob, eb, testpw))
|
||||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
|
// test with the wrong password in environment
|
||||||
|
ob.Reset()
|
||||||
|
eb.Reset()
|
||||||
|
|
||||||
|
os.Setenv("NEBULA_CA_PASSPHRASE", "invalid password")
|
||||||
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
|
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing encrypted ca-key: invalid passphrase or corrupt private key")
|
||||||
|
assert.Empty(t, ob.String())
|
||||||
|
assert.Empty(t, eb.String())
|
||||||
|
os.Setenv("NEBULA_CA_PASSPHRASE", "")
|
||||||
|
|
||||||
// test with the user not entering a password
|
// test with the user not entering a password
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
|
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.Error(t, signCert(args, ob, eb, nopw))
|
require.Error(t, signCert(args, ob, eb, nopw))
|
||||||
// normally the user hitting enter on the prompt would add newlines between these
|
// normally the user hitting enter on the prompt would add newlines between these
|
||||||
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
|
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
@@ -403,7 +424,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
eb.Reset()
|
eb.Reset()
|
||||||
|
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.Error(t, signCert(args, ob, eb, errpw))
|
require.Error(t, signCert(args, ob, eb, errpw))
|
||||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,13 +3,13 @@ package main
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"errors"
|
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -38,33 +38,33 @@ func Test_verify(t *testing.T) {
|
|||||||
|
|
||||||
// required args
|
// required args
|
||||||
assertHelpError(t, verify([]string{"-ca", "derp"}, ob, eb), "-crt is required")
|
assertHelpError(t, verify([]string{"-ca", "derp"}, ob, eb), "-crt is required")
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
assertHelpError(t, verify([]string{"-crt", "derp"}, ob, eb), "-ca is required")
|
assertHelpError(t, verify([]string{"-crt", "derp"}, ob, eb), "-ca is required")
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// no ca at path
|
// no ca at path
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb)
|
err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb)
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
assert.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError)
|
require.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError)
|
||||||
|
|
||||||
// invalid ca at path
|
// invalid ca at path
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
caFile, err := os.CreateTemp("", "verify-ca")
|
caFile, err := os.CreateTemp("", "verify-ca")
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
defer os.Remove(caFile.Name())
|
defer os.Remove(caFile.Name())
|
||||||
|
|
||||||
caFile.WriteString("-----BEGIN NOPE-----")
|
caFile.WriteString("-----BEGIN NOPE-----")
|
||||||
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
|
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
assert.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
|
require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
|
||||||
|
|
||||||
// make a ca for later
|
// make a ca for later
|
||||||
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
|
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
@@ -76,22 +76,22 @@ func Test_verify(t *testing.T) {
|
|||||||
|
|
||||||
// no crt at path
|
// no crt at path
|
||||||
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
|
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
assert.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
|
require.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
|
||||||
|
|
||||||
// invalid crt at path
|
// invalid crt at path
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
certFile, err := os.CreateTemp("", "verify-cert")
|
certFile, err := os.CreateTemp("", "verify-cert")
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
defer os.Remove(certFile.Name())
|
defer os.Remove(certFile.Name())
|
||||||
|
|
||||||
certFile.WriteString("-----BEGIN NOPE-----")
|
certFile.WriteString("-----BEGIN NOPE-----")
|
||||||
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
|
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
assert.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
|
require.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
|
||||||
|
|
||||||
// unverifiable cert at path
|
// unverifiable cert at path
|
||||||
crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
|
crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
|
||||||
@@ -106,9 +106,9 @@ func Test_verify(t *testing.T) {
|
|||||||
certFile.Write(b)
|
certFile.Write(b)
|
||||||
|
|
||||||
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
|
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
assert.True(t, errors.Is(err, cert.ErrSignatureMismatch))
|
require.ErrorIs(t, err, cert.ErrSignatureMismatch)
|
||||||
|
|
||||||
// verified cert at path
|
// verified cert at path
|
||||||
crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
|
crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
|
||||||
@@ -118,7 +118,7 @@ func Test_verify(t *testing.T) {
|
|||||||
certFile.Write(b)
|
certFile.Write(b)
|
||||||
|
|
||||||
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
|
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Empty(t, eb.String())
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime/debug"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
@@ -18,6 +20,17 @@ import (
|
|||||||
// at compile-time.
|
// at compile-time.
|
||||||
var Build string
|
var Build string
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if Build == "" {
|
||||||
|
info, ok := debug.ReadBuildInfo()
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
Build = strings.TrimPrefix(info.Main.Version, "v")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
serviceFlag := flag.String("service", "", "Control the system service.")
|
serviceFlag := flag.String("service", "", "Control the system service.")
|
||||||
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
|
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime/debug"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
@@ -18,6 +20,17 @@ import (
|
|||||||
// at compile-time.
|
// at compile-time.
|
||||||
var Build string
|
var Build string
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if Build == "" {
|
||||||
|
info, ok := debug.ReadBuildInfo()
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
Build = strings.TrimPrefix(info.Main.Version, "v")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
|
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
|
||||||
configTest := flag.Bool("test", false, "Test the config and print the end result. Non zero exit indicates a faulty config")
|
configTest := flag.Bool("test", false, "Test the config and print the end result. Non zero exit indicates a faulty config")
|
||||||
|
|||||||
@@ -17,14 +17,14 @@ import (
|
|||||||
|
|
||||||
"dario.cat/mergo"
|
"dario.cat/mergo"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"gopkg.in/yaml.v2"
|
"go.yaml.in/yaml/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type C struct {
|
type C struct {
|
||||||
path string
|
path string
|
||||||
files []string
|
files []string
|
||||||
Settings map[interface{}]interface{}
|
Settings map[string]any
|
||||||
oldSettings map[interface{}]interface{}
|
oldSettings map[string]any
|
||||||
callbacks []func(*C)
|
callbacks []func(*C)
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
reloadLock sync.Mutex
|
reloadLock sync.Mutex
|
||||||
@@ -32,7 +32,7 @@ type C struct {
|
|||||||
|
|
||||||
func NewC(l *logrus.Logger) *C {
|
func NewC(l *logrus.Logger) *C {
|
||||||
return &C{
|
return &C{
|
||||||
Settings: make(map[interface{}]interface{}),
|
Settings: make(map[string]any),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -92,8 +92,8 @@ func (c *C) HasChanged(k string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
nv interface{}
|
nv any
|
||||||
ov interface{}
|
ov any
|
||||||
)
|
)
|
||||||
|
|
||||||
if k == "" {
|
if k == "" {
|
||||||
@@ -147,7 +147,7 @@ func (c *C) ReloadConfig() {
|
|||||||
c.reloadLock.Lock()
|
c.reloadLock.Lock()
|
||||||
defer c.reloadLock.Unlock()
|
defer c.reloadLock.Unlock()
|
||||||
|
|
||||||
c.oldSettings = make(map[interface{}]interface{})
|
c.oldSettings = make(map[string]any)
|
||||||
for k, v := range c.Settings {
|
for k, v := range c.Settings {
|
||||||
c.oldSettings[k] = v
|
c.oldSettings[k] = v
|
||||||
}
|
}
|
||||||
@@ -167,7 +167,7 @@ func (c *C) ReloadConfigString(raw string) error {
|
|||||||
c.reloadLock.Lock()
|
c.reloadLock.Lock()
|
||||||
defer c.reloadLock.Unlock()
|
defer c.reloadLock.Unlock()
|
||||||
|
|
||||||
c.oldSettings = make(map[interface{}]interface{})
|
c.oldSettings = make(map[string]any)
|
||||||
for k, v := range c.Settings {
|
for k, v := range c.Settings {
|
||||||
c.oldSettings[k] = v
|
c.oldSettings[k] = v
|
||||||
}
|
}
|
||||||
@@ -201,7 +201,7 @@ func (c *C) GetStringSlice(k string, d []string) []string {
|
|||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
rv, ok := r.([]interface{})
|
rv, ok := r.([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
@@ -215,13 +215,13 @@ func (c *C) GetStringSlice(k string, d []string) []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetMap will get the map for k or return the default d if not found or invalid
|
// GetMap will get the map for k or return the default d if not found or invalid
|
||||||
func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
|
func (c *C) GetMap(k string, d map[string]any) map[string]any {
|
||||||
r := c.Get(k)
|
r := c.Get(k)
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
v, ok := r.(map[interface{}]interface{})
|
v, ok := r.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
@@ -243,7 +243,7 @@ func (c *C) GetInt(k string, d int) int {
|
|||||||
// GetUint32 will get the uint32 for k or return the default d if not found or invalid
|
// GetUint32 will get the uint32 for k or return the default d if not found or invalid
|
||||||
func (c *C) GetUint32(k string, d uint32) uint32 {
|
func (c *C) GetUint32(k string, d uint32) uint32 {
|
||||||
r := c.GetInt(k, int(d))
|
r := c.GetInt(k, int(d))
|
||||||
if uint64(r) > uint64(math.MaxUint32) {
|
if r < 0 || uint64(r) > uint64(math.MaxUint32) {
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
return uint32(r)
|
return uint32(r)
|
||||||
@@ -266,6 +266,22 @@ func (c *C) GetBool(k string, d bool) bool {
|
|||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AsBool(v any) (value bool, ok bool) {
|
||||||
|
switch x := v.(type) {
|
||||||
|
case bool:
|
||||||
|
return x, true
|
||||||
|
case string:
|
||||||
|
switch x {
|
||||||
|
case "y", "yes":
|
||||||
|
return true, true
|
||||||
|
case "n", "no":
|
||||||
|
return false, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
// GetDuration will get the duration for k or return the default d if not found or invalid
|
// GetDuration will get the duration for k or return the default d if not found or invalid
|
||||||
func (c *C) GetDuration(k string, d time.Duration) time.Duration {
|
func (c *C) GetDuration(k string, d time.Duration) time.Duration {
|
||||||
r := c.GetString(k, "")
|
r := c.GetString(k, "")
|
||||||
@@ -276,7 +292,7 @@ func (c *C) GetDuration(k string, d time.Duration) time.Duration {
|
|||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *C) Get(k string) interface{} {
|
func (c *C) Get(k string) any {
|
||||||
return c.get(k, c.Settings)
|
return c.get(k, c.Settings)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -284,10 +300,10 @@ func (c *C) IsSet(k string) bool {
|
|||||||
return c.get(k, c.Settings) != nil
|
return c.get(k, c.Settings) != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *C) get(k string, v interface{}) interface{} {
|
func (c *C) get(k string, v any) any {
|
||||||
parts := strings.Split(k, ".")
|
parts := strings.Split(k, ".")
|
||||||
for _, p := range parts {
|
for _, p := range parts {
|
||||||
m, ok := v.(map[interface{}]interface{})
|
m, ok := v.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -346,7 +362,7 @@ func (c *C) addFile(path string, direct bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *C) parseRaw(b []byte) error {
|
func (c *C) parseRaw(b []byte) error {
|
||||||
var m map[interface{}]interface{}
|
var m map[string]any
|
||||||
|
|
||||||
err := yaml.Unmarshal(b, &m)
|
err := yaml.Unmarshal(b, &m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -358,7 +374,7 @@ func (c *C) parseRaw(b []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *C) parse() error {
|
func (c *C) parse() error {
|
||||||
var m map[interface{}]interface{}
|
var m map[string]any
|
||||||
|
|
||||||
for _, path := range c.files {
|
for _, path := range c.files {
|
||||||
b, err := os.ReadFile(path)
|
b, err := os.ReadFile(path)
|
||||||
@@ -366,7 +382,7 @@ func (c *C) parse() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var nm map[interface{}]interface{}
|
var nm map[string]any
|
||||||
err = yaml.Unmarshal(b, &nm)
|
err = yaml.Unmarshal(b, &nm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"gopkg.in/yaml.v2"
|
"go.yaml.in/yaml/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConfig_Load(t *testing.T) {
|
func TestConfig_Load(t *testing.T) {
|
||||||
@@ -19,20 +19,20 @@ func TestConfig_Load(t *testing.T) {
|
|||||||
// invalid yaml
|
// invalid yaml
|
||||||
c := NewC(l)
|
c := NewC(l)
|
||||||
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
|
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
|
||||||
assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
|
require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[string]interface {}")
|
||||||
|
|
||||||
// simple multi config merge
|
// simple multi config merge
|
||||||
c = NewC(l)
|
c = NewC(l)
|
||||||
os.RemoveAll(dir)
|
os.RemoveAll(dir)
|
||||||
os.Mkdir(dir, 0755)
|
os.Mkdir(dir, 0755)
|
||||||
|
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
|
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
|
||||||
os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644)
|
os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644)
|
||||||
assert.Nil(t, c.Load(dir))
|
require.NoError(t, c.Load(dir))
|
||||||
expected := map[interface{}]interface{}{
|
expected := map[string]any{
|
||||||
"outer": map[interface{}]interface{}{
|
"outer": map[string]any{
|
||||||
"inner": "override",
|
"inner": "override",
|
||||||
},
|
},
|
||||||
"new": "hi",
|
"new": "hi",
|
||||||
@@ -44,12 +44,12 @@ func TestConfig_Get(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
// test simple type
|
// test simple type
|
||||||
c := NewC(l)
|
c := NewC(l)
|
||||||
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
|
c.Settings["firewall"] = map[string]any{"outbound": "hi"}
|
||||||
assert.Equal(t, "hi", c.Get("firewall.outbound"))
|
assert.Equal(t, "hi", c.Get("firewall.outbound"))
|
||||||
|
|
||||||
// test complex type
|
// test complex type
|
||||||
inner := []map[interface{}]interface{}{{"port": "1", "code": "2"}}
|
inner := []map[string]any{{"port": "1", "code": "2"}}
|
||||||
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": inner}
|
c.Settings["firewall"] = map[string]any{"outbound": inner}
|
||||||
assert.EqualValues(t, inner, c.Get("firewall.outbound"))
|
assert.EqualValues(t, inner, c.Get("firewall.outbound"))
|
||||||
|
|
||||||
// test missing
|
// test missing
|
||||||
@@ -59,7 +59,7 @@ func TestConfig_Get(t *testing.T) {
|
|||||||
func TestConfig_GetStringSlice(t *testing.T) {
|
func TestConfig_GetStringSlice(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := NewC(l)
|
c := NewC(l)
|
||||||
c.Settings["slice"] = []interface{}{"one", "two"}
|
c.Settings["slice"] = []any{"one", "two"}
|
||||||
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
|
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -67,28 +67,28 @@ func TestConfig_GetBool(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := NewC(l)
|
c := NewC(l)
|
||||||
c.Settings["bool"] = true
|
c.Settings["bool"] = true
|
||||||
assert.Equal(t, true, c.GetBool("bool", false))
|
assert.True(t, c.GetBool("bool", false))
|
||||||
|
|
||||||
c.Settings["bool"] = "true"
|
c.Settings["bool"] = "true"
|
||||||
assert.Equal(t, true, c.GetBool("bool", false))
|
assert.True(t, c.GetBool("bool", false))
|
||||||
|
|
||||||
c.Settings["bool"] = false
|
c.Settings["bool"] = false
|
||||||
assert.Equal(t, false, c.GetBool("bool", true))
|
assert.False(t, c.GetBool("bool", true))
|
||||||
|
|
||||||
c.Settings["bool"] = "false"
|
c.Settings["bool"] = "false"
|
||||||
assert.Equal(t, false, c.GetBool("bool", true))
|
assert.False(t, c.GetBool("bool", true))
|
||||||
|
|
||||||
c.Settings["bool"] = "Y"
|
c.Settings["bool"] = "Y"
|
||||||
assert.Equal(t, true, c.GetBool("bool", false))
|
assert.True(t, c.GetBool("bool", false))
|
||||||
|
|
||||||
c.Settings["bool"] = "yEs"
|
c.Settings["bool"] = "yEs"
|
||||||
assert.Equal(t, true, c.GetBool("bool", false))
|
assert.True(t, c.GetBool("bool", false))
|
||||||
|
|
||||||
c.Settings["bool"] = "N"
|
c.Settings["bool"] = "N"
|
||||||
assert.Equal(t, false, c.GetBool("bool", true))
|
assert.False(t, c.GetBool("bool", true))
|
||||||
|
|
||||||
c.Settings["bool"] = "nO"
|
c.Settings["bool"] = "nO"
|
||||||
assert.Equal(t, false, c.GetBool("bool", true))
|
assert.False(t, c.GetBool("bool", true))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_HasChanged(t *testing.T) {
|
func TestConfig_HasChanged(t *testing.T) {
|
||||||
@@ -101,14 +101,14 @@ func TestConfig_HasChanged(t *testing.T) {
|
|||||||
// Test key change
|
// Test key change
|
||||||
c = NewC(l)
|
c = NewC(l)
|
||||||
c.Settings["test"] = "hi"
|
c.Settings["test"] = "hi"
|
||||||
c.oldSettings = map[interface{}]interface{}{"test": "no"}
|
c.oldSettings = map[string]any{"test": "no"}
|
||||||
assert.True(t, c.HasChanged("test"))
|
assert.True(t, c.HasChanged("test"))
|
||||||
assert.True(t, c.HasChanged(""))
|
assert.True(t, c.HasChanged(""))
|
||||||
|
|
||||||
// No key change
|
// No key change
|
||||||
c = NewC(l)
|
c = NewC(l)
|
||||||
c.Settings["test"] = "hi"
|
c.Settings["test"] = "hi"
|
||||||
c.oldSettings = map[interface{}]interface{}{"test": "hi"}
|
c.oldSettings = map[string]any{"test": "hi"}
|
||||||
assert.False(t, c.HasChanged("test"))
|
assert.False(t, c.HasChanged("test"))
|
||||||
assert.False(t, c.HasChanged(""))
|
assert.False(t, c.HasChanged(""))
|
||||||
}
|
}
|
||||||
@@ -117,11 +117,11 @@ func TestConfig_ReloadConfig(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
done := make(chan bool, 1)
|
done := make(chan bool, 1)
|
||||||
dir, err := os.MkdirTemp("", "config-test")
|
dir, err := os.MkdirTemp("", "config-test")
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
|
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
|
||||||
|
|
||||||
c := NewC(l)
|
c := NewC(l)
|
||||||
assert.Nil(t, c.Load(dir))
|
require.NoError(t, c.Load(dir))
|
||||||
|
|
||||||
assert.False(t, c.HasChanged("outer.inner"))
|
assert.False(t, c.HasChanged("outer.inner"))
|
||||||
assert.False(t, c.HasChanged("outer"))
|
assert.False(t, c.HasChanged("outer"))
|
||||||
@@ -184,11 +184,11 @@ firewall:
|
|||||||
`),
|
`),
|
||||||
}
|
}
|
||||||
|
|
||||||
var m map[any]any
|
var m map[string]any
|
||||||
|
|
||||||
// merge the same way config.parse() merges
|
// merge the same way config.parse() merges
|
||||||
for _, b := range configs {
|
for _, b := range configs {
|
||||||
var nm map[any]any
|
var nm map[string]any
|
||||||
err := yaml.Unmarshal(b, &nm)
|
err := yaml.Unmarshal(b, &nm)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -205,15 +205,15 @@ firewall:
|
|||||||
t.Logf("Merged Config as YAML:\n%s", mYaml)
|
t.Logf("Merged Config as YAML:\n%s", mYaml)
|
||||||
|
|
||||||
// If a bug is present, some items might be replaced instead of merged like we expect
|
// If a bug is present, some items might be replaced instead of merged like we expect
|
||||||
expected := map[any]any{
|
expected := map[string]any{
|
||||||
"firewall": map[any]any{
|
"firewall": map[string]any{
|
||||||
"inbound": []any{
|
"inbound": []any{
|
||||||
map[any]any{"host": "any", "port": "any", "proto": "icmp"},
|
map[string]any{"host": "any", "port": "any", "proto": "icmp"},
|
||||||
map[any]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"},
|
map[string]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"},
|
||||||
map[any]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}},
|
map[string]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}},
|
||||||
"outbound": []any{
|
"outbound": []any{
|
||||||
map[any]any{"host": "any", "port": "any", "proto": "any"}}},
|
map[string]any{"host": "any", "port": "any", "proto": "any"}}},
|
||||||
"listen": map[any]any{
|
"listen": map[string]any{
|
||||||
"host": "0.0.0.0",
|
"host": "0.0.0.0",
|
||||||
"port": 4242,
|
"port": 4242,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -4,13 +4,16 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -27,130 +30,124 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type connectionManager struct {
|
type connectionManager struct {
|
||||||
in map[uint32]struct{}
|
|
||||||
inLock *sync.RWMutex
|
|
||||||
|
|
||||||
out map[uint32]struct{}
|
|
||||||
outLock *sync.RWMutex
|
|
||||||
|
|
||||||
// relayUsed holds which relay localIndexs are in use
|
// relayUsed holds which relay localIndexs are in use
|
||||||
relayUsed map[uint32]struct{}
|
relayUsed map[uint32]struct{}
|
||||||
relayUsedLock *sync.RWMutex
|
relayUsedLock *sync.RWMutex
|
||||||
|
|
||||||
hostMap *HostMap
|
hostMap *HostMap
|
||||||
trafficTimer *LockingTimerWheel[uint32]
|
trafficTimer *LockingTimerWheel[uint32]
|
||||||
intf *Interface
|
intf *Interface
|
||||||
pendingDeletion map[uint32]struct{}
|
punchy *Punchy
|
||||||
punchy *Punchy
|
|
||||||
|
// Configuration settings
|
||||||
checkInterval time.Duration
|
checkInterval time.Duration
|
||||||
pendingDeletionInterval time.Duration
|
pendingDeletionInterval time.Duration
|
||||||
metricsTxPunchy metrics.Counter
|
inactivityTimeout atomic.Int64
|
||||||
|
dropInactive atomic.Bool
|
||||||
|
|
||||||
|
metricsTxPunchy metrics.Counter
|
||||||
|
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager {
|
func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
|
||||||
var max time.Duration
|
cm := &connectionManager{
|
||||||
if checkInterval < pendingDeletionInterval {
|
hostMap: hm,
|
||||||
max = pendingDeletionInterval
|
l: l,
|
||||||
} else {
|
punchy: p,
|
||||||
max = checkInterval
|
relayUsed: make(map[uint32]struct{}),
|
||||||
|
relayUsedLock: &sync.RWMutex{},
|
||||||
|
metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
|
||||||
}
|
}
|
||||||
|
|
||||||
nc := &connectionManager{
|
cm.reload(c, true)
|
||||||
hostMap: intf.hostMap,
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
in: make(map[uint32]struct{}),
|
cm.reload(c, false)
|
||||||
inLock: &sync.RWMutex{},
|
})
|
||||||
out: make(map[uint32]struct{}),
|
|
||||||
outLock: &sync.RWMutex{},
|
|
||||||
relayUsed: make(map[uint32]struct{}),
|
|
||||||
relayUsedLock: &sync.RWMutex{},
|
|
||||||
trafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, max),
|
|
||||||
intf: intf,
|
|
||||||
pendingDeletion: make(map[uint32]struct{}),
|
|
||||||
checkInterval: checkInterval,
|
|
||||||
pendingDeletionInterval: pendingDeletionInterval,
|
|
||||||
punchy: punchy,
|
|
||||||
metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
|
|
||||||
l: l,
|
|
||||||
}
|
|
||||||
|
|
||||||
nc.Start(ctx)
|
return cm
|
||||||
return nc
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) In(localIndex uint32) {
|
func (cm *connectionManager) reload(c *config.C, initial bool) {
|
||||||
n.inLock.RLock()
|
if initial {
|
||||||
// If this already exists, return
|
cm.checkInterval = time.Duration(c.GetInt("timers.connection_alive_interval", 5)) * time.Second
|
||||||
if _, ok := n.in[localIndex]; ok {
|
cm.pendingDeletionInterval = time.Duration(c.GetInt("timers.pending_deletion_interval", 10)) * time.Second
|
||||||
n.inLock.RUnlock()
|
|
||||||
return
|
// We want at least a minimum resolution of 500ms per tick so that we can hit these intervals
|
||||||
|
// pretty close to their configured duration.
|
||||||
|
// The inactivity duration is checked each time a hostinfo ticks through so we don't need the wheel to contain it.
|
||||||
|
minDuration := min(time.Millisecond*500, cm.checkInterval, cm.pendingDeletionInterval)
|
||||||
|
maxDuration := max(cm.checkInterval, cm.pendingDeletionInterval)
|
||||||
|
cm.trafficTimer = NewLockingTimerWheel[uint32](minDuration, maxDuration)
|
||||||
|
}
|
||||||
|
|
||||||
|
if initial || c.HasChanged("tunnels.inactivity_timeout") {
|
||||||
|
old := cm.getInactivityTimeout()
|
||||||
|
cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
|
||||||
|
if !initial {
|
||||||
|
cm.l.WithField("oldDuration", old).
|
||||||
|
WithField("newDuration", cm.getInactivityTimeout()).
|
||||||
|
Info("Inactivity timeout has changed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if initial || c.HasChanged("tunnels.drop_inactive") {
|
||||||
|
old := cm.dropInactive.Load()
|
||||||
|
cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
|
||||||
|
if !initial {
|
||||||
|
cm.l.WithField("oldBool", old).
|
||||||
|
WithField("newBool", cm.dropInactive.Load()).
|
||||||
|
Info("Drop inactive setting has changed")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
n.inLock.RUnlock()
|
|
||||||
n.inLock.Lock()
|
|
||||||
n.in[localIndex] = struct{}{}
|
|
||||||
n.inLock.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) Out(localIndex uint32) {
|
func (cm *connectionManager) getInactivityTimeout() time.Duration {
|
||||||
n.outLock.RLock()
|
return (time.Duration)(cm.inactivityTimeout.Load())
|
||||||
// If this already exists, return
|
|
||||||
if _, ok := n.out[localIndex]; ok {
|
|
||||||
n.outLock.RUnlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
n.outLock.RUnlock()
|
|
||||||
n.outLock.Lock()
|
|
||||||
n.out[localIndex] = struct{}{}
|
|
||||||
n.outLock.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) RelayUsed(localIndex uint32) {
|
func (cm *connectionManager) In(h *HostInfo) {
|
||||||
n.relayUsedLock.RLock()
|
h.in.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *connectionManager) Out(h *HostInfo) {
|
||||||
|
h.out.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *connectionManager) RelayUsed(localIndex uint32) {
|
||||||
|
cm.relayUsedLock.RLock()
|
||||||
// If this already exists, return
|
// If this already exists, return
|
||||||
if _, ok := n.relayUsed[localIndex]; ok {
|
if _, ok := cm.relayUsed[localIndex]; ok {
|
||||||
n.relayUsedLock.RUnlock()
|
cm.relayUsedLock.RUnlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
n.relayUsedLock.RUnlock()
|
cm.relayUsedLock.RUnlock()
|
||||||
n.relayUsedLock.Lock()
|
cm.relayUsedLock.Lock()
|
||||||
n.relayUsed[localIndex] = struct{}{}
|
cm.relayUsed[localIndex] = struct{}{}
|
||||||
n.relayUsedLock.Unlock()
|
cm.relayUsedLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
|
// getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
|
||||||
// resets the state for this local index
|
// resets the state for this local index
|
||||||
func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
|
func (cm *connectionManager) getAndResetTrafficCheck(h *HostInfo, now time.Time) (bool, bool) {
|
||||||
n.inLock.Lock()
|
in := h.in.Swap(false)
|
||||||
n.outLock.Lock()
|
out := h.out.Swap(false)
|
||||||
_, in := n.in[localIndex]
|
if in || out {
|
||||||
_, out := n.out[localIndex]
|
h.lastUsed = now
|
||||||
delete(n.in, localIndex)
|
}
|
||||||
delete(n.out, localIndex)
|
|
||||||
n.inLock.Unlock()
|
|
||||||
n.outLock.Unlock()
|
|
||||||
return in, out
|
return in, out
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
|
// AddTrafficWatch must be called for every new HostInfo.
|
||||||
// Use a write lock directly because it should be incredibly rare that we are ever already tracking this index
|
// We will continue to monitor the HostInfo until the tunnel is dropped.
|
||||||
n.outLock.Lock()
|
func (cm *connectionManager) AddTrafficWatch(h *HostInfo) {
|
||||||
if _, ok := n.out[localIndex]; ok {
|
if h.out.Swap(true) == false {
|
||||||
n.outLock.Unlock()
|
cm.trafficTimer.Add(h.localIndexId, cm.checkInterval)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
n.out[localIndex] = struct{}{}
|
|
||||||
n.trafficTimer.Add(localIndex, n.checkInterval)
|
|
||||||
n.outLock.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) Start(ctx context.Context) {
|
func (cm *connectionManager) Start(ctx context.Context) {
|
||||||
go n.Run(ctx)
|
clockSource := time.NewTicker(cm.trafficTimer.t.tickDuration)
|
||||||
}
|
|
||||||
|
|
||||||
func (n *connectionManager) Run(ctx context.Context) {
|
|
||||||
//TODO: this tick should be based on the min wheel tick? Check firewall
|
|
||||||
clockSource := time.NewTicker(500 * time.Millisecond)
|
|
||||||
defer clockSource.Stop()
|
defer clockSource.Stop()
|
||||||
|
|
||||||
p := []byte("")
|
p := []byte("")
|
||||||
@@ -163,61 +160,61 @@ func (n *connectionManager) Run(ctx context.Context) {
|
|||||||
return
|
return
|
||||||
|
|
||||||
case now := <-clockSource.C:
|
case now := <-clockSource.C:
|
||||||
n.trafficTimer.Advance(now)
|
cm.trafficTimer.Advance(now)
|
||||||
for {
|
for {
|
||||||
localIndex, has := n.trafficTimer.Purge()
|
localIndex, has := cm.trafficTimer.Purge()
|
||||||
if !has {
|
if !has {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
n.doTrafficCheck(localIndex, p, nb, out, now)
|
cm.doTrafficCheck(localIndex, p, nb, out, now)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
|
func (cm *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
|
||||||
decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now)
|
decision, hostinfo, primary := cm.makeTrafficDecision(localIndex, now)
|
||||||
|
|
||||||
switch decision {
|
switch decision {
|
||||||
case deleteTunnel:
|
case deleteTunnel:
|
||||||
if n.hostMap.DeleteHostInfo(hostinfo) {
|
if cm.hostMap.DeleteHostInfo(hostinfo) {
|
||||||
// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
|
// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
|
||||||
n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
|
cm.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
|
||||||
}
|
}
|
||||||
|
|
||||||
case closeTunnel:
|
case closeTunnel:
|
||||||
n.intf.sendCloseTunnel(hostinfo)
|
cm.intf.sendCloseTunnel(hostinfo)
|
||||||
n.intf.closeTunnel(hostinfo)
|
cm.intf.closeTunnel(hostinfo)
|
||||||
|
|
||||||
case swapPrimary:
|
case swapPrimary:
|
||||||
n.swapPrimary(hostinfo, primary)
|
cm.swapPrimary(hostinfo, primary)
|
||||||
|
|
||||||
case migrateRelays:
|
case migrateRelays:
|
||||||
n.migrateRelayUsed(hostinfo, primary)
|
cm.migrateRelayUsed(hostinfo, primary)
|
||||||
|
|
||||||
case tryRehandshake:
|
case tryRehandshake:
|
||||||
n.tryRehandshake(hostinfo)
|
cm.tryRehandshake(hostinfo)
|
||||||
|
|
||||||
case sendTestPacket:
|
case sendTestPacket:
|
||||||
n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
|
cm.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
n.resetRelayTrafficCheck(hostinfo)
|
cm.resetRelayTrafficCheck(hostinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
|
func (cm *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
|
||||||
if hostinfo != nil {
|
if hostinfo != nil {
|
||||||
n.relayUsedLock.Lock()
|
cm.relayUsedLock.Lock()
|
||||||
defer n.relayUsedLock.Unlock()
|
defer cm.relayUsedLock.Unlock()
|
||||||
// No need to migrate any relays, delete usage info now.
|
// No need to migrate any relays, delete usage info now.
|
||||||
for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
|
for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
|
||||||
delete(n.relayUsed, idx)
|
delete(cm.relayUsed, idx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
|
func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
|
||||||
relayFor := oldhostinfo.relayState.CopyAllRelayFor()
|
relayFor := oldhostinfo.relayState.CopyAllRelayFor()
|
||||||
|
|
||||||
for _, r := range relayFor {
|
for _, r := range relayFor {
|
||||||
@@ -227,46 +224,51 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
|
|||||||
var relayFrom netip.Addr
|
var relayFrom netip.Addr
|
||||||
var relayTo netip.Addr
|
var relayTo netip.Addr
|
||||||
switch {
|
switch {
|
||||||
case ok && existing.State == Established:
|
case ok:
|
||||||
// This relay already exists in newhostinfo, then do nothing.
|
switch existing.State {
|
||||||
continue
|
case Established, PeerRequested, Disestablished:
|
||||||
case ok && existing.State == Requested:
|
// This relay already exists in newhostinfo, then do nothing.
|
||||||
// The relay exists in a Requested state; re-send the request
|
continue
|
||||||
index = existing.LocalIndex
|
case Requested:
|
||||||
switch r.Type {
|
// The relay exists in a Requested state; re-send the request
|
||||||
case TerminalType:
|
index = existing.LocalIndex
|
||||||
relayFrom = n.intf.myVpnAddrs[0]
|
switch r.Type {
|
||||||
relayTo = existing.PeerAddr
|
case TerminalType:
|
||||||
case ForwardingType:
|
relayFrom = cm.intf.myVpnAddrs[0]
|
||||||
relayFrom = existing.PeerAddr
|
relayTo = existing.PeerAddr
|
||||||
relayTo = newhostinfo.vpnAddrs[0]
|
case ForwardingType:
|
||||||
default:
|
relayFrom = existing.PeerAddr
|
||||||
// should never happen
|
relayTo = newhostinfo.vpnAddrs[0]
|
||||||
|
default:
|
||||||
|
// should never happen
|
||||||
|
panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case !ok:
|
case !ok:
|
||||||
n.relayUsedLock.RLock()
|
cm.relayUsedLock.RLock()
|
||||||
if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed {
|
if _, relayUsed := cm.relayUsed[r.LocalIndex]; !relayUsed {
|
||||||
// The relay hasn't been used; don't migrate it.
|
// The relay hasn't been used; don't migrate it.
|
||||||
n.relayUsedLock.RUnlock()
|
cm.relayUsedLock.RUnlock()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
n.relayUsedLock.RUnlock()
|
cm.relayUsedLock.RUnlock()
|
||||||
// The relay doesn't exist at all; create some relay state and send the request.
|
// The relay doesn't exist at all; create some relay state and send the request.
|
||||||
var err error
|
var err error
|
||||||
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
|
index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
|
cm.l.WithError(err).Error("failed to migrate relay to new hostinfo")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
switch r.Type {
|
switch r.Type {
|
||||||
case TerminalType:
|
case TerminalType:
|
||||||
relayFrom = n.intf.myVpnAddrs[0]
|
relayFrom = cm.intf.myVpnAddrs[0]
|
||||||
relayTo = r.PeerAddr
|
relayTo = r.PeerAddr
|
||||||
case ForwardingType:
|
case ForwardingType:
|
||||||
relayFrom = r.PeerAddr
|
relayFrom = r.PeerAddr
|
||||||
relayTo = newhostinfo.vpnAddrs[0]
|
relayTo = newhostinfo.vpnAddrs[0]
|
||||||
default:
|
default:
|
||||||
// should never happen
|
// should never happen
|
||||||
|
panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -279,12 +281,12 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
|
|||||||
switch newhostinfo.GetCert().Certificate.Version() {
|
switch newhostinfo.GetCert().Certificate.Version() {
|
||||||
case cert.Version1:
|
case cert.Version1:
|
||||||
if !relayFrom.Is4() {
|
if !relayFrom.Is4() {
|
||||||
n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
|
cm.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !relayTo.Is4() {
|
if !relayTo.Is4() {
|
||||||
n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
cm.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -296,16 +298,16 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
|
|||||||
req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
|
req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
|
||||||
req.RelayToAddr = netAddrToProtoAddr(relayTo)
|
req.RelayToAddr = netAddrToProtoAddr(relayTo)
|
||||||
default:
|
default:
|
||||||
newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
|
newhostinfo.logger(cm.l).Error("Unknown certificate version found while attempting to migrate relay")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
msg, err := req.Marshal()
|
msg, err := req.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
|
cm.l.WithError(err).Error("failed to marshal Control message to migrate relay")
|
||||||
} else {
|
} else {
|
||||||
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
|
cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
n.l.WithFields(logrus.Fields{
|
cm.l.WithFields(logrus.Fields{
|
||||||
"relayFrom": req.RelayFromAddr,
|
"relayFrom": req.RelayFromAddr,
|
||||||
"relayTo": req.RelayToAddr,
|
"relayTo": req.RelayToAddr,
|
||||||
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
||||||
@@ -316,46 +318,44 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
|
func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
|
||||||
n.hostMap.RLock()
|
// Read lock the main hostmap to order decisions based on tunnels being the primary tunnel
|
||||||
defer n.hostMap.RUnlock()
|
cm.hostMap.RLock()
|
||||||
|
defer cm.hostMap.RUnlock()
|
||||||
|
|
||||||
hostinfo := n.hostMap.Indexes[localIndex]
|
hostinfo := cm.hostMap.Indexes[localIndex]
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
|
cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap")
|
||||||
delete(n.pendingDeletion, localIndex)
|
|
||||||
return doNothing, nil, nil
|
return doNothing, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if n.isInvalidCertificate(now, hostinfo) {
|
if cm.isInvalidCertificate(now, hostinfo) {
|
||||||
delete(n.pendingDeletion, hostinfo.localIndexId)
|
|
||||||
return closeTunnel, hostinfo, nil
|
return closeTunnel, hostinfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
|
primary := cm.hostMap.Hosts[hostinfo.vpnAddrs[0]]
|
||||||
mainHostInfo := true
|
mainHostInfo := true
|
||||||
if primary != nil && primary != hostinfo {
|
if primary != nil && primary != hostinfo {
|
||||||
mainHostInfo = false
|
mainHostInfo = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for traffic on this hostinfo
|
// Check for traffic on this hostinfo
|
||||||
inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex)
|
inTraffic, outTraffic := cm.getAndResetTrafficCheck(hostinfo, now)
|
||||||
|
|
||||||
// A hostinfo is determined alive if there is incoming traffic
|
// A hostinfo is determined alive if there is incoming traffic
|
||||||
if inTraffic {
|
if inTraffic {
|
||||||
decision := doNothing
|
decision := doNothing
|
||||||
if n.l.Level >= logrus.DebugLevel {
|
if cm.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger(n.l).
|
hostinfo.logger(cm.l).
|
||||||
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
|
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
|
||||||
Debug("Tunnel status")
|
Debug("Tunnel status")
|
||||||
}
|
}
|
||||||
delete(n.pendingDeletion, hostinfo.localIndexId)
|
hostinfo.pendingDeletion.Store(false)
|
||||||
|
|
||||||
if mainHostInfo {
|
if mainHostInfo {
|
||||||
decision = tryRehandshake
|
decision = tryRehandshake
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if n.shouldSwapPrimary(hostinfo, primary) {
|
if cm.shouldSwapPrimary(hostinfo) {
|
||||||
decision = swapPrimary
|
decision = swapPrimary
|
||||||
} else {
|
} else {
|
||||||
// migrate the relays to the primary, if in use.
|
// migrate the relays to the primary, if in use.
|
||||||
@@ -363,46 +363,55 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
|
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
|
||||||
|
|
||||||
if !outTraffic {
|
if !outTraffic {
|
||||||
// Send a punch packet to keep the NAT state alive
|
// Send a punch packet to keep the NAT state alive
|
||||||
n.sendPunch(hostinfo)
|
cm.sendPunch(hostinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
return decision, hostinfo, primary
|
return decision, hostinfo, primary
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
|
if hostinfo.pendingDeletion.Load() {
|
||||||
// We have already sent a test packet and nothing was returned, this hostinfo is dead
|
// We have already sent a test packet and nothing was returned, this hostinfo is dead
|
||||||
hostinfo.logger(n.l).
|
hostinfo.logger(cm.l).
|
||||||
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
|
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
|
||||||
Info("Tunnel status")
|
Info("Tunnel status")
|
||||||
|
|
||||||
delete(n.pendingDeletion, hostinfo.localIndexId)
|
|
||||||
return deleteTunnel, hostinfo, nil
|
return deleteTunnel, hostinfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
decision := doNothing
|
decision := doNothing
|
||||||
if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
|
if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
|
||||||
if !outTraffic {
|
if !outTraffic {
|
||||||
|
inactiveFor, isInactive := cm.isInactive(hostinfo, now)
|
||||||
|
if isInactive {
|
||||||
|
// Tunnel is inactive, tear it down
|
||||||
|
hostinfo.logger(cm.l).
|
||||||
|
WithField("inactiveDuration", inactiveFor).
|
||||||
|
WithField("primary", mainHostInfo).
|
||||||
|
Info("Dropping tunnel due to inactivity")
|
||||||
|
|
||||||
|
return closeTunnel, hostinfo, primary
|
||||||
|
}
|
||||||
|
|
||||||
// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
|
// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
|
||||||
// Just maintain NAT state if configured to do so.
|
// Just maintain NAT state if configured to do so.
|
||||||
n.sendPunch(hostinfo)
|
cm.sendPunch(hostinfo)
|
||||||
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
|
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
|
||||||
return doNothing, nil, nil
|
return doNothing, nil, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if n.punchy.GetTargetEverything() {
|
if cm.punchy.GetTargetEverything() {
|
||||||
// This is similar to the old punchy behavior with a slight optimization.
|
// This is similar to the old punchy behavior with a slight optimization.
|
||||||
// We aren't receiving traffic but we are sending it, punch on all known
|
// We aren't receiving traffic but we are sending it, punch on all known
|
||||||
// ips in case we need to re-prime NAT state
|
// ips in case we need to re-prime NAT state
|
||||||
n.sendPunch(hostinfo)
|
cm.sendPunch(hostinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
if n.l.Level >= logrus.DebugLevel {
|
if cm.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger(n.l).
|
hostinfo.logger(cm.l).
|
||||||
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
|
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
|
||||||
Debug("Tunnel status")
|
Debug("Tunnel status")
|
||||||
}
|
}
|
||||||
@@ -411,17 +420,33 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
|
|||||||
decision = sendTestPacket
|
decision = sendTestPacket
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if n.l.Level >= logrus.DebugLevel {
|
if cm.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger(n.l).Debugf("Hostinfo sadness")
|
hostinfo.logger(cm.l).Debugf("Hostinfo sadness")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
|
hostinfo.pendingDeletion.Store(true)
|
||||||
n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
|
cm.trafficTimer.Add(hostinfo.localIndexId, cm.pendingDeletionInterval)
|
||||||
return decision, hostinfo, nil
|
return decision, hostinfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
|
func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time.Duration, bool) {
|
||||||
|
if cm.dropInactive.Load() == false {
|
||||||
|
// We aren't configured to drop inactive tunnels
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
inactiveDuration := now.Sub(hostinfo.lastUsed)
|
||||||
|
if inactiveDuration < cm.getInactivityTimeout() {
|
||||||
|
// It's not considered inactive
|
||||||
|
return inactiveDuration, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// The tunnel is inactive
|
||||||
|
return inactiveDuration, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
|
||||||
// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
|
// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
|
||||||
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
|
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
|
||||||
// Let's sort this out.
|
// Let's sort this out.
|
||||||
@@ -429,83 +454,127 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
|
|||||||
// Only one side should swap because if both swap then we may never resolve to a single tunnel.
|
// Only one side should swap because if both swap then we may never resolve to a single tunnel.
|
||||||
// vpn addr is static across all tunnels for this host pair so lets
|
// vpn addr is static across all tunnels for this host pair so lets
|
||||||
// use that to determine if we should consider swapping.
|
// use that to determine if we should consider swapping.
|
||||||
if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
|
if current.vpnAddrs[0].Compare(cm.intf.myVpnAddrs[0]) < 0 {
|
||||||
// Their primary vpn addr is less than mine. Do not swap.
|
// Their primary vpn addr is less than mine. Do not swap.
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
||||||
|
if crt == nil {
|
||||||
|
//my cert was reloaded away. We should definitely swap from this tunnel
|
||||||
|
return true
|
||||||
|
}
|
||||||
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
||||||
// settle down.
|
// settle down.
|
||||||
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
|
func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
|
||||||
n.hostMap.Lock()
|
cm.hostMap.Lock()
|
||||||
// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
|
// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
|
||||||
if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
|
if cm.hostMap.Hosts[current.vpnAddrs[0]] == primary {
|
||||||
n.hostMap.unlockedMakePrimary(current)
|
cm.hostMap.unlockedMakePrimary(current)
|
||||||
}
|
}
|
||||||
n.hostMap.Unlock()
|
cm.hostMap.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
|
// isInvalidCertificate decides if we should destroy a tunnel.
|
||||||
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
|
// returns true if pki.disconnect_invalid is true and the certificate is no longer valid.
|
||||||
// check and return true.
|
// Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
|
||||||
func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
||||||
remoteCert := hostinfo.GetCert()
|
remoteCert := hostinfo.GetCert()
|
||||||
if remoteCert == nil {
|
if remoteCert == nil {
|
||||||
return false
|
return false //don't tear down tunnels for handshakes in progress
|
||||||
}
|
}
|
||||||
|
|
||||||
caPool := n.intf.pki.GetCAPool()
|
caPool := cm.intf.pki.GetCAPool()
|
||||||
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return false
|
return false //cert is still valid! yay!
|
||||||
}
|
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
|
||||||
|
|
||||||
if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
|
|
||||||
// Block listed certificates should always be disconnected
|
// Block listed certificates should always be disconnected
|
||||||
|
hostinfo.logger(cm.l).WithError(err).
|
||||||
|
WithField("fingerprint", remoteCert.Fingerprint).
|
||||||
|
Info("Remote certificate is blocked, tearing down the tunnel")
|
||||||
|
return true
|
||||||
|
} else if cm.intf.disconnectInvalid.Load() {
|
||||||
|
hostinfo.logger(cm.l).WithError(err).
|
||||||
|
WithField("fingerprint", remoteCert.Fingerprint).
|
||||||
|
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||||
|
return true
|
||||||
|
} else {
|
||||||
|
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.logger(n.l).WithError(err).
|
|
||||||
WithField("fingerprint", remoteCert.Fingerprint).
|
|
||||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
|
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||||
if !n.punchy.GetPunch() {
|
if !cm.punchy.GetPunch() {
|
||||||
// Punching is disabled
|
// Punching is disabled
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if n.punchy.GetTargetEverything() {
|
if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
|
||||||
hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
|
// Do not punch to lighthouses, we assume our lighthouse update interval is good enough.
|
||||||
n.metricsTxPunchy.Inc(1)
|
// In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse
|
||||||
n.intf.outside.WriteTo([]byte{1}, addr)
|
// would lose the ability to notify us and punchy.respond would become unreliable.
|
||||||
})
|
|
||||||
|
|
||||||
} else if hostinfo.remote.IsValid() {
|
|
||||||
n.metricsTxPunchy.Inc(1)
|
|
||||||
n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
|
||||||
cs := n.intf.pki.getCertState()
|
|
||||||
curCrt := hostinfo.ConnectionState.myCert
|
|
||||||
myCrt := cs.getCertificate(curCrt.Version())
|
|
||||||
if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
|
||||||
// The current tunnel is using the latest certificate and version, no need to rehandshake.
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
if cm.punchy.GetTargetEverything() {
|
||||||
WithField("reason", "local certificate is not current").
|
hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
|
||||||
Info("Re-handshaking with remote")
|
cm.metricsTxPunchy.Inc(1)
|
||||||
|
cm.intf.outside.WriteTo([]byte{1}, addr)
|
||||||
|
})
|
||||||
|
|
||||||
n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
} else if hostinfo.remote.IsValid() {
|
||||||
|
cm.metricsTxPunchy.Inc(1)
|
||||||
|
cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||||
|
cs := cm.intf.pki.getCertState()
|
||||||
|
curCrt := hostinfo.ConnectionState.myCert
|
||||||
|
curCrtVersion := curCrt.Version()
|
||||||
|
myCrt := cs.getCertificate(curCrtVersion)
|
||||||
|
if myCrt == nil {
|
||||||
|
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
|
WithField("version", curCrtVersion).
|
||||||
|
WithField("reason", "local certificate removed").
|
||||||
|
Info("Re-handshaking with remote")
|
||||||
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
peerCrt := hostinfo.ConnectionState.peerCert
|
||||||
|
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
|
||||||
|
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
|
||||||
|
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
|
||||||
|
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
|
WithField("version", curCrtVersion).
|
||||||
|
WithField("peerVersion", peerCrt.Certificate.Version()).
|
||||||
|
WithField("reason", "local certificate version lower than peer, attempting to correct").
|
||||||
|
Info("Re-handshaking with remote")
|
||||||
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
|
||||||
|
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
|
||||||
|
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
|
WithField("reason", "local certificate is not current").
|
||||||
|
Info("Re-handshaking with remote")
|
||||||
|
|
||||||
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if curCrtVersion < cs.initiatingVersion {
|
||||||
|
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
|
WithField("reason", "current cert version < pki.initiatingVersion").
|
||||||
|
Info("Re-handshaking with remote")
|
||||||
|
|
||||||
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -14,6 +13,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestLighthouse() *LightHouse {
|
func newTestLighthouse() *LightHouse {
|
||||||
@@ -22,7 +22,7 @@ func newTestLighthouse() *LightHouse {
|
|||||||
addrMap: map[netip.Addr]*RemoteList{},
|
addrMap: map[netip.Addr]*RemoteList{},
|
||||||
queryChan: make(chan netip.Addr, 10),
|
queryChan: make(chan netip.Addr, 10),
|
||||||
}
|
}
|
||||||
lighthouses := map[netip.Addr]struct{}{}
|
lighthouses := []netip.Addr{}
|
||||||
staticList := map[netip.Addr]struct{}{}
|
staticList := map[netip.Addr]struct{}{}
|
||||||
|
|
||||||
lh.lighthouses.Store(&lighthouses)
|
lh.lighthouses.Store(&lighthouses)
|
||||||
@@ -43,10 +43,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
hostMap.preferredRanges.Store(&preferredRanges)
|
hostMap.preferredRanges.Store(&preferredRanges)
|
||||||
|
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
defaultVersion: cert.Version1,
|
initiatingVersion: cert.Version1,
|
||||||
privateKey: []byte{},
|
privateKey: []byte{},
|
||||||
v1Cert: &dummyCert{version: cert.Version1},
|
v1Cert: &dummyCert{version: cert.Version1},
|
||||||
v1HandshakeBytes: []byte{},
|
v1HandshakeBytes: []byte{},
|
||||||
}
|
}
|
||||||
|
|
||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
@@ -63,10 +63,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
ifce.pki.cs.Store(cs)
|
ifce.pki.cs.Store(cs)
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
conf := config.NewC(l)
|
||||||
defer cancel()
|
punchy := NewPunchyFromConfig(l, conf)
|
||||||
punchy := NewPunchyFromConfig(l, config.NewC(l))
|
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
||||||
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
|
nc.intf = ifce
|
||||||
p := []byte("")
|
p := []byte("")
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
@@ -84,32 +84,33 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||||
|
|
||||||
// We saw traffic out to vpnIp
|
// We saw traffic out to vpnIp
|
||||||
nc.Out(hostinfo.localIndexId)
|
nc.Out(hostinfo)
|
||||||
nc.In(hostinfo.localIndexId)
|
nc.In(hostinfo)
|
||||||
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
assert.False(t, hostinfo.pendingDeletion.Load())
|
||||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
assert.Contains(t, nc.out, hostinfo.localIndexId)
|
assert.True(t, hostinfo.out.Load())
|
||||||
|
assert.True(t, hostinfo.in.Load())
|
||||||
|
|
||||||
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
|
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
|
||||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||||
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
assert.False(t, hostinfo.pendingDeletion.Load())
|
||||||
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
assert.False(t, hostinfo.out.Load())
|
||||||
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
assert.False(t, hostinfo.in.Load())
|
||||||
|
|
||||||
// Do another traffic check tick, this host should be pending deletion now
|
// Do another traffic check tick, this host should be pending deletion now
|
||||||
nc.Out(hostinfo.localIndexId)
|
nc.Out(hostinfo)
|
||||||
|
assert.True(t, hostinfo.out.Load())
|
||||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||||
assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
assert.True(t, hostinfo.pendingDeletion.Load())
|
||||||
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
assert.False(t, hostinfo.out.Load())
|
||||||
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
assert.False(t, hostinfo.in.Load())
|
||||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||||
|
|
||||||
// Do a final traffic check tick, the host should now be removed
|
// Do a final traffic check tick, the host should now be removed
|
||||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||||
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs)
|
||||||
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
|
||||||
assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,10 +126,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
hostMap.preferredRanges.Store(&preferredRanges)
|
hostMap.preferredRanges.Store(&preferredRanges)
|
||||||
|
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
defaultVersion: cert.Version1,
|
initiatingVersion: cert.Version1,
|
||||||
privateKey: []byte{},
|
privateKey: []byte{},
|
||||||
v1Cert: &dummyCert{version: cert.Version1},
|
v1Cert: &dummyCert{version: cert.Version1},
|
||||||
v1HandshakeBytes: []byte{},
|
v1HandshakeBytes: []byte{},
|
||||||
}
|
}
|
||||||
|
|
||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
@@ -145,10 +146,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
ifce.pki.cs.Store(cs)
|
ifce.pki.cs.Store(cs)
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
conf := config.NewC(l)
|
||||||
defer cancel()
|
punchy := NewPunchyFromConfig(l, conf)
|
||||||
punchy := NewPunchyFromConfig(l, config.NewC(l))
|
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
||||||
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
|
nc.intf = ifce
|
||||||
p := []byte("")
|
p := []byte("")
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
@@ -166,33 +167,129 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||||
|
|
||||||
// We saw traffic out to vpnIp
|
// We saw traffic out to vpnIp
|
||||||
nc.Out(hostinfo.localIndexId)
|
nc.Out(hostinfo)
|
||||||
nc.In(hostinfo.localIndexId)
|
nc.In(hostinfo)
|
||||||
assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
|
assert.True(t, hostinfo.in.Load())
|
||||||
|
assert.True(t, hostinfo.out.Load())
|
||||||
|
assert.False(t, hostinfo.pendingDeletion.Load())
|
||||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
|
|
||||||
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
|
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
|
||||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||||
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
assert.False(t, hostinfo.pendingDeletion.Load())
|
||||||
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
assert.False(t, hostinfo.out.Load())
|
||||||
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
assert.False(t, hostinfo.in.Load())
|
||||||
|
|
||||||
// Do another traffic check tick, this host should be pending deletion now
|
// Do another traffic check tick, this host should be pending deletion now
|
||||||
nc.Out(hostinfo.localIndexId)
|
nc.Out(hostinfo)
|
||||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||||
assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
assert.True(t, hostinfo.pendingDeletion.Load())
|
||||||
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
assert.False(t, hostinfo.out.Load())
|
||||||
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
assert.False(t, hostinfo.in.Load())
|
||||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||||
|
|
||||||
// We saw traffic, should no longer be pending deletion
|
// We saw traffic, should no longer be pending deletion
|
||||||
nc.In(hostinfo.localIndexId)
|
nc.In(hostinfo)
|
||||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||||
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
assert.False(t, hostinfo.pendingDeletion.Load())
|
||||||
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
assert.False(t, hostinfo.out.Load())
|
||||||
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
assert.False(t, hostinfo.in.Load())
|
||||||
|
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
|
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
|
||||||
|
l := test.NewLogger()
|
||||||
|
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
||||||
|
vpnAddrs := []netip.Addr{netip.MustParseAddr("172.1.1.2")}
|
||||||
|
preferredRanges := []netip.Prefix{localrange}
|
||||||
|
|
||||||
|
// Very incomplete mock objects
|
||||||
|
hostMap := newHostMap(l)
|
||||||
|
hostMap.preferredRanges.Store(&preferredRanges)
|
||||||
|
|
||||||
|
cs := &CertState{
|
||||||
|
initiatingVersion: cert.Version1,
|
||||||
|
privateKey: []byte{},
|
||||||
|
v1Cert: &dummyCert{version: cert.Version1},
|
||||||
|
v1HandshakeBytes: []byte{},
|
||||||
|
}
|
||||||
|
|
||||||
|
lh := newTestLighthouse()
|
||||||
|
ifce := &Interface{
|
||||||
|
hostMap: hostMap,
|
||||||
|
inside: &test.NoopTun{},
|
||||||
|
outside: &udp.NoopConn{},
|
||||||
|
firewall: &Firewall{},
|
||||||
|
lightHouse: lh,
|
||||||
|
pki: &PKI{},
|
||||||
|
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
|
||||||
|
l: l,
|
||||||
|
}
|
||||||
|
ifce.pki.cs.Store(cs)
|
||||||
|
|
||||||
|
// Create manager
|
||||||
|
conf := config.NewC(l)
|
||||||
|
conf.Settings["tunnels"] = map[string]any{
|
||||||
|
"drop_inactive": true,
|
||||||
|
}
|
||||||
|
punchy := NewPunchyFromConfig(l, conf)
|
||||||
|
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
||||||
|
assert.True(t, nc.dropInactive.Load())
|
||||||
|
nc.intf = ifce
|
||||||
|
|
||||||
|
// Add an ip we have established a connection w/ to hostmap
|
||||||
|
hostinfo := &HostInfo{
|
||||||
|
vpnAddrs: vpnAddrs,
|
||||||
|
localIndexId: 1099,
|
||||||
|
remoteIndexId: 9901,
|
||||||
|
}
|
||||||
|
hostinfo.ConnectionState = &ConnectionState{
|
||||||
|
myCert: &dummyCert{version: cert.Version1},
|
||||||
|
H: &noise.HandshakeState{},
|
||||||
|
}
|
||||||
|
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||||
|
|
||||||
|
// Do a traffic check tick, in and out should be cleared but should not be pending deletion
|
||||||
|
nc.Out(hostinfo)
|
||||||
|
nc.In(hostinfo)
|
||||||
|
assert.True(t, hostinfo.out.Load())
|
||||||
|
assert.True(t, hostinfo.in.Load())
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
decision, _, _ := nc.makeTrafficDecision(hostinfo.localIndexId, now)
|
||||||
|
assert.Equal(t, tryRehandshake, decision)
|
||||||
|
assert.Equal(t, now, hostinfo.lastUsed)
|
||||||
|
assert.False(t, hostinfo.pendingDeletion.Load())
|
||||||
|
assert.False(t, hostinfo.out.Load())
|
||||||
|
assert.False(t, hostinfo.in.Load())
|
||||||
|
|
||||||
|
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*5))
|
||||||
|
assert.Equal(t, doNothing, decision)
|
||||||
|
assert.Equal(t, now, hostinfo.lastUsed)
|
||||||
|
assert.False(t, hostinfo.pendingDeletion.Load())
|
||||||
|
assert.False(t, hostinfo.out.Load())
|
||||||
|
assert.False(t, hostinfo.in.Load())
|
||||||
|
|
||||||
|
// Do another traffic check tick, should still not be pending deletion
|
||||||
|
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*10))
|
||||||
|
assert.Equal(t, doNothing, decision)
|
||||||
|
assert.Equal(t, now, hostinfo.lastUsed)
|
||||||
|
assert.False(t, hostinfo.pendingDeletion.Load())
|
||||||
|
assert.False(t, hostinfo.out.Load())
|
||||||
|
assert.False(t, hostinfo.in.Load())
|
||||||
|
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
|
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||||
|
|
||||||
|
// Finally advance beyond the inactivity timeout
|
||||||
|
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Minute*10))
|
||||||
|
assert.Equal(t, closeTunnel, decision)
|
||||||
|
assert.Equal(t, now, hostinfo.lastUsed)
|
||||||
|
assert.False(t, hostinfo.pendingDeletion.Load())
|
||||||
|
assert.False(t, hostinfo.out.Load())
|
||||||
|
assert.False(t, hostinfo.in.Load())
|
||||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||||
}
|
}
|
||||||
@@ -223,9 +320,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA)
|
caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
ncp := cert.NewCAPool()
|
ncp := cert.NewCAPool()
|
||||||
assert.NoError(t, ncp.AddCA(caCert))
|
require.NoError(t, ncp.AddCA(caCert))
|
||||||
|
|
||||||
pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
|
pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
tbs = &cert.TBSCertificate{
|
tbs = &cert.TBSCertificate{
|
||||||
@@ -237,7 +334,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
PublicKey: pubCrt,
|
PublicKey: pubCrt,
|
||||||
}
|
}
|
||||||
peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA)
|
peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
|
cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
|
||||||
|
|
||||||
@@ -263,10 +360,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
ifce.disconnectInvalid.Store(true)
|
ifce.disconnectInvalid.Store(true)
|
||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
conf := config.NewC(l)
|
||||||
defer cancel()
|
punchy := NewPunchyFromConfig(l, conf)
|
||||||
punchy := NewPunchyFromConfig(l, config.NewC(l))
|
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
|
||||||
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
|
nc.intf = ifce
|
||||||
ifce.connectionManager = nc
|
ifce.connectionManager = nc
|
||||||
|
|
||||||
hostinfo := &HostInfo{
|
hostinfo := &HostInfo{
|
||||||
@@ -349,6 +446,10 @@ func (d *dummyCert) PublicKey() []byte {
|
|||||||
return d.publicKey
|
return d.publicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *dummyCert) MarshalPublicKeyPEM() []byte {
|
||||||
|
return cert.MarshalPublicKeyToPEM(d.curve, d.publicKey)
|
||||||
|
}
|
||||||
|
|
||||||
func (d *dummyCert) Signature() []byte {
|
func (d *dummyCert) Signature() []byte {
|
||||||
return d.signature
|
return d.signature
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,11 +50,6 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
|
|||||||
}
|
}
|
||||||
|
|
||||||
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
|
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
|
||||||
|
|
||||||
b := NewBits(ReplayWindow)
|
|
||||||
// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
|
|
||||||
b.Update(l, 0)
|
|
||||||
|
|
||||||
hs, err := noise.NewHandshakeState(noise.Config{
|
hs, err := noise.NewHandshakeState(noise.Config{
|
||||||
CipherSuite: ncs,
|
CipherSuite: ncs,
|
||||||
Random: rand.Reader,
|
Random: rand.Reader,
|
||||||
@@ -74,7 +69,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
|
|||||||
ci := &ConnectionState{
|
ci := &ConnectionState{
|
||||||
H: hs,
|
H: hs,
|
||||||
initiator: initiator,
|
initiator: initiator,
|
||||||
window: b,
|
window: NewBits(ReplayWindow),
|
||||||
myCert: crt,
|
myCert: crt,
|
||||||
}
|
}
|
||||||
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
|
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
|
||||||
|
|||||||
23
control.go
23
control.go
@@ -26,14 +26,15 @@ type controlHostLister interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Control struct {
|
type Control struct {
|
||||||
f *Interface
|
f *Interface
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
sshStart func()
|
sshStart func()
|
||||||
statsStart func()
|
statsStart func()
|
||||||
dnsStart func()
|
dnsStart func()
|
||||||
lighthouseStart func()
|
lighthouseStart func()
|
||||||
|
connectionManagerStart func(context.Context)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ControlHostInfo struct {
|
type ControlHostInfo struct {
|
||||||
@@ -63,6 +64,9 @@ func (c *Control) Start() {
|
|||||||
if c.dnsStart != nil {
|
if c.dnsStart != nil {
|
||||||
go c.dnsStart()
|
go c.dnsStart()
|
||||||
}
|
}
|
||||||
|
if c.connectionManagerStart != nil {
|
||||||
|
go c.connectionManagerStart(c.ctx)
|
||||||
|
}
|
||||||
if c.lighthouseStart != nil {
|
if c.lighthouseStart != nil {
|
||||||
c.lighthouseStart()
|
c.lighthouseStart()
|
||||||
}
|
}
|
||||||
@@ -131,8 +135,7 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
|
|||||||
|
|
||||||
// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
|
// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
|
||||||
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
|
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
|
||||||
_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
|
if c.f.myVpnAddrsTable.Contains(vpnIp) {
|
||||||
if found {
|
|
||||||
// Only returning the default certificate since its impossible
|
// Only returning the default certificate since its impossible
|
||||||
// for any other host but ourselves to have more than 1
|
// for any other host but ourselves to have more than 1
|
||||||
return c.f.pki.getCertState().GetDefaultCertificate().Copy()
|
return c.f.pki.getCertState().GetDefaultCertificate().Copy()
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||||||
localIndexId: 201,
|
localIndexId: 201,
|
||||||
vpnAddrs: []netip.Addr{vpnIp},
|
vpnAddrs: []netip.Addr{vpnIp},
|
||||||
relayState: RelayState{
|
relayState: RelayState{
|
||||||
relays: map[netip.Addr]struct{}{},
|
relays: nil,
|
||||||
relayForByAddr: map[netip.Addr]*Relay{},
|
relayForByAddr: map[netip.Addr]*Relay{},
|
||||||
relayForByIdx: map[uint32]*Relay{},
|
relayForByIdx: map[uint32]*Relay{},
|
||||||
},
|
},
|
||||||
@@ -72,7 +72,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||||||
localIndexId: 201,
|
localIndexId: 201,
|
||||||
vpnAddrs: []netip.Addr{vpnIp2},
|
vpnAddrs: []netip.Addr{vpnIp2},
|
||||||
relayState: RelayState{
|
relayState: RelayState{
|
||||||
relays: map[netip.Addr]struct{}{},
|
relays: nil,
|
||||||
relayForByAddr: map[netip.Addr]*Relay{},
|
relayForByAddr: map[netip.Addr]*Relay{},
|
||||||
relayForByIdx: map[uint32]*Relay{},
|
relayForByIdx: map[uint32]*Relay{},
|
||||||
},
|
},
|
||||||
@@ -101,7 +101,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||||||
|
|
||||||
// Make sure we don't have any unexpected fields
|
// Make sure we don't have any unexpected fields
|
||||||
assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
|
assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
|
||||||
assert.EqualValues(t, &expectedInfo, thi)
|
assert.Equal(t, &expectedInfo, thi)
|
||||||
test.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
test.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
||||||
|
|
||||||
// Make sure we don't panic if the host info doesn't have a cert yet
|
// Make sure we don't panic if the host info doesn't have a cert yet
|
||||||
@@ -110,7 +110,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
|
func assertFields(t *testing.T, expected []string, actualStruct any) {
|
||||||
val := reflect.ValueOf(actualStruct).Elem()
|
val := reflect.ValueOf(actualStruct).Elem()
|
||||||
fields := make([]string, val.NumField())
|
fields := make([]string, val.NumField())
|
||||||
for i := 0; i < val.NumField(); i++ {
|
for i := 0; i < val.NumField(); i++ {
|
||||||
|
|||||||
@@ -174,6 +174,10 @@ func (c *Control) GetHostmap() *HostMap {
|
|||||||
return c.f.hostMap
|
return c.f.hostMap
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Control) GetF() *Interface {
|
||||||
|
return c.f
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Control) GetCertState() *CertState {
|
func (c *Control) GetCertState() *CertState {
|
||||||
return c.f.pki.getCertState()
|
return c.f.pki.getCertState()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ type dnsRecords struct {
|
|||||||
dnsMap4 map[string]netip.Addr
|
dnsMap4 map[string]netip.Addr
|
||||||
dnsMap6 map[string]netip.Addr
|
dnsMap6 map[string]netip.Addr
|
||||||
hostMap *HostMap
|
hostMap *HostMap
|
||||||
myVpnAddrsTable *bart.Table[struct{}]
|
myVpnAddrsTable *bart.Lite
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
|
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
|
||||||
@@ -112,8 +112,8 @@ func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
_, found := d.myVpnAddrsTable.Lookup(b)
|
//if we found it in this table, it's good
|
||||||
return found //if we found it in this table, it's good
|
return d.myVpnAddrsTable.Contains(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||||
|
|||||||
@@ -38,24 +38,24 @@ func TestParsequery(t *testing.T) {
|
|||||||
func Test_getDnsServerAddr(t *testing.T) {
|
func Test_getDnsServerAddr(t *testing.T) {
|
||||||
c := config.NewC(nil)
|
c := config.NewC(nil)
|
||||||
|
|
||||||
c.Settings["lighthouse"] = map[interface{}]interface{}{
|
c.Settings["lighthouse"] = map[string]any{
|
||||||
"dns": map[interface{}]interface{}{
|
"dns": map[string]any{
|
||||||
"host": "0.0.0.0",
|
"host": "0.0.0.0",
|
||||||
"port": "1",
|
"port": "1",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c))
|
assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c))
|
||||||
|
|
||||||
c.Settings["lighthouse"] = map[interface{}]interface{}{
|
c.Settings["lighthouse"] = map[string]any{
|
||||||
"dns": map[interface{}]interface{}{
|
"dns": map[string]any{
|
||||||
"host": "::",
|
"host": "::",
|
||||||
"port": "1",
|
"port": "1",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
|
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
|
||||||
|
|
||||||
c.Settings["lighthouse"] = map[interface{}]interface{}{
|
c.Settings["lighthouse"] = map[string]any{
|
||||||
"dns": map[interface{}]interface{}{
|
"dns": map[string]any{
|
||||||
"host": "[::]",
|
"host": "[::]",
|
||||||
"port": "1",
|
"port": "1",
|
||||||
},
|
},
|
||||||
@@ -63,8 +63,8 @@ func Test_getDnsServerAddr(t *testing.T) {
|
|||||||
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
|
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
|
||||||
|
|
||||||
// Make sure whitespace doesn't mess us up
|
// Make sure whitespace doesn't mess us up
|
||||||
c.Settings["lighthouse"] = map[interface{}]interface{}{
|
c.Settings["lighthouse"] = map[string]any{
|
||||||
"dns": map[interface{}]interface{}{
|
"dns": map[string]any{
|
||||||
"host": "[::] ",
|
"host": "[::] ",
|
||||||
"port": "1",
|
"port": "1",
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -19,16 +19,18 @@ import (
|
|||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"gopkg.in/yaml.v2"
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.yaml.in/yaml/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BenchmarkHotPath(b *testing.B) {
|
func BenchmarkHotPath(b *testing.B) {
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
// Put their info in our lighthouse
|
// Put their info in our lighthouse
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
// Start the servers
|
// Start the servers
|
||||||
myControl.Start()
|
myControl.Start()
|
||||||
@@ -37,6 +39,9 @@ func BenchmarkHotPath(b *testing.B) {
|
|||||||
r := router.NewR(b, myControl, theirControl)
|
r := router.NewR(b, myControl, theirControl)
|
||||||
r.CancelFlowLogs()
|
r.CancelFlowLogs()
|
||||||
|
|
||||||
|
assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
_ = r.RouteForAllUntilTxTun(theirControl)
|
_ = r.RouteForAllUntilTxTun(theirControl)
|
||||||
@@ -46,6 +51,39 @@ func BenchmarkHotPath(b *testing.B) {
|
|||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkHotPathRelay(b *testing.B) {
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||||
|
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||||
|
|
||||||
|
// Teach my how to get to the relay and that their can be reached via the relay
|
||||||
|
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
||||||
|
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
|
||||||
|
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
|
||||||
|
// Build a router so we don't have to reason who gets which packet
|
||||||
|
r := router.NewR(b, myControl, relayControl, theirControl)
|
||||||
|
r.CancelFlowLogs()
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
relayControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for n := 0; n < b.N; n++ {
|
||||||
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
|
_ = r.RouteForAllUntilTxTun(theirControl)
|
||||||
|
}
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
relayControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
func TestGoodHandshake(t *testing.T) {
|
func TestGoodHandshake(t *testing.T) {
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
@@ -96,6 +134,41 @@ func TestGoodHandshake(t *testing.T) {
|
|||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGoodHandshakeNoOverlap(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack!
|
||||||
|
|
||||||
|
// Put their info in our lighthouse
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
empty := []byte{}
|
||||||
|
t.Log("do something to cause a handshake")
|
||||||
|
myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, empty, empty)
|
||||||
|
|
||||||
|
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
|
||||||
|
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
|
||||||
|
|
||||||
|
t.Log("Get their stage 1 packet")
|
||||||
|
stage1Packet := theirControl.GetFromUDP(true)
|
||||||
|
|
||||||
|
t.Log("Have me consume their stage 1 packet. I have a tunnel now")
|
||||||
|
myControl.InjectUDPPacket(stage1Packet)
|
||||||
|
|
||||||
|
t.Log("Wait until we see a test packet come through to make sure we give the tunnel time to complete")
|
||||||
|
myControl.WaitForType(header.Test, 0, theirControl)
|
||||||
|
|
||||||
|
t.Log("Make sure our host infos are correct")
|
||||||
|
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
func TestWrongResponderHandshake(t *testing.T) {
|
func TestWrongResponderHandshake(t *testing.T) {
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
|
||||||
@@ -463,6 +536,35 @@ func TestRelays(t *testing.T) {
|
|||||||
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRelaysDontCareAboutIps(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||||
|
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||||
|
|
||||||
|
// Teach my how to get to the relay and that their can be reached via the relay
|
||||||
|
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
||||||
|
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
|
||||||
|
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
|
||||||
|
// Build a router so we don't have to reason who gets which packet
|
||||||
|
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
relayControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
t.Log("Trigger a handshake from me to them via the relay")
|
||||||
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
|
|
||||||
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
|
r.Log("Assert the tunnel works")
|
||||||
|
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
|
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
||||||
|
}
|
||||||
|
|
||||||
func TestReestablishRelays(t *testing.T) {
|
func TestReestablishRelays(t *testing.T) {
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||||
@@ -505,7 +607,7 @@ func TestReestablishRelays(t *testing.T) {
|
|||||||
curIndexes := len(myControl.GetHostmap().Indexes)
|
curIndexes := len(myControl.GetHostmap().Indexes)
|
||||||
for curIndexes >= start {
|
for curIndexes >= start {
|
||||||
curIndexes = len(myControl.GetHostmap().Indexes)
|
curIndexes = len(myControl.GetHostmap().Indexes)
|
||||||
r.Logf("Wait for the dead index to go away:start=%v indexes, currnet=%v indexes", start, curIndexes)
|
r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes)
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail"))
|
||||||
|
|
||||||
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
||||||
@@ -771,7 +873,7 @@ func TestRehandshakingRelays(t *testing.T) {
|
|||||||
"key": string(myNextPrivKey),
|
"key": string(myNextPrivKey),
|
||||||
}
|
}
|
||||||
rc, err := yaml.Marshal(relayConfig.Settings)
|
rc, err := yaml.Marshal(relayConfig.Settings)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
relayConfig.ReloadConfigString(string(rc))
|
relayConfig.ReloadConfigString(string(rc))
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -875,7 +977,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
|
|||||||
"key": string(myNextPrivKey),
|
"key": string(myNextPrivKey),
|
||||||
}
|
}
|
||||||
rc, err := yaml.Marshal(relayConfig.Settings)
|
rc, err := yaml.Marshal(relayConfig.Settings)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
relayConfig.ReloadConfigString(string(rc))
|
relayConfig.ReloadConfigString(string(rc))
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -970,7 +1072,7 @@ func TestRehandshaking(t *testing.T) {
|
|||||||
"key": string(myNextPrivKey),
|
"key": string(myNextPrivKey),
|
||||||
}
|
}
|
||||||
rc, err := yaml.Marshal(myConfig.Settings)
|
rc, err := yaml.Marshal(myConfig.Settings)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
myConfig.ReloadConfigString(string(rc))
|
myConfig.ReloadConfigString(string(rc))
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -987,17 +1089,17 @@ func TestRehandshaking(t *testing.T) {
|
|||||||
r.Log("Got the new cert")
|
r.Log("Got the new cert")
|
||||||
// Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly
|
// Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly
|
||||||
rc, err = yaml.Marshal(theirConfig.Settings)
|
rc, err = yaml.Marshal(theirConfig.Settings)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
var theirNewConfig m
|
var theirNewConfig m
|
||||||
assert.NoError(t, yaml.Unmarshal(rc, &theirNewConfig))
|
require.NoError(t, yaml.Unmarshal(rc, &theirNewConfig))
|
||||||
theirFirewall := theirNewConfig["firewall"].(map[interface{}]interface{})
|
theirFirewall := theirNewConfig["firewall"].(map[string]any)
|
||||||
theirFirewall["inbound"] = []m{{
|
theirFirewall["inbound"] = []m{{
|
||||||
"proto": "any",
|
"proto": "any",
|
||||||
"port": "any",
|
"port": "any",
|
||||||
"group": "new group",
|
"group": "new group",
|
||||||
}}
|
}}
|
||||||
rc, err = yaml.Marshal(theirNewConfig)
|
rc, err = yaml.Marshal(theirNewConfig)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
theirConfig.ReloadConfigString(string(rc))
|
theirConfig.ReloadConfigString(string(rc))
|
||||||
|
|
||||||
r.Log("Spin until there is only 1 tunnel")
|
r.Log("Spin until there is only 1 tunnel")
|
||||||
@@ -1051,6 +1153,9 @@ func TestRehandshakingLoser(t *testing.T) {
|
|||||||
t.Log("Stand up a tunnel between me and them")
|
t.Log("Stand up a tunnel between me and them")
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||||
|
theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
|
||||||
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
|
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
|
||||||
|
|
||||||
r.Log("Renew their certificate and spin until mine sees it")
|
r.Log("Renew their certificate and spin until mine sees it")
|
||||||
@@ -1067,7 +1172,7 @@ func TestRehandshakingLoser(t *testing.T) {
|
|||||||
"key": string(theirNextPrivKey),
|
"key": string(theirNextPrivKey),
|
||||||
}
|
}
|
||||||
rc, err := yaml.Marshal(theirConfig.Settings)
|
rc, err := yaml.Marshal(theirConfig.Settings)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
theirConfig.ReloadConfigString(string(rc))
|
theirConfig.ReloadConfigString(string(rc))
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -1083,17 +1188,17 @@ func TestRehandshakingLoser(t *testing.T) {
|
|||||||
|
|
||||||
// Flip my firewall to only allowing the new group to catch the tunnels reverting incorrectly
|
// Flip my firewall to only allowing the new group to catch the tunnels reverting incorrectly
|
||||||
rc, err = yaml.Marshal(myConfig.Settings)
|
rc, err = yaml.Marshal(myConfig.Settings)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
var myNewConfig m
|
var myNewConfig m
|
||||||
assert.NoError(t, yaml.Unmarshal(rc, &myNewConfig))
|
require.NoError(t, yaml.Unmarshal(rc, &myNewConfig))
|
||||||
theirFirewall := myNewConfig["firewall"].(map[interface{}]interface{})
|
theirFirewall := myNewConfig["firewall"].(map[string]any)
|
||||||
theirFirewall["inbound"] = []m{{
|
theirFirewall["inbound"] = []m{{
|
||||||
"proto": "any",
|
"proto": "any",
|
||||||
"port": "any",
|
"port": "any",
|
||||||
"group": "their new group",
|
"group": "their new group",
|
||||||
}}
|
}}
|
||||||
rc, err = yaml.Marshal(myNewConfig)
|
rc, err = yaml.Marshal(myNewConfig)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
myConfig.ReloadConfigString(string(rc))
|
myConfig.ReloadConfigString(string(rc))
|
||||||
|
|
||||||
r.Log("Spin until there is only 1 tunnel")
|
r.Log("Spin until there is only 1 tunnel")
|
||||||
@@ -1223,3 +1328,109 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
|
|||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}})
|
||||||
|
|
||||||
|
o := m{
|
||||||
|
"static_host_map": m{
|
||||||
|
lhVpnIpNet[0].Addr().String(): []string{lhUdpAddr.String()},
|
||||||
|
},
|
||||||
|
"lighthouse": m{
|
||||||
|
"hosts": []string{lhVpnIpNet[0].Addr().String()},
|
||||||
|
"local_allow_list": m{
|
||||||
|
// Try and block our lighthouse updates from using the actual addresses assigned to this computer
|
||||||
|
// If we start discovering addresses the test router doesn't know about then test traffic cant flow
|
||||||
|
"10.0.0.0/24": true,
|
||||||
|
"::/0": false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o)
|
||||||
|
theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)
|
||||||
|
|
||||||
|
// Build a router so we don't have to reason who gets which packet
|
||||||
|
r := router.NewR(t, lhControl, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
lhControl.Start()
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
t.Log("Stand up an ipv6 tunnel between me and them")
|
||||||
|
assert.True(t, myVpnIpNet[1].Addr().Is6())
|
||||||
|
assert.True(t, theirVpnIpNet[1].Addr().Is6())
|
||||||
|
assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
lhControl.Stop()
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoodHandshakeUnsafeDest(t *testing.T) {
|
||||||
|
unsafePrefix := "192.168.6.0/24"
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil)
|
||||||
|
route := m{"route": unsafePrefix, "via": theirVpnIpNet[0].Addr().String()}
|
||||||
|
myCfg := m{
|
||||||
|
"tun": m{
|
||||||
|
"unsafe_routes": []m{route},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg)
|
||||||
|
t.Logf("my config %v", myConfig)
|
||||||
|
// Put their info in our lighthouse
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
|
||||||
|
spookyDest := netip.MustParseAddr("192.168.6.4")
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
|
||||||
|
myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
|
|
||||||
|
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
|
||||||
|
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
|
||||||
|
|
||||||
|
t.Log("Get their stage 1 packet so that we can play with it")
|
||||||
|
stage1Packet := theirControl.GetFromUDP(true)
|
||||||
|
|
||||||
|
t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
|
||||||
|
// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
|
||||||
|
badPacket := stage1Packet.Copy()
|
||||||
|
badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
|
||||||
|
myControl.InjectUDPPacket(badPacket)
|
||||||
|
|
||||||
|
t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
|
||||||
|
myControl.InjectUDPPacket(stage1Packet)
|
||||||
|
|
||||||
|
t.Log("Wait until we see my cached packet come through")
|
||||||
|
myControl.WaitForType(1, 0, theirControl)
|
||||||
|
|
||||||
|
t.Log("Make sure our host infos are correct")
|
||||||
|
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
|
||||||
|
|
||||||
|
t.Log("Get that cached packet and make sure it looks right")
|
||||||
|
myCachedPacket := theirControl.GetFromTun(true)
|
||||||
|
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)
|
||||||
|
|
||||||
|
//reply
|
||||||
|
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))
|
||||||
|
//wait for reply
|
||||||
|
theirControl.WaitForType(1, 0, myControl)
|
||||||
|
theirCachedPacket := myControl.GetFromTun(true)
|
||||||
|
assertUdpPacket(t, []byte("Hi from the spookyman"), theirCachedPacket, spookyDest, myVpnIpNet[0].Addr(), 80, 80)
|
||||||
|
|
||||||
|
t.Log("Do a bidirectional tunnel test")
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|||||||
@@ -22,15 +22,14 @@ import (
|
|||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"gopkg.in/yaml.v2"
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.yaml.in/yaml/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type m map[string]interface{}
|
type m = map[string]any
|
||||||
|
|
||||||
// newSimpleServer creates a nebula instance with many assumptions
|
// newSimpleServer creates a nebula instance with many assumptions
|
||||||
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||||
l := NewTestLogger()
|
|
||||||
|
|
||||||
var vpnNetworks []netip.Prefix
|
var vpnNetworks []netip.Prefix
|
||||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||||
@@ -56,7 +55,54 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
|||||||
budpIp[3] = 239
|
budpIp[3] = 239
|
||||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||||
}
|
}
|
||||||
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
|
return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||||
|
return newSimpleServerWithUdpAndUnsafeNetworks(v, caCrt, caKey, name, sVpnNetworks, udpAddr, "", overrides)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, sUnsafeNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||||
|
l := NewTestLogger()
|
||||||
|
|
||||||
|
var vpnNetworks []netip.Prefix
|
||||||
|
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||||
|
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
vpnNetworks = append(vpnNetworks, vpnIpNet)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(vpnNetworks) == 0 {
|
||||||
|
panic("no vpn networks")
|
||||||
|
}
|
||||||
|
|
||||||
|
firewallInbound := []m{{
|
||||||
|
"proto": "any",
|
||||||
|
"port": "any",
|
||||||
|
"host": "any",
|
||||||
|
}}
|
||||||
|
|
||||||
|
var unsafeNetworks []netip.Prefix
|
||||||
|
if sUnsafeNetworks != "" {
|
||||||
|
firewallInbound = []m{{
|
||||||
|
"proto": "any",
|
||||||
|
"port": "any",
|
||||||
|
"host": "any",
|
||||||
|
"local_cidr": "0.0.0.0/0",
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, sn := range strings.Split(sUnsafeNetworks, ",") {
|
||||||
|
x, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
unsafeNetworks = append(unsafeNetworks, x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{})
|
||||||
|
|
||||||
caB, err := caCrt.MarshalPEM()
|
caB, err := caCrt.MarshalPEM()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -76,11 +122,7 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
|||||||
"port": "any",
|
"port": "any",
|
||||||
"host": "any",
|
"host": "any",
|
||||||
}},
|
}},
|
||||||
"inbound": []m{{
|
"inbound": firewallInbound,
|
||||||
"proto": "any",
|
|
||||||
"port": "any",
|
|
||||||
"host": "any",
|
|
||||||
}},
|
|
||||||
},
|
},
|
||||||
//"handshakes": m{
|
//"handshakes": m{
|
||||||
// "try_interval": "1s",
|
// "try_interval": "1s",
|
||||||
@@ -129,6 +171,109 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
|||||||
return control, vpnNetworks, udpAddr, c
|
return control, vpnNetworks, udpAddr, c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// newServer creates a nebula instance with fewer assumptions
|
||||||
|
func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||||
|
l := NewTestLogger()
|
||||||
|
|
||||||
|
vpnNetworks := certs[len(certs)-1].Networks()
|
||||||
|
|
||||||
|
var udpAddr netip.AddrPort
|
||||||
|
if vpnNetworks[0].Addr().Is4() {
|
||||||
|
budpIp := vpnNetworks[0].Addr().As4()
|
||||||
|
budpIp[1] -= 128
|
||||||
|
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
|
||||||
|
} else {
|
||||||
|
budpIp := vpnNetworks[0].Addr().As16()
|
||||||
|
// beef for funsies
|
||||||
|
budpIp[2] = 190
|
||||||
|
budpIp[3] = 239
|
||||||
|
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||||
|
}
|
||||||
|
|
||||||
|
caStr := ""
|
||||||
|
for _, ca := range caCrt {
|
||||||
|
x, err := ca.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
caStr += string(x)
|
||||||
|
}
|
||||||
|
certStr := ""
|
||||||
|
for _, c := range certs {
|
||||||
|
x, err := c.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
certStr += string(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
mc := m{
|
||||||
|
"pki": m{
|
||||||
|
"ca": caStr,
|
||||||
|
"cert": certStr,
|
||||||
|
"key": string(key),
|
||||||
|
},
|
||||||
|
//"tun": m{"disabled": true},
|
||||||
|
"firewall": m{
|
||||||
|
"outbound": []m{{
|
||||||
|
"proto": "any",
|
||||||
|
"port": "any",
|
||||||
|
"host": "any",
|
||||||
|
}},
|
||||||
|
"inbound": []m{{
|
||||||
|
"proto": "any",
|
||||||
|
"port": "any",
|
||||||
|
"host": "any",
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
//"handshakes": m{
|
||||||
|
// "try_interval": "1s",
|
||||||
|
//},
|
||||||
|
"listen": m{
|
||||||
|
"host": udpAddr.Addr().String(),
|
||||||
|
"port": udpAddr.Port(),
|
||||||
|
},
|
||||||
|
"logging": m{
|
||||||
|
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
|
||||||
|
"level": l.Level.String(),
|
||||||
|
},
|
||||||
|
"timers": m{
|
||||||
|
"pending_deletion_interval": 2,
|
||||||
|
"connection_alive_interval": 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if overrides != nil {
|
||||||
|
final := m{}
|
||||||
|
err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
mc = final
|
||||||
|
}
|
||||||
|
|
||||||
|
cb, err := yaml.Marshal(mc)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c := config.NewC(l)
|
||||||
|
cStr := string(cb)
|
||||||
|
c.LoadString(cStr)
|
||||||
|
|
||||||
|
control, err := nebula.Main(c, false, "e2e-test", l, nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return control, vpnNetworks, udpAddr, c
|
||||||
|
}
|
||||||
|
|
||||||
type doneCb func()
|
type doneCb func()
|
||||||
|
|
||||||
func deadline(t *testing.T, seconds time.Duration) doneCb {
|
func deadline(t *testing.T, seconds time.Duration) doneCb {
|
||||||
@@ -147,7 +292,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
|
func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
|
||||||
// Send a packet from them to me
|
// Send a packet from them to me
|
||||||
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
|
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
|
||||||
bPacket := r.RouteForAllUntilTxTun(controlA)
|
bPacket := r.RouteForAllUntilTxTun(controlA)
|
||||||
@@ -159,14 +304,14 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *n
|
|||||||
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
|
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) {
|
func assertHostInfoPair(t testing.TB, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) {
|
||||||
// Get both host infos
|
// Get both host infos
|
||||||
//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
|
//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
|
||||||
hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
|
hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
|
||||||
assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
|
require.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
|
||||||
|
|
||||||
hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
|
hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
|
||||||
assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
|
require.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
|
||||||
|
|
||||||
// Check that both vpn and real addr are correct
|
// Check that both vpn and real addr are correct
|
||||||
assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
|
assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
|
||||||
@@ -180,7 +325,7 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
|
|||||||
assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
|
assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
func assertUdpPacket(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
||||||
if toIp.Is6() {
|
if toIp.Is6() {
|
||||||
assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
|
assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
|
||||||
} else {
|
} else {
|
||||||
@@ -188,7 +333,7 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
func assertUdpPacket6(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
||||||
packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
|
packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
|
||||||
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
|
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
|
||||||
assert.NotNil(t, v6, "No ipv6 data found")
|
assert.NotNil(t, v6, "No ipv6 data found")
|
||||||
@@ -207,7 +352,7 @@ func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
|
|||||||
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
|
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
func assertUdpPacket4(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
||||||
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
|
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
|
||||||
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
|
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
|
||||||
assert.NotNil(t, v4, "No ipv4 data found")
|
assert.NotNil(t, v4, "No ipv4 data found")
|
||||||
|
|||||||
@@ -700,6 +700,7 @@ func (r *R) FlushAll() {
|
|||||||
r.Unlock()
|
r.Unlock()
|
||||||
panic("Can't FlushAll for host: " + p.To.String())
|
panic("Can't FlushAll for host: " + p.To.String())
|
||||||
}
|
}
|
||||||
|
receiver.InjectUDPPacket(p)
|
||||||
r.Unlock()
|
r.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
367
e2e/tunnels_test.go
Normal file
367
e2e/tunnels_test.go
Normal file
@@ -0,0 +1,367 @@
|
|||||||
|
//go:build e2e_testing
|
||||||
|
// +build e2e_testing
|
||||||
|
|
||||||
|
package e2e
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/slackhq/nebula/cert_test"
|
||||||
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDropInactiveTunnels(t *testing.T) {
|
||||||
|
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||||
|
// under ideal conditions
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}})
|
||||||
|
|
||||||
|
// Share our underlay information
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
|
||||||
|
r.Log("Assert the tunnel between me and them works")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
r.Log("Go inactive and wait for the tunnels to get dropped")
|
||||||
|
waitStart := time.Now()
|
||||||
|
for {
|
||||||
|
myIndexes := len(myControl.GetHostmap().Indexes)
|
||||||
|
theirIndexes := len(theirControl.GetHostmap().Indexes)
|
||||||
|
if myIndexes == 0 && theirIndexes == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
since := time.Since(waitStart)
|
||||||
|
r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
|
||||||
|
if since > time.Second*30 {
|
||||||
|
t.Fatal("Tunnel should have been declared inactive after 5 seconds and before 30 seconds")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
r.FlushAll()
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Logf("Inactive tunnels were dropped within %v", time.Since(waitStart))
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCertUpgrade(t *testing.T) {
|
||||||
|
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||||
|
// under ideal conditions
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
caB, err := ca.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
|
||||||
|
ca2B, err := ca2.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
||||||
|
|
||||||
|
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||||
|
_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||||
|
|
||||||
|
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||||
|
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||||
|
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||||
|
|
||||||
|
// Share our underlay information
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
r.Log("Assert the tunnel between me and them works")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
r.Log("yay")
|
||||||
|
//todo ???
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
r.FlushAll()
|
||||||
|
|
||||||
|
mc := m{
|
||||||
|
"pki": m{
|
||||||
|
"ca": caStr,
|
||||||
|
"cert": string(myCert2Pem),
|
||||||
|
"key": string(myPrivKey),
|
||||||
|
},
|
||||||
|
//"tun": m{"disabled": true},
|
||||||
|
"firewall": myC.Settings["firewall"],
|
||||||
|
//"handshakes": m{
|
||||||
|
// "try_interval": "1s",
|
||||||
|
//},
|
||||||
|
"listen": myC.Settings["listen"],
|
||||||
|
"logging": myC.Settings["logging"],
|
||||||
|
"timers": myC.Settings["timers"],
|
||||||
|
}
|
||||||
|
|
||||||
|
cb, err := yaml.Marshal(mc)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Logf("reload new v2-only config")
|
||||||
|
err = myC.ReloadConfigString(string(cb))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
r.Log("yay, spin until their sees it")
|
||||||
|
waitStart := time.Now()
|
||||||
|
for {
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
if c == nil {
|
||||||
|
r.Log("nil")
|
||||||
|
} else {
|
||||||
|
version := c.Cert.Version()
|
||||||
|
r.Logf("version %d", version)
|
||||||
|
if version == cert.Version2 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
since := time.Since(waitStart)
|
||||||
|
if since > time.Second*10 {
|
||||||
|
t.Fatal("Cert should be new by now")
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCertDowngrade(t *testing.T) {
|
||||||
|
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||||
|
// under ideal conditions
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
caB, err := ca.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
|
||||||
|
ca2B, err := ca2.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
||||||
|
|
||||||
|
myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||||
|
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||||
|
|
||||||
|
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||||
|
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||||
|
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||||
|
|
||||||
|
// Share our underlay information
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
r.Log("Assert the tunnel between me and them works")
|
||||||
|
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
|
//r.Log("yay")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
r.Log("yay")
|
||||||
|
//todo ???
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
r.FlushAll()
|
||||||
|
|
||||||
|
mc := m{
|
||||||
|
"pki": m{
|
||||||
|
"ca": caStr,
|
||||||
|
"cert": string(myCertPem),
|
||||||
|
"key": string(myPrivKey),
|
||||||
|
},
|
||||||
|
"firewall": myC.Settings["firewall"],
|
||||||
|
"listen": myC.Settings["listen"],
|
||||||
|
"logging": myC.Settings["logging"],
|
||||||
|
"timers": myC.Settings["timers"],
|
||||||
|
}
|
||||||
|
|
||||||
|
cb, err := yaml.Marshal(mc)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Logf("reload new v1-only config")
|
||||||
|
err = myC.ReloadConfigString(string(cb))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
r.Log("yay, spin until their sees it")
|
||||||
|
waitStart := time.Now()
|
||||||
|
for {
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||||
|
if c == nil || c2 == nil {
|
||||||
|
r.Log("nil")
|
||||||
|
} else {
|
||||||
|
version := c.Cert.Version()
|
||||||
|
theirVersion := c2.Cert.Version()
|
||||||
|
r.Logf("version %d,%d", version, theirVersion)
|
||||||
|
if version == cert.Version1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
since := time.Since(waitStart)
|
||||||
|
if since > time.Second*5 {
|
||||||
|
r.Log("it is unusual that the cert is not new yet, but not a failure yet")
|
||||||
|
}
|
||||||
|
if since > time.Second*10 {
|
||||||
|
r.Log("wtf")
|
||||||
|
t.Fatal("Cert should be new by now")
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCertMismatchCorrection(t *testing.T) {
|
||||||
|
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||||
|
// under ideal conditions
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
|
||||||
|
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||||
|
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||||
|
|
||||||
|
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||||
|
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||||
|
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||||
|
|
||||||
|
// Share our underlay information
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
r.Log("Assert the tunnel between me and them works")
|
||||||
|
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
|
//r.Log("yay")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
r.Log("yay")
|
||||||
|
//todo ???
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
r.FlushAll()
|
||||||
|
|
||||||
|
waitStart := time.Now()
|
||||||
|
for {
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||||
|
if c == nil || c2 == nil {
|
||||||
|
r.Log("nil")
|
||||||
|
} else {
|
||||||
|
version := c.Cert.Version()
|
||||||
|
theirVersion := c2.Cert.Version()
|
||||||
|
r.Logf("version %d,%d", version, theirVersion)
|
||||||
|
if version == theirVersion {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
since := time.Since(waitStart)
|
||||||
|
if since > time.Second*5 {
|
||||||
|
r.Log("wtf")
|
||||||
|
}
|
||||||
|
if since > time.Second*10 {
|
||||||
|
r.Log("wtf")
|
||||||
|
t.Fatal("Cert should be new by now")
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCrossStackRelaysWork(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
|
||||||
|
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
|
||||||
|
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
|
||||||
|
|
||||||
|
//myVpnV4 := myVpnIpNet[0]
|
||||||
|
myVpnV6 := myVpnIpNet[1]
|
||||||
|
relayVpnV4 := relayVpnIpNet[0]
|
||||||
|
relayVpnV6 := relayVpnIpNet[1]
|
||||||
|
theirVpnV6 := theirVpnIpNet[0]
|
||||||
|
|
||||||
|
// Teach my how to get to the relay and that their can be reached via the relay
|
||||||
|
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
|
||||||
|
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
|
||||||
|
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
|
||||||
|
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
|
||||||
|
|
||||||
|
// Build a router so we don't have to reason who gets which packet
|
||||||
|
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
relayControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
t.Log("Trigger a handshake from me to them via the relay")
|
||||||
|
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
|
||||||
|
|
||||||
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
|
r.Log("Assert the tunnel works")
|
||||||
|
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
|
||||||
|
|
||||||
|
t.Log("reply?")
|
||||||
|
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
|
||||||
|
p = r.RouteForAllUntilTxTun(myControl)
|
||||||
|
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
|
||||||
|
|
||||||
|
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
||||||
|
//t.Log("finish up")
|
||||||
|
//myControl.Stop()
|
||||||
|
//theirControl.Stop()
|
||||||
|
//relayControl.Stop()
|
||||||
|
}
|
||||||
@@ -13,11 +13,11 @@ pki:
|
|||||||
# disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
|
# disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
|
||||||
#disconnect_invalid: true
|
#disconnect_invalid: true
|
||||||
|
|
||||||
# default_version controls which certificate version is used in handshakes.
|
# initiating_version controls which certificate version is used when initiating handshakes.
|
||||||
# This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`.
|
# This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`.
|
||||||
# Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`.
|
# Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`.
|
||||||
# After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
|
# After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
|
||||||
# default_version: 1
|
# initiating_version: 1
|
||||||
|
|
||||||
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
|
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
|
||||||
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
|
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
|
||||||
@@ -126,8 +126,8 @@ lighthouse:
|
|||||||
# Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined,
|
# Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined,
|
||||||
# however using port 0 will dynamically assign a port and is recommended for roaming nodes.
|
# however using port 0 will dynamically assign a port and is recommended for roaming nodes.
|
||||||
listen:
|
listen:
|
||||||
# To listen on both any ipv4 and ipv6 use "::"
|
# To listen on only ipv4, use "0.0.0.0"
|
||||||
host: 0.0.0.0
|
host: "::"
|
||||||
port: 4242
|
port: 4242
|
||||||
# Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg)
|
# Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg)
|
||||||
# default is 64, does not support reload
|
# default is 64, does not support reload
|
||||||
@@ -144,6 +144,11 @@ listen:
|
|||||||
# valid values: always, never, private
|
# valid values: always, never, private
|
||||||
# This setting is reloadable.
|
# This setting is reloadable.
|
||||||
#send_recv_error: always
|
#send_recv_error: always
|
||||||
|
# The so_sock option is a Linux-specific feature that allows all outgoing Nebula packets to be tagged with a specific identifier.
|
||||||
|
# This tagging enables IP rule-based filtering. For example, it supports 0.0.0.0/0 unsafe_routes,
|
||||||
|
# allowing for more precise routing decisions based on the packet tags. Default is 0 meaning no mark is set.
|
||||||
|
# This setting is reloadable.
|
||||||
|
#so_mark: 0
|
||||||
|
|
||||||
# Routines is the number of thread pairs to run that consume from the tun and UDP queues.
|
# Routines is the number of thread pairs to run that consume from the tun and UDP queues.
|
||||||
# Currently, this defaults to 1 which means we have 1 tun queue reader and 1
|
# Currently, this defaults to 1 which means we have 1 tun queue reader and 1
|
||||||
@@ -234,7 +239,28 @@ tun:
|
|||||||
|
|
||||||
# Unsafe routes allows you to route traffic over nebula to non-nebula nodes
|
# Unsafe routes allows you to route traffic over nebula to non-nebula nodes
|
||||||
# Unsafe routes should be avoided unless you have hosts/services that cannot run nebula
|
# Unsafe routes should be avoided unless you have hosts/services that cannot run nebula
|
||||||
# NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate
|
# Supports weighted ECMP if you define a list of gateways, this can be used for load balancing or redundancy to hosts outside of nebula
|
||||||
|
# NOTES:
|
||||||
|
# * You will only see a single gateway in the routing table if you are not on linux
|
||||||
|
# * If a gateway is not reachable through the overlay another gateway will be selected to send the traffic through, ignoring weights
|
||||||
|
#
|
||||||
|
# unsafe_routes:
|
||||||
|
# # Multiple gateways without defining a weight defaults to a weight of 1, this will balance traffic equally between the three gateways
|
||||||
|
# - route: 192.168.87.0/24
|
||||||
|
# via:
|
||||||
|
# - gateway: 10.0.0.1
|
||||||
|
# - gateway: 10.0.0.2
|
||||||
|
# - gateway: 10.0.0.3
|
||||||
|
# # Multiple gateways with a weight, this will balance traffic accordingly
|
||||||
|
# - route: 192.168.87.0/24
|
||||||
|
# via:
|
||||||
|
# - gateway: 10.0.0.1
|
||||||
|
# weight: 10
|
||||||
|
# - gateway: 10.0.0.2
|
||||||
|
# weight: 5
|
||||||
|
#
|
||||||
|
# NOTE: The nebula certificate of the "via" node(s) *MUST* have the "route" defined as a subnet in its certificate
|
||||||
|
# `via`: single node or list of gateways to use for this route
|
||||||
# `mtu`: will default to tun mtu if this option is not specified
|
# `mtu`: will default to tun mtu if this option is not specified
|
||||||
# `metric`: will default to 0 if this option is not specified
|
# `metric`: will default to 0 if this option is not specified
|
||||||
# `install`: will default to true, controls whether this route is installed in the systems routing table.
|
# `install`: will default to true, controls whether this route is installed in the systems routing table.
|
||||||
@@ -249,6 +275,10 @@ tun:
|
|||||||
# On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of
|
# On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of
|
||||||
# in nebula configuration files. Default false, not reloadable.
|
# in nebula configuration files. Default false, not reloadable.
|
||||||
#use_system_route_table: false
|
#use_system_route_table: false
|
||||||
|
# Buffer size for reading routes updates. 0 means default system buffer size. (/proc/sys/net/core/rmem_default).
|
||||||
|
# If using massive routes updates, for example BGP, you may need to increase this value to avoid packet loss.
|
||||||
|
# SO_RCVBUFFORCE is used to avoid having to raise the system wide max
|
||||||
|
#use_system_route_table_buffer_size: 0
|
||||||
|
|
||||||
# Configure logging level
|
# Configure logging level
|
||||||
logging:
|
logging:
|
||||||
@@ -308,6 +338,18 @@ logging:
|
|||||||
# after receiving the response for lighthouse queries
|
# after receiving the response for lighthouse queries
|
||||||
#trigger_buffer: 64
|
#trigger_buffer: 64
|
||||||
|
|
||||||
|
# Tunnel manager settings
|
||||||
|
#tunnels:
|
||||||
|
# drop_inactive controls whether inactive tunnels are maintained or dropped after the inactive_timeout period has
|
||||||
|
# elapsed.
|
||||||
|
# In general, it is a good idea to enable this setting. It will be enabled by default in a future release.
|
||||||
|
# This setting is reloadable
|
||||||
|
#drop_inactive: false
|
||||||
|
|
||||||
|
# inactivity_timeout controls how long a tunnel MUST NOT see any inbound or outbound traffic before being considered
|
||||||
|
# inactive and eligible to be dropped.
|
||||||
|
# This setting is reloadable
|
||||||
|
#inactivity_timeout: 10m
|
||||||
|
|
||||||
# Nebula security group configuration
|
# Nebula security group configuration
|
||||||
firewall:
|
firewall:
|
||||||
@@ -320,11 +362,11 @@ firewall:
|
|||||||
outbound_action: drop
|
outbound_action: drop
|
||||||
inbound_action: drop
|
inbound_action: drop
|
||||||
|
|
||||||
# Controls the default value for local_cidr. Default is true, will be deprecated after v1.9 and defaulted to false.
|
# THIS FLAG IS DEPRECATED AND WILL BE REMOVED IN A FUTURE RELEASE. (Defaults to false.)
|
||||||
# This setting only affects nebula hosts with subnets encoded in their certificate. A nebula host acting as an
|
# This setting only affects nebula hosts exposing unsafe_routes. When set to false, each inbound rule must contain a
|
||||||
# unsafe router with `default_local_cidr_any: true` will expose their unsafe routes to every inbound rule regardless
|
# `local_cidr` if the intention is to allow traffic to flow to an unsafe route. When set to true, every firewall rule
|
||||||
# of the actual destination for the packet. Setting this to false requires each inbound rule to contain a `local_cidr`
|
# will apply to all configured unsafe_routes regardless of the actual destination of the packet, unless `local_cidr`
|
||||||
# if the intention is to allow traffic to flow to an unsafe route.
|
# is explicitly defined. This is usually not the desired behavior and should be avoided!
|
||||||
#default_local_cidr_any: false
|
#default_local_cidr_any: false
|
||||||
|
|
||||||
conntrack:
|
conntrack:
|
||||||
@@ -341,12 +383,11 @@ firewall:
|
|||||||
# host: `any` or a literal hostname, ie `test-host`
|
# host: `any` or a literal hostname, ie `test-host`
|
||||||
# group: `any` or a literal group name, ie `default-group`
|
# group: `any` or a literal group name, ie `default-group`
|
||||||
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
|
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
|
||||||
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
|
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address.
|
||||||
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes.
|
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address.
|
||||||
# If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network.
|
# This can be used to filter destinations when using unsafe_routes.
|
||||||
# Otherwise the default is any vpn network assigned to via the certificate.
|
# By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true.
|
||||||
# `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release.
|
# If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case.
|
||||||
# If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation.
|
|
||||||
# ca_name: An issuing CA name
|
# ca_name: An issuing CA name
|
||||||
# ca_sha: An issuing CA shasum
|
# ca_sha: An issuing CA shasum
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,12 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/overlay"
|
||||||
"github.com/slackhq/nebula/service"
|
"github.com/slackhq/nebula/service"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -59,7 +63,16 @@ pki:
|
|||||||
if err := cfg.LoadString(configStr); err != nil {
|
if err := cfg.LoadString(configStr); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
svc, err := service.New(&cfg)
|
|
||||||
|
logger := logrus.New()
|
||||||
|
logger.Out = os.Stdout
|
||||||
|
|
||||||
|
ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
svc, err := service.New(ctrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
252
firewall.go
252
firewall.go
@@ -8,6 +8,7 @@ import (
|
|||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -22,7 +23,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type FirewallInterface interface {
|
type FirewallInterface interface {
|
||||||
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
|
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type conn struct {
|
type conn struct {
|
||||||
@@ -53,7 +54,7 @@ type Firewall struct {
|
|||||||
|
|
||||||
// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
|
// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
|
||||||
// The vpn addresses are a full bit match while the unsafe networks only match the prefix
|
// The vpn addresses are a full bit match while the unsafe networks only match the prefix
|
||||||
routableNetworks *bart.Table[struct{}]
|
routableNetworks *bart.Lite
|
||||||
|
|
||||||
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
|
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
|
||||||
assignedNetworks []netip.Prefix
|
assignedNetworks []netip.Prefix
|
||||||
@@ -125,7 +126,7 @@ type firewallPort map[int32]*FirewallCA
|
|||||||
|
|
||||||
type firewallLocalCIDR struct {
|
type firewallLocalCIDR struct {
|
||||||
Any bool
|
Any bool
|
||||||
LocalCIDR *bart.Table[struct{}]
|
LocalCIDR *bart.Lite
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||||
@@ -148,17 +149,17 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
|||||||
tmax = defaultTimeout
|
tmax = defaultTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
routableNetworks := new(bart.Table[struct{}])
|
routableNetworks := new(bart.Lite)
|
||||||
var assignedNetworks []netip.Prefix
|
var assignedNetworks []netip.Prefix
|
||||||
for _, network := range c.Networks() {
|
for _, network := range c.Networks() {
|
||||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||||
routableNetworks.Insert(nprefix, struct{}{})
|
routableNetworks.Insert(nprefix)
|
||||||
assignedNetworks = append(assignedNetworks, network)
|
assignedNetworks = append(assignedNetworks, network)
|
||||||
}
|
}
|
||||||
|
|
||||||
hasUnsafeNetworks := false
|
hasUnsafeNetworks := false
|
||||||
for _, n := range c.UnsafeNetworks() {
|
for _, n := range c.UnsafeNetworks() {
|
||||||
routableNetworks.Insert(n, struct{}{})
|
routableNetworks.Insert(n)
|
||||||
hasUnsafeNetworks = true
|
hasUnsafeNetworks = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -247,22 +248,11 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddRule properly creates the in memory rule structure for a firewall table.
|
// AddRule properly creates the in memory rule structure for a firewall table.
|
||||||
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
|
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
|
||||||
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
|
|
||||||
// https://github.com/golang/go/issues/14131
|
|
||||||
sIp := ""
|
|
||||||
if ip.IsValid() {
|
|
||||||
sIp = ip.String()
|
|
||||||
}
|
|
||||||
lIp := ""
|
|
||||||
if localIp.IsValid() {
|
|
||||||
lIp = localIp.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// We need this rule string because we generate a hash. Removing this will break firewall reload.
|
// We need this rule string because we generate a hash. Removing this will break firewall reload.
|
||||||
ruleString := fmt.Sprintf(
|
ruleString := fmt.Sprintf(
|
||||||
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
|
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
|
||||||
incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha,
|
incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha,
|
||||||
)
|
)
|
||||||
f.rules += ruleString + "\n"
|
f.rules += ruleString + "\n"
|
||||||
|
|
||||||
@@ -270,7 +260,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
|||||||
if !incoming {
|
if !incoming {
|
||||||
direction = "outgoing"
|
direction = "outgoing"
|
||||||
}
|
}
|
||||||
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}).
|
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}).
|
||||||
Info("Firewall rule added")
|
Info("Firewall rule added")
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -297,7 +287,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
|||||||
return fmt.Errorf("unknown protocol %v", proto)
|
return fmt.Errorf("unknown protocol %v", proto)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha)
|
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleHash returns a hash representation of all inbound and outbound rules
|
// GetRuleHash returns a hash representation of all inbound and outbound rules
|
||||||
@@ -331,13 +321,12 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rs, ok := r.([]interface{})
|
rs, ok := r.([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("%s failed to parse, should be an array of rules", table)
|
return fmt.Errorf("%s failed to parse, should be an array of rules", table)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, t := range rs {
|
for i, t := range rs {
|
||||||
var groups []string
|
|
||||||
r, err := convertRule(l, t, table, i)
|
r, err := convertRule(l, t, table, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s rule #%v; %s", table, i, err)
|
return fmt.Errorf("%s rule #%v; %s", table, i, err)
|
||||||
@@ -347,23 +336,10 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
|
return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
|
if r.Host == "" && len(r.Groups) == 0 && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
|
||||||
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
|
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(r.Groups) > 0 {
|
|
||||||
groups = r.Groups
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.Group != "" {
|
|
||||||
// Check if we have both groups and group provided in the rule config
|
|
||||||
if len(groups) > 0 {
|
|
||||||
return fmt.Errorf("%s rule #%v; only one of group or groups should be defined, both provided", table, i)
|
|
||||||
}
|
|
||||||
|
|
||||||
groups = []string{r.Group}
|
|
||||||
}
|
|
||||||
|
|
||||||
var sPort, errPort string
|
var sPort, errPort string
|
||||||
if r.Code != "" {
|
if r.Code != "" {
|
||||||
errPort = "code"
|
errPort = "code"
|
||||||
@@ -392,23 +368,25 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
|
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
|
||||||
}
|
}
|
||||||
|
|
||||||
var cidr netip.Prefix
|
if r.Cidr != "" && r.Cidr != "any" {
|
||||||
if r.Cidr != "" {
|
_, err = netip.ParsePrefix(r.Cidr)
|
||||||
cidr, err = netip.ParsePrefix(r.Cidr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
|
return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var localCidr netip.Prefix
|
if r.LocalCidr != "" && r.LocalCidr != "any" {
|
||||||
if r.LocalCidr != "" {
|
_, err = netip.ParsePrefix(r.LocalCidr)
|
||||||
localCidr, err = netip.ParsePrefix(r.LocalCidr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
|
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha)
|
if warning := r.sanity(); warning != nil {
|
||||||
|
l.Warnf("%s rule #%v; %s", table, i, warning)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
|
return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
|
||||||
}
|
}
|
||||||
@@ -417,8 +395,10 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
|
var ErrUnknownNetworkType = errors.New("unknown network type")
|
||||||
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
|
var ErrPeerRejected = errors.New("remote address is not within a network that we handle")
|
||||||
|
var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate networks")
|
||||||
|
var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
|
||||||
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
||||||
|
|
||||||
// Drop returns an error if the packet should be dropped, explaining why. It
|
// Drop returns an error if the packet should be dropped, explaining why. It
|
||||||
@@ -429,24 +409,35 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure remote address matches nebula certificate
|
// Make sure remote address matches nebula certificate, and determine how to treat it
|
||||||
if h.networks != nil {
|
if h.networks == nil {
|
||||||
_, ok := h.networks.Lookup(fp.RemoteAddr)
|
|
||||||
if !ok {
|
|
||||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
|
||||||
return ErrInvalidRemoteIP
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Simple case: Certificate has one address and no unsafe networks
|
// Simple case: Certificate has one address and no unsafe networks
|
||||||
if h.vpnAddrs[0] != fp.RemoteAddr {
|
if h.vpnAddrs[0] != fp.RemoteAddr {
|
||||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
return ErrInvalidRemoteIP
|
return ErrInvalidRemoteIP
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
nwType, ok := h.networks.Lookup(fp.RemoteAddr)
|
||||||
|
if !ok {
|
||||||
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
|
return ErrInvalidRemoteIP
|
||||||
|
}
|
||||||
|
switch nwType {
|
||||||
|
case NetworkTypeVPN:
|
||||||
|
break // nothing special
|
||||||
|
case NetworkTypeVPNPeer:
|
||||||
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
|
return ErrPeerRejected // reject for now, one day this may have different FW rules
|
||||||
|
case NetworkTypeUnsafe:
|
||||||
|
break // nothing special, one day this may have different FW rules
|
||||||
|
default:
|
||||||
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
|
return ErrUnknownNetworkType //should never happen
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure we are supposed to be handling this local ip address
|
// Make sure we are supposed to be handling this local ip address
|
||||||
_, ok := f.routableNetworks.Lookup(fp.LocalAddr)
|
if !f.routableNetworks.Contains(fp.LocalAddr) {
|
||||||
if !ok {
|
|
||||||
f.metrics(incoming).droppedLocalAddr.Inc(1)
|
f.metrics(incoming).droppedLocalAddr.Inc(1)
|
||||||
return ErrInvalidLocalIP
|
return ErrInvalidLocalIP
|
||||||
}
|
}
|
||||||
@@ -642,7 +633,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
|
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
|
||||||
if startPort > endPort {
|
if startPort > endPort {
|
||||||
return fmt.Errorf("start port was lower than end port")
|
return fmt.Errorf("start port was lower than end port")
|
||||||
}
|
}
|
||||||
@@ -655,7 +646,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil {
|
if err := fp[i].addRule(f, groups, host, cidr, localCidr, caName, caSha); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -686,7 +677,7 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
|
|||||||
return fp[firewall.PortAny].match(p, c, caPool)
|
return fp[firewall.PortAny].match(p, c, caPool)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error {
|
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, localCidr, caName, caSha string) error {
|
||||||
fr := func() *FirewallRule {
|
fr := func() *FirewallRule {
|
||||||
return &FirewallRule{
|
return &FirewallRule{
|
||||||
Hosts: make(map[string]*firewallLocalCIDR),
|
Hosts: make(map[string]*firewallLocalCIDR),
|
||||||
@@ -700,14 +691,14 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc
|
|||||||
fc.Any = fr()
|
fc.Any = fr()
|
||||||
}
|
}
|
||||||
|
|
||||||
return fc.Any.addRule(f, groups, host, ip, localIp)
|
return fc.Any.addRule(f, groups, host, cidr, localCidr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if caSha != "" {
|
if caSha != "" {
|
||||||
if _, ok := fc.CAShas[caSha]; !ok {
|
if _, ok := fc.CAShas[caSha]; !ok {
|
||||||
fc.CAShas[caSha] = fr()
|
fc.CAShas[caSha] = fr()
|
||||||
}
|
}
|
||||||
err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp)
|
err := fc.CAShas[caSha].addRule(f, groups, host, cidr, localCidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -717,7 +708,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc
|
|||||||
if _, ok := fc.CANames[caName]; !ok {
|
if _, ok := fc.CANames[caName]; !ok {
|
||||||
fc.CANames[caName] = fr()
|
fc.CANames[caName] = fr()
|
||||||
}
|
}
|
||||||
err := fc.CANames[caName].addRule(f, groups, host, ip, localIp)
|
err := fc.CANames[caName].addRule(f, groups, host, cidr, localCidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -749,24 +740,24 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool
|
|||||||
return fc.CANames[s.Certificate.Name()].match(p, c)
|
return fc.CANames[s.Certificate.Name()].match(p, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
|
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localCidr string) error {
|
||||||
flc := func() *firewallLocalCIDR {
|
flc := func() *firewallLocalCIDR {
|
||||||
return &firewallLocalCIDR{
|
return &firewallLocalCIDR{
|
||||||
LocalCIDR: new(bart.Table[struct{}]),
|
LocalCIDR: new(bart.Lite),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if fr.isAny(groups, host, ip) {
|
if fr.isAny(groups, host, cidr) {
|
||||||
if fr.Any == nil {
|
if fr.Any == nil {
|
||||||
fr.Any = flc()
|
fr.Any = flc()
|
||||||
}
|
}
|
||||||
|
|
||||||
return fr.Any.addRule(f, localCIDR)
|
return fr.Any.addRule(f, localCidr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(groups) > 0 {
|
if len(groups) > 0 {
|
||||||
nlc := flc()
|
nlc := flc()
|
||||||
err := nlc.addRule(f, localCIDR)
|
err := nlc.addRule(f, localCidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -782,30 +773,34 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, l
|
|||||||
if nlc == nil {
|
if nlc == nil {
|
||||||
nlc = flc()
|
nlc = flc()
|
||||||
}
|
}
|
||||||
err := nlc.addRule(f, localCIDR)
|
err := nlc.addRule(f, localCidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fr.Hosts[host] = nlc
|
fr.Hosts[host] = nlc
|
||||||
}
|
}
|
||||||
|
|
||||||
if ip.IsValid() {
|
if cidr != "" {
|
||||||
nlc, _ := fr.CIDR.Get(ip)
|
c, err := netip.ParsePrefix(cidr)
|
||||||
if nlc == nil {
|
|
||||||
nlc = flc()
|
|
||||||
}
|
|
||||||
err := nlc.addRule(f, localCIDR)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fr.CIDR.Insert(ip, nlc)
|
nlc, _ := fr.CIDR.Get(c)
|
||||||
|
if nlc == nil {
|
||||||
|
nlc = flc()
|
||||||
|
}
|
||||||
|
err = nlc.addRule(f, localCidr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fr.CIDR.Insert(c, nlc)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool {
|
func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool {
|
||||||
if len(groups) == 0 && host == "" && !ip.IsValid() {
|
if len(groups) == 0 && host == "" && cidr == "" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -819,7 +814,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) boo
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if ip.IsValid() && ip.Bits() == 0 {
|
if cidr == "any" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -862,36 +857,39 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
matched := false
|
for _, v := range fr.CIDR.Supernets(netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())) {
|
||||||
prefix := netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())
|
if v.match(p, c) {
|
||||||
fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
|
return true
|
||||||
if prefix.Contains(p.RemoteAddr) && val.match(p, c) {
|
|
||||||
matched = true
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
return true
|
}
|
||||||
})
|
|
||||||
return matched
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
|
func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
|
||||||
if !localIp.IsValid() {
|
if localCidr == "any" {
|
||||||
|
flc.Any = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if localCidr == "" {
|
||||||
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
|
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
|
||||||
flc.Any = true
|
flc.Any = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, network := range f.assignedNetworks {
|
for _, network := range f.assignedNetworks {
|
||||||
flc.LocalCIDR.Insert(network, struct{}{})
|
flc.LocalCIDR.Insert(network)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
} else if localIp.Bits() == 0 {
|
|
||||||
flc.Any = true
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
flc.LocalCIDR.Insert(localIp, struct{}{})
|
c, err := netip.ParsePrefix(localCidr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
flc.LocalCIDR.Insert(c)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -904,8 +902,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
|
return flc.LocalCIDR.Contains(p.LocalAddr)
|
||||||
return ok
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type rule struct {
|
type rule struct {
|
||||||
@@ -913,7 +910,6 @@ type rule struct {
|
|||||||
Code string
|
Code string
|
||||||
Proto string
|
Proto string
|
||||||
Host string
|
Host string
|
||||||
Group string
|
|
||||||
Groups []string
|
Groups []string
|
||||||
Cidr string
|
Cidr string
|
||||||
LocalCidr string
|
LocalCidr string
|
||||||
@@ -921,15 +917,15 @@ type rule struct {
|
|||||||
CASha string
|
CASha string
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
|
func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
|
||||||
r := rule{}
|
r := rule{}
|
||||||
|
|
||||||
m, ok := p.(map[interface{}]interface{})
|
m, ok := p.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return r, errors.New("could not parse rule")
|
return r, errors.New("could not parse rule")
|
||||||
}
|
}
|
||||||
|
|
||||||
toString := func(k string, m map[interface{}]interface{}) string {
|
toString := func(k string, m map[string]any) string {
|
||||||
v, ok := m[k]
|
v, ok := m[k]
|
||||||
if !ok {
|
if !ok {
|
||||||
return ""
|
return ""
|
||||||
@@ -947,7 +943,7 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
|
|||||||
r.CASha = toString("ca_sha", m)
|
r.CASha = toString("ca_sha", m)
|
||||||
|
|
||||||
// Make sure group isn't an array
|
// Make sure group isn't an array
|
||||||
if v, ok := m["group"].([]interface{}); ok {
|
if v, ok := m["group"].([]any); ok {
|
||||||
if len(v) > 1 {
|
if len(v) > 1 {
|
||||||
return r, errors.New("group should contain a single value, an array with more than one entry was provided")
|
return r, errors.New("group should contain a single value, an array with more than one entry was provided")
|
||||||
}
|
}
|
||||||
@@ -955,7 +951,8 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
|
|||||||
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
|
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
|
||||||
m["group"] = v[0]
|
m["group"] = v[0]
|
||||||
}
|
}
|
||||||
r.Group = toString("group", m)
|
|
||||||
|
singleGroup := toString("group", m)
|
||||||
|
|
||||||
if rg, ok := m["groups"]; ok {
|
if rg, ok := m["groups"]; ok {
|
||||||
switch reflect.TypeOf(rg).Kind() {
|
switch reflect.TypeOf(rg).Kind() {
|
||||||
@@ -972,9 +969,60 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//flatten group vs groups
|
||||||
|
if singleGroup != "" {
|
||||||
|
// Check if we have both groups and group provided in the rule config
|
||||||
|
if len(r.Groups) > 0 {
|
||||||
|
return r, fmt.Errorf("only one of group or groups should be defined, both provided")
|
||||||
|
}
|
||||||
|
r.Groups = []string{singleGroup}
|
||||||
|
}
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sanity returns an error if the rule would be evaluated in a way that would short-circuit a configured check on a wildcard value
|
||||||
|
// rules are evaluated as "port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND local_cidr"
|
||||||
|
func (r *rule) sanity() error {
|
||||||
|
//port, proto, local_cidr are AND, no need to check here
|
||||||
|
//ca_sha and ca_name don't have a wildcard value, no need to check here
|
||||||
|
groupsEmpty := len(r.Groups) == 0
|
||||||
|
hostEmpty := r.Host == ""
|
||||||
|
cidrEmpty := r.Cidr == ""
|
||||||
|
|
||||||
|
if (groupsEmpty && hostEmpty && cidrEmpty) == true {
|
||||||
|
return nil //no content!
|
||||||
|
}
|
||||||
|
|
||||||
|
groupsHasAny := slices.Contains(r.Groups, "any")
|
||||||
|
if groupsHasAny && len(r.Groups) > 1 {
|
||||||
|
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the other groups specified", r.Groups)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Host == "any" {
|
||||||
|
if !groupsEmpty {
|
||||||
|
return fmt.Errorf("groups specified as %s, but host=any will match any host, regardless of groups", r.Groups)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cidrEmpty {
|
||||||
|
return fmt.Errorf("cidr specified as %s, but host=any will match any host, regardless of cidr", r.Cidr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if groupsHasAny {
|
||||||
|
if !hostEmpty && r.Host != "any" {
|
||||||
|
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified host %s", r.Groups, r.Host)
|
||||||
|
}
|
||||||
|
if !cidrEmpty {
|
||||||
|
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified cidr %s", r.Groups, r.Cidr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//todo alert on cidr-any
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func parsePort(s string) (startPort, endPort int32, err error) {
|
func parsePort(s string) (startPort, endPort int32, err error) {
|
||||||
if s == "any" {
|
if s == "any" {
|
||||||
startPort = firewall.PortAny
|
startPort = firewall.PortAny
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
type m map[string]interface{}
|
type m = map[string]any
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
|
ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
|
||||||
|
|||||||
793
firewall_test.go
793
firewall_test.go
File diff suppressed because it is too large
Load Diff
49
go.mod
49
go.mod
@@ -1,58 +1,55 @@
|
|||||||
module github.com/slackhq/nebula
|
module github.com/slackhq/nebula
|
||||||
|
|
||||||
go 1.22.0
|
go 1.25
|
||||||
|
|
||||||
toolchain go1.22.2
|
|
||||||
|
|
||||||
require (
|
require (
|
||||||
dario.cat/mergo v1.0.1
|
dario.cat/mergo v1.0.2
|
||||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
|
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
|
||||||
github.com/armon/go-radix v1.0.0
|
github.com/armon/go-radix v1.0.0
|
||||||
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
|
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
|
||||||
github.com/flynn/noise v1.1.0
|
github.com/flynn/noise v1.1.0
|
||||||
github.com/gaissmai/bart v0.13.0
|
github.com/gaissmai/bart v0.26.0
|
||||||
github.com/gogo/protobuf v1.3.2
|
github.com/gogo/protobuf v1.3.2
|
||||||
github.com/google/gopacket v1.1.19
|
github.com/google/gopacket v1.1.19
|
||||||
github.com/kardianos/service v1.2.2
|
github.com/kardianos/service v1.2.4
|
||||||
github.com/miekg/dns v1.1.62
|
github.com/miekg/dns v1.1.68
|
||||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
|
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
|
||||||
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
|
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
|
||||||
github.com/prometheus/client_golang v1.20.4
|
github.com/prometheus/client_golang v1.23.2
|
||||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
|
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
|
||||||
github.com/sirupsen/logrus v1.9.3
|
github.com/sirupsen/logrus v1.9.3
|
||||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
||||||
github.com/stretchr/testify v1.9.0
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/vishvananda/netlink v1.3.0
|
github.com/vishvananda/netlink v1.3.1
|
||||||
golang.org/x/crypto v0.28.0
|
go.yaml.in/yaml/v3 v3.0.4
|
||||||
|
golang.org/x/crypto v0.45.0
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
||||||
golang.org/x/net v0.30.0
|
golang.org/x/net v0.47.0
|
||||||
golang.org/x/sync v0.8.0
|
golang.org/x/sync v0.18.0
|
||||||
golang.org/x/sys v0.26.0
|
golang.org/x/sys v0.38.0
|
||||||
golang.org/x/term v0.25.0
|
golang.org/x/term v0.37.0
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||||
google.golang.org/protobuf v1.35.1
|
google.golang.org/protobuf v1.36.10
|
||||||
gopkg.in/yaml.v2 v2.4.0
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
|
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/beorn7/perks v1.0.1 // indirect
|
github.com/beorn7/perks v1.0.1 // indirect
|
||||||
github.com/bits-and-blooms/bitset v1.14.3 // indirect
|
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/google/btree v1.1.2 // indirect
|
github.com/google/btree v1.1.2 // indirect
|
||||||
github.com/klauspost/compress v1.17.9 // indirect
|
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/prometheus/client_model v0.6.1 // indirect
|
github.com/prometheus/client_model v0.6.2 // indirect
|
||||||
github.com/prometheus/common v0.55.0 // indirect
|
github.com/prometheus/common v0.66.1 // indirect
|
||||||
github.com/prometheus/procfs v0.15.1 // indirect
|
github.com/prometheus/procfs v0.16.1 // indirect
|
||||||
github.com/vishvananda/netns v0.0.4 // indirect
|
github.com/vishvananda/netns v0.0.5 // indirect
|
||||||
golang.org/x/mod v0.18.0 // indirect
|
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||||
|
golang.org/x/mod v0.24.0 // indirect
|
||||||
golang.org/x/time v0.5.0 // indirect
|
golang.org/x/time v0.5.0 // indirect
|
||||||
golang.org/x/tools v0.22.0 // indirect
|
golang.org/x/tools v0.33.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
|
||||||
)
|
)
|
||||||
|
|||||||
95
go.sum
95
go.sum
@@ -1,6 +1,6 @@
|
|||||||
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||||
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
|
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
||||||
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
|
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
||||||
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||||
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||||
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
||||||
@@ -14,8 +14,6 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
|
|||||||
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
|
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
|
||||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||||
github.com/bits-and-blooms/bitset v1.14.3 h1:Gd2c8lSNf9pKXom5JtD7AaKO8o7fGQ2LtFj1436qilA=
|
|
||||||
github.com/bits-and-blooms/bitset v1.14.3/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
|
|
||||||
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
@@ -26,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
|||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
|
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
|
||||||
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
||||||
github.com/gaissmai/bart v0.13.0 h1:pItEhXDVVebUa+i978FfQ7ye8xZc1FrMgs8nJPPWAgA=
|
github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0=
|
||||||
github.com/gaissmai/bart v0.13.0/go.mod h1:qSes2fnJ8hB410BW0ymHUN/eQkuGpTYyJcN8sKMYpJU=
|
github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
|
||||||
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||||
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||||
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
|
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
|
||||||
@@ -55,8 +53,8 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
|
|||||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
||||||
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
||||||
@@ -66,12 +64,12 @@ github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/
|
|||||||
github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||||
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
|
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
|
||||||
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
|
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
|
||||||
github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX60=
|
github.com/kardianos/service v1.2.4 h1:XNlGtZOYNx2u91urOdg/Kfmc+gfmuIo1Dd3rEi2OgBk=
|
||||||
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/kardianos/service v1.2.4/go.mod h1:E4V9ufUuY82F7Ztlu1eN9VXWIQxg8NoLQlmFe0MtrXc=
|
||||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||||
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
|
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||||
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||||
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
|
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
|
||||||
@@ -85,8 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
|||||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||||
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
|
github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
|
||||||
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
|
github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
|
||||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
|
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
|
||||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
|
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
@@ -108,24 +106,24 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
|
|||||||
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
|
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
|
||||||
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
|
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
|
||||||
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
|
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
|
||||||
github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI=
|
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
||||||
github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
|
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||||
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
||||||
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||||
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||||
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
|
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||||
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
|
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||||
github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
|
github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
|
||||||
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
|
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
|
||||||
github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
|
github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
|
||||||
github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc=
|
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
|
||||||
github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8=
|
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
|
||||||
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
||||||
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
|
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
|
||||||
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
|
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
|
||||||
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
|
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
|
||||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
|
||||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
|
||||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
|
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
|
||||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
|
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
|
||||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||||
@@ -145,29 +143,35 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
|
|||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
|
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
|
||||||
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
|
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
||||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
||||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
|
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||||
|
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||||
|
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
||||||
|
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||||
|
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||||
|
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||||
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
|
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||||
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
|
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=
|
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
|
||||||
golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
@@ -178,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
|
|||||||
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
|
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||||
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
|
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
@@ -187,8 +191,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
|||||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
|
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||||
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
@@ -199,18 +203,17 @@ golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||||||
golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||||
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24=
|
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||||
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
|
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
@@ -221,8 +224,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
|
|||||||
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||||
golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA=
|
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
|
||||||
golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c=
|
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
|
||||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
@@ -241,8 +244,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
|
|||||||
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
||||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||||
google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
|
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||||
google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
@@ -253,8 +256,6 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
|||||||
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
|
||||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
|||||||
266
handshake_ix.go
266
handshake_ix.go
@@ -2,7 +2,6 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/flynn/noise"
|
"github.com/flynn/noise"
|
||||||
@@ -23,13 +22,17 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we're connecting to a v6 address we must use a v2 cert
|
|
||||||
cs := f.pki.getCertState()
|
cs := f.pki.getCertState()
|
||||||
v := cs.defaultVersion
|
v := cs.initiatingVersion
|
||||||
for _, a := range hh.hostinfo.vpnAddrs {
|
if hh.initiatingVersionOverride != cert.VersionPre1 {
|
||||||
if a.Is6() {
|
v = hh.initiatingVersionOverride
|
||||||
v = cert.Version2
|
} else if v < cert.Version2 {
|
||||||
break
|
// If we're connecting to a v6 address we should encourage use of a V2 cert
|
||||||
|
for _, a := range hh.hostinfo.vpnAddrs {
|
||||||
|
if a.Is6() {
|
||||||
|
v = cert.Version2
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,6 +51,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||||
WithField("certVersion", v).
|
WithField("certVersion", v).
|
||||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
||||||
@@ -71,7 +75,8 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
|
|
||||||
hsBytes, err := hs.Marshal()
|
hsBytes, err := hs.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v).
|
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||||
|
WithField("certVersion", v).
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -100,8 +105,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
if crt == nil {
|
if crt == nil {
|
||||||
f.l.WithField("udpAddr", addr).
|
f.l.WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||||
WithField("certVersion", cs.defaultVersion).
|
WithField("certVersion", cs.initiatingVersion).
|
||||||
Error("Unable to handshake with host because no certificate is available")
|
Error("Unable to handshake with host because no certificate is available")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
||||||
@@ -132,13 +138,28 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
|
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e := f.l.WithError(err).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
|
Info("Handshake did not contain a certificate")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if f.l.Level > logrus.DebugLevel {
|
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
||||||
e = e.WithField("cert", remoteCert)
|
if err != nil {
|
||||||
|
fp, fperr := rc.Fingerprint()
|
||||||
|
if fperr != nil {
|
||||||
|
fp = "<error generating certificate fingerprint>"
|
||||||
|
}
|
||||||
|
|
||||||
|
e := f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
|
WithField("certVpnNetworks", rc.Networks()).
|
||||||
|
WithField("certFingerprint", fp)
|
||||||
|
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
e = e.WithField("cert", rc)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.Info("Invalid certificate from host")
|
e.Info("Invalid certificate from host")
|
||||||
@@ -147,64 +168,51 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
|
|
||||||
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
||||||
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
||||||
rc := cs.getCertificate(remoteCert.Certificate.Version())
|
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
|
||||||
if rc == nil {
|
if myCertOtherVersion == nil {
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
f.l.WithError(err).WithFields(m{
|
||||||
Info("Unable to handshake with host due to missing certificate version")
|
"udpAddr": addr,
|
||||||
return
|
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
"cert": remoteCert,
|
||||||
|
}).Debug("Might be unable to handshake with host due to missing certificate version")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Record the certificate we are actually using
|
||||||
|
ci.myCert = myCertOtherVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
// Record the certificate we are actually using
|
|
||||||
ci.myCert = rc
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||||
e := f.l.WithError(err).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
|
WithField("cert", remoteCert).
|
||||||
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
if f.l.Level > logrus.DebugLevel {
|
Info("No networks in certificate")
|
||||||
e = e.WithField("cert", remoteCert)
|
|
||||||
}
|
|
||||||
|
|
||||||
e.Info("Invalid vpn ip from host")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var vpnAddrs []netip.Addr
|
|
||||||
var filteredNetworks []netip.Prefix
|
|
||||||
certName := remoteCert.Certificate.Name()
|
certName := remoteCert.Certificate.Name()
|
||||||
|
certVersion := remoteCert.Certificate.Version()
|
||||||
fingerprint := remoteCert.Fingerprint
|
fingerprint := remoteCert.Fingerprint
|
||||||
issuer := remoteCert.Certificate.Issuer()
|
issuer := remoteCert.Certificate.Issuer()
|
||||||
|
vpnNetworks := remoteCert.Certificate.Networks()
|
||||||
|
|
||||||
for _, network := range remoteCert.Certificate.Networks() {
|
anyVpnAddrsInCommon := false
|
||||||
vpnAddr := network.Addr()
|
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||||
_, found := f.myVpnAddrsTable.Lookup(vpnAddr)
|
for i, network := range vpnNetworks {
|
||||||
if found {
|
if f.myVpnAddrsTable.Contains(network.Addr()) {
|
||||||
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
|
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
vpnAddrs[i] = network.Addr()
|
||||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
||||||
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
|
anyVpnAddrsInCommon = true
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
filteredNetworks = append(filteredNetworks, network)
|
|
||||||
vpnAddrs = append(vpnAddrs, vpnAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(vpnAddrs) == 0 {
|
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if addr.IsValid() {
|
if addr.IsValid() {
|
||||||
@@ -220,6 +228,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
|
||||||
@@ -234,30 +243,36 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
HandshakePacket: make(map[uint8][]byte, 0),
|
HandshakePacket: make(map[uint8][]byte, 0),
|
||||||
lastHandshakeTime: hs.Details.Time,
|
lastHandshakeTime: hs.Details.Time,
|
||||||
relayState: RelayState{
|
relayState: RelayState{
|
||||||
relays: map[netip.Addr]struct{}{},
|
relays: nil,
|
||||||
relayForByAddr: map[netip.Addr]*Relay{},
|
relayForByAddr: map[netip.Addr]*Relay{},
|
||||||
relayForByIdx: map[uint32]*Relay{},
|
relayForByIdx: map[uint32]*Relay{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
msgRxL := f.l.WithFields(m{
|
||||||
WithField("certName", certName).
|
"vpnAddrs": vpnAddrs,
|
||||||
WithField("fingerprint", fingerprint).
|
"udpAddr": addr,
|
||||||
WithField("issuer", issuer).
|
"certName": certName,
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
"certVersion": certVersion,
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"fingerprint": fingerprint,
|
||||||
Info("Handshake message received")
|
"issuer": issuer,
|
||||||
|
"initiatorIndex": hs.Details.InitiatorIndex,
|
||||||
|
"responderIndex": hs.Details.ResponderIndex,
|
||||||
|
"remoteIndex": h.RemoteIndex,
|
||||||
|
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
})
|
||||||
|
|
||||||
|
if anyVpnAddrsInCommon {
|
||||||
|
msgRxL.Info("Handshake message received")
|
||||||
|
} else {
|
||||||
|
//todo warn if not lighthouse or relay?
|
||||||
|
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
||||||
|
}
|
||||||
|
|
||||||
hs.Details.ResponderIndex = myIndex
|
hs.Details.ResponderIndex = myIndex
|
||||||
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
||||||
if hs.Details.Cert == nil {
|
if hs.Details.Cert == nil {
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
msgRxL.WithField("myCertVersion", ci.myCert.Version()).
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
WithField("certVersion", ci.myCert.Version()).
|
|
||||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -270,6 +285,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
||||||
@@ -281,6 +297,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
||||||
@@ -288,6 +305,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
} else if dKey == nil || eKey == nil {
|
} else if dKey == nil || eKey == nil {
|
||||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
|
||||||
@@ -312,7 +330,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
|
|
||||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
||||||
hostinfo.SetRemote(addr)
|
hostinfo.SetRemote(addr)
|
||||||
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
|
||||||
|
|
||||||
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -355,6 +373,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
WithField("oldHandshakeTime", existing.lastHandshakeTime).
|
WithField("oldHandshakeTime", existing.lastHandshakeTime).
|
||||||
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
|
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -370,6 +389,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
@@ -382,6 +402,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
// And we forget to update it here
|
// And we forget to update it here
|
||||||
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
@@ -398,6 +419,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
@@ -406,6 +428,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
} else {
|
} else {
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
@@ -424,6 +447,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
@@ -431,9 +455,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
Info("Handshake message sent")
|
Info("Handshake message sent")
|
||||||
}
|
}
|
||||||
|
|
||||||
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
|
f.connectionManager.AddTrafficWatch(hostinfo)
|
||||||
|
|
||||||
hostinfo.remotes.ResetBlockedRemotes()
|
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -487,35 +511,48 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
|
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
|
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
|
Info("Handshake did not contain a certificate")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
if f.l.Level > logrus.DebugLevel {
|
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
||||||
e = e.WithField("cert", remoteCert)
|
if err != nil {
|
||||||
|
fp, err := rc.Fingerprint()
|
||||||
|
if err != nil {
|
||||||
|
fp = "<error generating certificate fingerprint>"
|
||||||
}
|
}
|
||||||
|
|
||||||
e.Error("Invalid certificate from host")
|
e := f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
|
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
|
WithField("certFingerprint", fp).
|
||||||
|
WithField("certVpnNetworks", rc.Networks())
|
||||||
|
|
||||||
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
e = e.WithField("cert", rc)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.Info("Invalid certificate from host")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||||
e := f.l.WithError(err).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
|
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
|
WithField("cert", remoteCert).
|
||||||
if f.l.Level > logrus.DebugLevel {
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
e = e.WithField("cert", remoteCert)
|
Info("No networks in certificate")
|
||||||
}
|
|
||||||
|
|
||||||
e.Info("Empty networks from host")
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnNetworks := remoteCert.Certificate.Networks()
|
vpnNetworks := remoteCert.Certificate.Networks()
|
||||||
certName := remoteCert.Certificate.Name()
|
certName := remoteCert.Certificate.Name()
|
||||||
|
certVersion := remoteCert.Certificate.Version()
|
||||||
fingerprint := remoteCert.Fingerprint
|
fingerprint := remoteCert.Fingerprint
|
||||||
issuer := remoteCert.Certificate.Issuer()
|
issuer := remoteCert.Certificate.Issuer()
|
||||||
|
|
||||||
@@ -534,32 +571,26 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
var vpnAddrs []netip.Addr
|
correctHostResponded := false
|
||||||
var filteredNetworks []netip.Prefix
|
anyVpnAddrsInCommon := false
|
||||||
for _, network := range vpnNetworks {
|
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
for i, network := range vpnNetworks {
|
||||||
vpnAddr := network.Addr()
|
vpnAddrs[i] = network.Addr()
|
||||||
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
|
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
||||||
continue
|
anyVpnAddrsInCommon = true
|
||||||
|
}
|
||||||
|
if hostinfo.vpnAddrs[0] == network.Addr() {
|
||||||
|
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
|
||||||
|
correctHostResponded = true
|
||||||
}
|
}
|
||||||
|
|
||||||
filteredNetworks = append(filteredNetworks, network)
|
|
||||||
vpnAddrs = append(vpnAddrs, vpnAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(vpnAddrs) == 0 {
|
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the right host responded
|
// Ensure the right host responded
|
||||||
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
|
if !correctHostResponded {
|
||||||
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
||||||
WithField("udpAddr", addr).WithField("certName", certName).
|
WithField("udpAddr", addr).
|
||||||
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
Info("Incorrect host responded to handshake")
|
Info("Incorrect host responded to handshake")
|
||||||
|
|
||||||
@@ -567,6 +598,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||||
|
|
||||||
// Create a new hostinfo/handshake for the intended vpn ip
|
// Create a new hostinfo/handshake for the intended vpn ip
|
||||||
|
//TODO is hostinfo.vpnAddrs[0] always the address to use?
|
||||||
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
||||||
// Block the current used address
|
// Block the current used address
|
||||||
newHH.hostinfo.remotes = hostinfo.remotes
|
newHH.hostinfo.remotes = hostinfo.remotes
|
||||||
@@ -593,23 +625,29 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
ci.window.Update(f.l, 2)
|
ci.window.Update(f.l, 2)
|
||||||
|
|
||||||
duration := time.Since(hh.startTime).Nanoseconds()
|
duration := time.Since(hh.startTime).Nanoseconds()
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
WithField("durationNs", duration).
|
WithField("durationNs", duration).
|
||||||
WithField("sentCachedPackets", len(hh.packetStore)).
|
WithField("sentCachedPackets", len(hh.packetStore))
|
||||||
Info("Handshake message received")
|
if anyVpnAddrsInCommon {
|
||||||
|
msgRxL.Info("Handshake message received")
|
||||||
|
} else {
|
||||||
|
//todo warn if not lighthouse or relay?
|
||||||
|
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
||||||
|
}
|
||||||
|
|
||||||
// Build up the radix for the firewall if we have subnets in the cert
|
// Build up the radix for the firewall if we have subnets in the cert
|
||||||
hostinfo.vpnAddrs = vpnAddrs
|
hostinfo.vpnAddrs = vpnAddrs
|
||||||
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
|
||||||
|
|
||||||
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
||||||
f.handshakeManager.Complete(hostinfo, f)
|
f.handshakeManager.Complete(hostinfo, f)
|
||||||
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
|
f.connectionManager.AddTrafficWatch(hostinfo)
|
||||||
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
|
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
|
||||||
@@ -624,7 +662,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
|
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.remotes.ResetBlockedRemotes()
|
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
||||||
f.metricHandshakes.Update(duration)
|
f.metricHandshakes.Update(duration)
|
||||||
|
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -68,11 +68,12 @@ type HandshakeManager struct {
|
|||||||
type HandshakeHostInfo struct {
|
type HandshakeHostInfo struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
startTime time.Time // Time that we first started trying with this handshake
|
startTime time.Time // Time that we first started trying with this handshake
|
||||||
ready bool // Is the handshake ready
|
ready bool // Is the handshake ready
|
||||||
counter int64 // How many attempts have we made so far
|
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
|
||||||
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
counter int64 // How many attempts have we made so far
|
||||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
||||||
|
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||||
|
|
||||||
hostinfo *HostInfo
|
hostinfo *HostInfo
|
||||||
}
|
}
|
||||||
@@ -257,7 +258,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
Info("Handshake message sent")
|
Info("Handshake message sent")
|
||||||
} else if hm.l.IsLevelEnabled(logrus.DebugLevel) {
|
} else if hm.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
|
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
|
||||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
@@ -268,14 +269,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
|
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
|
||||||
// Send a RelayRequest to all known Relay IP's
|
// Send a RelayRequest to all known Relay IP's
|
||||||
for _, relay := range hostinfo.remotes.relays {
|
for _, relay := range hostinfo.remotes.relays {
|
||||||
// Don't relay to myself
|
// Don't relay through the host I'm trying to connect to
|
||||||
if relay == vpnIp {
|
if relay == vpnIp {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Don't relay through the host I'm trying to connect to
|
// Don't relay to myself
|
||||||
_, found := hm.f.myVpnAddrsTable.Lookup(relay)
|
if hm.f.myVpnAddrsTable.Contains(relay) {
|
||||||
if found {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -451,7 +451,7 @@ func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*Han
|
|||||||
vpnAddrs: []netip.Addr{vpnAddr},
|
vpnAddrs: []netip.Addr{vpnAddr},
|
||||||
HandshakePacket: make(map[uint8][]byte, 0),
|
HandshakePacket: make(map[uint8][]byte, 0),
|
||||||
relayState: RelayState{
|
relayState: RelayState{
|
||||||
relays: map[netip.Addr]struct{}{},
|
relays: nil,
|
||||||
relayForByAddr: map[netip.Addr]*Relay{},
|
relayForByAddr: map[netip.Addr]*Relay{},
|
||||||
relayForByIdx: map[uint32]*Relay{},
|
relayForByIdx: map[uint32]*Relay{},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -24,10 +24,10 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
|||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
|
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
defaultVersion: cert.Version1,
|
initiatingVersion: cert.Version1,
|
||||||
privateKey: []byte{},
|
privateKey: []byte{},
|
||||||
v1Cert: &dummyCert{version: cert.Version1},
|
v1Cert: &dummyCert{version: cert.Version1},
|
||||||
v1HandshakeBytes: []byte{},
|
v1HandshakeBytes: []byte{},
|
||||||
}
|
}
|
||||||
|
|
||||||
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
|
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
|
||||||
@@ -44,7 +44,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
|||||||
i.remotes = NewRemoteList([]netip.Addr{}, nil)
|
i.remotes = NewRemoteList([]netip.Addr{}, nil)
|
||||||
|
|
||||||
// Adding something to pending should not affect the main hostmap
|
// Adding something to pending should not affect the main hostmap
|
||||||
assert.Len(t, mainHM.Hosts, 0)
|
assert.Empty(t, mainHM.Hosts)
|
||||||
|
|
||||||
// Confirm they are in the pending index list
|
// Confirm they are in the pending index list
|
||||||
assert.Contains(t, blah.vpnIps, ip)
|
assert.Contains(t, blah.vpnIps, ip)
|
||||||
@@ -98,5 +98,5 @@ func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (mw *mockEncWriter) GetCertState() *CertState {
|
func (mw *mockEncWriter) GetCertState() *CertState {
|
||||||
return &CertState{defaultVersion: cert.Version2}
|
return &CertState{initiatingVersion: cert.Version2}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
// |-----------------------------------------------------------------------|
|
// |-----------------------------------------------------------------------|
|
||||||
// | payload... |
|
// | payload... |
|
||||||
|
|
||||||
type m map[string]interface{}
|
type m = map[string]any
|
||||||
|
|
||||||
const (
|
const (
|
||||||
Version uint8 = 1
|
Version uint8 = 1
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
type headerTest struct {
|
type headerTest struct {
|
||||||
@@ -111,7 +112,7 @@ func TestHeader_String(t *testing.T) {
|
|||||||
|
|
||||||
func TestHeader_MarshalJSON(t *testing.T) {
|
func TestHeader_MarshalJSON(t *testing.T) {
|
||||||
b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON()
|
b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON()
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(
|
assert.Equal(
|
||||||
t,
|
t,
|
||||||
"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",
|
"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",
|
||||||
|
|||||||
81
hostmap.go
81
hostmap.go
@@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -16,12 +17,10 @@ import (
|
|||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
)
|
)
|
||||||
|
|
||||||
// const ProbeLen = 100
|
|
||||||
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
|
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
|
||||||
const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse
|
const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse
|
||||||
const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery
|
const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery
|
||||||
const MaxRemotes = 10
|
const MaxRemotes = 10
|
||||||
const maxRecvError = 4
|
|
||||||
|
|
||||||
// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
|
// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
|
||||||
// 5 allows for an initial handshake and each host pair re-handshaking twice
|
// 5 allows for an initial handshake and each host pair re-handshaking twice
|
||||||
@@ -68,7 +67,7 @@ type HostMap struct {
|
|||||||
type RelayState struct {
|
type RelayState struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
|
|
||||||
relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
|
relays []netip.Addr // Ordered set of VpnAddrs of Hosts to use as relays to access this peer
|
||||||
// For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data,
|
// For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data,
|
||||||
// modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with
|
// modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with
|
||||||
// the RelayState Lock held)
|
// the RelayState Lock held)
|
||||||
@@ -79,7 +78,12 @@ type RelayState struct {
|
|||||||
func (rs *RelayState) DeleteRelay(ip netip.Addr) {
|
func (rs *RelayState) DeleteRelay(ip netip.Addr) {
|
||||||
rs.Lock()
|
rs.Lock()
|
||||||
defer rs.Unlock()
|
defer rs.Unlock()
|
||||||
delete(rs.relays, ip)
|
for idx, val := range rs.relays {
|
||||||
|
if val == ip {
|
||||||
|
rs.relays = append(rs.relays[:idx], rs.relays[idx+1:]...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
|
func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
|
||||||
@@ -124,16 +128,16 @@ func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
|
|||||||
func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
|
func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
|
||||||
rs.Lock()
|
rs.Lock()
|
||||||
defer rs.Unlock()
|
defer rs.Unlock()
|
||||||
rs.relays[ip] = struct{}{}
|
if !slices.Contains(rs.relays, ip) {
|
||||||
|
rs.relays = append(rs.relays, ip)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *RelayState) CopyRelayIps() []netip.Addr {
|
func (rs *RelayState) CopyRelayIps() []netip.Addr {
|
||||||
|
ret := make([]netip.Addr, len(rs.relays))
|
||||||
rs.RLock()
|
rs.RLock()
|
||||||
defer rs.RUnlock()
|
defer rs.RUnlock()
|
||||||
ret := make([]netip.Addr, 0, len(rs.relays))
|
copy(ret, rs.relays)
|
||||||
for ip := range rs.relays {
|
|
||||||
ret = append(ret, ip)
|
|
||||||
}
|
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -208,6 +212,18 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
|
|||||||
rs.relayForByIdx[idx] = r
|
rs.relayForByIdx[idx] = r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type NetworkType uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
NetworkTypeUnknown NetworkType = iota
|
||||||
|
// NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
|
||||||
|
NetworkTypeVPN
|
||||||
|
// NetworkTypeVPNPeer is a network that does not overlap one of our networks
|
||||||
|
NetworkTypeVPNPeer
|
||||||
|
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
|
||||||
|
NetworkTypeUnsafe
|
||||||
|
)
|
||||||
|
|
||||||
type HostInfo struct {
|
type HostInfo struct {
|
||||||
remote netip.AddrPort
|
remote netip.AddrPort
|
||||||
remotes *RemoteList
|
remotes *RemoteList
|
||||||
@@ -219,11 +235,10 @@ type HostInfo struct {
|
|||||||
// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
|
// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
|
||||||
// The host may have other vpn addresses that are outside our
|
// The host may have other vpn addresses that are outside our
|
||||||
// vpn networks but were removed because they are not usable
|
// vpn networks but were removed because they are not usable
|
||||||
vpnAddrs []netip.Addr
|
vpnAddrs []netip.Addr
|
||||||
recvError atomic.Uint32
|
|
||||||
|
|
||||||
// networks are both all vpn and unsafe networks assigned to this host
|
// networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host.
|
||||||
networks *bart.Table[struct{}]
|
networks *bart.Table[NetworkType]
|
||||||
relayState RelayState
|
relayState RelayState
|
||||||
|
|
||||||
// HandshakePacket records the packets used to create this hostinfo
|
// HandshakePacket records the packets used to create this hostinfo
|
||||||
@@ -250,6 +265,14 @@ type HostInfo struct {
|
|||||||
// Used to track other hostinfos for this vpn ip since only 1 can be primary
|
// Used to track other hostinfos for this vpn ip since only 1 can be primary
|
||||||
// Synchronised via hostmap lock and not the hostinfo lock.
|
// Synchronised via hostmap lock and not the hostinfo lock.
|
||||||
next, prev *HostInfo
|
next, prev *HostInfo
|
||||||
|
|
||||||
|
//TODO: in, out, and others might benefit from being an atomic.Int32. We could collapse connectionManager pendingDeletion, relayUsed, and in/out into this 1 thing
|
||||||
|
in, out, pendingDeletion atomic.Bool
|
||||||
|
|
||||||
|
// lastUsed tracks the last time ConnectionManager checked the tunnel and it was in use.
|
||||||
|
// This value will be behind against actual tunnel utilization in the hot path.
|
||||||
|
// This should only be used by the ConnectionManagers ticker routine.
|
||||||
|
lastUsed time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
type ViaSender struct {
|
type ViaSender struct {
|
||||||
@@ -719,26 +742,26 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *HostInfo) RecvErrorExceeded() bool {
|
// buildNetworks fills in the networks field of HostInfo. It accepts a cert.Certificate so you never ever mix the network types up.
|
||||||
if i.recvError.Add(1) >= maxRecvError {
|
func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certificate) {
|
||||||
return true
|
if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
|
||||||
}
|
if myVpnNetworksTable.Contains(c.Networks()[0].Addr()) {
|
||||||
return true
|
return // Simple case, no BART needed
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
|
||||||
if len(networks) == 1 && len(unsafeNetworks) == 0 {
|
|
||||||
// Simple case, no CIDRTree needed
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
i.networks = new(bart.Table[struct{}])
|
i.networks = new(bart.Table[NetworkType])
|
||||||
for _, network := range networks {
|
for _, network := range c.Networks() {
|
||||||
i.networks.Insert(network, struct{}{})
|
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||||
|
if myVpnNetworksTable.Contains(network.Addr()) {
|
||||||
|
i.networks.Insert(nprefix, NetworkTypeVPN)
|
||||||
|
} else {
|
||||||
|
i.networks.Insert(nprefix, NetworkTypeVPNPeer)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, network := range unsafeNetworks {
|
for _, network := range c.UnsafeNetworks() {
|
||||||
i.networks.Insert(network, struct{}{})
|
i.networks.Insert(network, NetworkTypeUnsafe)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHostMap_MakePrimary(t *testing.T) {
|
func TestHostMap_MakePrimary(t *testing.T) {
|
||||||
@@ -210,8 +211,36 @@ func TestHostMap_reload(t *testing.T) {
|
|||||||
assert.Empty(t, hm.GetPreferredRanges())
|
assert.Empty(t, hm.GetPreferredRanges())
|
||||||
|
|
||||||
c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]")
|
c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]")
|
||||||
assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))
|
assert.Equal(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))
|
||||||
|
|
||||||
c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
|
c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
|
||||||
assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
|
assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostMap_RelayState(t *testing.T) {
|
||||||
|
h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
|
||||||
|
a1 := netip.MustParseAddr("::1")
|
||||||
|
a2 := netip.MustParseAddr("2001::1")
|
||||||
|
|
||||||
|
h1.relayState.InsertRelayTo(a1)
|
||||||
|
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
|
||||||
|
h1.relayState.InsertRelayTo(a2)
|
||||||
|
assert.Equal(t, []netip.Addr{a1, a2}, h1.relayState.relays)
|
||||||
|
// Ensure that the first relay added is the first one returned in the copy
|
||||||
|
currentRelays := h1.relayState.CopyRelayIps()
|
||||||
|
require.Len(t, currentRelays, 2)
|
||||||
|
assert.Equal(t, a1, currentRelays[0])
|
||||||
|
|
||||||
|
// Deleting the last one in the list works ok
|
||||||
|
h1.relayState.DeleteRelay(a2)
|
||||||
|
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
|
||||||
|
|
||||||
|
// Deleting an element not in the list works ok
|
||||||
|
h1.relayState.DeleteRelay(a2)
|
||||||
|
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
|
||||||
|
|
||||||
|
// Deleting the only element in the list works ok
|
||||||
|
h1.relayState.DeleteRelay(a1)
|
||||||
|
assert.Equal(t, []netip.Addr{}, h1.relayState.relays)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
109
inside.go
109
inside.go
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/slackhq/nebula/noiseutil"
|
"github.com/slackhq/nebula/noiseutil"
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
@@ -21,14 +22,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
|
|
||||||
// Ignore local broadcast packets
|
// Ignore local broadcast packets
|
||||||
if f.dropLocalBroadcast {
|
if f.dropLocalBroadcast {
|
||||||
_, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr)
|
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
|
||||||
if found {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr)
|
if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) {
|
||||||
if found {
|
|
||||||
// Immediately forward packets from self to self.
|
// Immediately forward packets from self to self.
|
||||||
// This should only happen on Darwin-based and FreeBSD hosts, which
|
// This should only happen on Darwin-based and FreeBSD hosts, which
|
||||||
// routes packets from the Nebula addr to the Nebula addr through the Nebula
|
// routes packets from the Nebula addr to the Nebula addr through the Nebula
|
||||||
@@ -49,7 +48,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) {
|
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
|
||||||
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -121,22 +120,93 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
|
|||||||
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
|
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handshake will attempt to initiate a tunnel with the provided vpn address. This is a no-op if the tunnel is already established or being established
|
||||||
|
// it does not check if it is within our vpn networks!
|
||||||
func (f *Interface) Handshake(vpnAddr netip.Addr) {
|
func (f *Interface) Handshake(vpnAddr netip.Addr) {
|
||||||
f.getOrHandshake(vpnAddr, nil)
|
f.handshakeManager.GetOrHandshake(vpnAddr, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getOrHandshake returns nil if the vpnAddr is not routable.
|
// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
|
||||||
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
|
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
|
||||||
func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
||||||
_, found := f.myVpnNetworksTable.Lookup(vpnAddr)
|
if f.myVpnNetworksTable.Contains(vpnAddr) {
|
||||||
if !found {
|
return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
|
||||||
vpnAddr = f.inside.RouteFor(vpnAddr)
|
}
|
||||||
if !vpnAddr.IsValid() {
|
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
|
||||||
|
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
|
||||||
|
func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
||||||
|
destinationAddr := fwPacket.RemoteAddr
|
||||||
|
|
||||||
|
hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
|
||||||
|
|
||||||
|
// Host is inside the mesh, no routing required
|
||||||
|
if hostinfo != nil {
|
||||||
|
return hostinfo, ready
|
||||||
|
}
|
||||||
|
|
||||||
|
gateways := f.inside.RoutesFor(destinationAddr)
|
||||||
|
|
||||||
|
switch len(gateways) {
|
||||||
|
case 0:
|
||||||
|
return nil, false
|
||||||
|
case 1:
|
||||||
|
// Single gateway route
|
||||||
|
return f.handshakeManager.GetOrHandshake(gateways[0].Addr(), cacheCallback)
|
||||||
|
default:
|
||||||
|
// Multi gateway route, perform ECMP categorization
|
||||||
|
gatewayAddr, balancingOk := routing.BalancePacket(fwPacket, gateways)
|
||||||
|
|
||||||
|
if !balancingOk {
|
||||||
|
// This happens if the gateway buckets were not calculated, this _should_ never happen
|
||||||
|
f.l.Error("Gateway buckets not calculated, fallback from ECMP to random routing. Please report this bug.")
|
||||||
|
}
|
||||||
|
|
||||||
|
var handshakeInfoForChosenGateway *HandshakeHostInfo
|
||||||
|
var hhReceiver = func(hh *HandshakeHostInfo) {
|
||||||
|
handshakeInfoForChosenGateway = hh
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the handshakeHostInfo for later.
|
||||||
|
// If this node is not reachable we will attempt other nodes, if none are reachable we will
|
||||||
|
// cache the packet for this gateway.
|
||||||
|
if hostinfo, ready = f.handshakeManager.GetOrHandshake(gatewayAddr, hhReceiver); ready {
|
||||||
|
return hostinfo, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// It appears the selected gateway cannot be reached, find another gateway to fallback on.
|
||||||
|
// The current implementation breaks ECMP but that seems better than no connectivity.
|
||||||
|
// If ECMP is also required when a gateway is down then connectivity status
|
||||||
|
// for each gateway needs to be kept and the weights recalculated when they go up or down.
|
||||||
|
// This would also need to interact with unsafe_route updates through reloading the config or
|
||||||
|
// use of the use_system_route_table option
|
||||||
|
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
f.l.WithField("destination", destinationAddr).
|
||||||
|
WithField("originalGateway", gatewayAddr).
|
||||||
|
Debugln("Calculated gateway for ECMP not available, attempting other gateways")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range gateways {
|
||||||
|
// Skip the gateway that failed previously
|
||||||
|
if gateways[i].Addr() == gatewayAddr {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// We do not need the HandshakeHostInfo since we cache the packet in the originally chosen gateway
|
||||||
|
if hostinfo, ready = f.handshakeManager.GetOrHandshake(gateways[i].Addr(), nil); ready {
|
||||||
|
return hostinfo, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No gateways reachable, cache the packet in the originally chosen gateway
|
||||||
|
cacheCallback(handshakeInfoForChosenGateway)
|
||||||
|
return hostinfo, false
|
||||||
}
|
}
|
||||||
|
|
||||||
return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
|
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
|
||||||
@@ -161,9 +231,10 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
|||||||
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
|
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
|
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr.
|
||||||
|
// This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr
|
||||||
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
|
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
|
||||||
hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
|
hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
|
||||||
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
|
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -218,7 +289,7 @@ func (f *Interface) SendVia(via *HostInfo,
|
|||||||
c := via.ConnectionState.messageCounter.Add(1)
|
c := via.ConnectionState.messageCounter.Add(1)
|
||||||
|
|
||||||
out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c)
|
out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c)
|
||||||
f.connectionManager.Out(via.localIndexId)
|
f.connectionManager.Out(via)
|
||||||
|
|
||||||
// Authenticate the header and payload, but do not encrypt for this message type.
|
// Authenticate the header and payload, but do not encrypt for this message type.
|
||||||
// The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload.
|
// The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload.
|
||||||
@@ -286,7 +357,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
|
|
||||||
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
|
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
|
||||||
out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
|
out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
|
||||||
f.connectionManager.Out(hostinfo.localIndexId)
|
f.connectionManager.Out(hostinfo)
|
||||||
|
|
||||||
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
|
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
|
||||||
// all our addrs and enable a faster roaming.
|
// all our addrs and enable a faster roaming.
|
||||||
|
|||||||
62
interface.go
62
interface.go
@@ -24,23 +24,23 @@ import (
|
|||||||
const mtu = 9001
|
const mtu = 9001
|
||||||
|
|
||||||
type InterfaceConfig struct {
|
type InterfaceConfig struct {
|
||||||
HostMap *HostMap
|
HostMap *HostMap
|
||||||
Outside udp.Conn
|
Outside udp.Conn
|
||||||
Inside overlay.Device
|
Inside overlay.Device
|
||||||
pki *PKI
|
pki *PKI
|
||||||
Firewall *Firewall
|
Cipher string
|
||||||
ServeDns bool
|
Firewall *Firewall
|
||||||
HandshakeManager *HandshakeManager
|
ServeDns bool
|
||||||
lightHouse *LightHouse
|
HandshakeManager *HandshakeManager
|
||||||
checkInterval time.Duration
|
lightHouse *LightHouse
|
||||||
pendingDeletionInterval time.Duration
|
connectionManager *connectionManager
|
||||||
DropLocalBroadcast bool
|
DropLocalBroadcast bool
|
||||||
DropMulticast bool
|
DropMulticast bool
|
||||||
routines int
|
routines int
|
||||||
MessageMetrics *MessageMetrics
|
MessageMetrics *MessageMetrics
|
||||||
version string
|
version string
|
||||||
relayManager *relayManager
|
relayManager *relayManager
|
||||||
punchy *Punchy
|
punchy *Punchy
|
||||||
|
|
||||||
tryPromoteEvery uint32
|
tryPromoteEvery uint32
|
||||||
reQueryEvery uint32
|
reQueryEvery uint32
|
||||||
@@ -61,11 +61,11 @@ type Interface struct {
|
|||||||
serveDns bool
|
serveDns bool
|
||||||
createTime time.Time
|
createTime time.Time
|
||||||
lightHouse *LightHouse
|
lightHouse *LightHouse
|
||||||
myBroadcastAddrsTable *bart.Table[struct{}]
|
myBroadcastAddrsTable *bart.Lite
|
||||||
myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate
|
myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate
|
||||||
myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
|
myVpnAddrsTable *bart.Lite
|
||||||
myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate
|
myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate
|
||||||
myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate
|
myVpnNetworksTable *bart.Lite
|
||||||
dropLocalBroadcast bool
|
dropLocalBroadcast bool
|
||||||
dropMulticast bool
|
dropMulticast bool
|
||||||
routines int
|
routines int
|
||||||
@@ -157,6 +157,9 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
if c.Firewall == nil {
|
if c.Firewall == nil {
|
||||||
return nil, errors.New("no firewall rules")
|
return nil, errors.New("no firewall rules")
|
||||||
}
|
}
|
||||||
|
if c.connectionManager == nil {
|
||||||
|
return nil, errors.New("no connection manager")
|
||||||
|
}
|
||||||
|
|
||||||
cs := c.pki.getCertState()
|
cs := c.pki.getCertState()
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
@@ -181,7 +184,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
myVpnAddrsTable: cs.myVpnAddrsTable,
|
myVpnAddrsTable: cs.myVpnAddrsTable,
|
||||||
myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
|
myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
|
||||||
relayManager: c.relayManager,
|
relayManager: c.relayManager,
|
||||||
|
connectionManager: c.connectionManager,
|
||||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||||
|
|
||||||
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||||
@@ -198,7 +201,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
ifce.reQueryEvery.Store(c.reQueryEvery)
|
ifce.reQueryEvery.Store(c.reQueryEvery)
|
||||||
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
||||||
|
|
||||||
ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)
|
ifce.connectionManager.intf = ifce
|
||||||
|
|
||||||
return ifce, nil
|
return ifce, nil
|
||||||
}
|
}
|
||||||
@@ -219,6 +222,13 @@ func (f *Interface) activate() {
|
|||||||
WithField("boringcrypto", boringEnabled()).
|
WithField("boringcrypto", boringEnabled()).
|
||||||
Info("Nebula interface is active")
|
Info("Nebula interface is active")
|
||||||
|
|
||||||
|
if f.routines > 1 {
|
||||||
|
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
|
||||||
|
f.routines = 1
|
||||||
|
f.l.Warn("routines is not supported on this platform, falling back to a single routine")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
||||||
|
|
||||||
// Prepare n tun queues
|
// Prepare n tun queues
|
||||||
@@ -410,7 +420,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
|||||||
udpStats := udp.NewUDPStatsEmitter(f.writers)
|
udpStats := udp.NewUDPStatsEmitter(f.writers)
|
||||||
|
|
||||||
certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
|
certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
|
||||||
certDefaultVersion := metrics.GetOrRegisterGauge("certificate.default_version", nil)
|
certInitiatingVersion := metrics.GetOrRegisterGauge("certificate.initiating_version", nil)
|
||||||
certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
|
certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -425,7 +435,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
|||||||
certState := f.pki.getCertState()
|
certState := f.pki.getCertState()
|
||||||
defaultCrt := certState.GetDefaultCertificate()
|
defaultCrt := certState.GetDefaultCertificate()
|
||||||
certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
|
certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
|
||||||
certDefaultVersion.Update(int64(defaultCrt.Version()))
|
certInitiatingVersion.Update(int64(defaultCrt.Version()))
|
||||||
|
|
||||||
// Report the max certificate version we are capable of using
|
// Report the max certificate version we are capable of using
|
||||||
if certState.v2Cert != nil {
|
if certState.v2Cert != nil {
|
||||||
|
|||||||
283
lighthouse.go
283
lighthouse.go
@@ -24,6 +24,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var ErrHostNotKnown = errors.New("host not known")
|
var ErrHostNotKnown = errors.New("host not known")
|
||||||
|
var ErrBadDetailsVpnAddr = errors.New("invalid packet, malformed detailsVpnAddr")
|
||||||
|
|
||||||
type LightHouse struct {
|
type LightHouse struct {
|
||||||
//TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time
|
//TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time
|
||||||
@@ -32,7 +33,7 @@ type LightHouse struct {
|
|||||||
amLighthouse bool
|
amLighthouse bool
|
||||||
|
|
||||||
myVpnNetworks []netip.Prefix
|
myVpnNetworks []netip.Prefix
|
||||||
myVpnNetworksTable *bart.Table[struct{}]
|
myVpnNetworksTable *bart.Lite
|
||||||
punchConn udp.Conn
|
punchConn udp.Conn
|
||||||
punchy *Punchy
|
punchy *Punchy
|
||||||
|
|
||||||
@@ -56,7 +57,7 @@ type LightHouse struct {
|
|||||||
// staticList exists to avoid having a bool in each addrMap entry
|
// staticList exists to avoid having a bool in each addrMap entry
|
||||||
// since static should be rare
|
// since static should be rare
|
||||||
staticList atomic.Pointer[map[netip.Addr]struct{}]
|
staticList atomic.Pointer[map[netip.Addr]struct{}]
|
||||||
lighthouses atomic.Pointer[map[netip.Addr]struct{}]
|
lighthouses atomic.Pointer[[]netip.Addr]
|
||||||
|
|
||||||
interval atomic.Int64
|
interval atomic.Int64
|
||||||
updateCancel context.CancelFunc
|
updateCancel context.CancelFunc
|
||||||
@@ -107,7 +108,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
|
|||||||
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
|
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
lighthouses := make(map[netip.Addr]struct{})
|
lighthouses := make([]netip.Addr, 0)
|
||||||
h.lighthouses.Store(&lighthouses)
|
h.lighthouses.Store(&lighthouses)
|
||||||
staticList := make(map[netip.Addr]struct{})
|
staticList := make(map[netip.Addr]struct{})
|
||||||
h.staticList.Store(&staticList)
|
h.staticList.Store(&staticList)
|
||||||
@@ -143,7 +144,7 @@ func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
|
|||||||
return *lh.staticList.Load()
|
return *lh.staticList.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
|
func (lh *LightHouse) GetLighthouses() []netip.Addr {
|
||||||
return *lh.lighthouses.Load()
|
return *lh.lighthouses.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,8 +202,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||||||
|
|
||||||
//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
|
//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
|
||||||
addr := addrs[0].Unmap()
|
addr := addrs[0].Unmap()
|
||||||
_, found := lh.myVpnNetworksTable.Lookup(addr)
|
if lh.myVpnNetworksTable.Contains(addr) {
|
||||||
if found {
|
|
||||||
lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
|
lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
|
||||||
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
|
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
|
||||||
continue
|
continue
|
||||||
@@ -307,13 +307,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if initial || c.HasChanged("lighthouse.hosts") {
|
if initial || c.HasChanged("lighthouse.hosts") {
|
||||||
lhMap := make(map[netip.Addr]struct{})
|
lhList, err := lh.parseLighthouses(c)
|
||||||
err := lh.parseLighthouses(c, lhMap)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
lh.lighthouses.Store(&lhMap)
|
lh.lighthouses.Store(&lhList)
|
||||||
if !initial {
|
if !initial {
|
||||||
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
|
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
|
||||||
lh.l.Info("lighthouse.hosts has changed")
|
lh.l.Info("lighthouse.hosts has changed")
|
||||||
@@ -347,37 +346,38 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
|
func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
|
||||||
lhs := c.GetStringSlice("lighthouse.hosts", []string{})
|
lhs := c.GetStringSlice("lighthouse.hosts", []string{})
|
||||||
if lh.amLighthouse && len(lhs) != 0 {
|
if lh.amLighthouse && len(lhs) != 0 {
|
||||||
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
|
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
|
||||||
}
|
}
|
||||||
|
out := make([]netip.Addr, len(lhs))
|
||||||
|
|
||||||
for i, host := range lhs {
|
for i, host := range lhs {
|
||||||
addr, err := netip.ParseAddr(host)
|
addr, err := netip.ParseAddr(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
|
return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, found := lh.myVpnNetworksTable.Lookup(addr)
|
if !lh.myVpnNetworksTable.Contains(addr) {
|
||||||
if !found {
|
lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}).
|
||||||
return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
|
Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
|
||||||
}
|
}
|
||||||
lhMap[addr] = struct{}{}
|
out[i] = addr
|
||||||
}
|
}
|
||||||
|
|
||||||
if !lh.amLighthouse && len(lhMap) == 0 {
|
if !lh.amLighthouse && len(out) == 0 {
|
||||||
lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
|
lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
|
||||||
}
|
}
|
||||||
|
|
||||||
staticList := lh.GetStaticHostList()
|
staticList := lh.GetStaticHostList()
|
||||||
for lhAddr, _ := range lhMap {
|
for i := range out {
|
||||||
if _, ok := staticList[lhAddr]; !ok {
|
if _, ok := staticList[out[i]]; !ok {
|
||||||
return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr)
|
return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getStaticMapCadence(c *config.C) (time.Duration, error) {
|
func getStaticMapCadence(c *config.C) (time.Duration, error) {
|
||||||
@@ -422,7 +422,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
shm := c.GetMap("static_host_map", map[interface{}]interface{}{})
|
shm := c.GetMap("static_host_map", map[string]any{})
|
||||||
i := 0
|
i := 0
|
||||||
|
|
||||||
for k, v := range shm {
|
for k, v := range shm {
|
||||||
@@ -431,14 +431,14 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
|
|||||||
return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err)
|
return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, found := lh.myVpnNetworksTable.Lookup(vpnAddr)
|
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
|
||||||
if !found {
|
lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}).
|
||||||
return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
|
Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
|
||||||
}
|
}
|
||||||
|
|
||||||
vals, ok := v.([]interface{})
|
vals, ok := v.([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
vals = []interface{}{v}
|
vals = []any{v}
|
||||||
}
|
}
|
||||||
remoteAddrs := []string{}
|
remoteAddrs := []string{}
|
||||||
for _, v := range vals {
|
for _, v := range vals {
|
||||||
@@ -489,7 +489,7 @@ func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList {
|
|||||||
lh.Lock()
|
lh.Lock()
|
||||||
defer lh.Unlock()
|
defer lh.Unlock()
|
||||||
// Add an entry if we don't already have one
|
// Add an entry if we don't already have one
|
||||||
return lh.unlockedGetRemoteList(vpnAddrs)
|
return lh.unlockedGetRemoteList(vpnAddrs) //todo CERT-V2 this contains addrmap lookups we could potentially skip
|
||||||
}
|
}
|
||||||
|
|
||||||
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
|
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
|
||||||
@@ -522,11 +522,15 @@ func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (in
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
|
func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
|
||||||
// First we check the static mapping
|
// First we check the static host map. If any of the VpnAddrs to be deleted are present, do nothing.
|
||||||
// and do nothing if it is there
|
staticList := lh.GetStaticHostList()
|
||||||
if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok {
|
for _, addr := range allVpnAddrs {
|
||||||
return
|
if _, ok := staticList[addr]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// None of the VpnAddrs were present. Now we can do the deletes.
|
||||||
lh.Lock()
|
lh.Lock()
|
||||||
rm, ok := lh.addrMap[allVpnAddrs[0]]
|
rm, ok := lh.addrMap[allVpnAddrs[0]]
|
||||||
if ok {
|
if ok {
|
||||||
@@ -568,7 +572,7 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
|
|||||||
am.unlockedSetHostnamesResults(hr)
|
am.unlockedSetHostnamesResults(hr)
|
||||||
|
|
||||||
for _, addrPort := range hr.GetAddrs() {
|
for _, addrPort := range hr.GetAddrs() {
|
||||||
if !lh.shouldAdd(vpnAddr, addrPort.Addr()) {
|
if !lh.shouldAdd([]netip.Addr{vpnAddr}, addrPort.Addr()) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
switch {
|
switch {
|
||||||
@@ -630,31 +634,37 @@ func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool {
|
|||||||
return len(calculatedV4) > 0 || len(calculatedV6) > 0
|
return len(calculatedV4) > 0 || len(calculatedV6) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// unlockedGetRemoteList
|
// unlockedGetRemoteList assumes you have the lh lock
|
||||||
// assumes you have the lh lock
|
|
||||||
func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
|
func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
|
||||||
am, ok := lh.addrMap[allAddrs[0]]
|
// before we go and make a new remotelist, we need to make sure we don't have one for any of this set of vpnaddrs yet
|
||||||
if !ok {
|
for i, addr := range allAddrs {
|
||||||
am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) })
|
am, ok := lh.addrMap[addr]
|
||||||
for _, addr := range allAddrs {
|
if ok {
|
||||||
lh.addrMap[addr] = am
|
if i != 0 {
|
||||||
|
lh.addrMap[allAddrs[0]] = am
|
||||||
|
}
|
||||||
|
return am
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
am := NewRemoteList(allAddrs, lh.shouldAdd)
|
||||||
|
for _, addr := range allAddrs {
|
||||||
|
lh.addrMap[addr] = am
|
||||||
|
}
|
||||||
return am
|
return am
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool {
|
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
|
||||||
allow := lh.GetRemoteAllowList().Allow(vpnAddr, to)
|
allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
|
||||||
if lh.l.Level >= logrus.TraceLevel {
|
if lh.l.Level >= logrus.TraceLevel {
|
||||||
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow).
|
lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
|
||||||
Trace("remoteAllowList.Allow")
|
Trace("remoteAllowList.Allow")
|
||||||
}
|
}
|
||||||
if !allow {
|
if !allow {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
_, found := lh.myVpnNetworksTable.Lookup(to)
|
if lh.myVpnNetworksTable.Contains(to) {
|
||||||
if found {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -674,8 +684,7 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
_, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr())
|
if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) {
|
||||||
if found {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -695,8 +704,7 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
_, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr())
|
if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) {
|
||||||
if found {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -704,19 +712,22 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
|
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
|
||||||
if _, ok := lh.GetLighthouses()[vpnAddr]; ok {
|
l := lh.GetLighthouses()
|
||||||
return true
|
for i := range l {
|
||||||
|
if l[i] == vpnAddr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: CERT-V2 IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake
|
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
|
||||||
// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially
|
|
||||||
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool {
|
|
||||||
l := lh.GetLighthouses()
|
l := lh.GetLighthouses()
|
||||||
for _, a := range vpnAddr {
|
for i := range vpnAddrs {
|
||||||
if _, ok := l[a]; ok {
|
for j := range l {
|
||||||
return true
|
if l[j] == vpnAddrs[i] {
|
||||||
|
return true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
@@ -758,12 +769,12 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
|||||||
queried := 0
|
queried := 0
|
||||||
lighthouses := lh.GetLighthouses()
|
lighthouses := lh.GetLighthouses()
|
||||||
|
|
||||||
for lhVpnAddr := range lighthouses {
|
for _, lhVpnAddr := range lighthouses {
|
||||||
hi := lh.ifce.GetHostInfo(lhVpnAddr)
|
hi := lh.ifce.GetHostInfo(lhVpnAddr)
|
||||||
if hi != nil {
|
if hi != nil {
|
||||||
v = hi.ConnectionState.myCert.Version()
|
v = hi.ConnectionState.myCert.Version()
|
||||||
} else {
|
} else {
|
||||||
v = lh.ifce.GetCertState().defaultVersion
|
v = lh.ifce.GetCertState().initiatingVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
if v == cert.Version1 {
|
if v == cert.Version1 {
|
||||||
@@ -856,8 +867,7 @@ func (lh *LightHouse) SendUpdate() {
|
|||||||
|
|
||||||
lal := lh.GetLocalAllowList()
|
lal := lh.GetLocalAllowList()
|
||||||
for _, e := range localAddrs(lh.l, lal) {
|
for _, e := range localAddrs(lh.l, lal) {
|
||||||
_, found := lh.myVpnNetworksTable.Lookup(e)
|
if lh.myVpnNetworksTable.Contains(e) {
|
||||||
if found {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -877,13 +887,13 @@ func (lh *LightHouse) SendUpdate() {
|
|||||||
updated := 0
|
updated := 0
|
||||||
lighthouses := lh.GetLighthouses()
|
lighthouses := lh.GetLighthouses()
|
||||||
|
|
||||||
for lhVpnAddr := range lighthouses {
|
for _, lhVpnAddr := range lighthouses {
|
||||||
var v cert.Version
|
var v cert.Version
|
||||||
hi := lh.ifce.GetHostInfo(lhVpnAddr)
|
hi := lh.ifce.GetHostInfo(lhVpnAddr)
|
||||||
if hi != nil {
|
if hi != nil {
|
||||||
v = hi.ConnectionState.myCert.Version()
|
v = hi.ConnectionState.myCert.Version()
|
||||||
} else {
|
} else {
|
||||||
v = lh.ifce.GetCertState().defaultVersion
|
v = lh.ifce.GetCertState().initiatingVersion
|
||||||
}
|
}
|
||||||
if v == cert.Version1 {
|
if v == cert.Version1 {
|
||||||
if v1Update == nil {
|
if v1Update == nil {
|
||||||
@@ -935,7 +945,6 @@ func (lh *LightHouse) SendUpdate() {
|
|||||||
V4AddrPorts: v4,
|
V4AddrPorts: v4,
|
||||||
V6AddrPorts: v6,
|
V6AddrPorts: v6,
|
||||||
RelayVpnAddrs: relays,
|
RelayVpnAddrs: relays,
|
||||||
VpnAddr: netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()),
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1055,19 +1064,19 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
useVersion := cert.Version1
|
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
|
||||||
var queryVpnAddr netip.Addr
|
if err != nil {
|
||||||
if n.Details.OldVpnAddr != 0 {
|
|
||||||
b := [4]byte{}
|
|
||||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
|
||||||
queryVpnAddr = netip.AddrFrom4(b)
|
|
||||||
useVersion = 1
|
|
||||||
} else if n.Details.VpnAddr != nil {
|
|
||||||
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
|
||||||
useVersion = 2
|
|
||||||
} else {
|
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery")
|
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
|
||||||
|
Debugln("Dropping malformed HostQuery")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
|
||||||
|
// this case really shouldn't be possible to represent, but reject it anyway.
|
||||||
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
|
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
|
||||||
|
Debugln("invalid vpn addr for v1 handleHostQuery")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1076,9 +1085,6 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
|||||||
n = lhh.resetMeta()
|
n = lhh.resetMeta()
|
||||||
n.Type = NebulaMeta_HostQueryReply
|
n.Type = NebulaMeta_HostQueryReply
|
||||||
if useVersion == cert.Version1 {
|
if useVersion == cert.Version1 {
|
||||||
if !queryVpnAddr.Is4() {
|
|
||||||
return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
|
|
||||||
}
|
|
||||||
b := queryVpnAddr.As4()
|
b := queryVpnAddr.As4()
|
||||||
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
|
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
|
||||||
} else {
|
} else {
|
||||||
@@ -1114,7 +1120,7 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
|||||||
targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
|
targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
|
||||||
var useVersion cert.Version
|
var useVersion cert.Version
|
||||||
if targetHI == nil {
|
if targetHI == nil {
|
||||||
useVersion = lhh.lh.ifce.GetCertState().defaultVersion
|
useVersion = lhh.lh.ifce.GetCertState().initiatingVersion
|
||||||
} else {
|
} else {
|
||||||
crt := targetHI.GetCert().Certificate
|
crt := targetHI.GetCert().Certificate
|
||||||
useVersion = crt.Version()
|
useVersion = crt.Version()
|
||||||
@@ -1123,8 +1129,9 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
|||||||
if ok {
|
if ok {
|
||||||
whereToPunch = newDest
|
whereToPunch = newDest
|
||||||
} else {
|
} else {
|
||||||
//TODO: CERT-V2 this means the destination will have no addresses in common with the punch-ee
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
//choosing to do nothing for now, but maybe we return an error?
|
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1183,19 +1190,17 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
|
|||||||
if !r.Is4() {
|
if !r.Is4() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
b = r.As4()
|
b = r.As4()
|
||||||
n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:]))
|
n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:]))
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if v == cert.Version2 {
|
} else if v == cert.Version2 {
|
||||||
for _, r := range c.relay.relay {
|
for _, r := range c.relay.relay {
|
||||||
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
|
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
//TODO: CERT-V2 don't panic
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
panic("unsupported version")
|
lhh.l.WithField("version", v).Debug("unsupported protocol version")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1205,18 +1210,16 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lhh.lh.Lock()
|
certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
||||||
|
if err != nil {
|
||||||
var certVpnAddr netip.Addr
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
if n.Details.OldVpnAddr != 0 {
|
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
|
||||||
b := [4]byte{}
|
}
|
||||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
return
|
||||||
certVpnAddr = netip.AddrFrom4(b)
|
|
||||||
} else if n.Details.VpnAddr != nil {
|
|
||||||
certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
|
||||||
}
|
}
|
||||||
relays := n.Details.GetRelays()
|
relays := n.Details.GetRelays()
|
||||||
|
|
||||||
|
lhh.lh.Lock()
|
||||||
am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr})
|
am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr})
|
||||||
am.Lock()
|
am.Lock()
|
||||||
lhh.lh.Unlock()
|
lhh.lh.Unlock()
|
||||||
@@ -1241,27 +1244,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// not using GetVpnAddrAndVersion because we don't want to error on a blank detailsVpnAddr
|
||||||
var detailsVpnAddr netip.Addr
|
var detailsVpnAddr netip.Addr
|
||||||
useVersion := cert.Version1
|
var useVersion cert.Version
|
||||||
if n.Details.OldVpnAddr != 0 {
|
if n.Details.OldVpnAddr != 0 { //v1 always sets this field
|
||||||
b := [4]byte{}
|
b := [4]byte{}
|
||||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
||||||
detailsVpnAddr = netip.AddrFrom4(b)
|
detailsVpnAddr = netip.AddrFrom4(b)
|
||||||
useVersion = cert.Version1
|
useVersion = cert.Version1
|
||||||
} else if n.Details.VpnAddr != nil {
|
} else if n.Details.VpnAddr != nil { //this field is "optional" in v2, but if it's set, we should enforce it
|
||||||
detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
||||||
useVersion = cert.Version2
|
useVersion = cert.Version2
|
||||||
} else {
|
} else {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
detailsVpnAddr = netip.Addr{}
|
||||||
lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification")
|
useVersion = cert.Version2
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: CERT-V2 hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4?
|
//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
|
||||||
//TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right?
|
if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
|
||||||
//Simple check that the host sent this not someone else
|
|
||||||
if !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
|
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
|
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
|
||||||
}
|
}
|
||||||
@@ -1275,24 +1275,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
|||||||
am.Lock()
|
am.Lock()
|
||||||
lhh.lh.Unlock()
|
lhh.lh.Unlock()
|
||||||
|
|
||||||
am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
|
am.unlockedSetV4(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
|
||||||
am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
|
am.unlockedSetV6(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
|
||||||
am.unlockedSetRelay(fromVpnAddrs[0], relays)
|
am.unlockedSetRelay(fromVpnAddrs[0], relays)
|
||||||
am.Unlock()
|
am.Unlock()
|
||||||
|
|
||||||
n = lhh.resetMeta()
|
n = lhh.resetMeta()
|
||||||
n.Type = NebulaMeta_HostUpdateNotificationAck
|
n.Type = NebulaMeta_HostUpdateNotificationAck
|
||||||
|
switch useVersion {
|
||||||
if useVersion == cert.Version1 {
|
case cert.Version1:
|
||||||
if !fromVpnAddrs[0].Is4() {
|
if !fromVpnAddrs[0].Is4() {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
|
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
vpnAddrB := fromVpnAddrs[0].As4()
|
vpnAddrB := fromVpnAddrs[0].As4()
|
||||||
n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:])
|
n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:])
|
||||||
} else if useVersion == cert.Version2 {
|
case cert.Version2:
|
||||||
n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0])
|
// do nothing, we want to send a blank message
|
||||||
} else {
|
default:
|
||||||
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
|
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1310,13 +1310,20 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
|||||||
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
||||||
//It's possible the lighthouse is communicating with us using a non primary vpn addr,
|
//It's possible the lighthouse is communicating with us using a non primary vpn addr,
|
||||||
//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
|
//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
|
||||||
//maybe one day we'll have a better idea, if it matters.
|
|
||||||
if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
|
if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
|
||||||
|
if err != nil {
|
||||||
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
|
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
empty := []byte{0}
|
empty := []byte{0}
|
||||||
punch := func(vpnPeer netip.AddrPort) {
|
punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
|
||||||
if !vpnPeer.IsValid() {
|
if !vpnPeer.IsValid() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1328,48 +1335,38 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
var logVpnAddr netip.Addr
|
|
||||||
if n.Details.OldVpnAddr != 0 {
|
|
||||||
b := [4]byte{}
|
|
||||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
|
||||||
logVpnAddr = netip.AddrFrom4(b)
|
|
||||||
} else if n.Details.VpnAddr != nil {
|
|
||||||
logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
|
||||||
}
|
|
||||||
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
|
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
remoteAllowList := lhh.lh.GetRemoteAllowList()
|
||||||
for _, a := range n.Details.V4AddrPorts {
|
for _, a := range n.Details.V4AddrPorts {
|
||||||
punch(protoV4AddrPortToNetAddrPort(a))
|
b := protoV4AddrPortToNetAddrPort(a)
|
||||||
|
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
||||||
|
punch(b, detailsVpnAddr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, a := range n.Details.V6AddrPorts {
|
for _, a := range n.Details.V6AddrPorts {
|
||||||
punch(protoV6AddrPortToNetAddrPort(a))
|
b := protoV6AddrPortToNetAddrPort(a)
|
||||||
|
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
||||||
|
punch(b, detailsVpnAddr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// This sends a nebula test packet to the host trying to contact us. In the case
|
// This sends a nebula test packet to the host trying to contact us. In the case
|
||||||
// of a double nat or other difficult scenario, this may help establish
|
// of a double nat or other difficult scenario, this may help establish
|
||||||
// a tunnel.
|
// a tunnel.
|
||||||
if lhh.lh.punchy.GetRespond() {
|
if lhh.lh.punchy.GetRespond() {
|
||||||
var queryVpnAddr netip.Addr
|
|
||||||
if n.Details.OldVpnAddr != 0 {
|
|
||||||
b := [4]byte{}
|
|
||||||
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
|
|
||||||
queryVpnAddr = netip.AddrFrom4(b)
|
|
||||||
} else if n.Details.VpnAddr != nil {
|
|
||||||
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(lhh.lh.punchy.GetRespondDelay())
|
time.Sleep(lhh.lh.punchy.GetRespondDelay())
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr)
|
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
|
||||||
}
|
}
|
||||||
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
|
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
|
||||||
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
||||||
// managed by a channel.
|
// managed by a channel.
|
||||||
w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1448,3 +1445,17 @@ func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr,
|
|||||||
}
|
}
|
||||||
return netip.Addr{}, false
|
return netip.Addr{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *NebulaMetaDetails) GetVpnAddrAndVersion() (netip.Addr, cert.Version, error) {
|
||||||
|
if d.OldVpnAddr != 0 {
|
||||||
|
b := [4]byte{}
|
||||||
|
binary.BigEndian.PutUint32(b[:], d.OldVpnAddr)
|
||||||
|
detailsVpnAddr := netip.AddrFrom4(b)
|
||||||
|
return detailsVpnAddr, cert.Version1, nil
|
||||||
|
} else if d.VpnAddr != nil {
|
||||||
|
detailsVpnAddr := protoAddrToNetAddr(d.VpnAddr)
|
||||||
|
return detailsVpnAddr, cert.Version2, nil
|
||||||
|
} else {
|
||||||
|
return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,7 +13,8 @@ import (
|
|||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"gopkg.in/yaml.v2"
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.yaml.in/yaml/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOldIPv4Only(t *testing.T) {
|
func TestOldIPv4Only(t *testing.T) {
|
||||||
@@ -21,7 +22,7 @@ func TestOldIPv4Only(t *testing.T) {
|
|||||||
b := []byte{8, 129, 130, 132, 80, 16, 10}
|
b := []byte{8, 129, 130, 132, 80, 16, 10}
|
||||||
var m V4AddrPort
|
var m V4AddrPort
|
||||||
err := m.Unmarshal(b)
|
err := m.Unmarshal(b)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
ip := netip.MustParseAddr("10.1.1.1")
|
ip := netip.MustParseAddr("10.1.1.1")
|
||||||
bp := ip.As4()
|
bp := ip.As4()
|
||||||
assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr())
|
assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr())
|
||||||
@@ -30,8 +31,8 @@ func TestOldIPv4Only(t *testing.T) {
|
|||||||
func Test_lhStaticMapping(t *testing.T) {
|
func Test_lhStaticMapping(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
|
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
|
||||||
nt := new(bart.Table[struct{}])
|
nt := new(bart.Lite)
|
||||||
nt.Insert(myVpnNet, struct{}{})
|
nt.Insert(myVpnNet)
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
myVpnNetworksTable: nt,
|
myVpnNetworksTable: nt,
|
||||||
@@ -39,24 +40,24 @@ func Test_lhStaticMapping(t *testing.T) {
|
|||||||
lh1 := "10.128.0.2"
|
lh1 := "10.128.0.2"
|
||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
|
c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1}}
|
||||||
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
|
c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
|
||||||
_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
lh2 := "10.128.0.3"
|
lh2 := "10.128.0.3"
|
||||||
c = config.NewC(l)
|
c = config.NewC(l)
|
||||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
|
c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1, lh2}}
|
||||||
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
|
c.Settings["static_host_map"] = map[string]any{lh1: []any{"100.1.1.1:4242"}}
|
||||||
_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
|
require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReloadLighthouseInterval(t *testing.T) {
|
func TestReloadLighthouseInterval(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
|
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
|
||||||
nt := new(bart.Table[struct{}])
|
nt := new(bart.Lite)
|
||||||
nt.Insert(myVpnNet, struct{}{})
|
nt.Insert(myVpnNet)
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
myVpnNetworksTable: nt,
|
myVpnNetworksTable: nt,
|
||||||
@@ -64,34 +65,34 @@ func TestReloadLighthouseInterval(t *testing.T) {
|
|||||||
lh1 := "10.128.0.2"
|
lh1 := "10.128.0.2"
|
||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
c.Settings["lighthouse"] = map[interface{}]interface{}{
|
c.Settings["lighthouse"] = map[string]any{
|
||||||
"hosts": []interface{}{lh1},
|
"hosts": []any{lh1},
|
||||||
"interval": "1s",
|
"interval": "1s",
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
|
c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
lh.ifce = &mockEncWriter{}
|
lh.ifce = &mockEncWriter{}
|
||||||
|
|
||||||
// The first one routine is kicked off by main.go currently, lets make sure that one dies
|
// The first one routine is kicked off by main.go currently, lets make sure that one dies
|
||||||
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5"))
|
require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5"))
|
||||||
assert.Equal(t, int64(5), lh.interval.Load())
|
assert.Equal(t, int64(5), lh.interval.Load())
|
||||||
|
|
||||||
// Subsequent calls are killed off by the LightHouse.Reload function
|
// Subsequent calls are killed off by the LightHouse.Reload function
|
||||||
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10"))
|
require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10"))
|
||||||
assert.Equal(t, int64(10), lh.interval.Load())
|
assert.Equal(t, int64(10), lh.interval.Load())
|
||||||
|
|
||||||
// If this completes then nothing is stealing our reload routine
|
// If this completes then nothing is stealing our reload routine
|
||||||
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11"))
|
require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11"))
|
||||||
assert.Equal(t, int64(11), lh.interval.Load())
|
assert.Equal(t, int64(11), lh.interval.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
|
myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
|
||||||
nt := new(bart.Table[struct{}])
|
nt := new(bart.Lite)
|
||||||
nt.Insert(myVpnNet, struct{}{})
|
nt.Insert(myVpnNet)
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
myVpnNetworksTable: nt,
|
myVpnNetworksTable: nt,
|
||||||
@@ -99,9 +100,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
if !assert.NoError(b, err) {
|
require.NoError(b, err)
|
||||||
b.Fatal()
|
|
||||||
}
|
|
||||||
|
|
||||||
hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
|
hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
|
||||||
hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
|
hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
|
||||||
@@ -145,7 +144,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
p, err := req.Marshal()
|
p, err := req.Marshal()
|
||||||
assert.NoError(b, err)
|
require.NoError(b, err)
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
lhh.HandleRequest(rAddr, hi, p, mw)
|
lhh.HandleRequest(rAddr, hi, p, mw)
|
||||||
}
|
}
|
||||||
@@ -160,7 +159,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
p, err := req.Marshal()
|
p, err := req.Marshal()
|
||||||
assert.NoError(b, err)
|
require.NoError(b, err)
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
lhh.HandleRequest(rAddr, hi, p, mw)
|
lhh.HandleRequest(rAddr, hi, p, mw)
|
||||||
@@ -193,19 +192,19 @@ func TestLighthouse_Memory(t *testing.T) {
|
|||||||
theirVpnIp := netip.MustParseAddr("10.128.0.3")
|
theirVpnIp := netip.MustParseAddr("10.128.0.3")
|
||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
|
c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
|
||||||
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
|
c.Settings["listen"] = map[string]any{"port": 4242}
|
||||||
|
|
||||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
||||||
nt := new(bart.Table[struct{}])
|
nt := new(bart.Lite)
|
||||||
nt.Insert(myVpnNet, struct{}{})
|
nt.Insert(myVpnNet)
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
myVpnNetworksTable: nt,
|
myVpnNetworksTable: nt,
|
||||||
}
|
}
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
lh.ifce = &mockEncWriter{}
|
lh.ifce = &mockEncWriter{}
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
lhh := lh.NewRequestHandler()
|
lhh := lh.NewRequestHandler()
|
||||||
|
|
||||||
// Test that my first update responds with just that
|
// Test that my first update responds with just that
|
||||||
@@ -278,31 +277,31 @@ func TestLighthouse_Memory(t *testing.T) {
|
|||||||
func TestLighthouse_reload(t *testing.T) {
|
func TestLighthouse_reload(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
|
c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
|
||||||
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
|
c.Settings["listen"] = map[string]any{"port": 4242}
|
||||||
|
|
||||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
||||||
nt := new(bart.Table[struct{}])
|
nt := new(bart.Lite)
|
||||||
nt.Insert(myVpnNet, struct{}{})
|
nt.Insert(myVpnNet)
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
myVpnNetworksTable: nt,
|
myVpnNetworksTable: nt,
|
||||||
}
|
}
|
||||||
|
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
nc := map[interface{}]interface{}{
|
nc := map[string]any{
|
||||||
"static_host_map": map[interface{}]interface{}{
|
"static_host_map": map[string]any{
|
||||||
"10.128.0.2": []interface{}{"1.1.1.1:4242"},
|
"10.128.0.2": []any{"1.1.1.1:4242"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
rc, err := yaml.Marshal(nc)
|
rc, err := yaml.Marshal(nc)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
c.ReloadConfigString(string(rc))
|
c.ReloadConfigString(string(rc))
|
||||||
|
|
||||||
err = lh.reload(c, false)
|
err = lh.reload(c, false)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
|
func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
|
||||||
@@ -418,7 +417,7 @@ func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tw *testEncWriter) GetCertState() *CertState {
|
func (tw *testEncWriter) GetCertState() *CertState {
|
||||||
return &CertState{defaultVersion: tw.protocolVersion}
|
return &CertState{initiatingVersion: tw.protocolVersion}
|
||||||
}
|
}
|
||||||
|
|
||||||
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
|
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
|
||||||
@@ -494,3 +493,123 @@ func Test_findNetworkUnion(t *testing.T) {
|
|||||||
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
|
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) {
|
||||||
|
l := test.NewLogger()
|
||||||
|
|
||||||
|
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
|
||||||
|
|
||||||
|
testSameHostNotStatic := netip.MustParseAddr("10.128.0.41")
|
||||||
|
testStaticHost := netip.MustParseAddr("10.128.0.42")
|
||||||
|
//myVpnIp := netip.MustParseAddr("10.128.0.2")
|
||||||
|
|
||||||
|
c := config.NewC(l)
|
||||||
|
lh1 := "10.128.0.2"
|
||||||
|
c.Settings["lighthouse"] = map[string]any{
|
||||||
|
"hosts": []any{lh1},
|
||||||
|
"interval": "1s",
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Settings["listen"] = map[string]any{"port": 4242}
|
||||||
|
c.Settings["static_host_map"] = map[string]any{
|
||||||
|
lh1: []any{"1.1.1.1:4242"},
|
||||||
|
"10.128.0.42": []any{"1.2.3.4:4242"},
|
||||||
|
}
|
||||||
|
|
||||||
|
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
||||||
|
nt := new(bart.Lite)
|
||||||
|
nt.Insert(myVpnNet)
|
||||||
|
cs := &CertState{
|
||||||
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
|
myVpnNetworksTable: nt,
|
||||||
|
}
|
||||||
|
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
lh.ifce = &mockEncWriter{}
|
||||||
|
|
||||||
|
//test that we actually have the static entry:
|
||||||
|
out := lh.Query(testStaticHost)
|
||||||
|
assert.NotNil(t, out)
|
||||||
|
assert.Equal(t, out.vpnAddrs[0], testStaticHost)
|
||||||
|
out.Rebuild([]netip.Prefix{}) //why tho
|
||||||
|
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
||||||
|
|
||||||
|
//bolt on a lower numbered primary IP
|
||||||
|
am := lh.unlockedGetRemoteList([]netip.Addr{testStaticHost})
|
||||||
|
am.vpnAddrs = []netip.Addr{testSameHostNotStatic, testStaticHost}
|
||||||
|
lh.addrMap[testSameHostNotStatic] = am
|
||||||
|
out.Rebuild([]netip.Prefix{}) //???
|
||||||
|
|
||||||
|
//test that we actually have the static entry:
|
||||||
|
out = lh.Query(testStaticHost)
|
||||||
|
assert.NotNil(t, out)
|
||||||
|
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
|
||||||
|
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
|
||||||
|
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
||||||
|
|
||||||
|
//test that we actually have the static entry for BOTH:
|
||||||
|
out2 := lh.Query(testSameHostNotStatic)
|
||||||
|
assert.Same(t, out2, out)
|
||||||
|
|
||||||
|
//now do the delete
|
||||||
|
lh.DeleteVpnAddrs([]netip.Addr{testSameHostNotStatic, testStaticHost})
|
||||||
|
//verify
|
||||||
|
out = lh.Query(testSameHostNotStatic)
|
||||||
|
assert.NotNil(t, out)
|
||||||
|
if out == nil {
|
||||||
|
t.Fatal("expected non-nil query for the static host")
|
||||||
|
}
|
||||||
|
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
|
||||||
|
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
|
||||||
|
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLighthouse_DeletesWork(t *testing.T) {
|
||||||
|
l := test.NewLogger()
|
||||||
|
|
||||||
|
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
|
||||||
|
testHost := netip.MustParseAddr("10.128.0.42")
|
||||||
|
|
||||||
|
c := config.NewC(l)
|
||||||
|
lh1 := "10.128.0.2"
|
||||||
|
c.Settings["lighthouse"] = map[string]any{
|
||||||
|
"hosts": []any{lh1},
|
||||||
|
"interval": "1s",
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Settings["listen"] = map[string]any{"port": 4242}
|
||||||
|
c.Settings["static_host_map"] = map[string]any{
|
||||||
|
lh1: []any{"1.1.1.1:4242"},
|
||||||
|
}
|
||||||
|
|
||||||
|
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
||||||
|
nt := new(bart.Lite)
|
||||||
|
nt.Insert(myVpnNet)
|
||||||
|
cs := &CertState{
|
||||||
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
|
myVpnNetworksTable: nt,
|
||||||
|
}
|
||||||
|
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
lh.ifce = &mockEncWriter{}
|
||||||
|
|
||||||
|
//insert the host
|
||||||
|
am := lh.unlockedGetRemoteList([]netip.Addr{testHost})
|
||||||
|
am.vpnAddrs = []netip.Addr{testHost}
|
||||||
|
am.addrs = []netip.AddrPort{myUdpAddr2}
|
||||||
|
lh.addrMap[testHost] = am
|
||||||
|
am.Rebuild([]netip.Prefix{}) //???
|
||||||
|
|
||||||
|
//test that we actually have the entry:
|
||||||
|
out := lh.Query(testHost)
|
||||||
|
assert.NotNil(t, out)
|
||||||
|
assert.Equal(t, out.vpnAddrs[0], testHost)
|
||||||
|
out.Rebuild([]netip.Prefix{}) //why tho
|
||||||
|
assert.Equal(t, out.addrs[0], myUdpAddr2)
|
||||||
|
|
||||||
|
//now do the delete
|
||||||
|
lh.DeleteVpnAddrs([]netip.Addr{testHost})
|
||||||
|
//verify
|
||||||
|
out = lh.Query(testHost)
|
||||||
|
assert.Nil(t, out)
|
||||||
|
}
|
||||||
|
|||||||
73
main.go
73
main.go
@@ -5,6 +5,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"runtime/debug"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -13,10 +15,10 @@ import (
|
|||||||
"github.com/slackhq/nebula/sshd"
|
"github.com/slackhq/nebula/sshd"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
"gopkg.in/yaml.v2"
|
"go.yaml.in/yaml/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type m map[string]interface{}
|
type m = map[string]any
|
||||||
|
|
||||||
func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
|
func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
@@ -27,6 +29,10 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
if buildVersion == "" {
|
||||||
|
buildVersion = moduleVersion()
|
||||||
|
}
|
||||||
|
|
||||||
l := logger
|
l := logger
|
||||||
l.Formatter = &logrus.TextFormatter{
|
l.Formatter = &logrus.TextFormatter{
|
||||||
FullTimestamp: true,
|
FullTimestamp: true,
|
||||||
@@ -75,7 +81,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
if c.GetBool("sshd.enabled", false) {
|
if c.GetBool("sshd.enabled", false) {
|
||||||
sshStart, err = configSSH(l, ssh, c)
|
sshStart, err = configSSH(l, ssh, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
|
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
|
||||||
|
sshStart = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -185,6 +192,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
|
|
||||||
hostMap := NewHostMapFromConfig(l, c)
|
hostMap := NewHostMapFromConfig(l, c)
|
||||||
punchy := NewPunchyFromConfig(l, c)
|
punchy := NewPunchyFromConfig(l, c)
|
||||||
|
connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy)
|
||||||
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
|
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
|
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
|
||||||
@@ -220,31 +228,26 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
checkInterval := c.GetInt("timers.connection_alive_interval", 5)
|
|
||||||
pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
|
|
||||||
|
|
||||||
ifConfig := &InterfaceConfig{
|
ifConfig := &InterfaceConfig{
|
||||||
HostMap: hostMap,
|
HostMap: hostMap,
|
||||||
Inside: tun,
|
Inside: tun,
|
||||||
Outside: udpConns[0],
|
Outside: udpConns[0],
|
||||||
pki: pki,
|
pki: pki,
|
||||||
Firewall: fw,
|
Firewall: fw,
|
||||||
ServeDns: serveDns,
|
ServeDns: serveDns,
|
||||||
HandshakeManager: handshakeManager,
|
HandshakeManager: handshakeManager,
|
||||||
lightHouse: lightHouse,
|
connectionManager: connManager,
|
||||||
checkInterval: time.Second * time.Duration(checkInterval),
|
lightHouse: lightHouse,
|
||||||
pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval),
|
tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery),
|
||||||
tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery),
|
reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
|
||||||
reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
|
reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
|
||||||
reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
|
DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false),
|
||||||
DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false),
|
DropMulticast: c.GetBool("tun.drop_multicast", false),
|
||||||
DropMulticast: c.GetBool("tun.drop_multicast", false),
|
routines: routines,
|
||||||
routines: routines,
|
MessageMetrics: messageMetrics,
|
||||||
MessageMetrics: messageMetrics,
|
version: buildVersion,
|
||||||
version: buildVersion,
|
relayManager: NewRelayManager(ctx, l, hostMap, c),
|
||||||
relayManager: NewRelayManager(ctx, l, hostMap, c),
|
punchy: punchy,
|
||||||
punchy: punchy,
|
|
||||||
|
|
||||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
@@ -296,5 +299,21 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
statsStart,
|
statsStart,
|
||||||
dnsStart,
|
dnsStart,
|
||||||
lightHouse.StartUpdateWorker,
|
lightHouse.StartUpdateWorker,
|
||||||
|
connManager.Start,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func moduleVersion() string {
|
||||||
|
info, ok := debug.ReadBuildInfo()
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dep := range info.Deps {
|
||||||
|
if dep.Path == "github.com/slackhq/nebula" {
|
||||||
|
return strings.TrimPrefix(dep.Version, "v")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|||||||
18
metadata.go
18
metadata.go
@@ -1,18 +0,0 @@
|
|||||||
package nebula
|
|
||||||
|
|
||||||
/*
|
|
||||||
|
|
||||||
import (
|
|
||||||
proto "google.golang.org/protobuf/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
func HandleMetaProto(p []byte) {
|
|
||||||
m := &NebulaMeta{}
|
|
||||||
err := proto.Unmarshal(p, m)
|
|
||||||
if err != nil {
|
|
||||||
l.Debugf("problem unmarshaling meta message: %s", err)
|
|
||||||
}
|
|
||||||
//fmt.Println(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
*/
|
|
||||||
36
outside.go
36
outside.go
@@ -31,8 +31,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
|
|
||||||
//l.Error("in packet ", header, packet[HeaderLen:])
|
//l.Error("in packet ", header, packet[HeaderLen:])
|
||||||
if ip.IsValid() {
|
if ip.IsValid() {
|
||||||
_, found := f.myVpnNetworksTable.Lookup(ip.Addr())
|
if f.myVpnNetworksTable.Contains(ip.Addr()) {
|
||||||
if found {
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
|
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
|
||||||
}
|
}
|
||||||
@@ -82,7 +81,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
// Pull the Roaming parts up here, and return in all call paths.
|
// Pull the Roaming parts up here, and return in all call paths.
|
||||||
f.handleHostRoaming(hostinfo, ip)
|
f.handleHostRoaming(hostinfo, ip)
|
||||||
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
|
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
|
||||||
f.connectionManager.In(hostinfo.localIndexId)
|
f.connectionManager.In(hostinfo)
|
||||||
f.connectionManager.RelayUsed(h.RemoteIndex)
|
f.connectionManager.RelayUsed(h.RemoteIndex)
|
||||||
|
|
||||||
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
|
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
|
||||||
@@ -214,7 +213,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
|
|
||||||
f.handleHostRoaming(hostinfo, ip)
|
f.handleHostRoaming(hostinfo, ip)
|
||||||
|
|
||||||
f.connectionManager.In(hostinfo.localIndexId)
|
f.connectionManager.In(hostinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
|
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
|
||||||
@@ -255,16 +254,18 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleEncrypted returns true if a packet should be processed, false otherwise
|
||||||
func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
|
func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
|
||||||
// If connectionstate exists and the replay protector allows, process packet
|
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
|
||||||
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
|
if ci == nil {
|
||||||
if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
|
|
||||||
if addr.IsValid() {
|
if addr.IsValid() {
|
||||||
f.maybeSendRecvError(addr, h.RemoteIndex)
|
f.maybeSendRecvError(addr, h.RemoteIndex)
|
||||||
return false
|
|
||||||
} else {
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// If the window check fails, refuse to process the packet, but don't send a recv error
|
||||||
|
if !ci.window.Check(f.l, h.MessageCounter) {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
@@ -313,12 +314,11 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
offset := ipv6.HeaderLen // Start at the end of the ipv6 header
|
offset := ipv6.HeaderLen // Start at the end of the ipv6 header
|
||||||
next := 0
|
next := 0
|
||||||
for {
|
for {
|
||||||
if dataLen < offset {
|
if protoAt >= dataLen {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
proto := layers.IPProtocol(data[protoAt])
|
proto := layers.IPProtocol(data[protoAt])
|
||||||
//fmt.Println(proto, protoAt)
|
|
||||||
switch proto {
|
switch proto {
|
||||||
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
|
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
|
||||||
fp.Protocol = uint8(proto)
|
fp.Protocol = uint8(proto)
|
||||||
@@ -366,7 +366,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
|
|
||||||
case layers.IPProtocolAH:
|
case layers.IPProtocolAH:
|
||||||
// Auth headers, used by IPSec, have a different meaning for header length
|
// Auth headers, used by IPSec, have a different meaning for header length
|
||||||
if dataLen < offset+1 {
|
if dataLen <= offset+1 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -374,7 +374,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
|
|
||||||
default:
|
default:
|
||||||
// Normal ipv6 header length processing
|
// Normal ipv6 header length processing
|
||||||
if dataLen < offset+1 {
|
if dataLen <= offset+1 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -500,7 +500,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
f.connectionManager.In(hostinfo.localIndexId)
|
f.connectionManager.In(hostinfo)
|
||||||
_, err = f.readers[q].Write(out)
|
_, err = f.readers[q].Write(out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Failed to write to tun")
|
f.l.WithError(err).Error("Failed to write to tun")
|
||||||
@@ -539,10 +539,6 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hostinfo.RecvErrorExceeded() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
|
if hostinfo.remote.IsValid() && hostinfo.remote != addr {
|
||||||
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
|
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
|
||||||
return
|
return
|
||||||
|
|||||||
100
outside_test.go
100
outside_test.go
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,13 +21,13 @@ func Test_newPacket(t *testing.T) {
|
|||||||
|
|
||||||
// length fails
|
// length fails
|
||||||
err := newPacket([]byte{}, true, p)
|
err := newPacket([]byte{}, true, p)
|
||||||
assert.ErrorIs(t, err, ErrPacketTooShort)
|
require.ErrorIs(t, err, ErrPacketTooShort)
|
||||||
|
|
||||||
err = newPacket([]byte{0x40}, true, p)
|
err = newPacket([]byte{0x40}, true, p)
|
||||||
assert.ErrorIs(t, err, ErrIPv4PacketTooShort)
|
require.ErrorIs(t, err, ErrIPv4PacketTooShort)
|
||||||
|
|
||||||
err = newPacket([]byte{0x60}, true, p)
|
err = newPacket([]byte{0x60}, true, p)
|
||||||
assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||||
|
|
||||||
// length fail with ip options
|
// length fail with ip options
|
||||||
h := ipv4.Header{
|
h := ipv4.Header{
|
||||||
@@ -39,15 +40,15 @@ func Test_newPacket(t *testing.T) {
|
|||||||
|
|
||||||
b, _ := h.Marshal()
|
b, _ := h.Marshal()
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
|
require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
|
||||||
|
|
||||||
// not an ipv4 packet
|
// not an ipv4 packet
|
||||||
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
||||||
assert.ErrorIs(t, err, ErrUnknownIPVersion)
|
require.ErrorIs(t, err, ErrUnknownIPVersion)
|
||||||
|
|
||||||
// invalid ihl
|
// invalid ihl
|
||||||
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
||||||
assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
|
require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
|
||||||
|
|
||||||
// account for variable ip header length - incoming
|
// account for variable ip header length - incoming
|
||||||
h = ipv4.Header{
|
h = ipv4.Header{
|
||||||
@@ -63,7 +64,7 @@ func Test_newPacket(t *testing.T) {
|
|||||||
b = append(b, []byte{0, 3, 0, 4}...)
|
b = append(b, []byte{0, 3, 0, 4}...)
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
|
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr)
|
||||||
@@ -85,7 +86,7 @@ func Test_newPacket(t *testing.T) {
|
|||||||
b = append(b, []byte{0, 5, 0, 6}...)
|
b = append(b, []byte{0, 5, 0, 6}...)
|
||||||
err = newPacket(b, false, p)
|
err = newPacket(b, false, p)
|
||||||
|
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(2), p.Protocol)
|
assert.Equal(t, uint8(2), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr)
|
||||||
@@ -111,10 +112,49 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
FixLengths: false,
|
FixLengths: false,
|
||||||
}
|
}
|
||||||
err := gopacket.SerializeLayers(buffer, opt, &ip)
|
err := gopacket.SerializeLayers(buffer, opt, &ip)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = newPacket(buffer.Bytes(), true, p)
|
err = newPacket(buffer.Bytes(), true, p)
|
||||||
assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||||
|
|
||||||
|
// A v6 packet with a hop-by-hop extension
|
||||||
|
// ICMPv6 Payload (Echo Request)
|
||||||
|
icmpLayer := layers.ICMPv6{
|
||||||
|
TypeCode: layers.ICMPv6TypeEchoRequest,
|
||||||
|
}
|
||||||
|
// Hop-by-Hop Extension Header
|
||||||
|
hopOption := layers.IPv6HopByHopOption{}
|
||||||
|
hopOption.OptionData = []byte{0, 0, 0, 0}
|
||||||
|
hopByHop := layers.IPv6HopByHop{}
|
||||||
|
hopByHop.Options = append(hopByHop.Options, &hopOption)
|
||||||
|
|
||||||
|
ip = layers.IPv6{
|
||||||
|
Version: 6,
|
||||||
|
HopLimit: 128,
|
||||||
|
NextHeader: layers.IPProtocolIPv6Destination,
|
||||||
|
SrcIP: net.IPv6linklocalallrouters,
|
||||||
|
DstIP: net.IPv6linklocalallnodes,
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer.Clear()
|
||||||
|
err = gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: false,
|
||||||
|
FixLengths: true,
|
||||||
|
}, &ip, &hopByHop, &icmpLayer)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
// Ensure buffer length checks during parsing with the next 2 tests.
|
||||||
|
|
||||||
|
// A full IPv6 header and 1 byte in the first extension, but missing
|
||||||
|
// the length byte.
|
||||||
|
err = newPacket(buffer.Bytes()[:41], true, p)
|
||||||
|
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||||
|
|
||||||
|
// A full IPv6 header plus 1 full extension, but only 1 byte of the
|
||||||
|
// next layer, missing length byte
|
||||||
|
err = newPacket(buffer.Bytes()[:49], true, p)
|
||||||
|
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||||
|
|
||||||
// A good ICMP packet
|
// A good ICMP packet
|
||||||
ip = layers.IPv6{
|
ip = layers.IPv6{
|
||||||
@@ -134,7 +174,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = newPacket(buffer.Bytes(), true, p)
|
err = newPacket(buffer.Bytes(), true, p)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
|
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
@@ -146,7 +186,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
b := buffer.Bytes()
|
b := buffer.Bytes()
|
||||||
b[6] = byte(layers.IPProtocolESP)
|
b[6] = byte(layers.IPProtocolESP)
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol)
|
assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
@@ -158,7 +198,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
b = buffer.Bytes()
|
b = buffer.Bytes()
|
||||||
b[6] = byte(layers.IPProtocolNoNextHeader)
|
b[6] = byte(layers.IPProtocolNoNextHeader)
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol)
|
assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
@@ -170,7 +210,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
b = buffer.Bytes()
|
b = buffer.Bytes()
|
||||||
b[6] = 255 // 255 is a reserved protocol number
|
b[6] = 255 // 255 is a reserved protocol number
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||||
|
|
||||||
// A good UDP packet
|
// A good UDP packet
|
||||||
ip = layers.IPv6{
|
ip = layers.IPv6{
|
||||||
@@ -186,7 +226,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
DstPort: layers.UDPPort(22),
|
DstPort: layers.UDPPort(22),
|
||||||
}
|
}
|
||||||
err = udp.SetNetworkLayerForChecksum(&ip)
|
err = udp.SetNetworkLayerForChecksum(&ip)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
buffer.Clear()
|
buffer.Clear()
|
||||||
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
|
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
|
||||||
@@ -197,7 +237,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
|
|
||||||
// incoming
|
// incoming
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
@@ -207,7 +247,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
|
|
||||||
// outgoing
|
// outgoing
|
||||||
err = newPacket(b, false, p)
|
err = newPacket(b, false, p)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||||
@@ -217,14 +257,14 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
|
|
||||||
// Too short UDP packet
|
// Too short UDP packet
|
||||||
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
||||||
assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||||
|
|
||||||
// A good TCP packet
|
// A good TCP packet
|
||||||
b[6] = byte(layers.IPProtocolTCP)
|
b[6] = byte(layers.IPProtocolTCP)
|
||||||
|
|
||||||
// incoming
|
// incoming
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
@@ -234,7 +274,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
|
|
||||||
// outgoing
|
// outgoing
|
||||||
err = newPacket(b, false, p)
|
err = newPacket(b, false, p)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||||
@@ -244,7 +284,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
|
|
||||||
// Too short TCP packet
|
// Too short TCP packet
|
||||||
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
||||||
assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||||
|
|
||||||
// A good UDP packet with an AH header
|
// A good UDP packet with an AH header
|
||||||
ip = layers.IPv6{
|
ip = layers.IPv6{
|
||||||
@@ -279,7 +319,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
b = append(b, udpHeader...)
|
b = append(b, udpHeader...)
|
||||||
|
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
@@ -287,10 +327,14 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
assert.Equal(t, uint16(22), p.LocalPort)
|
assert.Equal(t, uint16(22), p.LocalPort)
|
||||||
assert.False(t, p.Fragment)
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
|
// Ensure buffer bounds checking during processing
|
||||||
|
err = newPacket(b[:41], true, p)
|
||||||
|
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||||
|
|
||||||
// Invalid AH header
|
// Invalid AH header
|
||||||
b = buffer.Bytes()
|
b = buffer.Bytes()
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_newPacket_ipv6Fragment(t *testing.T) {
|
func Test_newPacket_ipv6Fragment(t *testing.T) {
|
||||||
@@ -338,7 +382,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
|||||||
|
|
||||||
// Test first fragment incoming
|
// Test first fragment incoming
|
||||||
err = newPacket(firstFrag, true, p)
|
err = newPacket(firstFrag, true, p)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
||||||
@@ -348,7 +392,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
|||||||
|
|
||||||
// Test first fragment outgoing
|
// Test first fragment outgoing
|
||||||
err = newPacket(firstFrag, false, p)
|
err = newPacket(firstFrag, false, p)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||||
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
||||||
@@ -377,7 +421,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
|||||||
|
|
||||||
// Test second fragment incoming
|
// Test second fragment incoming
|
||||||
err = newPacket(secondFrag, true, p)
|
err = newPacket(secondFrag, true, p)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
||||||
@@ -387,7 +431,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
|||||||
|
|
||||||
// Test second fragment outgoing
|
// Test second fragment outgoing
|
||||||
err = newPacket(secondFrag, false, p)
|
err = newPacket(secondFrag, false, p)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||||
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
||||||
@@ -397,7 +441,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
|||||||
|
|
||||||
// Too short of a fragment packet
|
// Too short of a fragment packet
|
||||||
err = newPacket(secondFrag[:len(secondFrag)-10], false, p)
|
err = newPacket(secondFrag[:len(secondFrag)-10], false, p)
|
||||||
assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkParseV6(b *testing.B) {
|
func BenchmarkParseV6(b *testing.B) {
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ package overlay
|
|||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Device interface {
|
type Device interface {
|
||||||
@@ -10,6 +12,7 @@ type Device interface {
|
|||||||
Activate() error
|
Activate() error
|
||||||
Networks() []netip.Prefix
|
Networks() []netip.Prefix
|
||||||
Name() string
|
Name() string
|
||||||
RouteFor(netip.Addr) netip.Addr
|
RoutesFor(netip.Addr) routing.Gateways
|
||||||
|
SupportsMultiqueue() bool
|
||||||
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,13 +11,14 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Route struct {
|
type Route struct {
|
||||||
MTU int
|
MTU int
|
||||||
Metric int
|
Metric int
|
||||||
Cidr netip.Prefix
|
Cidr netip.Prefix
|
||||||
Via netip.Addr
|
Via routing.Gateways
|
||||||
Install bool
|
Install bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -47,15 +48,17 @@ func (r Route) String() string {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) {
|
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
|
||||||
routeTree := new(bart.Table[netip.Addr])
|
routeTree := new(bart.Table[routing.Gateways])
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !allowMTU && r.MTU > 0 {
|
if !allowMTU && r.MTU > 0 {
|
||||||
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
|
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Via.IsValid() {
|
gateways := r.Via
|
||||||
routeTree.Insert(r.Cidr, r.Via)
|
if len(gateways) > 0 {
|
||||||
|
routing.CalculateBucketsForGateways(gateways)
|
||||||
|
routeTree.Insert(r.Cidr, gateways)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return routeTree, nil
|
return routeTree, nil
|
||||||
@@ -69,7 +72,7 @@ func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
|||||||
return []Route{}, nil
|
return []Route{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rawRoutes, ok := r.([]interface{})
|
rawRoutes, ok := r.([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("tun.routes is not an array")
|
return nil, fmt.Errorf("tun.routes is not an array")
|
||||||
}
|
}
|
||||||
@@ -80,7 +83,7 @@ func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
|||||||
|
|
||||||
routes := make([]Route, len(rawRoutes))
|
routes := make([]Route, len(rawRoutes))
|
||||||
for i, r := range rawRoutes {
|
for i, r := range rawRoutes {
|
||||||
m, ok := r.(map[interface{}]interface{})
|
m, ok := r.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("entry %v in tun.routes is invalid", i+1)
|
return nil, fmt.Errorf("entry %v in tun.routes is invalid", i+1)
|
||||||
}
|
}
|
||||||
@@ -148,7 +151,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
|||||||
return []Route{}, nil
|
return []Route{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rawRoutes, ok := r.([]interface{})
|
rawRoutes, ok := r.([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("tun.unsafe_routes is not an array")
|
return nil, fmt.Errorf("tun.unsafe_routes is not an array")
|
||||||
}
|
}
|
||||||
@@ -159,7 +162,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
|||||||
|
|
||||||
routes := make([]Route, len(rawRoutes))
|
routes := make([]Route, len(rawRoutes))
|
||||||
for i, r := range rawRoutes {
|
for i, r := range rawRoutes {
|
||||||
m, ok := r.(map[interface{}]interface{})
|
m, ok := r.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1)
|
return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1)
|
||||||
}
|
}
|
||||||
@@ -201,14 +204,63 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
|||||||
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1)
|
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
via, ok := rVia.(string)
|
var gateways routing.Gateways
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia)
|
|
||||||
}
|
|
||||||
|
|
||||||
viaVpnIp, err := netip.ParseAddr(via)
|
switch via := rVia.(type) {
|
||||||
if err != nil {
|
case string:
|
||||||
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
|
viaIp, err := netip.ParseAddr(via)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gateways = routing.Gateways{routing.NewGateway(viaIp, 1)}
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
gateways = make(routing.Gateways, len(via))
|
||||||
|
for ig, v := range via {
|
||||||
|
gatewayMap, ok := v.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("entry %v in tun.unsafe_routes[%v].via is invalid", i+1, ig+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
rGateway, ok := gatewayMap["gateway"]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not present", i+1, ig+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedGateway, ok := rGateway.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not a string", i+1, ig+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
gatewayIp, err := netip.ParseAddr(parsedGateway)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] failed to parse address: %v", i+1, ig+1, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rGatewayWeight, ok := gatewayMap["weight"]
|
||||||
|
if !ok {
|
||||||
|
rGatewayWeight = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
gatewayWeight, ok := rGatewayWeight.(int)
|
||||||
|
if !ok {
|
||||||
|
_, err = strconv.ParseInt(rGatewayWeight.(string), 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not an integer", i+1, ig+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if gatewayWeight < 1 || gatewayWeight > math.MaxInt32 {
|
||||||
|
return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not in range (1-%d) : %v", i+1, ig+1, math.MaxInt32, gatewayWeight)
|
||||||
|
}
|
||||||
|
|
||||||
|
gateways[ig] = routing.NewGateway(gatewayIp, gatewayWeight)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string or list of gateways: found %T", i+1, rVia)
|
||||||
}
|
}
|
||||||
|
|
||||||
rRoute, ok := m["route"]
|
rRoute, ok := m["route"]
|
||||||
@@ -226,7 +278,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r := Route{
|
r := Route{
|
||||||
Via: viaVpnIp,
|
Via: gateways,
|
||||||
MTU: mtu,
|
MTU: mtu,
|
||||||
Metric: metric,
|
Metric: metric,
|
||||||
Install: install,
|
Install: install,
|
||||||
|
|||||||
@@ -6,94 +6,96 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_parseRoutes(t *testing.T) {
|
func Test_parseRoutes(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
n, err := netip.ParsePrefix("10.0.0.0/24")
|
n, err := netip.ParsePrefix("10.0.0.0/24")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// test no routes config
|
// test no routes config
|
||||||
routes, err := parseRoutes(c, []netip.Prefix{n})
|
routes, err := parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, routes, 0)
|
assert.Empty(t, routes)
|
||||||
|
|
||||||
// not an array
|
// not an array
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
|
c.Settings["tun"] = map[string]any{"routes": "hi"}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "tun.routes is not an array")
|
require.EqualError(t, err, "tun.routes is not an array")
|
||||||
|
|
||||||
// no routes
|
// no routes
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
|
c.Settings["tun"] = map[string]any{"routes": []any{}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, routes, 0)
|
assert.Empty(t, routes)
|
||||||
|
|
||||||
// weird route
|
// weird route
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
|
c.Settings["tun"] = map[string]any{"routes": []any{"asdf"}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1 in tun.routes is invalid")
|
require.EqualError(t, err, "entry 1 in tun.routes is invalid")
|
||||||
|
|
||||||
// no mtu
|
// no mtu
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
|
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{}}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
|
require.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
|
||||||
|
|
||||||
// bad mtu
|
// bad mtu
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
|
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "nope"}}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
|
require.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
|
||||||
|
|
||||||
// low mtu
|
// low mtu
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
|
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "499"}}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
|
require.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
|
||||||
|
|
||||||
// missing route
|
// missing route
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
|
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500"}}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.route in tun.routes is not present")
|
require.EqualError(t, err, "entry 1.route in tun.routes is not present")
|
||||||
|
|
||||||
// unparsable route
|
// unparsable route
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
|
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "nope"}}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
require.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
||||||
|
|
||||||
// below network range
|
// below network range
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
|
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "1.0.0.0/8"}}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
|
require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
|
||||||
|
|
||||||
// above network range
|
// above network range
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
|
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "10.0.1.0/24"}}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
|
require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
|
||||||
|
|
||||||
// Not in multiple ranges
|
// Not in multiple ranges
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}}
|
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "192.0.0.0/24"}}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
|
routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
|
require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
|
||||||
|
|
||||||
// happy case
|
// happy case
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
|
c.Settings["tun"] = map[string]any{"routes": []any{
|
||||||
map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
|
map[string]any{"mtu": "9000", "route": "10.0.0.0/29"},
|
||||||
map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
|
map[string]any{"mtu": "8000", "route": "10.0.0.1/32"},
|
||||||
}}
|
}}
|
||||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, routes, 2)
|
assert.Len(t, routes, 2)
|
||||||
|
|
||||||
tested := 0
|
tested := 0
|
||||||
@@ -119,116 +121,140 @@ func Test_parseUnsafeRoutes(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
n, err := netip.ParsePrefix("10.0.0.0/24")
|
n, err := netip.ParsePrefix("10.0.0.0/24")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// test no routes config
|
// test no routes config
|
||||||
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, routes, 0)
|
assert.Empty(t, routes)
|
||||||
|
|
||||||
// not an array
|
// not an array
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
|
c.Settings["tun"] = map[string]any{"unsafe_routes": "hi"}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "tun.unsafe_routes is not an array")
|
require.EqualError(t, err, "tun.unsafe_routes is not an array")
|
||||||
|
|
||||||
// no routes
|
// no routes
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, routes, 0)
|
assert.Empty(t, routes)
|
||||||
|
|
||||||
// weird route
|
// weird route
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{"asdf"}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
|
require.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
|
||||||
|
|
||||||
// no via
|
// no via
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
|
require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
|
||||||
|
|
||||||
// invalid via
|
// invalid via
|
||||||
for _, invalidValue := range []interface{}{
|
for _, invalidValue := range []any{
|
||||||
127, false, nil, 1.0, []string{"1", "2"},
|
127, false, nil, 1.0, []string{"1", "2"},
|
||||||
} {
|
} {
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": invalidValue}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
|
require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string or list of gateways: found %T", invalidValue))
|
||||||
}
|
}
|
||||||
|
|
||||||
// unparsable via
|
// Unparsable list of via
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": []string{"1", "2"}}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
|
require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not a string or list of gateways: found []string")
|
||||||
|
|
||||||
|
// unparsable via
|
||||||
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": "nope"}}}
|
||||||
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
|
assert.Nil(t, routes)
|
||||||
|
require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
|
||||||
|
|
||||||
|
// unparsable gateway
|
||||||
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "1"}}}}}
|
||||||
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
|
assert.Nil(t, routes)
|
||||||
|
require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] failed to parse address: ParseAddr(\"1\"): unable to parse IP")
|
||||||
|
|
||||||
|
// missing gateway element
|
||||||
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"weight": "1"}}}}}
|
||||||
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
|
assert.Nil(t, routes)
|
||||||
|
require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] is not present")
|
||||||
|
|
||||||
|
// unparsable weight element
|
||||||
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "10.0.0.1", "weight": "a"}}}}}
|
||||||
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
|
assert.Nil(t, routes)
|
||||||
|
require.EqualError(t, err, "entry .weight in tun.unsafe_routes[1].via[1] is not an integer")
|
||||||
|
|
||||||
// missing route
|
// missing route
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
|
require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
|
||||||
|
|
||||||
// unparsable route
|
// unparsable route
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
require.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
||||||
|
|
||||||
// within network range
|
// within network range
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
|
require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
|
||||||
|
|
||||||
// below network range
|
// below network range
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Len(t, routes, 1)
|
assert.Len(t, routes, 1)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// above network range
|
// above network range
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Len(t, routes, 1)
|
assert.Len(t, routes, 1)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// no mtu
|
// no mtu
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Len(t, routes, 1)
|
assert.Len(t, routes, 1)
|
||||||
assert.Equal(t, 0, routes[0].MTU)
|
assert.Equal(t, 0, routes[0].MTU)
|
||||||
|
|
||||||
// bad mtu
|
// bad mtu
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "nope"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
|
require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
|
||||||
|
|
||||||
// low mtu
|
// low mtu
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "499"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
|
require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
|
||||||
|
|
||||||
// bad install
|
// bad install
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
|
require.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
|
||||||
|
|
||||||
// happy case
|
// happy case
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{
|
||||||
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"},
|
map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"},
|
||||||
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0},
|
map[string]any{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0},
|
||||||
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
|
map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
|
||||||
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
|
map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
|
||||||
}}
|
}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, routes, 4)
|
assert.Len(t, routes, 4)
|
||||||
|
|
||||||
tested := 0
|
tested := 0
|
||||||
@@ -260,38 +286,119 @@ func Test_makeRouteTree(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
n, err := netip.ParsePrefix("10.0.0.0/24")
|
n, err := netip.ParsePrefix("10.0.0.0/24")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{
|
||||||
map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
|
map[string]any{"via": "192.168.0.1", "route": "1.0.0.0/28"},
|
||||||
map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
|
map[string]any{"via": "192.168.0.2", "route": "1.0.0.1/32"},
|
||||||
}}
|
}}
|
||||||
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, routes, 2)
|
assert.Len(t, routes, 2)
|
||||||
routeTree, err := makeRouteTree(l, routes, true)
|
routeTree, err := makeRouteTree(l, routes, true)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ip, err := netip.ParseAddr("1.0.0.2")
|
ip, err := netip.ParseAddr("1.0.0.2")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
r, ok := routeTree.Lookup(ip)
|
r, ok := routeTree.Lookup(ip)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
nip, err := netip.ParseAddr("192.168.0.1")
|
nip, err := netip.ParseAddr("192.168.0.1")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, nip, r)
|
assert.Equal(t, nip, r[0].Addr())
|
||||||
|
|
||||||
ip, err = netip.ParseAddr("1.0.0.1")
|
ip, err = netip.ParseAddr("1.0.0.1")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
r, ok = routeTree.Lookup(ip)
|
r, ok = routeTree.Lookup(ip)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
nip, err = netip.ParseAddr("192.168.0.2")
|
nip, err = netip.ParseAddr("192.168.0.2")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, nip, r)
|
assert.Equal(t, nip, r[0].Addr())
|
||||||
|
|
||||||
ip, err = netip.ParseAddr("1.1.0.1")
|
ip, err = netip.ParseAddr("1.1.0.1")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
r, ok = routeTree.Lookup(ip)
|
r, ok = routeTree.Lookup(ip)
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_makeMultipathUnsafeRouteTree(t *testing.T) {
|
||||||
|
l := test.NewLogger()
|
||||||
|
c := config.NewC(l)
|
||||||
|
n, err := netip.ParsePrefix("10.0.0.0/24")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
c.Settings["tun"] = map[string]any{
|
||||||
|
"unsafe_routes": []any{
|
||||||
|
map[string]any{
|
||||||
|
"route": "192.168.86.0/24",
|
||||||
|
"via": "192.168.100.10",
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"route": "192.168.87.0/24",
|
||||||
|
"via": []any{
|
||||||
|
map[string]any{
|
||||||
|
"gateway": "10.0.0.1",
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"gateway": "10.0.0.2",
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"gateway": "10.0.0.3",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"route": "192.168.89.0/24",
|
||||||
|
"via": []any{
|
||||||
|
map[string]any{
|
||||||
|
"gateway": "10.0.0.1",
|
||||||
|
"weight": 10,
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"gateway": "10.0.0.2",
|
||||||
|
"weight": 5,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, routes, 3)
|
||||||
|
routeTree, err := makeRouteTree(l, routes, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ip, err := netip.ParseAddr("192.168.86.1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
r, ok := routeTree.Lookup(ip)
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
nip, err := netip.ParseAddr("192.168.100.10")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, nip, r[0].Addr())
|
||||||
|
|
||||||
|
ip, err = netip.ParseAddr("192.168.87.1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
r, ok = routeTree.Lookup(ip)
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
expectedGateways := routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 1),
|
||||||
|
routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 1),
|
||||||
|
routing.NewGateway(netip.MustParseAddr("10.0.0.3"), 1)}
|
||||||
|
|
||||||
|
routing.CalculateBucketsForGateways(expectedGateways)
|
||||||
|
assert.ElementsMatch(t, expectedGateways, r)
|
||||||
|
|
||||||
|
ip, err = netip.ParseAddr("192.168.89.1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
r, ok = routeTree.Lookup(ip)
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
expectedGateways = routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 10),
|
||||||
|
routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 5)}
|
||||||
|
|
||||||
|
routing.CalculateBucketsForGateways(expectedGateways)
|
||||||
|
assert.ElementsMatch(t, expectedGateways, r)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -70,3 +72,51 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
|
|||||||
|
|
||||||
return removed
|
return removed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
||||||
|
pLen := 128
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
pLen = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func flipBytes(b []byte) []byte {
|
||||||
|
for i := 0; i < len(b); i++ {
|
||||||
|
b[i] ^= 0xFF
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
func orBytes(a []byte, b []byte) []byte {
|
||||||
|
ret := make([]byte, len(a))
|
||||||
|
for i := 0; i < len(a); i++ {
|
||||||
|
ret[i] = a[i] | b[i]
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func getBroadcast(cidr netip.Prefix) netip.Addr {
|
||||||
|
broadcast, _ := netip.AddrFromSlice(
|
||||||
|
orBytes(
|
||||||
|
cidr.Addr().AsSlice(),
|
||||||
|
flipBytes(prefixToMask(cidr).AsSlice()),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return broadcast
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
|
||||||
|
for _, gateway := range gateways {
|
||||||
|
if dest.Addr().Is4() && gateway.Addr().Is4() {
|
||||||
|
return gateway, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dest.Addr().Is6() && gateway.Addr().Is6() {
|
||||||
|
return gateway, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest)
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -21,7 +22,7 @@ type tun struct {
|
|||||||
fd int
|
fd int
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -56,7 +57,7 @@ func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, erro
|
|||||||
return nil, fmt.Errorf("newTun not supported in Android")
|
return nil, fmt.Errorf("newTun not supported in Android")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@@ -94,6 +95,10 @@ func (t *tun) Name() string {
|
|||||||
return "android"
|
return "android"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) SupportsMultiqueue() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -17,6 +16,7 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
netroute "golang.org/x/net/route"
|
netroute "golang.org/x/net/route"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
@@ -28,7 +28,7 @@ type tun struct {
|
|||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
DefaultMTU int
|
DefaultMTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
linkAddr *netroute.LinkAddr
|
linkAddr *netroute.LinkAddr
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
|
||||||
@@ -341,12 +341,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
r, ok := t.routeTree.Load().Lookup(ip)
|
r, ok := t.routeTree.Load().Lookup(ip)
|
||||||
if ok {
|
if ok {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
return netip.Addr{}
|
return routing.Gateways{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the LinkAddr for the interface of the given name
|
// Get the LinkAddr for the interface of the given name
|
||||||
@@ -381,7 +381,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Via.IsValid() || !r.Install {
|
if len(r.Via) == 0 || !r.Install {
|
||||||
// We don't allow route MTUs so only install routes with a via
|
// We don't allow route MTUs so only install routes with a via
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -392,7 +392,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
t.l.WithField("route", r.Cidr).
|
t.l.WithField("route", r.Cidr).
|
||||||
Warnf("unable to add unsafe_route, identical route already exists")
|
Warnf("unable to add unsafe_route, identical route already exists")
|
||||||
} else {
|
} else {
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
|
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
retErr.Log(t.l)
|
retErr.Log(t.l)
|
||||||
} else {
|
} else {
|
||||||
@@ -549,16 +549,10 @@ func (t *tun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) SupportsMultiqueue() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||||
}
|
}
|
||||||
|
|
||||||
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
|
||||||
pLen := 128
|
|
||||||
if prefix.Addr().Is4() {
|
|
||||||
pLen = 32
|
|
||||||
}
|
|
||||||
|
|
||||||
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
|
|
||||||
return addr
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
type disabledTun struct {
|
type disabledTun struct {
|
||||||
@@ -43,8 +44,8 @@ func (*disabledTun) Activate() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
|
func (*disabledTun) RoutesFor(addr netip.Addr) routing.Gateways {
|
||||||
return netip.Addr{}
|
return routing.Gateways{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *disabledTun) Networks() []netip.Prefix {
|
func (t *disabledTun) Networks() []netip.Prefix {
|
||||||
@@ -104,6 +105,10 @@ func (t *disabledTun) Write(b []byte) (int, error) {
|
|||||||
return len(b), nil
|
return len(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *disabledTun) SupportsMultiqueue() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,23 +10,28 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"strconv"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
|
netroute "golang.org/x/net/route"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
|
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
|
||||||
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
|
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
|
||||||
FIODGNAME = 0x80106678
|
FIODGNAME = 0x80106678
|
||||||
|
TUNSIFMODE = 0x8004745e
|
||||||
|
TUNSIFHEAD = 0x80047460
|
||||||
|
OSIOCAIFADDR_IN6 = 0x8088691b
|
||||||
|
IN6_IFF_NODAD = 0x0020
|
||||||
)
|
)
|
||||||
|
|
||||||
type fiodgnameArg struct {
|
type fiodgnameArg struct {
|
||||||
@@ -36,43 +41,159 @@ type fiodgnameArg struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ifreqRename struct {
|
type ifreqRename struct {
|
||||||
Name [16]byte
|
Name [unix.IFNAMSIZ]byte
|
||||||
Data uintptr
|
Data uintptr
|
||||||
}
|
}
|
||||||
|
|
||||||
type ifreqDestroy struct {
|
type ifreqDestroy struct {
|
||||||
Name [16]byte
|
Name [unix.IFNAMSIZ]byte
|
||||||
pad [16]byte
|
pad [16]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ifReq struct {
|
||||||
|
Name [unix.IFNAMSIZ]byte
|
||||||
|
Flags uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
type ifreqMTU struct {
|
||||||
|
Name [unix.IFNAMSIZ]byte
|
||||||
|
MTU int32
|
||||||
|
}
|
||||||
|
|
||||||
|
type addrLifetime struct {
|
||||||
|
Expire uint64
|
||||||
|
Preferred uint64
|
||||||
|
Vltime uint32
|
||||||
|
Pltime uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type ifreqAlias4 struct {
|
||||||
|
Name [unix.IFNAMSIZ]byte
|
||||||
|
Addr unix.RawSockaddrInet4
|
||||||
|
DstAddr unix.RawSockaddrInet4
|
||||||
|
MaskAddr unix.RawSockaddrInet4
|
||||||
|
VHid uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type ifreqAlias6 struct {
|
||||||
|
Name [unix.IFNAMSIZ]byte
|
||||||
|
Addr unix.RawSockaddrInet6
|
||||||
|
DstAddr unix.RawSockaddrInet6
|
||||||
|
PrefixMask unix.RawSockaddrInet6
|
||||||
|
Flags uint32
|
||||||
|
Lifetime addrLifetime
|
||||||
|
VHid uint32
|
||||||
|
}
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
Device string
|
Device string
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
MTU int
|
MTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
|
linkAddr *netroute.LinkAddr
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
devFd int
|
||||||
|
}
|
||||||
|
|
||||||
io.ReadWriteCloser
|
func (t *tun) Read(to []byte) (int, error) {
|
||||||
|
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
|
||||||
|
if t.devFd < 0 {
|
||||||
|
return -1, syscall.EINVAL
|
||||||
|
}
|
||||||
|
|
||||||
|
// first 4 bytes is protocol family, in network byte order
|
||||||
|
head := make([]byte, 4)
|
||||||
|
|
||||||
|
iovecs := []syscall.Iovec{
|
||||||
|
{&head[0], 4},
|
||||||
|
{&to[0], uint64(len(to))},
|
||||||
|
}
|
||||||
|
|
||||||
|
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if errno != 0 {
|
||||||
|
err = syscall.Errno(errno)
|
||||||
|
} else {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
// fix bytes read number to exclude header
|
||||||
|
bytesRead := int(n)
|
||||||
|
if bytesRead < 0 {
|
||||||
|
return bytesRead, err
|
||||||
|
} else if bytesRead < 4 {
|
||||||
|
return 0, err
|
||||||
|
} else {
|
||||||
|
return bytesRead - 4, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write is only valid for single threaded use
|
||||||
|
func (t *tun) Write(from []byte) (int, error) {
|
||||||
|
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
|
||||||
|
if t.devFd < 0 {
|
||||||
|
return -1, syscall.EINVAL
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(from) <= 1 {
|
||||||
|
return 0, syscall.EIO
|
||||||
|
}
|
||||||
|
ipVer := from[0] >> 4
|
||||||
|
var head []byte
|
||||||
|
// first 4 bytes is protocol family, in network byte order
|
||||||
|
if ipVer == 4 {
|
||||||
|
head = []byte{0, 0, 0, syscall.AF_INET}
|
||||||
|
} else if ipVer == 6 {
|
||||||
|
head = []byte{0, 0, 0, syscall.AF_INET6}
|
||||||
|
} else {
|
||||||
|
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||||
|
}
|
||||||
|
iovecs := []syscall.Iovec{
|
||||||
|
{&head[0], 4},
|
||||||
|
{&from[0], uint64(len(from))},
|
||||||
|
}
|
||||||
|
|
||||||
|
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if errno != 0 {
|
||||||
|
err = syscall.Errno(errno)
|
||||||
|
} else {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return int(n) - 4, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Close() error {
|
func (t *tun) Close() error {
|
||||||
if t.ReadWriteCloser != nil {
|
if t.devFd >= 0 {
|
||||||
if err := t.ReadWriteCloser.Close(); err != nil {
|
err := syscall.Close(t.devFd)
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
t.l.WithError(err).Error("Error closing device")
|
||||||
}
|
}
|
||||||
defer syscall.Close(s)
|
t.devFd = -1
|
||||||
|
|
||||||
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
c := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
// destroying the interface can block if a read() is still pending. Do this asynchronously.
|
||||||
|
defer close(c)
|
||||||
|
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||||
|
if err == nil {
|
||||||
|
defer syscall.Close(s)
|
||||||
|
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
||||||
|
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithError(err).Error("Error destroying tunnel")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Destroy the interface
|
// wait up to 1 second so we start blocking at the ioctl
|
||||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
select {
|
||||||
return err
|
case <-c:
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -84,32 +205,37 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun,
|
|||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
// Try to open existing tun device
|
// Try to open existing tun device
|
||||||
var file *os.File
|
var fd int
|
||||||
var err error
|
var err error
|
||||||
deviceName := c.GetString("tun.dev", "")
|
deviceName := c.GetString("tun.dev", "")
|
||||||
if deviceName != "" {
|
if deviceName != "" {
|
||||||
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
|
fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
|
||||||
}
|
}
|
||||||
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
|
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
|
||||||
// If the device doesn't already exist, request a new one and rename it
|
// If the device doesn't already exist, request a new one and rename it
|
||||||
file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0)
|
fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
rawConn, err := file.SyscallConn()
|
// Read the name of the interface
|
||||||
if err != nil {
|
var name [16]byte
|
||||||
return nil, fmt.Errorf("SyscallConn: %v", err)
|
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
|
||||||
|
ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
|
||||||
|
|
||||||
|
if ctrlErr == nil {
|
||||||
|
// set broadcast mode and multicast
|
||||||
|
ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
|
||||||
|
ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctrlErr == nil {
|
||||||
|
// turn on link-layer mode, to support ipv6
|
||||||
|
ifhead := uint32(1)
|
||||||
|
ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
|
||||||
}
|
}
|
||||||
|
|
||||||
var name [16]byte
|
|
||||||
var ctrlErr error
|
|
||||||
rawConn.Control(func(fd uintptr) {
|
|
||||||
// Read the name of the interface
|
|
||||||
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
|
|
||||||
ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg)))
|
|
||||||
})
|
|
||||||
if ctrlErr != nil {
|
if ctrlErr != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -121,11 +247,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
|
|
||||||
// If the name doesn't match the desired interface name, rename it now
|
// If the name doesn't match the desired interface name, rename it now
|
||||||
if ifName != deviceName {
|
if ifName != deviceName {
|
||||||
s, err := syscall.Socket(
|
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
syscall.AF_INET,
|
|
||||||
syscall.SOCK_DGRAM,
|
|
||||||
syscall.IPPROTO_IP,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -148,11 +270,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
t := &tun{
|
t := &tun{
|
||||||
ReadWriteCloser: file,
|
Device: deviceName,
|
||||||
Device: deviceName,
|
vpnNetworks: vpnNetworks,
|
||||||
vpnNetworks: vpnNetworks,
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
l: l,
|
||||||
l: l,
|
devFd: fd,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
@@ -171,38 +293,111 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||||
var err error
|
if cidr.Addr().Is4() {
|
||||||
// TODO use syscalls instead of exec.Command
|
ifr := ifreqAlias4{
|
||||||
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
Name: t.deviceBytes(),
|
||||||
t.l.Debug("command: ", cmd.String())
|
Addr: unix.RawSockaddrInet4{
|
||||||
if err = cmd.Run(); err != nil {
|
Len: unix.SizeofSockaddrInet4,
|
||||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
Family: unix.AF_INET,
|
||||||
|
Addr: cidr.Addr().As4(),
|
||||||
|
},
|
||||||
|
DstAddr: unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: getBroadcast(cidr).As4(),
|
||||||
|
},
|
||||||
|
MaskAddr: unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: prefixToMask(cidr).As4(),
|
||||||
|
},
|
||||||
|
VHid: 0,
|
||||||
|
}
|
||||||
|
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer syscall.Close(s)
|
||||||
|
// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
|
||||||
|
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||||
|
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
|
if cidr.Addr().Is6() {
|
||||||
t.l.Debug("command: ", cmd.String())
|
ifr := ifreqAlias6{
|
||||||
if err = cmd.Run(); err != nil {
|
Name: t.deviceBytes(),
|
||||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
Addr: unix.RawSockaddrInet6{
|
||||||
|
Len: unix.SizeofSockaddrInet6,
|
||||||
|
Family: unix.AF_INET6,
|
||||||
|
Addr: cidr.Addr().As16(),
|
||||||
|
},
|
||||||
|
PrefixMask: unix.RawSockaddrInet6{
|
||||||
|
Len: unix.SizeofSockaddrInet6,
|
||||||
|
Family: unix.AF_INET6,
|
||||||
|
Addr: prefixToMask(cidr).As16(),
|
||||||
|
},
|
||||||
|
Lifetime: addrLifetime{
|
||||||
|
Expire: 0,
|
||||||
|
Preferred: 0,
|
||||||
|
Vltime: 0xffffffff,
|
||||||
|
Pltime: 0xffffffff,
|
||||||
|
},
|
||||||
|
Flags: IN6_IFF_NODAD,
|
||||||
|
}
|
||||||
|
s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer syscall.Close(s)
|
||||||
|
|
||||||
|
if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||||
|
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
return fmt.Errorf("unknown address type %v", cidr)
|
||||||
t.l.Debug("command: ", cmd.String())
|
|
||||||
if err = cmd.Run(); err != nil {
|
|
||||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unsafe path routes
|
|
||||||
return t.addRoutes(false)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Activate() error {
|
func (t *tun) Activate() error {
|
||||||
|
// Setup our default MTU
|
||||||
|
err := t.setMTU()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
linkAddr, err := getLinkAddr(t.Device)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if linkAddr == nil {
|
||||||
|
return fmt.Errorf("unable to discover link_addr for tun interface")
|
||||||
|
}
|
||||||
|
t.linkAddr = linkAddr
|
||||||
|
|
||||||
for i := range t.vpnNetworks {
|
for i := range t.vpnNetworks {
|
||||||
err := t.addIp(t.vpnNetworks[i])
|
err := t.addIp(t.vpnNetworks[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
return t.addRoutes(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) setMTU() error {
|
||||||
|
// Set the MTU on the device
|
||||||
|
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer syscall.Close(s)
|
||||||
|
|
||||||
|
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
|
||||||
|
err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
@@ -242,7 +437,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@@ -255,6 +450,10 @@ func (t *tun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) SupportsMultiqueue() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
||||||
}
|
}
|
||||||
@@ -262,20 +461,21 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Via.IsValid() || !r.Install {
|
if len(r.Via) == 0 || !r.Install {
|
||||||
// We don't allow route MTUs so only install routes with a via
|
// We don't allow route MTUs so only install routes with a via
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
|
err := addRoute(r.Cidr, t.linkAddr)
|
||||||
t.l.Debug("command: ", cmd.String())
|
if err != nil {
|
||||||
if err := cmd.Run(); err != nil {
|
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
|
|
||||||
if logErrors {
|
if logErrors {
|
||||||
retErr.Log(t.l)
|
retErr.Log(t.l)
|
||||||
} else {
|
} else {
|
||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
t.l.WithField("route", r).Info("Added route")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -288,9 +488,8 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
|
err := delRoute(r.Cidr, t.linkAddr)
|
||||||
t.l.Debug("command: ", cmd.String())
|
if err != nil {
|
||||||
if err := cmd.Run(); err != nil {
|
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.WithField("route", r).Info("Removed route")
|
||||||
@@ -305,3 +504,120 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||||
|
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||||
|
}
|
||||||
|
defer unix.Close(sock)
|
||||||
|
|
||||||
|
route := &netroute.RouteMessage{
|
||||||
|
Version: unix.RTM_VERSION,
|
||||||
|
Type: unix.RTM_ADD,
|
||||||
|
Flags: unix.RTF_UP,
|
||||||
|
Seq: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||||
|
unix.RTAX_GATEWAY: gateway,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||||
|
unix.RTAX_GATEWAY: gateway,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := route.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = unix.Write(sock, data[:])
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, unix.EEXIST) {
|
||||||
|
// Try to do a change
|
||||||
|
route.Type = unix.RTM_CHANGE
|
||||||
|
data, err = route.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
||||||
|
}
|
||||||
|
_, err = unix.Write(sock, data[:])
|
||||||
|
fmt.Println("DOING CHANGE")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||||
|
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||||
|
}
|
||||||
|
defer unix.Close(sock)
|
||||||
|
|
||||||
|
route := netroute.RouteMessage{
|
||||||
|
Version: unix.RTM_VERSION,
|
||||||
|
Type: unix.RTM_DELETE,
|
||||||
|
Seq: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||||
|
unix.RTAX_GATEWAY: gateway,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||||
|
unix.RTAX_GATEWAY: gateway,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := route.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||||
|
}
|
||||||
|
_, err = unix.Write(sock, data[:])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getLinkAddr Gets the link address for the interface of the given name
|
||||||
|
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
||||||
|
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range msgs {
|
||||||
|
switch m := m.(type) {
|
||||||
|
case *netroute.InterfaceMessage:
|
||||||
|
if m.Name == name {
|
||||||
|
sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
|
||||||
|
if ok {
|
||||||
|
return sa, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -23,7 +24,7 @@ type tun struct {
|
|||||||
io.ReadWriteCloser
|
io.ReadWriteCloser
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,7 +80,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@@ -150,6 +151,10 @@ func (t *tun) Name() string {
|
|||||||
return "iOS"
|
return "iOS"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) SupportsMultiqueue() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
@@ -33,10 +34,11 @@ type tun struct {
|
|||||||
deviceIndex int
|
deviceIndex int
|
||||||
ioctlFd uintptr
|
ioctlFd uintptr
|
||||||
|
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
routeChan chan struct{}
|
routeChan chan struct{}
|
||||||
useSystemRoutes bool
|
useSystemRoutes bool
|
||||||
|
useSystemRoutesBufferSize int
|
||||||
|
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
@@ -123,12 +125,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||||||
|
|
||||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
t := &tun{
|
t := &tun{
|
||||||
ReadWriteCloser: file,
|
ReadWriteCloser: file,
|
||||||
fd: int(file.Fd()),
|
fd: int(file.Fd()),
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||||
l: l,
|
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := t.reload(c, true)
|
err := t.reload(c, true)
|
||||||
@@ -213,6 +216,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) SupportsMultiqueue() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -231,7 +238,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|||||||
return file, nil
|
return file, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@@ -290,7 +297,6 @@ func (t *tun) addIPs(link netlink.Link) error {
|
|||||||
|
|
||||||
//add all new addresses
|
//add all new addresses
|
||||||
for i := range newAddrs {
|
for i := range newAddrs {
|
||||||
//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
|
|
||||||
//AddrReplace still adds new IPs, but if their properties change it will change them as well
|
//AddrReplace still adds new IPs, but if their properties change it will change them as well
|
||||||
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
|
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -358,6 +364,11 @@ func (t *tun) Activate() error {
|
|||||||
t.l.WithError(err).Error("Failed to set tun tx queue length")
|
t.l.WithError(err).Error("Failed to set tun tx queue length")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const modeNone = 1
|
||||||
|
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
|
||||||
|
t.l.WithError(err).Warn("Failed to disable link local address generation")
|
||||||
|
}
|
||||||
|
|
||||||
if err = t.addIPs(link); err != nil {
|
if err = t.addIPs(link); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -463,7 +474,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
|
|
||||||
err := netlink.RouteReplace(&nr)
|
err := netlink.RouteReplace(&nr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
|
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
retErr.Log(t.l)
|
retErr.Log(t.l)
|
||||||
} else {
|
} else {
|
||||||
@@ -530,7 +541,13 @@ func (t *tun) watchRoutes() {
|
|||||||
rch := make(chan netlink.RouteUpdate)
|
rch := make(chan netlink.RouteUpdate)
|
||||||
doneChan := make(chan struct{})
|
doneChan := make(chan struct{})
|
||||||
|
|
||||||
if err := netlink.RouteSubscribe(rch, doneChan); err != nil {
|
netlinkOptions := netlink.RouteSubscribeOptions{
|
||||||
|
ReceiveBufferSize: t.useSystemRoutesBufferSize,
|
||||||
|
ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0,
|
||||||
|
ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") },
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil {
|
||||||
t.l.WithError(err).Errorf("failed to subscribe to system route changes")
|
t.l.WithError(err).Errorf("failed to subscribe to system route changes")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -540,8 +557,14 @@ func (t *tun) watchRoutes() {
|
|||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case r := <-rch:
|
case r, ok := <-rch:
|
||||||
t.updateRoutes(r)
|
if ok {
|
||||||
|
t.updateRoutes(r)
|
||||||
|
} else {
|
||||||
|
// may be should do something here as
|
||||||
|
// netlink stops sending updates
|
||||||
|
return
|
||||||
|
}
|
||||||
case <-doneChan:
|
case <-doneChan:
|
||||||
// netlink.RouteSubscriber will close the rch for us
|
// netlink.RouteSubscriber will close the rch for us
|
||||||
return
|
return
|
||||||
@@ -550,20 +573,7 @@ func (t *tun) watchRoutes() {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
|
||||||
if r.Gw == nil {
|
|
||||||
// Not a gateway route, ignore
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
gwAddr, ok := netip.AddrFromSlice(r.Gw)
|
|
||||||
if !ok {
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
gwAddr = gwAddr.Unmap()
|
|
||||||
withinNetworks := false
|
withinNetworks := false
|
||||||
for i := range t.vpnNetworks {
|
for i := range t.vpnNetworks {
|
||||||
if t.vpnNetworks[i].Contains(gwAddr) {
|
if t.vpnNetworks[i].Contains(gwAddr) {
|
||||||
@@ -571,9 +581,84 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !withinNetworks {
|
|
||||||
// Gateway isn't in our overlay network, ignore
|
return withinNetworks
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, not in our networks")
|
}
|
||||||
|
|
||||||
|
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
||||||
|
var gateways routing.Gateways
|
||||||
|
|
||||||
|
link, err := netlink.LinkByName(t.Device)
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name")
|
||||||
|
return gateways
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this route is relevant to our interface and there is a gateway then add it
|
||||||
|
if r.LinkIndex == link.Attrs().Index {
|
||||||
|
gwAddr, ok := getGatewayAddr(r.Gw, r.Via)
|
||||||
|
if ok {
|
||||||
|
if t.isGatewayInVpnNetworks(gwAddr) {
|
||||||
|
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
|
||||||
|
} else {
|
||||||
|
// Gateway isn't in our overlay network, ignore
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range r.MultiPath {
|
||||||
|
// If this route is relevant to our interface and there is a gateway then add it
|
||||||
|
if p.LinkIndex == link.Attrs().Index {
|
||||||
|
gwAddr, ok := getGatewayAddr(p.Gw, p.Via)
|
||||||
|
if ok {
|
||||||
|
if t.isGatewayInVpnNetworks(gwAddr) {
|
||||||
|
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
|
||||||
|
} else {
|
||||||
|
// Gateway isn't in our overlay network, ignore
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
routing.CalculateBucketsForGateways(gateways)
|
||||||
|
return gateways
|
||||||
|
}
|
||||||
|
|
||||||
|
func getGatewayAddr(gw net.IP, via netlink.Destination) (netip.Addr, bool) {
|
||||||
|
// Try to use the old RTA_GATEWAY first
|
||||||
|
gwAddr, ok := netip.AddrFromSlice(gw)
|
||||||
|
if !ok {
|
||||||
|
// Fallback to the new RTA_VIA
|
||||||
|
rVia, ok := via.(*netlink.Via)
|
||||||
|
if ok {
|
||||||
|
gwAddr, ok = netip.AddrFromSlice(rVia.Addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if gwAddr.IsValid() {
|
||||||
|
gwAddr = gwAddr.Unmap()
|
||||||
|
return gwAddr, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return netip.Addr{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||||
|
gateways := t.getGatewaysFromRoute(&r.Route)
|
||||||
|
if len(gateways) == 0 {
|
||||||
|
// No gateways relevant to our network, no routing changes required.
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Dst == nil {
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -589,12 +674,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
|||||||
newTree := t.routeTree.Load().Clone()
|
newTree := t.routeTree.Load().Clone()
|
||||||
|
|
||||||
if r.Type == unix.RTM_NEWROUTE {
|
if r.Type == unix.RTM_NEWROUTE {
|
||||||
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
|
t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
|
||||||
newTree.Insert(dst, gwAddr)
|
newTree.Insert(dst, gateways)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
|
||||||
newTree.Delete(dst)
|
newTree.Delete(dst)
|
||||||
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
|
|
||||||
}
|
}
|
||||||
t.routeTree.Store(newTree)
|
t.routeTree.Store(newTree)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,13 +4,12 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
@@ -18,12 +17,44 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
|
netroute "golang.org/x/net/route"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ifreqDestroy struct {
|
const (
|
||||||
Name [16]byte
|
SIOCAIFADDR_IN6 = 0x8080696b
|
||||||
pad [16]byte
|
TUNSIFHEAD = 0x80047442
|
||||||
|
TUNSIFMODE = 0x80047458
|
||||||
|
)
|
||||||
|
|
||||||
|
type ifreqAlias4 struct {
|
||||||
|
Name [unix.IFNAMSIZ]byte
|
||||||
|
Addr unix.RawSockaddrInet4
|
||||||
|
DstAddr unix.RawSockaddrInet4
|
||||||
|
MaskAddr unix.RawSockaddrInet4
|
||||||
|
}
|
||||||
|
|
||||||
|
type ifreqAlias6 struct {
|
||||||
|
Name [unix.IFNAMSIZ]byte
|
||||||
|
Addr unix.RawSockaddrInet6
|
||||||
|
DstAddr unix.RawSockaddrInet6
|
||||||
|
PrefixMask unix.RawSockaddrInet6
|
||||||
|
Flags uint32
|
||||||
|
Lifetime addrLifetime
|
||||||
|
}
|
||||||
|
|
||||||
|
type ifreq struct {
|
||||||
|
Name [unix.IFNAMSIZ]byte
|
||||||
|
data int
|
||||||
|
}
|
||||||
|
|
||||||
|
type addrLifetime struct {
|
||||||
|
Expire uint64
|
||||||
|
Preferred uint64
|
||||||
|
Vltime uint32
|
||||||
|
Pltime uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
@@ -31,42 +62,20 @@ type tun struct {
|
|||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
MTU int
|
MTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
f *os.File
|
||||||
io.ReadWriteCloser
|
fd int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Close() error {
|
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||||
if t.ReadWriteCloser != nil {
|
|
||||||
if err := t.ReadWriteCloser.Close(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer syscall.Close(s)
|
|
||||||
|
|
||||||
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
|
||||||
|
|
||||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
||||||
}
|
}
|
||||||
|
|
||||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
// Try to open tun device
|
// Try to open tun device
|
||||||
var file *os.File
|
|
||||||
var err error
|
var err error
|
||||||
deviceName := c.GetString("tun.dev", "")
|
deviceName := c.GetString("tun.dev", "")
|
||||||
if deviceName == "" {
|
if deviceName == "" {
|
||||||
@@ -76,17 +85,23 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||||
}
|
}
|
||||||
|
|
||||||
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
|
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = unix.SetNonblock(fd, true)
|
||||||
|
if err != nil {
|
||||||
|
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
||||||
|
}
|
||||||
|
|
||||||
t := &tun{
|
t := &tun{
|
||||||
ReadWriteCloser: file,
|
f: os.NewFile(uintptr(fd), ""),
|
||||||
Device: deviceName,
|
fd: fd,
|
||||||
vpnNetworks: vpnNetworks,
|
Device: deviceName,
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
vpnNetworks: vpnNetworks,
|
||||||
l: l,
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
@@ -104,58 +119,227 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
func (t *tun) Close() error {
|
||||||
var err error
|
if t.f != nil {
|
||||||
|
if err := t.f.Close(); err != nil {
|
||||||
// TODO use syscalls instead of exec.Command
|
return fmt.Errorf("error closing tun file: %w", err)
|
||||||
if cidr.Addr().Is6() {
|
|
||||||
cmd := exec.Command("/sbin/ifconfig", t.Device, "inet6", cidr.Addr().String(), "prefixlen", strconv.Itoa(cidr.Bits()), "alias")
|
|
||||||
t.l.Debug("command: ", cmd.String())
|
|
||||||
if err = cmd.Run(); err != nil {
|
|
||||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
|
// t.f.Close should have handled it for us but let's be extra sure
|
||||||
t.l.Debug("command: ", cmd.String())
|
_ = unix.Close(t.fd)
|
||||||
if err = cmd.Run(); err != nil {
|
|
||||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||||
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
if err != nil {
|
||||||
t.l.Debug("command: ", cmd.String())
|
return err
|
||||||
if err = cmd.Run(); err != nil {
|
|
||||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
|
||||||
}
|
}
|
||||||
|
defer syscall.Close(s)
|
||||||
|
|
||||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
|
ifr := ifreq{Name: t.deviceBytes()}
|
||||||
t.l.Debug("command: ", cmd.String())
|
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifr)))
|
||||||
if err = cmd.Run(); err != nil {
|
return err
|
||||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) Read(to []byte) (int, error) {
|
||||||
|
rc, err := t.f.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var errno syscall.Errno
|
||||||
|
var n uintptr
|
||||||
|
err = rc.Read(func(fd uintptr) bool {
|
||||||
|
// first 4 bytes is protocol family, in network byte order
|
||||||
|
head := [4]byte{}
|
||||||
|
iovecs := []syscall.Iovec{
|
||||||
|
{&head[0], 4},
|
||||||
|
{&to[0], uint64(len(to))},
|
||||||
|
}
|
||||||
|
|
||||||
|
n, _, errno = syscall.Syscall(syscall.SYS_READV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||||
|
if errno.Temporary() {
|
||||||
|
// We got an EAGAIN, EINTR, or EWOULDBLOCK, go again
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
if err == syscall.EBADF || err.Error() == "use of closed file" {
|
||||||
|
// Go doesn't export poll.ErrFileClosing but happily reports it to us so here we are
|
||||||
|
// https://github.com/golang/go/blob/master/src/internal/poll/fd_poll_runtime.go#L121
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("failed to make read call for tun: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if errno != 0 {
|
||||||
|
return 0, fmt.Errorf("failed to make inner read call for tun: %w", errno)
|
||||||
|
}
|
||||||
|
|
||||||
|
// fix bytes read number to exclude header
|
||||||
|
bytesRead := int(n)
|
||||||
|
if bytesRead < 0 {
|
||||||
|
return bytesRead, nil
|
||||||
|
} else if bytesRead < 4 {
|
||||||
|
return 0, nil
|
||||||
|
} else {
|
||||||
|
return bytesRead - 4, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write is only valid for single threaded use
|
||||||
|
func (t *tun) Write(from []byte) (int, error) {
|
||||||
|
if len(from) <= 1 {
|
||||||
|
return 0, syscall.EIO
|
||||||
|
}
|
||||||
|
|
||||||
|
ipVer := from[0] >> 4
|
||||||
|
var head [4]byte
|
||||||
|
// first 4 bytes is protocol family, in network byte order
|
||||||
|
if ipVer == 4 {
|
||||||
|
head[3] = syscall.AF_INET
|
||||||
|
} else if ipVer == 6 {
|
||||||
|
head[3] = syscall.AF_INET6
|
||||||
|
} else {
|
||||||
|
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||||
|
}
|
||||||
|
|
||||||
|
rc, err := t.f.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var errno syscall.Errno
|
||||||
|
var n uintptr
|
||||||
|
err = rc.Write(func(fd uintptr) bool {
|
||||||
|
iovecs := []syscall.Iovec{
|
||||||
|
{&head[0], 4},
|
||||||
|
{&from[0], uint64(len(from))},
|
||||||
|
}
|
||||||
|
|
||||||
|
n, _, errno = syscall.Syscall(syscall.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||||
|
// According to NetBSD documentation for TUN, writes will only return errors in which
|
||||||
|
// this packet will never be delivered so just go on living life.
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if errno != 0 {
|
||||||
|
return 0, errno
|
||||||
|
}
|
||||||
|
|
||||||
|
return int(n) - 4, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||||
|
if cidr.Addr().Is4() {
|
||||||
|
var req ifreqAlias4
|
||||||
|
req.Name = t.deviceBytes()
|
||||||
|
req.Addr = unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: cidr.Addr().As4(),
|
||||||
|
}
|
||||||
|
req.DstAddr = unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: cidr.Addr().As4(),
|
||||||
|
}
|
||||||
|
req.MaskAddr = unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: prefixToMask(cidr).As4(),
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer syscall.Close(s)
|
||||||
|
|
||||||
|
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
|
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if cidr.Addr().Is6() {
|
||||||
|
var req ifreqAlias6
|
||||||
|
req.Name = t.deviceBytes()
|
||||||
|
req.Addr = unix.RawSockaddrInet6{
|
||||||
|
Len: unix.SizeofSockaddrInet6,
|
||||||
|
Family: unix.AF_INET6,
|
||||||
|
Addr: cidr.Addr().As16(),
|
||||||
|
}
|
||||||
|
req.PrefixMask = unix.RawSockaddrInet6{
|
||||||
|
Len: unix.SizeofSockaddrInet6,
|
||||||
|
Family: unix.AF_INET6,
|
||||||
|
Addr: prefixToMask(cidr).As16(),
|
||||||
|
}
|
||||||
|
req.Lifetime = addrLifetime{
|
||||||
|
Vltime: 0xffffffff,
|
||||||
|
Pltime: 0xffffffff,
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer syscall.Close(s)
|
||||||
|
|
||||||
|
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
|
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("unknown address type %v", cidr)
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) Activate() error {
|
func (t *tun) Activate() error {
|
||||||
|
mode := int32(unix.IFF_BROADCAST)
|
||||||
|
err := ioctl(uintptr(t.fd), TUNSIFMODE, uintptr(unsafe.Pointer(&mode)))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to set tun device mode: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
v := 1
|
||||||
|
err = ioctl(uintptr(t.fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&v)))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to set tun device head: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to set tun mtu: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
for i := range t.vpnNetworks {
|
for i := range t.vpnNetworks {
|
||||||
err := t.addIp(t.vpnNetworks[i])
|
err = t.addIp(t.vpnNetworks[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
|
||||||
t.l.Debug("command: ", cmd.String())
|
|
||||||
if err := cmd.Run(); err != nil {
|
|
||||||
return fmt.Errorf("failed to run '%s': %s", cmd, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unsafe path routes
|
|
||||||
return t.addRoutes(false)
|
return t.addRoutes(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
|
||||||
|
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer syscall.Close(s)
|
||||||
|
|
||||||
|
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
|
||||||
|
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -193,7 +377,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@@ -206,27 +390,33 @@ func (t *tun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) SupportsMultiqueue() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Via.IsValid() || !r.Install {
|
if len(r.Via) == 0 || !r.Install {
|
||||||
// We don't allow route MTUs so only install routes with a via
|
// We don't allow route MTUs so only install routes with a via
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
err := addRoute(r.Cidr, t.vpnNetworks)
|
||||||
t.l.Debug("command: ", cmd.String())
|
if err != nil {
|
||||||
if err := cmd.Run(); err != nil {
|
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
|
|
||||||
if logErrors {
|
if logErrors {
|
||||||
retErr.Log(t.l)
|
retErr.Log(t.l)
|
||||||
} else {
|
} else {
|
||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
t.l.WithField("route", r).Info("Added route")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -239,10 +429,8 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: CERT-V2 is this right?
|
err := delRoute(r.Cidr, t.vpnNetworks)
|
||||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
if err != nil {
|
||||||
t.l.Debug("command: ", cmd.String())
|
|
||||||
if err := cmd.Run(); err != nil {
|
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.WithField("route", r).Info("Removed route")
|
||||||
@@ -257,3 +445,109 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||||
|
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||||
|
}
|
||||||
|
defer unix.Close(sock)
|
||||||
|
|
||||||
|
route := &netroute.RouteMessage{
|
||||||
|
Version: unix.RTM_VERSION,
|
||||||
|
Type: unix.RTM_ADD,
|
||||||
|
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
|
||||||
|
Seq: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
gw, err := selectGateway(prefix, gateways)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||||
|
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
gw, err := selectGateway(prefix, gateways)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||||
|
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := route.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = unix.Write(sock, data[:])
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, unix.EEXIST) {
|
||||||
|
// Try to do a change
|
||||||
|
route.Type = unix.RTM_CHANGE
|
||||||
|
data, err = route.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
||||||
|
}
|
||||||
|
_, err = unix.Write(sock, data[:])
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||||
|
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||||
|
}
|
||||||
|
defer unix.Close(sock)
|
||||||
|
|
||||||
|
route := netroute.RouteMessage{
|
||||||
|
Version: unix.RTM_VERSION,
|
||||||
|
Type: unix.RTM_DELETE,
|
||||||
|
Seq: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
gw, err := selectGateway(prefix, gateways)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||||
|
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
gw, err := selectGateway(prefix, gateways)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||||
|
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := route.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||||
|
}
|
||||||
|
_, err = unix.Write(sock, data[:])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,71 +4,97 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
|
netroute "golang.org/x/net/route"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
SIOCAIFADDR_IN6 = 0x8080691a
|
||||||
|
)
|
||||||
|
|
||||||
|
type ifreqAlias4 struct {
|
||||||
|
Name [unix.IFNAMSIZ]byte
|
||||||
|
Addr unix.RawSockaddrInet4
|
||||||
|
DstAddr unix.RawSockaddrInet4
|
||||||
|
MaskAddr unix.RawSockaddrInet4
|
||||||
|
}
|
||||||
|
|
||||||
|
type ifreqAlias6 struct {
|
||||||
|
Name [unix.IFNAMSIZ]byte
|
||||||
|
Addr unix.RawSockaddrInet6
|
||||||
|
DstAddr unix.RawSockaddrInet6
|
||||||
|
PrefixMask unix.RawSockaddrInet6
|
||||||
|
Flags uint32
|
||||||
|
Lifetime [2]uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type ifreq struct {
|
||||||
|
Name [unix.IFNAMSIZ]byte
|
||||||
|
data int
|
||||||
|
}
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
Device string
|
Device string
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
MTU int
|
MTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
f *os.File
|
||||||
io.ReadWriteCloser
|
fd int
|
||||||
|
|
||||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||||
out []byte
|
out []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Close() error {
|
|
||||||
if t.ReadWriteCloser != nil {
|
|
||||||
return t.ReadWriteCloser.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
|
|
||||||
}
|
|
||||||
|
|
||||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||||
|
|
||||||
|
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
|
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
|
||||||
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
|
// Try to open tun device
|
||||||
|
var err error
|
||||||
deviceName := c.GetString("tun.dev", "")
|
deviceName := c.GetString("tun.dev", "")
|
||||||
if deviceName == "" {
|
if deviceName == "" {
|
||||||
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
|
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !deviceNameRE.MatchString(deviceName) {
|
if !deviceNameRE.MatchString(deviceName) {
|
||||||
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
|
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||||
}
|
}
|
||||||
|
|
||||||
file, err := os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
|
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = unix.SetNonblock(fd, true)
|
||||||
|
if err != nil {
|
||||||
|
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
||||||
|
}
|
||||||
|
|
||||||
t := &tun{
|
t := &tun{
|
||||||
ReadWriteCloser: file,
|
f: os.NewFile(uintptr(fd), ""),
|
||||||
Device: deviceName,
|
fd: fd,
|
||||||
vpnNetworks: vpnNetworks,
|
Device: deviceName,
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
vpnNetworks: vpnNetworks,
|
||||||
l: l,
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
@@ -86,6 +112,154 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) Close() error {
|
||||||
|
if t.f != nil {
|
||||||
|
if err := t.f.Close(); err != nil {
|
||||||
|
return fmt.Errorf("error closing tun file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// t.f.Close should have handled it for us but let's be extra sure
|
||||||
|
_ = unix.Close(t.fd)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Read(to []byte) (int, error) {
|
||||||
|
buf := make([]byte, len(to)+4)
|
||||||
|
|
||||||
|
n, err := t.f.Read(buf)
|
||||||
|
|
||||||
|
copy(to, buf[4:])
|
||||||
|
return n - 4, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write is only valid for single threaded use
|
||||||
|
func (t *tun) Write(from []byte) (int, error) {
|
||||||
|
buf := t.out
|
||||||
|
if cap(buf) < len(from)+4 {
|
||||||
|
buf = make([]byte, len(from)+4)
|
||||||
|
t.out = buf
|
||||||
|
}
|
||||||
|
buf = buf[:len(from)+4]
|
||||||
|
|
||||||
|
if len(from) == 0 {
|
||||||
|
return 0, syscall.EIO
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine the IP Family for the NULL L2 Header
|
||||||
|
ipVer := from[0] >> 4
|
||||||
|
if ipVer == 4 {
|
||||||
|
buf[3] = syscall.AF_INET
|
||||||
|
} else if ipVer == 6 {
|
||||||
|
buf[3] = syscall.AF_INET6
|
||||||
|
} else {
|
||||||
|
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(buf[4:], from)
|
||||||
|
|
||||||
|
n, err := t.f.Write(buf)
|
||||||
|
return n - 4, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||||
|
if cidr.Addr().Is4() {
|
||||||
|
var req ifreqAlias4
|
||||||
|
req.Name = t.deviceBytes()
|
||||||
|
req.Addr = unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: cidr.Addr().As4(),
|
||||||
|
}
|
||||||
|
req.DstAddr = unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: cidr.Addr().As4(),
|
||||||
|
}
|
||||||
|
req.MaskAddr = unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: prefixToMask(cidr).As4(),
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer syscall.Close(s)
|
||||||
|
|
||||||
|
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
|
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = addRoute(cidr, t.vpnNetworks)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to set route for vpn network %v: %w", cidr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if cidr.Addr().Is6() {
|
||||||
|
var req ifreqAlias6
|
||||||
|
req.Name = t.deviceBytes()
|
||||||
|
req.Addr = unix.RawSockaddrInet6{
|
||||||
|
Len: unix.SizeofSockaddrInet6,
|
||||||
|
Family: unix.AF_INET6,
|
||||||
|
Addr: cidr.Addr().As16(),
|
||||||
|
}
|
||||||
|
req.PrefixMask = unix.RawSockaddrInet6{
|
||||||
|
Len: unix.SizeofSockaddrInet6,
|
||||||
|
Family: unix.AF_INET6,
|
||||||
|
Addr: prefixToMask(cidr).As16(),
|
||||||
|
}
|
||||||
|
req.Lifetime[0] = 0xffffffff
|
||||||
|
req.Lifetime[1] = 0xffffffff
|
||||||
|
|
||||||
|
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer syscall.Close(s)
|
||||||
|
|
||||||
|
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
|
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("unknown address type %v", cidr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Activate() error {
|
||||||
|
err := t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to set tun mtu: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range t.vpnNetworks {
|
||||||
|
err = t.addIp(t.vpnNetworks[i])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return t.addRoutes(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
|
||||||
|
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer syscall.Close(s)
|
||||||
|
|
||||||
|
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
|
||||||
|
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -123,63 +297,46 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
var err error
|
|
||||||
// TODO use syscalls instead of exec.Command
|
|
||||||
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
|
||||||
t.l.Debug("command: ", cmd.String())
|
|
||||||
if err = cmd.Run(); err != nil {
|
|
||||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
|
||||||
t.l.Debug("command: ", cmd.String())
|
|
||||||
if err = cmd.Run(); err != nil {
|
|
||||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
|
|
||||||
t.l.Debug("command: ", cmd.String())
|
|
||||||
if err = cmd.Run(); err != nil {
|
|
||||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unsafe path routes
|
|
||||||
return t.addRoutes(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Activate() error {
|
|
||||||
for i := range t.vpnNetworks {
|
|
||||||
err := t.addIp(t.vpnNetworks[i])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) Networks() []netip.Prefix {
|
||||||
|
return t.vpnNetworks
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Name() string {
|
||||||
|
return t.Device
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) SupportsMultiqueue() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Via.IsValid() || !r.Install {
|
if len(r.Via) == 0 || !r.Install {
|
||||||
// We don't allow route MTUs so only install routes with a via
|
// We don't allow route MTUs so only install routes with a via
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
//TODO: CERT-V2 is this right?
|
|
||||||
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
err := addRoute(r.Cidr, t.vpnNetworks)
|
||||||
t.l.Debug("command: ", cmd.String())
|
if err != nil {
|
||||||
if err := cmd.Run(); err != nil {
|
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
|
|
||||||
if logErrors {
|
if logErrors {
|
||||||
retErr.Log(t.l)
|
retErr.Log(t.l)
|
||||||
} else {
|
} else {
|
||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
t.l.WithField("route", r).Info("Added route")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -191,10 +348,9 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
if !r.Install {
|
if !r.Install {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
//TODO: CERT-V2 is this right?
|
|
||||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
err := delRoute(r.Cidr, t.vpnNetworks)
|
||||||
t.l.Debug("command: ", cmd.String())
|
if err != nil {
|
||||||
if err := cmd.Run(); err != nil {
|
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.WithField("route", r).Info("Removed route")
|
||||||
@@ -203,52 +359,115 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Networks() []netip.Prefix {
|
func (t *tun) deviceBytes() (o [16]byte) {
|
||||||
return t.vpnNetworks
|
for i, c := range t.Device {
|
||||||
}
|
o[i] = byte(c)
|
||||||
|
|
||||||
func (t *tun) Name() string {
|
|
||||||
return t.Device
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Read(to []byte) (int, error) {
|
|
||||||
buf := make([]byte, len(to)+4)
|
|
||||||
|
|
||||||
n, err := t.ReadWriteCloser.Read(buf)
|
|
||||||
|
|
||||||
copy(to, buf[4:])
|
|
||||||
return n - 4, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write is only valid for single threaded use
|
|
||||||
func (t *tun) Write(from []byte) (int, error) {
|
|
||||||
buf := t.out
|
|
||||||
if cap(buf) < len(from)+4 {
|
|
||||||
buf = make([]byte, len(from)+4)
|
|
||||||
t.out = buf
|
|
||||||
}
|
}
|
||||||
buf = buf[:len(from)+4]
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if len(from) == 0 {
|
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||||
return 0, syscall.EIO
|
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||||
|
}
|
||||||
|
defer unix.Close(sock)
|
||||||
|
|
||||||
|
route := &netroute.RouteMessage{
|
||||||
|
Version: unix.RTM_VERSION,
|
||||||
|
Type: unix.RTM_ADD,
|
||||||
|
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
|
||||||
|
Seq: 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine the IP Family for the NULL L2 Header
|
if prefix.Addr().Is4() {
|
||||||
ipVer := from[0] >> 4
|
gw, err := selectGateway(prefix, gateways)
|
||||||
if ipVer == 4 {
|
if err != nil {
|
||||||
buf[3] = syscall.AF_INET
|
return err
|
||||||
} else if ipVer == 6 {
|
}
|
||||||
buf[3] = syscall.AF_INET6
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||||
|
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
gw, err := selectGateway(prefix, gateways)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||||
|
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
copy(buf[4:], from)
|
data, err := route.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
n, err := t.ReadWriteCloser.Write(buf)
|
_, err = unix.Write(sock, data[:])
|
||||||
return n - 4, err
|
if err != nil {
|
||||||
|
if errors.Is(err, unix.EEXIST) {
|
||||||
|
// Try to do a change
|
||||||
|
route.Type = unix.RTM_CHANGE
|
||||||
|
data, err = route.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
||||||
|
}
|
||||||
|
_, err = unix.Write(sock, data[:])
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||||
|
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||||
|
}
|
||||||
|
defer unix.Close(sock)
|
||||||
|
|
||||||
|
route := netroute.RouteMessage{
|
||||||
|
Version: unix.RTM_VERSION,
|
||||||
|
Type: unix.RTM_DELETE,
|
||||||
|
Seq: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
gw, err := selectGateway(prefix, gateways)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||||
|
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
gw, err := selectGateway(prefix, gateways)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||||
|
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := route.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||||
|
}
|
||||||
|
_, err = unix.Write(sock, data[:])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,13 +13,14 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TestTun struct {
|
type TestTun struct {
|
||||||
Device string
|
Device string
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
Routes []Route
|
Routes []Route
|
||||||
routeTree *bart.Table[netip.Addr]
|
routeTree *bart.Table[routing.Gateways]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
|
||||||
closed atomic.Bool
|
closed atomic.Bool
|
||||||
@@ -86,7 +87,7 @@ func (t *TestTun) Get(block bool) []byte {
|
|||||||
// Below this is boilerplate implementation to make nebula actually work
|
// Below this is boilerplate implementation to make nebula actually work
|
||||||
//********************************************************************************************************************//
|
//********************************************************************************************************************//
|
||||||
|
|
||||||
func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr {
|
func (t *TestTun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
r, _ := t.routeTree.Lookup(ip)
|
r, _ := t.routeTree.Lookup(ip)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@@ -131,6 +132,10 @@ func (t *TestTun) Read(b []byte) (int, error) {
|
|||||||
return len(p), nil
|
return len(p), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TestTun) SupportsMultiqueue() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/slackhq/nebula/wintun"
|
"github.com/slackhq/nebula/wintun"
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
@@ -31,7 +32,7 @@ type winTun struct {
|
|||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
MTU int
|
MTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
|
||||||
tun *wintun.NativeTun
|
tun *wintun.NativeTun
|
||||||
@@ -147,15 +148,18 @@ func (t *winTun) addRoutes(logErrors bool) error {
|
|||||||
foundDefault4 := false
|
foundDefault4 := false
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Via.IsValid() || !r.Install {
|
if len(r.Via) == 0 || !r.Install {
|
||||||
// We don't allow route MTUs so only install routes with a via
|
// We don't allow route MTUs so only install routes with a via
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add our unsafe route
|
// Add our unsafe route
|
||||||
err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
|
// Windows does not support multipath routes natively, so we install only a single route.
|
||||||
|
// This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally.
|
||||||
|
// In effect this provides multipath routing support to windows supporting loadbalancing and redundancy.
|
||||||
|
err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
|
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
retErr.Log(t.l)
|
retErr.Log(t.l)
|
||||||
continue
|
continue
|
||||||
@@ -198,7 +202,8 @@ func (t *winTun) removeRoutes(routes []Route) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := luid.DeleteRoute(r.Cidr, r.Via)
|
// See comment on luid.AddRoute
|
||||||
|
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
} else {
|
} else {
|
||||||
@@ -208,7 +213,7 @@ func (t *winTun) removeRoutes(routes []Route) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
|
func (t *winTun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@@ -229,6 +234,10 @@ func (t *winTun) Write(b []byte) (int, error) {
|
|||||||
return t.tun.Write(b, 0)
|
return t.tun.Write(b, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *winTun) SupportsMultiqueue() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||||
@@ -38,9 +39,17 @@ type UserDevice struct {
|
|||||||
func (d *UserDevice) Activate() error {
|
func (d *UserDevice) Activate() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks }
|
|
||||||
func (d *UserDevice) Name() string { return "faketun0" }
|
func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks }
|
||||||
func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
|
func (d *UserDevice) Name() string { return "faketun0" }
|
||||||
|
func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
|
return routing.Gateways{routing.NewGateway(ip, 1)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *UserDevice) SupportsMultiqueue() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -180,6 +180,7 @@ func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) {
|
|||||||
pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true),
|
pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true),
|
||||||
pkcs11.NewAttribute(pkcs11.CKA_WRAP, true),
|
pkcs11.NewAttribute(pkcs11.CKA_WRAP, true),
|
||||||
pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true),
|
pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true),
|
||||||
|
pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, NoiseKeySize),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up the parameters which include the peer's public key
|
// Set up the parameters which include the peer's public key
|
||||||
|
|||||||
157
pki.go
157
pki.go
@@ -33,16 +33,16 @@ type CertState struct {
|
|||||||
v2Cert cert.Certificate
|
v2Cert cert.Certificate
|
||||||
v2HandshakeBytes []byte
|
v2HandshakeBytes []byte
|
||||||
|
|
||||||
defaultVersion cert.Version
|
initiatingVersion cert.Version
|
||||||
privateKey []byte
|
privateKey []byte
|
||||||
pkcs11Backed bool
|
pkcs11Backed bool
|
||||||
cipher string
|
cipher string
|
||||||
|
|
||||||
myVpnNetworks []netip.Prefix
|
myVpnNetworks []netip.Prefix
|
||||||
myVpnNetworksTable *bart.Table[struct{}]
|
myVpnNetworksTable *bart.Lite
|
||||||
myVpnAddrs []netip.Addr
|
myVpnAddrs []netip.Addr
|
||||||
myVpnAddrsTable *bart.Table[struct{}]
|
myVpnAddrsTable *bart.Lite
|
||||||
myVpnBroadcastAddrsTable *bart.Table[struct{}]
|
myVpnBroadcastAddrsTable *bart.Lite
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
|
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
|
||||||
@@ -100,55 +100,62 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
|||||||
currentState := p.cs.Load()
|
currentState := p.cs.Load()
|
||||||
if newState.v1Cert != nil {
|
if newState.v1Cert != nil {
|
||||||
if currentState.v1Cert == nil {
|
if currentState.v1Cert == nil {
|
||||||
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
|
//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
|
||||||
}
|
} else {
|
||||||
|
// did IP in cert change? if so, don't set
|
||||||
|
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
||||||
|
return util.NewContextualError(
|
||||||
|
"Networks in new cert was different from old",
|
||||||
|
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// did IP in cert change? if so, don't set
|
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
||||||
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
return util.NewContextualError(
|
||||||
return util.NewContextualError(
|
"Curve in new v1 cert was different from old",
|
||||||
"Networks in new cert was different from old",
|
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1},
|
||||||
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
|
nil,
|
||||||
nil,
|
)
|
||||||
)
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
|
||||||
return util.NewContextualError(
|
|
||||||
"Curve in new cert was different from old",
|
|
||||||
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
} else if currentState.v1Cert != nil {
|
|
||||||
//TODO: CERT-V2 we should be able to tear this down
|
|
||||||
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if newState.v2Cert != nil {
|
if newState.v2Cert != nil {
|
||||||
if currentState.v2Cert == nil {
|
if currentState.v2Cert == nil {
|
||||||
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
|
//adding certs is fine, actually
|
||||||
}
|
} else {
|
||||||
|
// did IP in cert change? if so, don't set
|
||||||
|
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
||||||
|
return util.NewContextualError(
|
||||||
|
"Networks in new cert was different from old",
|
||||||
|
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// did IP in cert change? if so, don't set
|
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
||||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
return util.NewContextualError(
|
||||||
return util.NewContextualError(
|
"Curve in new cert was different from old",
|
||||||
"Networks in new cert was different from old",
|
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2},
|
||||||
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
|
nil,
|
||||||
nil,
|
)
|
||||||
)
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
|
||||||
return util.NewContextualError(
|
|
||||||
"Curve in new cert was different from old",
|
|
||||||
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if currentState.v2Cert != nil {
|
} else if currentState.v2Cert != nil {
|
||||||
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
|
//newState.v1Cert is non-nil bc empty certstates aren't permitted
|
||||||
|
if newState.v1Cert == nil {
|
||||||
|
return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
|
||||||
|
}
|
||||||
|
//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
|
||||||
|
if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
|
||||||
|
return util.NewContextualError(
|
||||||
|
"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
|
||||||
|
m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cipher cant be hot swapped so just leave it at what it was before
|
// Cipher cant be hot swapped so just leave it at what it was before
|
||||||
@@ -193,7 +200,7 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cs *CertState) GetDefaultCertificate() cert.Certificate {
|
func (cs *CertState) GetDefaultCertificate() cert.Certificate {
|
||||||
c := cs.getCertificate(cs.defaultVersion)
|
c := cs.getCertificate(cs.initiatingVersion)
|
||||||
if c == nil {
|
if c == nil {
|
||||||
panic("No default certificate found")
|
panic("No default certificate found")
|
||||||
}
|
}
|
||||||
@@ -316,37 +323,37 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
|
|||||||
return nil, errors.New("no certificates found in pki.cert")
|
return nil, errors.New("no certificates found in pki.cert")
|
||||||
}
|
}
|
||||||
|
|
||||||
useDefaultVersion := uint32(1)
|
useInitiatingVersion := uint32(1)
|
||||||
if v1 == nil {
|
if v1 == nil {
|
||||||
// The only condition that requires v2 as the default is if only a v2 certificate is present
|
// The only condition that requires v2 as the default is if only a v2 certificate is present
|
||||||
// We do this to avoid having to configure it specifically in the config file
|
// We do this to avoid having to configure it specifically in the config file
|
||||||
useDefaultVersion = 2
|
useInitiatingVersion = 2
|
||||||
}
|
}
|
||||||
|
|
||||||
rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion)
|
rawInitiatingVersion := c.GetUint32("pki.initiating_version", useInitiatingVersion)
|
||||||
var defaultVersion cert.Version
|
var initiatingVersion cert.Version
|
||||||
switch rawDefaultVersion {
|
switch rawInitiatingVersion {
|
||||||
case 1:
|
case 1:
|
||||||
if v1 == nil {
|
if v1 == nil {
|
||||||
return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert")
|
return nil, fmt.Errorf("can not use pki.initiating_version 1 without a v1 certificate in pki.cert")
|
||||||
}
|
}
|
||||||
defaultVersion = cert.Version1
|
initiatingVersion = cert.Version1
|
||||||
case 2:
|
case 2:
|
||||||
defaultVersion = cert.Version2
|
initiatingVersion = cert.Version2
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion)
|
return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey)
|
return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
|
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
|
||||||
cs := CertState{
|
cs := CertState{
|
||||||
privateKey: privateKey,
|
privateKey: privateKey,
|
||||||
pkcs11Backed: pkcs11backed,
|
pkcs11Backed: pkcs11backed,
|
||||||
myVpnNetworksTable: new(bart.Table[struct{}]),
|
myVpnNetworksTable: new(bart.Lite),
|
||||||
myVpnAddrsTable: new(bart.Table[struct{}]),
|
myVpnAddrsTable: new(bart.Lite),
|
||||||
myVpnBroadcastAddrsTable: new(bart.Table[struct{}]),
|
myVpnBroadcastAddrsTable: new(bart.Lite),
|
||||||
}
|
}
|
||||||
|
|
||||||
if v1 != nil && v2 != nil {
|
if v1 != nil && v2 != nil {
|
||||||
@@ -358,9 +365,11 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
|||||||
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
|
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: CERT-V2 make sure v2 has v1s address
|
if v1.Networks()[0] != v2.Networks()[0] {
|
||||||
|
return nil, util.NewContextualError("v1 and v2 networks are not the same", nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
cs.defaultVersion = dv
|
cs.initiatingVersion = dv
|
||||||
}
|
}
|
||||||
|
|
||||||
if v1 != nil {
|
if v1 != nil {
|
||||||
@@ -379,8 +388,8 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
|||||||
cs.v1Cert = v1
|
cs.v1Cert = v1
|
||||||
cs.v1HandshakeBytes = v1hs
|
cs.v1HandshakeBytes = v1hs
|
||||||
|
|
||||||
if cs.defaultVersion == 0 {
|
if cs.initiatingVersion == 0 {
|
||||||
cs.defaultVersion = cert.Version1
|
cs.initiatingVersion = cert.Version1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -400,8 +409,8 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
|||||||
cs.v2Cert = v2
|
cs.v2Cert = v2
|
||||||
cs.v2HandshakeBytes = v2hs
|
cs.v2HandshakeBytes = v2hs
|
||||||
|
|
||||||
if cs.defaultVersion == 0 {
|
if cs.initiatingVersion == 0 {
|
||||||
cs.defaultVersion = cert.Version2
|
cs.initiatingVersion = cert.Version2
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -414,16 +423,16 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
|||||||
|
|
||||||
for _, network := range crt.Networks() {
|
for _, network := range crt.Networks() {
|
||||||
cs.myVpnNetworks = append(cs.myVpnNetworks, network)
|
cs.myVpnNetworks = append(cs.myVpnNetworks, network)
|
||||||
cs.myVpnNetworksTable.Insert(network, struct{}{})
|
cs.myVpnNetworksTable.Insert(network)
|
||||||
|
|
||||||
cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr())
|
cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr())
|
||||||
cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{})
|
cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()))
|
||||||
|
|
||||||
if network.Addr().Is4() {
|
if network.Addr().Is4() {
|
||||||
addr := network.Masked().Addr().As4()
|
addr := network.Masked().Addr().As4()
|
||||||
mask := net.CIDRMask(network.Bits(), network.Addr().BitLen())
|
mask := net.CIDRMask(network.Bits(), network.Addr().BitLen())
|
||||||
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
|
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
|
||||||
cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{})
|
cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -514,9 +523,13 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
|
|||||||
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
|
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
|
bl := c.GetStringSlice("pki.blocklist", []string{})
|
||||||
l.WithField("fingerprint", fp).Info("Blocklisting cert")
|
if len(bl) > 0 {
|
||||||
caPool.BlocklistFingerprint(fp)
|
for _, fp := range bl {
|
||||||
|
caPool.BlocklistFingerprint(fp)
|
||||||
|
}
|
||||||
|
|
||||||
|
l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates")
|
||||||
}
|
}
|
||||||
|
|
||||||
return caPool, nil
|
return caPool, nil
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewPunchyFromConfig(t *testing.T) {
|
func TestNewPunchyFromConfig(t *testing.T) {
|
||||||
@@ -15,39 +16,39 @@ func TestNewPunchyFromConfig(t *testing.T) {
|
|||||||
|
|
||||||
// Test defaults
|
// Test defaults
|
||||||
p := NewPunchyFromConfig(l, c)
|
p := NewPunchyFromConfig(l, c)
|
||||||
assert.Equal(t, false, p.GetPunch())
|
assert.False(t, p.GetPunch())
|
||||||
assert.Equal(t, false, p.GetRespond())
|
assert.False(t, p.GetRespond())
|
||||||
assert.Equal(t, time.Second, p.GetDelay())
|
assert.Equal(t, time.Second, p.GetDelay())
|
||||||
assert.Equal(t, 5*time.Second, p.GetRespondDelay())
|
assert.Equal(t, 5*time.Second, p.GetRespondDelay())
|
||||||
|
|
||||||
// punchy deprecation
|
// punchy deprecation
|
||||||
c.Settings["punchy"] = true
|
c.Settings["punchy"] = true
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(l, c)
|
||||||
assert.Equal(t, true, p.GetPunch())
|
assert.True(t, p.GetPunch())
|
||||||
|
|
||||||
// punchy.punch
|
// punchy.punch
|
||||||
c.Settings["punchy"] = map[interface{}]interface{}{"punch": true}
|
c.Settings["punchy"] = map[string]any{"punch": true}
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(l, c)
|
||||||
assert.Equal(t, true, p.GetPunch())
|
assert.True(t, p.GetPunch())
|
||||||
|
|
||||||
// punch_back deprecation
|
// punch_back deprecation
|
||||||
c.Settings["punch_back"] = true
|
c.Settings["punch_back"] = true
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(l, c)
|
||||||
assert.Equal(t, true, p.GetRespond())
|
assert.True(t, p.GetRespond())
|
||||||
|
|
||||||
// punchy.respond
|
// punchy.respond
|
||||||
c.Settings["punchy"] = map[interface{}]interface{}{"respond": true}
|
c.Settings["punchy"] = map[string]any{"respond": true}
|
||||||
c.Settings["punch_back"] = false
|
c.Settings["punch_back"] = false
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(l, c)
|
||||||
assert.Equal(t, true, p.GetRespond())
|
assert.True(t, p.GetRespond())
|
||||||
|
|
||||||
// punchy.delay
|
// punchy.delay
|
||||||
c.Settings["punchy"] = map[interface{}]interface{}{"delay": "1m"}
|
c.Settings["punchy"] = map[string]any{"delay": "1m"}
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(l, c)
|
||||||
assert.Equal(t, time.Minute, p.GetDelay())
|
assert.Equal(t, time.Minute, p.GetDelay())
|
||||||
|
|
||||||
// punchy.respond_delay
|
// punchy.respond_delay
|
||||||
c.Settings["punchy"] = map[interface{}]interface{}{"respond_delay": "1m"}
|
c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
|
||||||
p = NewPunchyFromConfig(l, c)
|
p = NewPunchyFromConfig(l, c)
|
||||||
assert.Equal(t, time.Minute, p.GetRespondDelay())
|
assert.Equal(t, time.Minute, p.GetRespondDelay())
|
||||||
}
|
}
|
||||||
@@ -56,22 +57,22 @@ func TestPunchy_reload(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
delay, _ := time.ParseDuration("1m")
|
delay, _ := time.ParseDuration("1m")
|
||||||
assert.NoError(t, c.LoadString(`
|
require.NoError(t, c.LoadString(`
|
||||||
punchy:
|
punchy:
|
||||||
delay: 1m
|
delay: 1m
|
||||||
respond: false
|
respond: false
|
||||||
`))
|
`))
|
||||||
p := NewPunchyFromConfig(l, c)
|
p := NewPunchyFromConfig(l, c)
|
||||||
assert.Equal(t, delay, p.GetDelay())
|
assert.Equal(t, delay, p.GetDelay())
|
||||||
assert.Equal(t, false, p.GetRespond())
|
assert.False(t, p.GetRespond())
|
||||||
|
|
||||||
newDelay, _ := time.ParseDuration("10m")
|
newDelay, _ := time.ParseDuration("10m")
|
||||||
assert.NoError(t, c.ReloadConfigString(`
|
require.NoError(t, c.ReloadConfigString(`
|
||||||
punchy:
|
punchy:
|
||||||
delay: 10m
|
delay: 10m
|
||||||
respond: true
|
respond: true
|
||||||
`))
|
`))
|
||||||
p.reload(c, false)
|
p.reload(c, false)
|
||||||
assert.Equal(t, newDelay, p.GetDelay())
|
assert.Equal(t, newDelay, p.GetDelay())
|
||||||
assert.Equal(t, true, p.GetRespond())
|
assert.True(t, p.GetRespond())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -241,15 +241,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
|||||||
logMsg.Info("handleCreateRelayRequest")
|
logMsg.Info("handleCreateRelayRequest")
|
||||||
// Is the source of the relay me? This should never happen, but did happen due to
|
// Is the source of the relay me? This should never happen, but did happen due to
|
||||||
// an issue migrating relays over to newly re-handshaked host info objects.
|
// an issue migrating relays over to newly re-handshaked host info objects.
|
||||||
_, found := f.myVpnAddrsTable.Lookup(from)
|
if f.myVpnAddrsTable.Contains(from) {
|
||||||
if found {
|
|
||||||
logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
|
logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Is the target of the relay me?
|
// Is the target of the relay me?
|
||||||
_, found = f.myVpnAddrsTable.Lookup(target)
|
if f.myVpnAddrsTable.Contains(target) {
|
||||||
if found {
|
|
||||||
existingRelay, ok := h.relayState.QueryRelayForByIp(from)
|
existingRelay, ok := h.relayState.QueryRelayForByIp(from)
|
||||||
if ok {
|
if ok {
|
||||||
switch existingRelay.State {
|
switch existingRelay.State {
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ type RemoteList struct {
|
|||||||
// The full list of vpn addresses assigned to this host
|
// The full list of vpn addresses assigned to this host
|
||||||
vpnAddrs []netip.Addr
|
vpnAddrs []netip.Addr
|
||||||
|
|
||||||
// A deduplicated set of addresses. Any accessor should lock beforehand.
|
// A deduplicated set of underlay addresses. Any accessor should lock beforehand.
|
||||||
addrs []netip.AddrPort
|
addrs []netip.AddrPort
|
||||||
|
|
||||||
// A set of relay addresses. VpnIp addresses that the remote identified as relays.
|
// A set of relay addresses. VpnIp addresses that the remote identified as relays.
|
||||||
@@ -201,8 +201,10 @@ type RemoteList struct {
|
|||||||
// For learned addresses, this is the vpnIp that sent the packet
|
// For learned addresses, this is the vpnIp that sent the packet
|
||||||
cache map[netip.Addr]*cache
|
cache map[netip.Addr]*cache
|
||||||
|
|
||||||
hr *hostnamesResults
|
hr *hostnamesResults
|
||||||
shouldAdd func(netip.Addr) bool
|
|
||||||
|
// shouldAdd is a nillable function that decides if x should be added to addrs.
|
||||||
|
shouldAdd func(vpnAddrs []netip.Addr, x netip.Addr) bool
|
||||||
|
|
||||||
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
|
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
|
||||||
// They should not be tried again during a handshake
|
// They should not be tried again during a handshake
|
||||||
@@ -213,7 +215,7 @@ type RemoteList struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewRemoteList creates a new empty RemoteList
|
// NewRemoteList creates a new empty RemoteList
|
||||||
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
|
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func([]netip.Addr, netip.Addr) bool) *RemoteList {
|
||||||
r := &RemoteList{
|
r := &RemoteList{
|
||||||
vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
|
vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
|
||||||
addrs: make([]netip.AddrPort, 0),
|
addrs: make([]netip.AddrPort, 0),
|
||||||
@@ -368,6 +370,15 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RefreshFromHandshake locks and updates the RemoteList to account for data learned upon a completed handshake
|
||||||
|
func (r *RemoteList) RefreshFromHandshake(vpnAddrs []netip.Addr) {
|
||||||
|
r.Lock()
|
||||||
|
r.badRemotes = nil
|
||||||
|
r.vpnAddrs = make([]netip.Addr, len(vpnAddrs))
|
||||||
|
copy(r.vpnAddrs, vpnAddrs)
|
||||||
|
r.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
// ResetBlockedRemotes locks and clears the blocked remotes list
|
// ResetBlockedRemotes locks and clears the blocked remotes list
|
||||||
func (r *RemoteList) ResetBlockedRemotes() {
|
func (r *RemoteList) ResetBlockedRemotes() {
|
||||||
r.Lock()
|
r.Lock()
|
||||||
@@ -577,7 +588,7 @@ func (r *RemoteList) unlockedCollect() {
|
|||||||
|
|
||||||
dnsAddrs := r.hr.GetAddrs()
|
dnsAddrs := r.hr.GetAddrs()
|
||||||
for _, addr := range dnsAddrs {
|
for _, addr := range dnsAddrs {
|
||||||
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
|
if r.shouldAdd == nil || r.shouldAdd(r.vpnAddrs, addr.Addr()) {
|
||||||
if !r.unlockedIsBad(addr) {
|
if !r.unlockedIsBad(addr) {
|
||||||
addrs = append(addrs, addr)
|
addrs = append(addrs, addr)
|
||||||
}
|
}
|
||||||
|
|||||||
39
routing/balance.go
Normal file
39
routing/balance.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/firewall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Hashes the packet source and destination port and always returns a positive integer
|
||||||
|
// Based on 'Prospecting for Hash Functions'
|
||||||
|
// - https://nullprogram.com/blog/2018/07/31/
|
||||||
|
// - https://github.com/skeeto/hash-prospector
|
||||||
|
// [16 21f0aaad 15 d35a2d97 15] = 0.10760229515479501
|
||||||
|
func hashPacket(p *firewall.Packet) int {
|
||||||
|
x := (uint32(p.LocalPort) << 16) | uint32(p.RemotePort)
|
||||||
|
x ^= x >> 16
|
||||||
|
x *= 0x21f0aaad
|
||||||
|
x ^= x >> 15
|
||||||
|
x *= 0xd35a2d97
|
||||||
|
x ^= x >> 15
|
||||||
|
|
||||||
|
return int(x) & 0x7FFFFFFF
|
||||||
|
}
|
||||||
|
|
||||||
|
// For this function to work correctly it requires that the buckets for the gateways have been calculated
|
||||||
|
// If the contract is violated balancing will not work properly and the second return value will return false
|
||||||
|
func BalancePacket(fwPacket *firewall.Packet, gateways []Gateway) (netip.Addr, bool) {
|
||||||
|
hash := hashPacket(fwPacket)
|
||||||
|
|
||||||
|
for i := range gateways {
|
||||||
|
if hash <= gateways[i].BucketUpperBound() {
|
||||||
|
return gateways[i].Addr(), true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If you land here then the buckets for the gateways are not properly calculated
|
||||||
|
// Fallback to random routing and let the caller know
|
||||||
|
return gateways[hash%len(gateways)].Addr(), false
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user