Mirror of https://github.com/slackhq/nebula.git (synced 2025-11-22 00:15:37 +01:00)

Compare commits: v1.1.0...windows_ud (84 commits)
| SHA1 |
|---|
| b6c6b96c79 |
| 6dea7760cb |
| ea07a89cc8 |
| 3aaaea6309 |
| 5506da3de9 |
| 6c55d67f18 |
| 64d8035d09 |
| 73a5ed90b2 |
| d604270966 |
| 29c5f31f90 |
| b6234abfb3 |
| 2a4beb41b9 |
| d232ccbfab |
| ecfb40f29c |
| 1bae5b2550 |
| 73081d99bc |
| e7e6a23cde |
| a0583ebdca |
| 27d9a67dda |
| 2bce222550 |
| 3dd1108099 |
| d4b81f9b8d |
| 454bc8a6bb |
| ce9ad37431 |
| ee7c27093c |
| 2e7ca027a4 |
| 672ce1f0a8 |
| 384b1166ea |
| 0389596f66 |
| 43a3988afc |
| 5c23676a0f |
| f6d0b4b893 |
| 0d6b55e495 |
| c71c84882e |
| 0010db46e4 |
| 68e3e84fdc |
| 6238f1550b |
| 50b04413c7 |
| ef498a31da |
| 2e5a477a50 |
| 32fe9bfe75 |
| 9b8b3c478b |
| 7b3f23d9a1 |
| 25964b54f6 |
| ac557f381b |
| a54f3fc681 |
| 5545cff6ef |
| f3a6d8d990 |
| 9b06748506 |
| 4756c9613d |
| 4645e6034b |
| aba42f9fa6 |
| 41578ca971 |
| 1ea8847085 |
| 55858c64cc |
| e94c6b0125 |
| b37a91cfbc |
| 3212b769d4 |
| ecf0e5a9f6 |
| ff13aba8fc |
| cc03ff9e9a |
| 363c836422 |
| fb252db4a1 |
| 4f6313ebd3 |
| 0a474e757b |
| 7cd342c7ab |
| 7cdbb14a18 |
| b4f2f7ce4e |
| ff64d1f952 |
| 9e2ff7df57 |
| 1297090af3 |
| add1b21777 |
| 1cb3201b5e |
| 41968551f9 |
| 8548ac3c31 |
| fb9b36f677 |
| 4d1928f1e3 |
| a91a40212d |
| 179a369130 |
| df69371620 |
| eda344d88f |
| 065e2ff88a |
| 45a5de2719 |
| 2d24ef7166 |
.github/workflows/gofmt.yml (23 changed lines, vendored)

@@ -4,6 +4,9 @@ on:
     branches:
       - master
   pull_request:
+    paths:
+      - '.github/workflows/gofmt.yml'
+      - '**.go'
 jobs:
 
   gofmt:
@@ -11,19 +14,31 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-    - name: Set up Go 1.13
+    - name: Set up Go 1.16
       uses: actions/setup-go@v1
       with:
-        go-version: 1.13
+        go-version: 1.16
       id: go
 
     - name: Check out code into the Go module directory
       uses: actions/checkout@v1
 
+    - uses: actions/cache@v1
+      with:
+        path: ~/go/pkg/mod
+        key: ${{ runner.os }}-gofmt1.16-${{ hashFiles('**/go.sum') }}
+        restore-keys: |
+          ${{ runner.os }}-gofmt1.16-
+
+    - name: Install goimports
+      run: |
+        go get golang.org/x/tools/cmd/goimports
+        go build golang.org/x/tools/cmd/goimports
+
     - name: gofmt
       run: |
-        if [ "$(find . -iname '*.go' | xargs gofmt -l)" ]
+        if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs ./goimports -l)" ]
        then
-          find . -iname '*.go' | xargs gofmt -d
+          find . -iname '*.go' | grep -v '\.pb\.go$' | xargs ./goimports -d
           exit 1
        fi
.github/workflows/release.yml (45 changed lines, vendored)

@@ -10,17 +10,17 @@ jobs:
     name: Build Linux All
     runs-on: ubuntu-latest
     steps:
-    - name: Set up Go 1.13
+    - name: Set up Go 1.16
      uses: actions/setup-go@v1
      with:
-        go-version: 1.13
+        go-version: 1.16
 
     - name: Checkout code
      uses: actions/checkout@v2
 
     - name: Build
      run: |
-        make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" release-linux
+        make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" release-linux release-freebsd
        mkdir release
        mv build/*.tar.gz release
 
@@ -34,10 +34,10 @@ jobs:
     name: Build Windows amd64
     runs-on: windows-latest
     steps:
-    - name: Set up Go 1.13
+    - name: Set up Go 1.16
      uses: actions/setup-go@v1
      with:
-        go-version: 1.13
+        go-version: 1.16
 
     - name: Checkout code
      uses: actions/checkout@v2
@@ -58,10 +58,10 @@ jobs:
     name: Build Darwin amd64
     runs-on: macOS-latest
     steps:
-    - name: Set up Go 1.13
+    - name: Set up Go 1.16
      uses: actions/setup-go@v1
      with:
-        go-version: 1.13
+        go-version: 1.16
 
     - name: Checkout code
      uses: actions/checkout@v2
@@ -69,6 +69,7 @@ jobs:
     - name: Build
      run: |
        make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" service build/nebula-darwin-amd64.tar.gz
+        make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" service build/nebula-darwin-arm64.tar.gz
        mkdir release
        mv build/*.tar.gz release
 
@@ -159,6 +160,16 @@ jobs:
        asset_name: nebula-darwin-amd64.tar.gz
        asset_content_type: application/gzip
 
+    - name: Upload darwin-arm64
+      uses: actions/upload-release-asset@v1.0.1
+      env:
+        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+      with:
+        upload_url: ${{ steps.create_release.outputs.upload_url }}
+        asset_path: ./darwin-latest/nebula-darwin-arm64.tar.gz
+        asset_name: nebula-darwin-arm64.tar.gz
+        asset_content_type: application/gzip
+
     - name: Upload windows-amd64
      uses: actions/upload-release-asset@v1.0.1
      env:
@@ -278,3 +289,23 @@ jobs:
        asset_path: ./linux-latest/nebula-linux-mips64le.tar.gz
        asset_name: nebula-linux-mips64le.tar.gz
        asset_content_type: application/gzip
+
+    - name: Upload linux-mips-softfloat
+      uses: actions/upload-release-asset@v1.0.1
+      env:
+        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+      with:
+        upload_url: ${{ steps.create_release.outputs.upload_url }}
+        asset_path: ./linux-latest/nebula-linux-mips-softfloat.tar.gz
+        asset_name: nebula-linux-mips-softfloat.tar.gz
+        asset_content_type: application/gzip
+
+    - name: Upload freebsd-amd64
+      uses: actions/upload-release-asset@v1.0.1
+      env:
+        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+      with:
+        upload_url: ${{ steps.create_release.outputs.upload_url }}
+        asset_path: ./linux-latest/nebula-freebsd-amd64.tar.gz
+        asset_name: nebula-freebsd-amd64.tar.gz
+        asset_content_type: application/gzip
.github/workflows/smoke.yml (22 changed lines, vendored)

@@ -4,24 +4,38 @@ on:
     branches:
       - master
   pull_request:
+    paths:
+      - '.github/workflows/smoke**'
+      - '**Makefile'
+      - '**.go'
+      - '**.proto'
+      - 'go.mod'
+      - 'go.sum'
 jobs:
 
   smoke:
-    name: Run 3 node smoke test
+    name: Run multi node smoke test
     runs-on: ubuntu-latest
     steps:
 
-    - name: Set up Go 1.13
+    - name: Set up Go 1.16
      uses: actions/setup-go@v1
      with:
-        go-version: 1.13
+        go-version: 1.16
      id: go
 
     - name: Check out code into the Go module directory
      uses: actions/checkout@v1
 
+    - uses: actions/cache@v1
+      with:
+        path: ~/go/pkg/mod
+        key: ${{ runner.os }}-go1.16-${{ hashFiles('**/go.sum') }}
+        restore-keys: |
+          ${{ runner.os }}-go1.16-
+
     - name: build
-      run: make
+      run: make bin-docker
 
     - name: setup docker image
      working-directory: ./.github/workflows/smoke
.github/workflows/smoke/Dockerfile (6 changed lines, vendored)

@@ -1,5 +1,7 @@
 FROM debian:buster
 
-ADD ./build /
+ADD ./build /nebula
 
-ENTRYPOINT ["/nebula"]
+WORKDIR /nebula
+
+ENTRYPOINT ["/nebula/nebula"]
.github/workflows/smoke/build.sh (35 changed lines, vendored)

@@ -8,17 +8,32 @@ mkdir ./build
 (
     cd build
 
-    cp ../../../../nebula .
-    cp ../../../../nebula-cert .
+    cp ../../../../build/linux-amd64/nebula .
+    cp ../../../../build/linux-amd64/nebula-cert .
 
-    HOST="lighthouse1" AM_LIGHTHOUSE=true ../genconfig.sh >lighthouse1.yml
-    HOST="host2" LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" ../genconfig.sh >host2.yml
-    HOST="host3" LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" ../genconfig.sh >host3.yml
+    HOST="lighthouse1" \
+        AM_LIGHTHOUSE=true \
+        ../genconfig.sh >lighthouse1.yml
 
-    ./nebula-cert ca -name "Smoke Test"
-    ./nebula-cert sign -name "lighthouse1" -ip "192.168.100.1/24"
-    ./nebula-cert sign -name "host2" -ip "192.168.100.2/24"
-    ./nebula-cert sign -name "host3" -ip "192.168.100.3/24"
+    HOST="host2" \
+        LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
+        ../genconfig.sh >host2.yml
+
+    HOST="host3" \
+        LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
+        INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
+        ../genconfig.sh >host3.yml
+
+    HOST="host4" \
+        LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
+        OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
+        ../genconfig.sh >host4.yml
+
+    ../../../../nebula-cert ca -name "Smoke Test"
+    ../../../../nebula-cert sign -name "lighthouse1" -groups "lighthouse,lighthouse1" -ip "192.168.100.1/24"
+    ../../../../nebula-cert sign -name "host2" -groups "host,host2" -ip "192.168.100.2/24"
+    ../../../../nebula-cert sign -name "host3" -groups "host,host3" -ip "192.168.100.3/24"
+    ../../../../nebula-cert sign -name "host4" -groups "host,host4" -ip "192.168.100.4/24"
 )
 
-docker build -t nebula:smoke .
+sudo docker build -t nebula:smoke .
.github/workflows/smoke/genconfig.sh (18 changed lines, vendored)

@@ -2,6 +2,7 @@
 
 set -e
 
+FIREWALL_ALL='[{"port": "any", "proto": "any", "host": "any"}]'
 
 if [ "$STATIC_HOSTS" ] || [ "$LIGHTHOUSES" ]
 then
@@ -32,9 +33,9 @@ lighthouse_hosts() {
 
 cat <<EOF
 pki:
-  ca: /ca.crt
-  cert: /${HOST}.crt
-  key: /${HOST}.key
+  ca: ca.crt
+  cert: ${HOST}.crt
+  key: ${HOST}.key
 
 lighthouse:
   am_lighthouse: ${AM_LIGHTHOUSE:-false}
@@ -48,13 +49,6 @@ tun:
   dev: ${TUN_DEV:-nebula1}
 
 firewall:
-  outbound:
-    - port: any
-      proto: any
-      host: any
-
-  inbound:
-    - port: any
-      proto: any
-      host: any
+  outbound: ${OUTBOUND:-$FIREWALL_ALL}
+  inbound: ${INBOUND:-$FIREWALL_ALL}
 EOF
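With the firewall block parameterized, whatever list is passed in INBOUND or OUTBOUND drops straight into the rendered config as a YAML flow sequence. As a sketch of the rendered output for the host3 invocation shown in build.sh above (only the firewall section is shown here):

    firewall:
      outbound: [{"port": "any", "proto": "any", "host": "any"}]
      inbound: [{"port": "any", "proto": "icmp", "group": "lighthouse"}]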
.github/workflows/smoke/smoke.sh (69 changed lines, vendored)

@@ -1,12 +1,33 @@
-#!/bin/sh
+#!/bin/bash
 
 set -e -x
 
-docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config lighthouse1.yml &
+set -o pipefail
+
+mkdir -p logs
+
+cleanup() {
+    set +e
+    if [ "$(jobs -r)" ]
+    then
+        sudo docker kill lighthouse1 host2 host3 host4
+    fi
+}
+
+trap cleanup EXIT
+
+sudo docker run --name lighthouse1 --rm nebula:smoke -config lighthouse1.yml -test
+sudo docker run --name host2 --rm nebula:smoke -config host2.yml -test
+sudo docker run --name host3 --rm nebula:smoke -config host3.yml -test
+sudo docker run --name host4 --rm nebula:smoke -config host4.yml -test
+
+sudo docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 &
 sleep 1
-docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host2.yml &
+sudo docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host2.yml 2>&1 | tee logs/host2 &
 sleep 1
-docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host3.yml &
+sudo docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host3.yml 2>&1 | tee logs/host3 &
 sleep 1
+sudo docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host4.yml 2>&1 | tee logs/host4 &
+sleep 1
 
 set +x
@@ -14,21 +35,49 @@ echo
 echo " *** Testing ping from lighthouse1"
 echo
 set -x
-docker exec lighthouse1 ping -c1 192.168.100.2
-docker exec lighthouse1 ping -c1 192.168.100.3
+sudo docker exec lighthouse1 ping -c1 192.168.100.2
+sudo docker exec lighthouse1 ping -c1 192.168.100.3
 
 set +x
 echo
 echo " *** Testing ping from host2"
 echo
 set -x
-docker exec host2 ping -c1 192.168.100.1
-docker exec host2 ping -c1 192.168.100.3
+sudo docker exec host2 ping -c1 192.168.100.1
+# Should fail because not allowed by host3 inbound firewall
+! sudo docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
 
 set +x
 echo
 echo " *** Testing ping from host3"
 echo
 set -x
-docker exec host3 ping -c1 192.168.100.1
-docker exec host3 ping -c1 192.168.100.2
+sudo docker exec host3 ping -c1 192.168.100.1
+sudo docker exec host3 ping -c1 192.168.100.2
+
+set +x
+echo
+echo " *** Testing ping from host4"
+echo
+set -x
+sudo docker exec host4 ping -c1 192.168.100.1
+# Should fail because not allowed by host4 outbound firewall
+! sudo docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1
+! sudo docker exec host4 ping -c1 192.168.100.3 -w5 || exit 1
+
+set +x
+echo
+echo " *** Testing conntrack"
+echo
+set -x
+# host2 can ping host3 now that host3 pinged it first
+sudo docker exec host2 ping -c1 192.168.100.3
+# host4 can ping host2 once conntrack established
+sudo docker exec host2 ping -c1 192.168.100.4
+sudo docker exec host4 ping -c1 192.168.100.2
+
+sudo docker exec host4 sh -c 'kill 1'
+sudo docker exec host3 sh -c 'kill 1'
+sudo docker exec host2 sh -c 'kill 1'
+sudo docker exec lighthouse1 sh -c 'kill 1'
+sleep 1
.github/workflows/test.yml (29 changed lines, vendored)

@@ -4,6 +4,13 @@ on:
     branches:
       - master
   pull_request:
+    paths:
+      - '.github/workflows/test.yml'
+      - '**Makefile'
+      - '**.go'
+      - '**.proto'
+      - 'go.mod'
+      - 'go.sum'
 jobs:
 
   test-linux:
@@ -11,15 +18,22 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-    - name: Set up Go 1.13
+    - name: Set up Go 1.16
      uses: actions/setup-go@v1
      with:
-        go-version: 1.13
+        go-version: 1.16
      id: go
 
     - name: Check out code into the Go module directory
      uses: actions/checkout@v1
 
+    - uses: actions/cache@v1
+      with:
+        path: ~/go/pkg/mod
+        key: ${{ runner.os }}-go1.16-${{ hashFiles('**/go.sum') }}
+        restore-keys: |
+          ${{ runner.os }}-go1.16-
+
     - name: Build
      run: make all
 
@@ -34,15 +48,22 @@ jobs:
       os: [windows-latest, macOS-latest]
     steps:
 
-    - name: Set up Go 1.13
+    - name: Set up Go 1.16
      uses: actions/setup-go@v1
      with:
-        go-version: 1.13
+        go-version: 1.16
      id: go
 
     - name: Check out code into the Go module directory
      uses: actions/checkout@v1
 
+    - uses: actions/cache@v1
+      with:
+        path: ~/go/pkg/mod
+        key: ${{ runner.os }}-go1.16-${{ hashFiles('**/go.sum') }}
+        restore-keys: |
+          ${{ runner.os }}-go1.16-
+
     - name: Build nebula
      run: go build ./cmd/nebula
CHANGELOG.md (142 changed lines)

@@ -7,6 +7,144 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+### Changed
+
+- Updated the kardianos/service go library from 1.0.0 to 1.1.0, which
+  now creates launchd plist to write stdout/stderr to files by default.
+
+## [1.3.0] - 2020-09-22
+
+### Added
+
+- You can emit statistics about non-message packets by setting the option
+  `stats.message_metrics`. You can similarly emit detailed statistics about
+  lighthouse packets by setting the option `stats.lighthouse_metrics`. See
+  the example config for more details. (#230)
+
+- We now support freebsd/amd64. This is experimental, please give us feedback.
+  (#103)
+
+- We now release a binary for `linux/mips-softfloat` which has also been
+  stripped to reduce filesize and hopefully have a better chance on running on
+  small mips devices. (#231)
+
+- You can set `tun.disabled` to true to run a standalone lighthouse without a
+  tun device (and thus, without root). (#269)
+
+- You can set `logging.disable_timestamp` to remove timestamps from log lines,
+  which is useful when output is redirected to a logging system that already
+  adds timestamps. (#288)
+
+### Changed
+
+- Handshakes should now trigger faster, as we try to be proactive with sending
+  them instead of waiting for the next timer tick in most cases. (#246, #265)
+
+- Previously, we would drop the conntrack table whenever firewall rules were
+  changed during a SIGHUP. Now, we will maintain the table and just validate
+  that an entry still matches with the new rule set. (#233)
+
+- Debug logs for firewall drops now include the reason. (#220, #239)
+
+- Logs for handshakes now include the fingerprint of the remote host. (#262)
+
+- Config item `pki.blacklist` is now `pki.blocklist`. (#272)
+
+- Better support for older Linux kernels. We now only set `SO_REUSEPORT` if
+  `tun.routines` is greater than 1 (default is 1). We also only use the
+  `recvmmsg` syscall if `listen.batch` is greater than 1 (default is 64).
+  (#275)
+
+- It is possible to run Nebula as a library inside of another process now.
+  Note that this is still experimental and the internal APIs around this might
+  change in minor version releases. (#279)
+
+### Deprecated
+
+- `pki.blacklist` is deprecated in favor of `pki.blocklist` with the same
+  functionality. Existing configs will continue to load for this release to
+  allow for migrations. (#272)
+
+### Fixed
+
+- `advmss` is now set correctly for each route table entry when `tun.routes`
+  is configured to have some routes with higher MTU. (#245)
+
+- Packets that arrive on the tun device with an unroutable destination IP are
+  now dropped correctly, instead of wasting time making queries to the
+  lighthouses for IP `0.0.0.0` (#267)
+
+## [1.2.0] - 2020-04-08
+
+### Added
+
+- Add `logging.timestamp_format` config option. The primary purpose of this
+  change is to allow logging timestamps with millisecond precision. (#187)
+
+- Support `unsafe_routes` on Windows. (#184)
+
+- Add `lighthouse.remote_allow_list` to filter which subnets we will use to
+  handshake with other hosts. See the example config for more details. (#217)
+
+- Add `lighthouse.local_allow_list` to filter which local IP addresses and/or
+  interfaces we advertise to the lighthouses. See the example config for more
+  details. (#217)
+
+- Wireshark dissector plugin. Add this file in `dist/wireshark` to your
+  Wireshark plugins folder to see Nebula packet headers decoded. (#216)
+
+- systemd unit for Arch, so it can be built entirely from this repo. (#216)
+
+### Changed
+
+- Added a delay to punching via lighthouse signal to deal with race conditions
+  in some linux conntrack implementations. (#210)
+
+  See deprecated, this also adds a new `punchy.delay` option that defaults to `1s`.
+
+- Validate all `lighthouse.hosts` and `static_host_map` VPN IPs are in the
+  subnet defined in our cert. Exit with a fatal error if they are not in our
+  subnet, as this is an invalid configuration (we will not have the proper
+  routes set up to communicate with these hosts). (#170)
+
+- Use absolute paths to system binaries on macOS and Windows. (#191)
+
+- Add configuration options for `handshakes`. This includes options to tweak
+  `try_interval`, `retries` and `wait_rotation`. See example config for
+  descriptions. (#179)
+
+- Allow `-config` file to not end in `.yaml` or `yml`. Useful when using
+  `-test` and automated tools like Ansible that create temporary files without
+  suffixes. (#189)
+
+- The config test mode, `-test`, is now more thorough and catches more parsing
+  issues. (#177)
+
+- Various documentation and example fixes. (#196)
+
+- Improved log messages. (#181, #200)
+
+- Dependencies updated. (#188)
+
+### Deprecated
+
+- `punchy`, `punch_back` configuration options have been collapsed under the
+  now top level `punchy` config directive. (#210)
+
+  `punchy.punch` - This is the old `punchy` option. Should we perform NAT hole
+  punching (default false)?
+
+  `punchy.respond` - This is the old `punch_back` option. Should we respond to
+  hole punching by hole punching back (default false)?
+
+### Fixed
+
+- Reduce memory allocations when not using `unsafe_routes`. (#198)
+
+- Ignore packets from self to self. (#192)
+
+- MTU fixed for `unsafe_routes`. (#209)
+
 ## [1.1.0] - 2020-01-17
 
 ### Added
@@ -47,6 +185,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Initial public release.
 
-[Unreleased]: https://github.com/slackhq/nebula/compare/v1.1.0...HEAD
+[Unreleased]: https://github.com/slackhq/nebula/compare/v1.3.0...HEAD
+[1.3.0]: https://github.com/slackhq/nebula/releases/tag/v1.3.0
+[1.2.0]: https://github.com/slackhq/nebula/releases/tag/v1.2.0
 [1.1.0]: https://github.com/slackhq/nebula/releases/tag/v1.1.0
 [1.0.0]: https://github.com/slackhq/nebula/releases/tag/v1.0.0
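Most of the options called out in these changelog entries are plain config keys, so a single YAML sketch ties them together. Only the dotted option names come from the changelog; the nesting and the values shown are illustrative assumptions, and the example config in the repo is the authoritative reference:

    # 1.3.0 additions (illustrative values)
    stats:
      message_metrics: true
      lighthouse_metrics: true
    tun:
      disabled: true            # standalone lighthouse, no tun device and no root
    logging:
      disable_timestamp: true
    # 1.2.0: punchy and punch_back collapsed under a top level punchy block
    punchy:
      punch: true               # the old `punchy` option
      respond: true             # the old `punch_back` option
      delay: 1s                 # defaults to 1s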
Makefile (59 changed lines)

@@ -1,8 +1,18 @@
+GOMINVERSION = 1.16
 NEBULA_CMD_PATH = "./cmd/nebula"
 BUILD_NUMBER ?= dev+$(shell date -u '+%Y%m%d%H%M%S')
 GO111MODULE = on
 export GO111MODULE
 
+# Ensure the version of go we are using is at least what is defined in GOMINVERSION at the top of this file
+GOVERSION := $(shell go version | awk '{print substr($$3, 3)}')
+GOISMIN := $(shell expr "$(GOVERSION)" ">=" "$(GOMINVERSION)")
+ifneq "$(GOISMIN)" "1"
+$(error "go version $(GOVERSION) is not supported, upgrade to $(GOMINVERSION) or above")
+endif
+
+LDFLAGS = -X main.Build=$(BUILD_NUMBER)
+
 ALL_LINUX = linux-amd64 \
     linux-386 \
     linux-ppc64le \
@@ -13,43 +23,59 @@ ALL_LINUX = linux-amd64 \
     linux-mips \
     linux-mipsle \
     linux-mips64 \
-    linux-mips64le
+    linux-mips64le \
+    linux-mips-softfloat
 
 ALL = $(ALL_LINUX) \
     darwin-amd64 \
+    darwin-arm64 \
+    freebsd-amd64 \
     windows-amd64
 
 
 all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert)
 
 release: $(ALL:%=build/nebula-%.tar.gz)
 
 release-linux: $(ALL_LINUX:%=build/nebula-%.tar.gz)
 
+release-freebsd: build/nebula-freebsd-amd64.tar.gz
+
+BUILD_ARGS = -trimpath
+
 bin-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe
     mv $? .
 
 bin-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert
     mv $? .
 
+bin-freebsd: build/freebsd-amd64/nebula build/freebsd-amd64/nebula-cert
+    mv $? .
+
 bin:
-    go build -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" -o ./nebula ${NEBULA_CMD_PATH}
-    go build -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" -o ./nebula-cert ./cmd/nebula-cert
+    go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula ${NEBULA_CMD_PATH}
+    go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula-cert ./cmd/nebula-cert
 
 install:
-    go install -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" ${NEBULA_CMD_PATH}
-    go install -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula-cert
+    go install $(BUILD_ARGS) -ldflags "$(LDFLAGS)" ${NEBULA_CMD_PATH}
+    go install $(BUILD_ARGS) -ldflags "$(LDFLAGS)" ./cmd/nebula-cert
 
+build/linux-arm-%: GOENV += GOARM=$(word 3, $(subst -, ,$*))
+build/linux-mips-%: GOENV += GOMIPS=$(word 3, $(subst -, ,$*))
+
+# Build an extra small binary for mips-softfloat
+build/linux-mips-softfloat/%: LDFLAGS += -s -w
+
 build/%/nebula: .FORCE
     GOOS=$(firstword $(subst -, , $*)) \
-        GOARCH=$(word 2, $(subst -, ,$*)) \
-        GOARM=$(word 3, $(subst -, ,$*)) \
-        go build -trimpath -o $@ -ldflags "-X main.Build=$(BUILD_NUMBER)" ${NEBULA_CMD_PATH}
+        GOARCH=$(word 2, $(subst -, ,$*)) $(GOENV) \
+        go build $(BUILD_ARGS) -o $@ -ldflags "$(LDFLAGS)" ${NEBULA_CMD_PATH}
 
 build/%/nebula-cert: .FORCE
     GOOS=$(firstword $(subst -, , $*)) \
-        GOARCH=$(word 2, $(subst -, ,$*)) \
-        GOARM=$(word 3, $(subst -, ,$*)) \
-        go build -trimpath -o $@ -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula-cert
+        GOARCH=$(word 2, $(subst -, ,$*)) $(GOENV) \
+        go build $(BUILD_ARGS) -o $@ -ldflags "$(LDFLAGS)" ./cmd/nebula-cert
 
 build/%/nebula.exe: build/%/nebula
     mv $< $@
@@ -101,6 +127,15 @@ ifeq ($(words $(MAKECMDGOALS)),1)
     $(MAKE) service ${.DEFAULT_GOAL} --no-print-directory
 endif
 
+bin-docker: bin build/linux-amd64/nebula build/linux-amd64/nebula-cert
+
+smoke-docker: bin-docker
+    cd .github/workflows/smoke/ && ./build.sh
+    cd .github/workflows/smoke/ && ./smoke.sh
+
+smoke-docker-race: BUILD_ARGS = -race
+smoke-docker-race: smoke-docker
+
 .FORCE:
-.PHONY: test test-cov-html bench bench-cpu bench-cpu-long bin proto release service
+.PHONY: test test-cov-html bench bench-cpu bench-cpu-long bin proto release service smoke-docker smoke-docker-race
 .DEFAULT_GOAL := bin
README.md (19 changed lines)

@@ -1,7 +1,6 @@
 ## What is Nebula?
 Nebula is a scalable overlay networking tool with a focus on performance, simplicity and security.
-It lets you seamlessly connect computers anywhere in the world. Nebula is portable, and runs on Linux, OSX, and Windows.
-(Also: keep this quiet, but we have an early prototype running on iOS).
+It lets you seamlessly connect computers anywhere in the world. Nebula is portable, and runs on Linux, OSX, Windows, iOS, and Android.
 It can be used to connect a small number of computers, but is also able to connect tens of thousands of computers.
 
 Nebula incorporates a number of existing concepts like encryption, security groups, certificates,
@@ -13,6 +12,22 @@ You can read more about Nebula [here](https://medium.com/p/884110a5579).
 
 You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/enQtOTA5MDI4NDg3MTg4LTkwY2EwNTI4NzQyMzc0M2ZlODBjNWI3NTY1MzhiOThiMmZlZjVkMTI0NGY4YTMyNjUwMWEyNzNkZTJmYzQxOGU)
 
+## Supported Platforms
+
+#### Desktop and Server
+
+Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for downloads
+
+- Linux - 64 and 32 bit, arm, and others
+- Windows
+- MacOS
+- Freebsd
+
+#### Mobile
+
+- [iOS](https://apps.apple.com/us/app/mobile-nebula/id1509587936?itsct=apps_box&itscg=30200)
+- [Android](https://play.google.com/store/apps/details?id=net.defined.mobile_nebula&pcampaignid=pcampaignidMKT-Other-global-all-co-prtnr-py-PartBadge-Mar2515-1)
+
 ## Technical Overview
 
 Nebula is a mutually authenticated peer-to-peer software defined network based on the [Noise Protocol Framework](https://noiseprotocol.org/).
allow_list.go (new file, 48 lines)

@@ -0,0 +1,48 @@
package nebula

import (
    "fmt"
    "regexp"
)

type AllowList struct {
    // The values of this cidrTree are `bool`, signifying allow/deny
    cidrTree *CIDRTree

    // To avoid ambiguity, all rules must be true, or all rules must be false.
    nameRules []AllowListNameRule
}

type AllowListNameRule struct {
    Name  *regexp.Regexp
    Allow bool
}

func (al *AllowList) Allow(ip uint32) bool {
    if al == nil {
        return true
    }

    result := al.cidrTree.MostSpecificContains(ip)
    switch v := result.(type) {
    case bool:
        return v
    default:
        panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
    }
}

func (al *AllowList) AllowName(name string) bool {
    if al == nil || len(al.nameRules) == 0 {
        return true
    }

    for _, rule := range al.nameRules {
        if rule.Name.MatchString(name) {
            return rule.Allow
        }
    }

    // If no rules match, return the default, which is the inverse of the rules
    return !al.nameRules[0].Allow
}
allow_list_test.go (new file, 47 lines)

@@ -0,0 +1,47 @@
package nebula

import (
    "net"
    "regexp"
    "testing"

    "github.com/stretchr/testify/assert"
)

func TestAllowList_Allow(t *testing.T) {
    assert.Equal(t, true, ((*AllowList)(nil)).Allow(ip2int(net.ParseIP("1.1.1.1"))))

    tree := NewCIDRTree()
    tree.AddCIDR(getCIDR("0.0.0.0/0"), true)
    tree.AddCIDR(getCIDR("10.0.0.0/8"), false)
    tree.AddCIDR(getCIDR("10.42.42.0/24"), true)
    al := &AllowList{cidrTree: tree}

    assert.Equal(t, true, al.Allow(ip2int(net.ParseIP("1.1.1.1"))))
    assert.Equal(t, false, al.Allow(ip2int(net.ParseIP("10.0.0.4"))))
    assert.Equal(t, true, al.Allow(ip2int(net.ParseIP("10.42.42.42"))))
}

func TestAllowList_AllowName(t *testing.T) {
    assert.Equal(t, true, ((*AllowList)(nil)).AllowName("docker0"))

    rules := []AllowListNameRule{
        {Name: regexp.MustCompile("^docker.*$"), Allow: false},
        {Name: regexp.MustCompile("^tun.*$"), Allow: false},
    }
    al := &AllowList{nameRules: rules}

    assert.Equal(t, false, al.AllowName("docker0"))
    assert.Equal(t, false, al.AllowName("tun0"))
    assert.Equal(t, true, al.AllowName("eth0"))

    rules = []AllowListNameRule{
        {Name: regexp.MustCompile("^eth.*$"), Allow: true},
        {Name: regexp.MustCompile("^ens.*$"), Allow: true},
    }
    al = &AllowList{nameRules: rules}

    assert.Equal(t, false, al.AllowName("docker0"))
    assert.Equal(t, true, al.AllowName("eth0"))
    assert.Equal(t, true, al.AllowName("ens5"))
}
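The AllowList type above is presumably what sits behind the `lighthouse.remote_allow_list` and `lighthouse.local_allow_list` options described in the 1.2.0 changelog entries: the CIDR tree holds subnet rules and the name rules match interface names by regex. A rough YAML sketch using the same values as the tests above; the option names come from the changelog, but the nested `interfaces` layout is an assumption taken from the example config rather than from this diff:

    lighthouse:
      remote_allow_list:
        "0.0.0.0/0": true
        "10.0.0.0/8": false
        "10.42.42.0/24": true
      local_allow_list:
        interfaces:
          "docker.*": false
          "tun.*": false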
@@ -212,10 +212,10 @@ func TestBitsLostCounter(t *testing.T) {
 func BenchmarkBits(b *testing.B) {
     z := NewBits(10)
     for n := 0; n < b.N; n++ {
-        for i, _ := range z.bits {
+        for i := range z.bits {
             z.bits[i] = true
         }
-        for i, _ := range z.bits {
+        for i := range z.bits {
             z.bits[i] = false
         }
cert.go (12 changed lines)

@@ -149,10 +149,16 @@ func loadCAFromConfig(c *Config) (*cert.NebulaCAPool, error) {
         return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
     }
 
-    // pki.blacklist entered the scene at about the same time we aliased x509 to pki, not supporting backwards compat
+    for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
+        l.WithField("fingerprint", fp).Infof("Blocklisting cert")
+        CAs.BlocklistFingerprint(fp)
+    }
+
+    // Support deprecated config for at leaast one minor release to allow for migrations
     for _, fp := range c.GetStringSlice("pki.blacklist", []string{}) {
-        l.WithField("fingerprint", fp).Infof("Blacklisting cert")
-        CAs.BlacklistFingerprint(fp)
+        l.WithField("fingerprint", fp).Infof("Blocklisting cert")
+        l.Warn("pki.blacklist is deprecated and will not be supported in a future release. Please migrate your config to use pki.blocklist")
+        CAs.BlocklistFingerprint(fp)
     }
 
     return CAs, nil
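Config-wise the new key keeps the old shape: a list of certificate SHA-256 fingerprints under `pki` (the code above reads it with GetStringSlice, and the CA pool code below compares against Sha256Sum). A minimal sketch; the paths mirror the smoke-test config and the fingerprint is just the sample value that appears in the certificate tests further down:

    pki:
      ca: ca.crt
      cert: host2.crt
      key: host2.key
      blocklist:
        - 5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd
      # deprecated spelling, still honored for one more release:
      # blacklist:
      #   - <fingerprint>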
cert/ca.go (22 changed lines)

@@ -8,14 +8,14 @@ import (
 
 type NebulaCAPool struct {
     CAs           map[string]*NebulaCertificate
-    certBlacklist map[string]struct{}
+    certBlocklist map[string]struct{}
 }
 
 // NewCAPool creates a CAPool
 func NewCAPool() *NebulaCAPool {
     ca := NebulaCAPool{
         CAs:           make(map[string]*NebulaCertificate),
-        certBlacklist: make(map[string]struct{}),
+        certBlocklist: make(map[string]struct{}),
     }
 
     return &ca
@@ -67,24 +67,24 @@ func (ncp *NebulaCAPool) AddCACertificate(pemBytes []byte) ([]byte, error) {
     return pemBytes, nil
 }
 
-// BlacklistFingerprint adds a cert fingerprint to the blacklist
-func (ncp *NebulaCAPool) BlacklistFingerprint(f string) {
-    ncp.certBlacklist[f] = struct{}{}
+// BlocklistFingerprint adds a cert fingerprint to the blocklist
+func (ncp *NebulaCAPool) BlocklistFingerprint(f string) {
+    ncp.certBlocklist[f] = struct{}{}
 }
 
-// ResetCertBlacklist removes all previously blacklisted cert fingerprints
-func (ncp *NebulaCAPool) ResetCertBlacklist() {
-    ncp.certBlacklist = make(map[string]struct{})
+// ResetCertBlocklist removes all previously blocklisted cert fingerprints
+func (ncp *NebulaCAPool) ResetCertBlocklist() {
+    ncp.certBlocklist = make(map[string]struct{})
 }
 
-// IsBlacklisted returns true if the fingerprint fails to generate or has been explicitly blacklisted
-func (ncp *NebulaCAPool) IsBlacklisted(c *NebulaCertificate) bool {
+// IsBlocklisted returns true if the fingerprint fails to generate or has been explicitly blocklisted
+func (ncp *NebulaCAPool) IsBlocklisted(c *NebulaCertificate) bool {
     h, err := c.Sha256Sum()
     if err != nil {
         return true
     }
 
-    if _, ok := ncp.certBlacklist[h]; ok {
+    if _, ok := ncp.certBlocklist[h]; ok {
         return true
     }
 
cert/cert.go (74 changed lines)

@@ -1,18 +1,18 @@
 package cert
 
 import (
+    "bytes"
     "crypto"
     "crypto/rand"
     "crypto/sha256"
     "encoding/binary"
     "encoding/hex"
+    "encoding/json"
     "encoding/pem"
     "fmt"
     "net"
     "time"
 
-    "bytes"
-    "encoding/json"
     "github.com/golang/protobuf/proto"
     "golang.org/x/crypto/curve25519"
     "golang.org/x/crypto/ed25519"
@@ -61,6 +61,10 @@ func UnmarshalNebulaCertificate(b []byte) (*NebulaCertificate, error) {
         return nil, err
     }
 
+    if rc.Details == nil {
+        return nil, fmt.Errorf("encoded Details was nil")
+    }
+
     if len(rc.Details.Ips)%2 != 0 {
         return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found")
     }
@@ -123,6 +127,9 @@ func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, er
     if p == nil {
         return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
     }
+    if p.Type != CertBanner {
+        return nil, r, fmt.Errorf("bytes did not contain a proper nebula certificate banner")
+    }
     nc, err := UnmarshalNebulaCertificate(p.Bytes)
     return nc, r, err
 }
@@ -244,10 +251,10 @@ func (nc *NebulaCertificate) Expired(t time.Time) bool {
     return nc.Details.NotBefore.After(t) || nc.Details.NotAfter.Before(t)
 }
 
-// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blacklist, etc)
+// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
 func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) {
-    if ncp.IsBlacklisted(nc) {
-        return false, fmt.Errorf("certificate has been blacklisted")
+    if ncp.IsBlocklisted(nc) {
+        return false, fmt.Errorf("certificate has been blocked")
     }
 
     signer, err := ncp.GetCAForCert(nc)
@@ -468,6 +475,63 @@ func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) {
     return json.Marshal(jc)
 }
 
+//func (nc *NebulaCertificate) Copy() *NebulaCertificate {
+//    r, err := nc.Marshal()
+//    if err != nil {
+//        //TODO
+//        return nil
+//    }
+//
+//    c, err := UnmarshalNebulaCertificate(r)
+//    return c
+//}
+
+func (nc *NebulaCertificate) Copy() *NebulaCertificate {
+    c := &NebulaCertificate{
+        Details: NebulaCertificateDetails{
+            Name:           nc.Details.Name,
+            Groups:         make([]string, len(nc.Details.Groups)),
+            Ips:            make([]*net.IPNet, len(nc.Details.Ips)),
+            Subnets:        make([]*net.IPNet, len(nc.Details.Subnets)),
+            NotBefore:      nc.Details.NotBefore,
+            NotAfter:       nc.Details.NotAfter,
+            PublicKey:      make([]byte, len(nc.Details.PublicKey)),
+            IsCA:           nc.Details.IsCA,
+            Issuer:         nc.Details.Issuer,
+            InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)),
+        },
+        Signature: make([]byte, len(nc.Signature)),
+    }
+
+    copy(c.Signature, nc.Signature)
+    copy(c.Details.Groups, nc.Details.Groups)
+    copy(c.Details.PublicKey, nc.Details.PublicKey)
+
+    for i, p := range nc.Details.Ips {
+        c.Details.Ips[i] = &net.IPNet{
+            IP:   make(net.IP, len(p.IP)),
+            Mask: make(net.IPMask, len(p.Mask)),
+        }
+        copy(c.Details.Ips[i].IP, p.IP)
+        copy(c.Details.Ips[i].Mask, p.Mask)
+    }
+
+    for i, p := range nc.Details.Subnets {
+        c.Details.Subnets[i] = &net.IPNet{
+            IP:   make(net.IP, len(p.IP)),
+            Mask: make(net.IPMask, len(p.Mask)),
+        }
+        copy(c.Details.Subnets[i].IP, p.IP)
+        copy(c.Details.Subnets[i].Mask, p.Mask)
+    }
+
+    for g := range nc.Details.InvertedGroups {
+        c.Details.InvertedGroups[g] = struct{}{}
+    }
+
+    return c
+}
+
 func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool {
     for _, net := range rootIps {
         if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) {
@@ -9,6 +9,7 @@ import (
     "strings"
     "time"
 
     "github.com/golang/protobuf/proto"
+    "github.com/slackhq/nebula/util"
     "github.com/stretchr/testify/assert"
     "golang.org/x/crypto/curve25519"
     "golang.org/x/crypto/ed25519"
@@ -172,13 +173,13 @@ func TestNebulaCertificate_Verify(t *testing.T) {
 
     f, err := c.Sha256Sum()
     assert.Nil(t, err)
-    caPool.BlacklistFingerprint(f)
+    caPool.BlocklistFingerprint(f)
 
     v, err := c.Verify(time.Now(), caPool)
     assert.False(t, v)
-    assert.EqualError(t, err, "certificate has been blacklisted")
+    assert.EqualError(t, err, "certificate has been blocked")
 
-    caPool.ResetCertBlacklist()
+    caPool.ResetCertBlocklist()
     v, err = c.Verify(time.Now(), caPool)
     assert.True(t, v)
     assert.Nil(t, err)
@@ -446,6 +447,255 @@ BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
     assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
 }
 
+func appendByteSlices(b ...[]byte) []byte {
+    retSlice := []byte{}
+    for _, v := range b {
+        retSlice = append(retSlice, v...)
+    }
+    return retSlice
+}
+
+func TestUnmrshalCertPEM(t *testing.T) {
+    goodCert := []byte(`
+# A good cert
+-----BEGIN NEBULA CERTIFICATE-----
+CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
+vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
+bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
+-----END NEBULA CERTIFICATE-----
+`)
+    badBanner := []byte(`# A bad banner
+-----BEGIN NOT A NEBULA CERTIFICATE-----
+CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
+vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
+bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
+-----END NOT A NEBULA CERTIFICATE-----
+`)
+    invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA CERTIFICATE-----
+CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
+vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
+bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
+-END NEBULA CERTIFICATE----`)
+
+    certBundle := appendByteSlices(goodCert, badBanner, invalidPem)
+
+    // Success test case
+    cert, rest, err := UnmarshalNebulaCertificateFromPEM(certBundle)
+    assert.NotNil(t, cert)
+    assert.Equal(t, rest, append(badBanner, invalidPem...))
+    assert.Nil(t, err)
+
+    // Fail due to invalid banner.
+    cert, rest, err = UnmarshalNebulaCertificateFromPEM(rest)
+    assert.Nil(t, cert)
+    assert.Equal(t, rest, invalidPem)
+    assert.EqualError(t, err, "bytes did not contain a proper nebula certificate banner")
+
+    // Fail due to ivalid PEM format, because
+    // it's missing the requisite pre-encapsulation boundary.
+    cert, rest, err = UnmarshalNebulaCertificateFromPEM(rest)
+    assert.Nil(t, cert)
+    assert.Equal(t, rest, invalidPem)
+    assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}
+
+func TestUnmarshalEd25519PrivateKey(t *testing.T) {
+    privKey := []byte(`# A good key
+-----BEGIN NEBULA ED25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NEBULA ED25519 PRIVATE KEY-----
+`)
+    shortKey := []byte(`# A short key
+-----BEGIN NEBULA ED25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
+-----END NEBULA ED25519 PRIVATE KEY-----
+`)
+    invalidBanner := []byte(`# Invalid banner
+-----BEGIN NOT A NEBULA PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NOT A NEBULA PRIVATE KEY-----
+`)
+    invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA ED25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-END NEBULA ED25519 PRIVATE KEY-----`)
+
+    keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem)
+
+    // Success test case
+    k, rest, err := UnmarshalEd25519PrivateKey(keyBundle)
+    assert.Len(t, k, 64)
+    assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+    assert.Nil(t, err)
+
+    // Fail due to short key
+    k, rest, err = UnmarshalEd25519PrivateKey(rest)
+    assert.Nil(t, k)
+    assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+    assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
+
+    // Fail due to invalid banner
+    k, rest, err = UnmarshalEd25519PrivateKey(rest)
+    assert.Nil(t, k)
+    assert.Equal(t, rest, invalidPem)
+    assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519 private key banner")
+
+    // Fail due to ivalid PEM format, because
+    // it's missing the requisite pre-encapsulation boundary.
+    k, rest, err = UnmarshalEd25519PrivateKey(rest)
+    assert.Nil(t, k)
+    assert.Equal(t, rest, invalidPem)
+    assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}
+
+func TestUnmarshalX25519PrivateKey(t *testing.T) {
+    privKey := []byte(`# A good key
+-----BEGIN NEBULA X25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA X25519 PRIVATE KEY-----
+`)
+    shortKey := []byte(`# A short key
+-----BEGIN NEBULA X25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NEBULA X25519 PRIVATE KEY-----
+`)
+    invalidBanner := []byte(`# Invalid banner
+-----BEGIN NOT A NEBULA PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NOT A NEBULA PRIVATE KEY-----
+`)
+    invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA X25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-END NEBULA X25519 PRIVATE KEY-----`)
+
+    keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem)
+
+    // Success test case
+    k, rest, err := UnmarshalX25519PrivateKey(keyBundle)
+    assert.Len(t, k, 32)
+    assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+    assert.Nil(t, err)
+
+    // Fail due to short key
+    k, rest, err = UnmarshalX25519PrivateKey(rest)
+    assert.Nil(t, k)
+    assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+    assert.EqualError(t, err, "key was not 32 bytes, is invalid X25519 private key")
+
+    // Fail due to invalid banner
+    k, rest, err = UnmarshalX25519PrivateKey(rest)
+    assert.Nil(t, k)
+    assert.Equal(t, rest, invalidPem)
+    assert.EqualError(t, err, "bytes did not contain a proper nebula X25519 private key banner")
+
+    // Fail due to ivalid PEM format, because
+    // it's missing the requisite pre-encapsulation boundary.
+    k, rest, err = UnmarshalX25519PrivateKey(rest)
+    assert.Nil(t, k)
+    assert.Equal(t, rest, invalidPem)
+    assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}
+
+func TestUnmarshalEd25519PublicKey(t *testing.T) {
+    pubKey := []byte(`# A good key
+-----BEGIN NEBULA ED25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA ED25519 PUBLIC KEY-----
+`)
+    shortKey := []byte(`# A short key
+-----BEGIN NEBULA ED25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NEBULA ED25519 PUBLIC KEY-----
+`)
+    invalidBanner := []byte(`# Invalid banner
+-----BEGIN NOT A NEBULA PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NOT A NEBULA PUBLIC KEY-----
+`)
+    invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA ED25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-END NEBULA ED25519 PUBLIC KEY-----`)
+
+    keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem)
+
+    // Success test case
+    k, rest, err := UnmarshalEd25519PublicKey(keyBundle)
+    assert.Equal(t, len(k), 32)
+    assert.Nil(t, err)
+    assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+
+    // Fail due to short key
+    k, rest, err = UnmarshalEd25519PublicKey(rest)
+    assert.Nil(t, k)
+    assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+    assert.EqualError(t, err, "key was not 32 bytes, is invalid ed25519 public key")
+
+    // Fail due to invalid banner
+    k, rest, err = UnmarshalEd25519PublicKey(rest)
+    assert.Nil(t, k)
+    assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519 public key banner")
+    assert.Equal(t, rest, invalidPem)
+
+    // Fail due to ivalid PEM format, because
+    // it's missing the requisite pre-encapsulation boundary.
+    k, rest, err = UnmarshalEd25519PublicKey(rest)
+    assert.Nil(t, k)
+    assert.Equal(t, rest, invalidPem)
+    assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}
+
+func TestUnmarshalX25519PublicKey(t *testing.T) {
+    pubKey := []byte(`# A good key
+-----BEGIN NEBULA X25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA X25519 PUBLIC KEY-----
+`)
+    shortKey := []byte(`# A short key
+-----BEGIN NEBULA X25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NEBULA X25519 PUBLIC KEY-----
+`)
+    invalidBanner := []byte(`# Invalid banner
+-----BEGIN NOT A NEBULA PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NOT A NEBULA PUBLIC KEY-----
+`)
+    invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA X25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-END NEBULA X25519 PUBLIC KEY-----`)
+
+    keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem)
+
+    // Success test case
+    k, rest, err := UnmarshalX25519PublicKey(keyBundle)
+    assert.Equal(t, len(k), 32)
+    assert.Nil(t, err)
+    assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+
+    // Fail due to short key
+    k, rest, err = UnmarshalX25519PublicKey(rest)
+    assert.Nil(t, k)
+    assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+    assert.EqualError(t, err, "key was not 32 bytes, is invalid X25519 public key")
+
+    // Fail due to invalid banner
+    k, rest, err = UnmarshalX25519PublicKey(rest)
+    assert.Nil(t, k)
+    assert.EqualError(t, err, "bytes did not contain a proper nebula X25519 public key banner")
+    assert.Equal(t, rest, invalidPem)
+
+    // Fail due to ivalid PEM format, because
+    // it's missing the requisite pre-encapsulation boundary.
+    k, rest, err = UnmarshalX25519PublicKey(rest)
+    assert.Nil(t, k)
+    assert.Equal(t, rest, invalidPem)
+    assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}
+
 // Ensure that upgrading the protobuf library does not change how certificates
 // are marshalled, since this would break signature verification
 func TestMarshalingNebulaCertificateConsistency(t *testing.T) {
@@ -487,6 +737,24 @@ func TestMarshalingNebulaCertificateConsistency(t *testing.T) {
     assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
 }
 
+func TestNebulaCertificate_Copy(t *testing.T) {
+    ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+    assert.Nil(t, err)
+
+    c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+    assert.Nil(t, err)
+    cc := c.Copy()
+
+    util.AssertDeepCopyEqual(t, c, cc)
+}
+
+func TestUnmarshalNebulaCertificate(t *testing.T) {
+    // Test that we don't panic with an invalid certificate (#332)
+    data := []byte("\x98\x00\x00")
+    _, err := UnmarshalNebulaCertificate(data)
+    assert.EqualError(t, err, "encoded Details was nil")
+}
+
 func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
     pub, priv, err := ed25519.GenerateKey(rand.Reader)
     if before.IsZero() {
@@ -498,11 +766,12 @@ func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
 
     nc := &NebulaCertificate{
         Details: NebulaCertificateDetails{
-            Name:      "test ca",
-            NotBefore: before,
-            NotAfter:  after,
-            PublicKey: pub,
-            IsCA:      true,
+            Name:           "test ca",
+            NotBefore:      time.Unix(before.Unix(), 0),
+            NotAfter:       time.Unix(after.Unix(), 0),
+            PublicKey:      pub,
+            IsCA:           true,
+            InvertedGroups: make(map[string]struct{}),
         },
     }
 
@@ -544,17 +813,17 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
 
     if len(ips) == 0 {
         ips = []*net.IPNet{
-            {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-            {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-            {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
+            {IP: net.ParseIP("10.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
+            {IP: net.ParseIP("10.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
+            {IP: net.ParseIP("10.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
         }
     }
 
     if len(subnets) == 0 {
         subnets = []*net.IPNet{
-            {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-            {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-            {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
+            {IP: net.ParseIP("9.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
+            {IP: net.ParseIP("9.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
+            {IP: net.ParseIP("9.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
         }
     }
 
@@ -562,15 +831,16 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
 
     nc := &NebulaCertificate{
         Details: NebulaCertificateDetails{
-            Name:      "testing",
-            Ips:       ips,
-            Subnets:   subnets,
-            Groups:    groups,
-            NotBefore: before,
-            NotAfter:  after,
-            PublicKey: pub,
-            IsCA:      false,
-            Issuer:    issuer,
+            Name:           "testing",
+            Ips:            ips,
+            Subnets:        subnets,
+            Groups:         groups,
+            NotBefore:      time.Unix(before.Unix(), 0),
+            NotAfter:       time.Unix(after.Unix(), 0),
+            PublicKey:      pub,
+            IsCA:           false,
+            Issuer:         issuer,
+            InvertedGroups: make(map[string]struct{}),
         },
     }
 
@@ -11,6 +11,7 @@ import (
     "strings"
     "time"
 
+    "github.com/skip2/go-qrcode"
     "github.com/slackhq/nebula/cert"
     "golang.org/x/crypto/ed25519"
 )
@@ -21,6 +22,7 @@ type caFlags struct {
     duration    *time.Duration
     outKeyPath  *string
     outCertPath *string
+    outQRPath   *string
     groups      *string
     ips         *string
     subnets     *string
@@ -33,6 +35,7 @@ func newCaFlags() *caFlags {
     cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
     cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to")
     cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to")
+    cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
     cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use")
     cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ip and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use")
     cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ip and network in CIDR notation. This will limit which subnet addresses and networks subordinate certs can use")
@@ -146,6 +149,18 @@ func ca(args []string, out io.Writer, errOut io.Writer) error {
         return fmt.Errorf("error while writing out-crt: %s", err)
     }
 
+    if *cf.outQRPath != "" {
+        b, err = qrcode.Encode(string(b), qrcode.Medium, -5)
+        if err != nil {
+            return fmt.Errorf("error while generating qr code: %s", err)
+        }
+
+        err = ioutil.WriteFile(*cf.outQRPath, b, 0600)
+        if err != nil {
+            return fmt.Errorf("error while writing out-qr: %s", err)
+        }
+    }
+
     return nil
 }
@@ -37,6 +37,8 @@ func Test_caHelp(t *testing.T) {
|
||||
" \tOptional: path to write the certificate to (default \"ca.crt\")\n"+
|
||||
" -out-key string\n"+
|
||||
" \tOptional: path to write the private key to (default \"ca.key\")\n"+
|
||||
" -out-qr string\n"+
|
||||
" \tOptional: output a qr code image (png) of the certificate\n"+
|
||||
" -subnets string\n"+
|
||||
" \tOptional: comma separated list of ip and network in CIDR notation. This will limit which subnet addresses and networks subordinate certs can use\n",
|
||||
ob.String(),
|
||||
|
||||
@@ -3,10 +3,11 @@ package main
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
//TODO: all flag parsing continueOnError will print to stderr on its own currently
|
||||
|
||||
@@ -4,23 +4,27 @@ import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/skip2/go-qrcode"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
)
|
||||
|
||||
type printFlags struct {
|
||||
set *flag.FlagSet
|
||||
json *bool
|
||||
path *string
|
||||
set *flag.FlagSet
|
||||
json *bool
|
||||
outQRPath *string
|
||||
path *string
|
||||
}
|
||||
|
||||
func newPrintFlags() *printFlags {
|
||||
pf := printFlags{set: flag.NewFlagSet("print", flag.ContinueOnError)}
|
||||
pf.set.Usage = func() {}
|
||||
pf.json = pf.set.Bool("json", false, "Optional: outputs certificates in json format")
|
||||
pf.outQRPath = pf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
|
||||
pf.path = pf.set.String("path", "", "Required: path to the certificate")
|
||||
|
||||
return &pf
|
||||
@@ -43,6 +47,8 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
|
||||
}
|
||||
|
||||
var c *cert.NebulaCertificate
|
||||
var qrBytes []byte
|
||||
part := 0
|
||||
|
||||
for {
|
||||
c, rawCert, err = cert.UnmarshalNebulaCertificateFromPEM(rawCert)
|
||||
@@ -60,9 +66,31 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
|
||||
out.Write([]byte("\n"))
|
||||
}
|
||||
|
||||
if *pf.outQRPath != "" {
|
||||
b, err := c.MarshalToPEM()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while marshalling cert to PEM: %s", err)
|
||||
}
|
||||
qrBytes = append(qrBytes, b...)
|
||||
}
|
||||
|
||||
if rawCert == nil || len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" {
|
||||
break
|
||||
}
|
||||
|
||||
part++
|
||||
}
|
||||
|
||||
if *pf.outQRPath != "" {
|
||||
b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while generating qr code: %s", err)
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(*pf.outQRPath, b, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-qr: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -2,12 +2,13 @@ package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_printSummary(t *testing.T) {
|
||||
@@ -22,6 +23,8 @@ func Test_printHelp(t *testing.T) {
|
||||
"Usage of "+os.Args[0]+" print <flags>: prints details about a certificate\n"+
|
||||
" -json\n"+
|
||||
" \tOptional: outputs certificates in json format\n"+
|
||||
" -out-qr string\n"+
|
||||
" \tOptional: output a qr code image (png) of the certificate\n"+
|
||||
" -path string\n"+
|
||||
" \tRequired: path to the certificate\n",
|
||||
ob.String(),
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/skip2/go-qrcode"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
)
|
||||
@@ -25,6 +26,7 @@ type signFlags struct {
|
||||
inPubPath *string
|
||||
outKeyPath *string
|
||||
outCertPath *string
|
||||
outQRPath *string
|
||||
groups *string
|
||||
subnets *string
|
||||
}
|
||||
@@ -40,8 +42,9 @@ func newSignFlags() *signFlags {
|
||||
sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key")
|
||||
sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to")
|
||||
sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to")
|
||||
sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
|
||||
sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
|
||||
sf.subnets = sf.set.String("subnets", "", "Optional: comma seperated list of subnet this cert can serve for")
|
||||
sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of subnet this cert can serve for")
|
||||
return &sf
|
||||
|
||||
}
|
||||
@@ -203,6 +206,18 @@ func signCert(args []string, out io.Writer, errOut io.Writer) error {
|
||||
return fmt.Errorf("error while writing out-crt: %s", err)
|
||||
}
|
||||
|
||||
if *sf.outQRPath != "" {
|
||||
b, err = qrcode.Encode(string(b), qrcode.Medium, -5)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while generating qr code: %s", err)
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(*sf.outQRPath, b, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-qr: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -45,8 +45,10 @@ func Test_signHelp(t *testing.T) {
|
||||
" \tOptional: path to write the certificate to\n"+
|
||||
" -out-key string\n"+
|
||||
" \tOptional (if in-pub not set): path to write the private key to\n"+
|
||||
" -out-qr string\n"+
|
||||
" \tOptional: output a qr code image (png) of the certificate\n"+
|
||||
" -subnets string\n"+
|
||||
" \tOptional: comma seperated list of subnet this cert can serve for\n",
|
||||
" \tOptional: comma separated list of subnet this cert can serve for\n",
|
||||
ob.String(),
|
||||
)
|
||||
}
|
||||
@@ -286,5 +288,4 @@ func Test_signCert(t *testing.T) {
|
||||
assert.EqualError(t, signCert(args, ob, eb), "refusing to overwrite existing cert: "+crtF.Name())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
}
|
||||
|
||||
@@ -3,12 +3,13 @@ package main
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
)
|
||||
|
||||
type verifyFlags struct {
|
||||
|
||||
@@ -3,13 +3,14 @@ package main
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
)
|
||||
|
||||
func Test_verifySummary(t *testing.T) {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
"fmt"
"os"

"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
)

@@ -45,5 +46,30 @@ func main() {
os.Exit(1)
}

nebula.Main(*configPath, *configTest, Build)
config := nebula.NewConfig()
err := config.Load(*configPath)
if err != nil {
fmt.Printf("failed to load config: %s", err)
os.Exit(1)
}

l := logrus.New()
l.Out = os.Stdout
c, err := nebula.Main(config, *configTest, Build, l, nil)

switch v := err.(type) {
case nebula.ContextualError:
v.Log(l)
os.Exit(1)
case error:
l.WithError(err).Error("Failed to start")
os.Exit(1)
}

if !*configTest {
c.Start()
c.ShutdownBlock()
}

os.Exit(0)
}

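The rewritten main above is also the pattern any embedder would follow now that nebula.Main returns a Control handle instead of blocking. A hedged sketch of driving the same lifecycle from library code, assuming only the signatures visible in this diff (NewConfig, Config.LoadString, Main, Control.Start/ShutdownBlock); the package name, runEmbedded, and the build string are illustrative, not part of the change:

package embedexample

import (
	"os"

	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula"
)

// runEmbedded is a hypothetical wrapper: load config from a string rather than
// a path, then drive the same lifecycle the CLI above drives.
func runEmbedded(rawYAML string) error {
	config := nebula.NewConfig()
	if err := config.LoadString(rawYAML); err != nil { // LoadString is added by this change
		return err
	}

	l := logrus.New()
	l.Out = os.Stdout

	// Same call shape as the CLI: configTest=false, an illustrative build string,
	// and nil for the final argument, mirroring what main passes.
	ctrl, err := nebula.Main(config, false, "embedded-build", l, nil)
	if err != nil {
		return err
	}

	ctrl.Start()         // non-blocking
	ctrl.ShutdownBlock() // blocks until SIGTERM/SIGINT, which in turn calls Stop
	return nil
}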
@@ -1,44 +1,53 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
)
|
||||
|
||||
var logger service.Logger
|
||||
|
||||
type program struct {
|
||||
exit chan struct{}
|
||||
configPath *string
|
||||
configTest *bool
|
||||
build string
|
||||
control *nebula.Control
|
||||
}
|
||||
|
||||
func (p *program) Start(s service.Service) error {
|
||||
logger.Info("Nebula service starting.")
|
||||
p.exit = make(chan struct{})
|
||||
// Start should not block.
|
||||
go p.run()
|
||||
return nil
|
||||
}
|
||||
logger.Info("Nebula service starting.")
|
||||
|
||||
func (p *program) run() error {
|
||||
nebula.Main(*p.configPath, *p.configTest, Build)
|
||||
config := nebula.NewConfig()
|
||||
err := config.Load(*p.configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load config: %s", err)
|
||||
}
|
||||
|
||||
l := logrus.New()
|
||||
l.Out = os.Stdout
|
||||
p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.control.Start()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *program) Stop(s service.Service) error {
|
||||
logger.Info("Nebula service stopping.")
|
||||
close(p.exit)
|
||||
p.control.Stop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {
|
||||
|
||||
if *configPath == "" {
|
||||
ex, err := os.Executable()
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
)
|
||||
|
||||
@@ -39,5 +40,30 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
nebula.Main(*configPath, *configTest, Build)
|
||||
config := nebula.NewConfig()
|
||||
err := config.Load(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("failed to load config: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
l := logrus.New()
|
||||
l.Out = os.Stdout
|
||||
c, err := nebula.Main(config, *configTest, Build, l, nil)
|
||||
|
||||
switch v := err.(type) {
|
||||
case nebula.ContextualError:
|
||||
v.Log(l)
|
||||
os.Exit(1)
|
||||
case error:
|
||||
l.WithError(err).Error("Failed to start")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if !*configTest {
|
||||
c.Start()
|
||||
c.ShutdownBlock()
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
188 config.go
@@ -1,19 +1,23 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/imdario/mergo"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gopkg.in/yaml.v2"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/imdario/mergo"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
@@ -35,7 +39,7 @@ func (c *Config) Load(path string) error {
|
||||
c.path = path
|
||||
c.files = make([]string, 0)
|
||||
|
||||
err := c.resolve(path)
|
||||
err := c.resolve(path, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -54,6 +58,13 @@ func (c *Config) Load(path string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) LoadString(raw string) error {
|
||||
if raw == "" {
|
||||
return errors.New("Empty configuration")
|
||||
}
|
||||
return c.parseRaw([]byte(raw))
|
||||
}
|
||||
|
||||
// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
|
||||
// here should decide if they need to make a change to the current process before making the change. HasChanged can be
|
||||
// used to help decide if a change is necessary.
|
||||
@@ -213,10 +224,137 @@ func (c *Config) GetDuration(k string, d time.Duration) time.Duration {
|
||||
return v
|
||||
}
|
||||
|
||||
func (c *Config) GetAllowList(k string, allowInterfaces bool) (*AllowList, error) {
|
||||
r := c.Get(k)
|
||||
if r == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rawMap, ok := r.(map[interface{}]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, r)
|
||||
}
|
||||
|
||||
tree := NewCIDRTree()
|
||||
var nameRules []AllowListNameRule
|
||||
|
||||
firstValue := true
|
||||
allValuesMatch := true
|
||||
defaultSet := false
|
||||
var allValues bool
|
||||
|
||||
for rawKey, rawValue := range rawMap {
|
||||
rawCIDR, ok := rawKey.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
|
||||
}
|
||||
|
||||
// Special rule for interface names
|
||||
if rawCIDR == "interfaces" {
|
||||
if !allowInterfaces {
|
||||
return nil, fmt.Errorf("config `%s` does not support `interfaces`", k)
|
||||
}
|
||||
var err error
|
||||
nameRules, err = c.getAllowListInterfaces(k, rawValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
value, ok := rawValue.(bool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
|
||||
}
|
||||
|
||||
_, cidr, err := net.ParseCIDR(rawCIDR)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
|
||||
}
|
||||
|
||||
// TODO: should we error on duplicate CIDRs in the config?
|
||||
tree.AddCIDR(cidr, value)
|
||||
|
||||
if firstValue {
|
||||
allValues = value
|
||||
firstValue = false
|
||||
} else {
|
||||
if value != allValues {
|
||||
allValuesMatch = false
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this is 0.0.0.0/0
|
||||
bits, size := cidr.Mask.Size()
|
||||
if bits == 0 && size == 32 {
|
||||
defaultSet = true
|
||||
}
|
||||
}
|
||||
|
||||
if !defaultSet {
|
||||
if allValuesMatch {
|
||||
_, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
|
||||
tree.AddCIDR(zeroCIDR, !allValues)
|
||||
} else {
|
||||
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
|
||||
}
|
||||
}
|
||||
|
||||
return &AllowList{cidrTree: tree, nameRules: nameRules}, nil
|
||||
}
|
||||
|
||||
func (c *Config) getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
|
||||
var nameRules []AllowListNameRule
|
||||
|
||||
rawRules, ok := v.(map[interface{}]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
|
||||
}
|
||||
|
||||
firstEntry := true
|
||||
var allValues bool
|
||||
for rawName, rawAllow := range rawRules {
|
||||
name, ok := rawName.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
|
||||
}
|
||||
allow, ok := rawAllow.(bool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
|
||||
}
|
||||
|
||||
nameRE, err := regexp.Compile("^" + name + "$")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err)
|
||||
}
|
||||
|
||||
nameRules = append(nameRules, AllowListNameRule{
|
||||
Name: nameRE,
|
||||
Allow: allow,
|
||||
})
|
||||
|
||||
if firstEntry {
|
||||
allValues = allow
|
||||
firstEntry = false
|
||||
} else {
|
||||
if allow != allValues {
|
||||
return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nameRules, nil
|
||||
}
|
||||
|
||||
func (c *Config) Get(k string) interface{} {
|
||||
return c.get(k, c.Settings)
|
||||
}
|
||||
|
||||
func (c *Config) IsSet(k string) bool {
|
||||
return c.get(k, c.Settings) != nil
|
||||
}
|
||||
|
||||
func (c *Config) get(k string, v interface{}) interface{} {
|
||||
parts := strings.Split(k, ".")
|
||||
for _, p := range parts {
|
||||
@@ -234,14 +372,16 @@ func (c *Config) get(k string, v interface{}) interface{} {
|
||||
return v
|
||||
}
|
||||
|
||||
func (c *Config) resolve(path string) error {
|
||||
// direct signifies if this is the config path directly specified by the user,
|
||||
// versus a file/dir found by recursing into that path
|
||||
func (c *Config) resolve(path string, direct bool) error {
|
||||
i, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !i.IsDir() {
|
||||
c.addFile(path)
|
||||
c.addFile(path, direct)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -251,7 +391,7 @@ func (c *Config) resolve(path string) error {
|
||||
}
|
||||
|
||||
for _, p := range paths {
|
||||
err := c.resolve(filepath.Join(path, p))
|
||||
err := c.resolve(filepath.Join(path, p), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -260,10 +400,10 @@ func (c *Config) resolve(path string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) addFile(path string) error {
|
||||
func (c *Config) addFile(path string, direct bool) error {
|
||||
ext := filepath.Ext(path)
|
||||
|
||||
if ext != ".yaml" && ext != ".yml" {
|
||||
if !direct && ext != ".yaml" && ext != ".yml" {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -276,6 +416,18 @@ func (c *Config) addFile(path string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) parseRaw(b []byte) error {
|
||||
var m map[interface{}]interface{}
|
||||
|
||||
err := yaml.Unmarshal(b, &m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Settings = m
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) parse() error {
|
||||
var m map[interface{}]interface{}
|
||||
|
||||
@@ -328,12 +480,26 @@ func configLogger(c *Config) error {
|
||||
}
|
||||
l.SetLevel(logLevel)
|
||||
|
||||
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
|
||||
timestampFormat := c.GetString("logging.timestamp_format", "")
|
||||
fullTimestamp := (timestampFormat != "")
|
||||
if timestampFormat == "" {
|
||||
timestampFormat = time.RFC3339
|
||||
}
|
||||
|
||||
logFormat := strings.ToLower(c.GetString("logging.format", "text"))
|
||||
switch logFormat {
|
||||
case "text":
|
||||
l.Formatter = &logrus.TextFormatter{}
|
||||
l.Formatter = &logrus.TextFormatter{
|
||||
TimestampFormat: timestampFormat,
|
||||
FullTimestamp: fullTimestamp,
|
||||
DisableTimestamp: disableTimestamp,
|
||||
}
|
||||
case "json":
|
||||
l.Formatter = &logrus.JSONFormatter{}
|
||||
l.Formatter = &logrus.JSONFormatter{
|
||||
TimestampFormat: timestampFormat,
|
||||
DisableTimestamp: disableTimestamp,
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
|
||||
}
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_Load(t *testing.T) {
|
||||
@@ -86,6 +87,76 @@ func TestConfig_GetBool(t *testing.T) {
|
||||
assert.Equal(t, false, c.GetBool("bool", true))
|
||||
}
|
||||
|
||||
func TestConfig_GetAllowList(t *testing.T) {
|
||||
c := NewConfig()
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
"192.168.0.0": true,
|
||||
}
|
||||
r, err := c.GetAllowList("allowlist", false)
|
||||
assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0")
|
||||
assert.Nil(t, r)
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
"192.168.0.0/16": "abc",
|
||||
}
|
||||
r, err = c.GetAllowList("allowlist", false)
|
||||
assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
"192.168.0.0/16": true,
|
||||
"10.0.0.0/8": false,
|
||||
}
|
||||
r, err = c.GetAllowList("allowlist", false)
|
||||
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
"0.0.0.0/0": true,
|
||||
"10.0.0.0/8": false,
|
||||
"10.42.42.0/24": true,
|
||||
}
|
||||
r, err = c.GetAllowList("allowlist", false)
|
||||
if assert.NoError(t, err) {
|
||||
assert.NotNil(t, r)
|
||||
}
|
||||
|
||||
// Test interface names
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
"interfaces": map[interface{}]interface{}{
|
||||
`docker.*`: false,
|
||||
},
|
||||
}
|
||||
r, err = c.GetAllowList("allowlist", false)
|
||||
assert.EqualError(t, err, "config `allowlist` does not support `interfaces`")
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
"interfaces": map[interface{}]interface{}{
|
||||
`docker.*`: "foo",
|
||||
},
|
||||
}
|
||||
r, err = c.GetAllowList("allowlist", true)
|
||||
assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
"interfaces": map[interface{}]interface{}{
|
||||
`docker.*`: false,
|
||||
`eth.*`: true,
|
||||
},
|
||||
}
|
||||
r, err = c.GetAllowList("allowlist", true)
|
||||
assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
"interfaces": map[interface{}]interface{}{
|
||||
`docker.*`: false,
|
||||
},
|
||||
}
|
||||
r, err = c.GetAllowList("allowlist", true)
|
||||
if assert.NoError(t, err) {
|
||||
assert.NotNil(t, r)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_HasChanged(t *testing.T) {
|
||||
// No reload has occurred, return false
|
||||
c := NewConfig()
|
||||
|
||||
@@ -141,14 +141,17 @@ func (n *connectionManager) Start() {
|
||||
|
||||
func (n *connectionManager) Run() {
|
||||
clockSource := time.Tick(500 * time.Millisecond)
|
||||
p := []byte("")
|
||||
nb := make([]byte, 12, 12)
|
||||
out := make([]byte, mtu)
|
||||
|
||||
for now := range clockSource {
|
||||
n.HandleMonitorTick(now)
|
||||
n.HandleMonitorTick(now, p, nb, out)
|
||||
n.HandleDeletionTick(now)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *connectionManager) HandleMonitorTick(now time.Time) {
|
||||
func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) {
|
||||
n.TrafficTimer.advance(now)
|
||||
for {
|
||||
ep := n.TrafficTimer.Purge()
|
||||
@@ -182,16 +185,16 @@ func (n *connectionManager) HandleMonitorTick(now time.Time) {
|
||||
continue
|
||||
}
|
||||
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).
|
||||
hostinfo.logger().
|
||||
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
|
||||
Debug("Tunnel status")
|
||||
|
||||
if hostinfo != nil && hostinfo.ConnectionState != nil {
|
||||
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
||||
n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, p, nb, out)
|
||||
|
||||
} else {
|
||||
l.Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
|
||||
hostinfo.logger().Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
|
||||
}
|
||||
n.AddPendingDeletion(vpnIP)
|
||||
}
|
||||
@@ -233,7 +236,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
|
||||
if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
|
||||
cn = hostinfo.ConnectionState.peerCert.Details.Name
|
||||
}
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).
|
||||
hostinfo.logger().
|
||||
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
|
||||
WithField("certName", cn).
|
||||
Info("Tunnel status")
|
||||
@@ -244,8 +247,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
|
||||
if n.intf.lightHouse != nil {
|
||||
n.intf.lightHouse.DeleteVpnIP(vpnIP)
|
||||
}
|
||||
n.hostMap.DeleteVpnIP(vpnIP)
|
||||
n.hostMap.DeleteIndex(hostinfo.localIndexId)
|
||||
n.hostMap.DeleteHostInfo(hostinfo)
|
||||
} else {
|
||||
n.ClearIP(vpnIP)
|
||||
n.ClearPendingDeletion(vpnIP)
|
||||
|
||||
@@ -28,7 +28,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||
rawCertificateNoKey: []byte{},
|
||||
}
|
||||
|
||||
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false)
|
||||
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
|
||||
ifce := &Interface{
|
||||
hostMap: hostMap,
|
||||
inside: &Tun{},
|
||||
@@ -36,19 +36,21 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||
certState: cs,
|
||||
firewall: &Firewall{},
|
||||
lightHouse: lh,
|
||||
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}),
|
||||
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
// Create manager
|
||||
nc := newConnectionManager(ifce, 5, 10)
|
||||
nc.HandleMonitorTick(now)
|
||||
p := []byte("")
|
||||
nb := make([]byte, 12, 12)
|
||||
out := make([]byte, mtu)
|
||||
nc.HandleMonitorTick(now, p, nb, out)
|
||||
// Add an ip we have established a connection w/ to hostmap
|
||||
hostinfo := nc.hostMap.AddVpnIP(vpnIP)
|
||||
hostinfo.ConnectionState = &ConnectionState{
|
||||
certState: cs,
|
||||
H: &noise.HandshakeState{},
|
||||
messageCounter: new(uint64),
|
||||
certState: cs,
|
||||
H: &noise.HandshakeState{},
|
||||
}
|
||||
|
||||
// We saw traffic out to vpnIP
|
||||
@@ -57,18 +59,18 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
|
||||
// Move ahead 5s. Nothing should happen
|
||||
next_tick := now.Add(5 * time.Second)
|
||||
nc.HandleMonitorTick(next_tick)
|
||||
nc.HandleMonitorTick(next_tick, p, nb, out)
|
||||
nc.HandleDeletionTick(next_tick)
|
||||
// Move ahead 6s. We haven't heard back
|
||||
next_tick = now.Add(6 * time.Second)
|
||||
nc.HandleMonitorTick(next_tick)
|
||||
nc.HandleMonitorTick(next_tick, p, nb, out)
|
||||
nc.HandleDeletionTick(next_tick)
|
||||
// This host should now be up for deletion
|
||||
assert.Contains(t, nc.pendingDeletion, vpnIP)
|
||||
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
|
||||
// Move ahead some more
|
||||
next_tick = now.Add(45 * time.Second)
|
||||
nc.HandleMonitorTick(next_tick)
|
||||
nc.HandleMonitorTick(next_tick, p, nb, out)
|
||||
nc.HandleDeletionTick(next_tick)
|
||||
// The host should be evicted
|
||||
assert.NotContains(t, nc.pendingDeletion, vpnIP)
|
||||
@@ -91,7 +93,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
rawCertificateNoKey: []byte{},
|
||||
}
|
||||
|
||||
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false)
|
||||
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
|
||||
ifce := &Interface{
|
||||
hostMap: hostMap,
|
||||
inside: &Tun{},
|
||||
@@ -99,19 +101,21 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
certState: cs,
|
||||
firewall: &Firewall{},
|
||||
lightHouse: lh,
|
||||
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}),
|
||||
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
// Create manager
|
||||
nc := newConnectionManager(ifce, 5, 10)
|
||||
nc.HandleMonitorTick(now)
|
||||
p := []byte("")
|
||||
nb := make([]byte, 12, 12)
|
||||
out := make([]byte, mtu)
|
||||
nc.HandleMonitorTick(now, p, nb, out)
|
||||
// Add an ip we have established a connection w/ to hostmap
|
||||
hostinfo := nc.hostMap.AddVpnIP(vpnIP)
|
||||
hostinfo.ConnectionState = &ConnectionState{
|
||||
certState: cs,
|
||||
H: &noise.HandshakeState{},
|
||||
messageCounter: new(uint64),
|
||||
certState: cs,
|
||||
H: &noise.HandshakeState{},
|
||||
}
|
||||
|
||||
// We saw traffic out to vpnIP
|
||||
@@ -120,11 +124,11 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
|
||||
// Move ahead 5s. Nothing should happen
|
||||
next_tick := now.Add(5 * time.Second)
|
||||
nc.HandleMonitorTick(next_tick)
|
||||
nc.HandleMonitorTick(next_tick, p, nb, out)
|
||||
nc.HandleDeletionTick(next_tick)
|
||||
// Move ahead 6s. We haven't heard back
|
||||
next_tick = now.Add(6 * time.Second)
|
||||
nc.HandleMonitorTick(next_tick)
|
||||
nc.HandleMonitorTick(next_tick, p, nb, out)
|
||||
nc.HandleDeletionTick(next_tick)
|
||||
// This host should now be up for deletion
|
||||
assert.Contains(t, nc.pendingDeletion, vpnIP)
|
||||
@@ -133,7 +137,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
nc.In(vpnIP)
|
||||
// Move ahead some more
|
||||
next_tick = now.Add(45 * time.Second)
|
||||
nc.HandleMonitorTick(next_tick)
|
||||
nc.HandleMonitorTick(next_tick, p, nb, out)
|
||||
nc.HandleDeletionTick(next_tick)
|
||||
// The host should be evicted
|
||||
assert.NotContains(t, nc.pendingDeletion, vpnIP)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
@@ -12,17 +13,17 @@ import (
|
||||
const ReplayWindow = 1024
|
||||
|
||||
type ConnectionState struct {
|
||||
eKey *NebulaCipherState
|
||||
dKey *NebulaCipherState
|
||||
H *noise.HandshakeState
|
||||
certState *CertState
|
||||
peerCert *cert.NebulaCertificate
|
||||
initiator bool
|
||||
messageCounter *uint64
|
||||
window *Bits
|
||||
queueLock sync.Mutex
|
||||
writeLock sync.Mutex
|
||||
ready bool
|
||||
eKey *NebulaCipherState
|
||||
dKey *NebulaCipherState
|
||||
H *noise.HandshakeState
|
||||
certState *CertState
|
||||
peerCert *cert.NebulaCertificate
|
||||
initiator bool
|
||||
atomicMessageCounter uint64
|
||||
window *Bits
|
||||
queueLock sync.Mutex
|
||||
writeLock sync.Mutex
|
||||
ready bool
|
||||
}
|
||||
|
||||
func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
|
||||
@@ -54,12 +55,11 @@ func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePa
|
||||
// The queue and ready params prevent a counter race that would happen when
|
||||
// sending stored packets and simultaneously accepting new traffic.
|
||||
ci := &ConnectionState{
|
||||
H: hs,
|
||||
initiator: initiator,
|
||||
window: b,
|
||||
ready: false,
|
||||
certState: curCertState,
|
||||
messageCounter: new(uint64),
|
||||
H: hs,
|
||||
initiator: initiator,
|
||||
window: b,
|
||||
ready: false,
|
||||
certState: curCertState,
|
||||
}
|
||||
|
||||
return ci
|
||||
@@ -69,7 +69,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(m{
|
||||
"certificate": cs.peerCert,
|
||||
"initiator": cs.initiator,
|
||||
"message_counter": cs.messageCounter,
|
||||
"message_counter": atomic.LoadUint64(&cs.atomicMessageCounter),
|
||||
"ready": cs.ready,
|
||||
})
|
||||
}
|
||||
|
||||
176 control.go (new file)
@@ -0,0 +1,176 @@
package nebula

import (
"net"
"os"
"os/signal"
"sync/atomic"
"syscall"

"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
)

// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc

type Control struct {
f *Interface
l *logrus.Logger
}

type ControlHostInfo struct {
VpnIP net.IP `json:"vpnIp"`
LocalIndex uint32 `json:"localIndex"`
RemoteIndex uint32 `json:"remoteIndex"`
RemoteAddrs []udpAddr `json:"remoteAddrs"`
CachedPackets int `json:"cachedPackets"`
Cert *cert.NebulaCertificate `json:"cert"`
MessageCounter uint64 `json:"messageCounter"`
CurrentRemote udpAddr `json:"currentRemote"`
}

// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
func (c *Control) Start() {
c.f.run()
}

// Stop signals nebula to shutdown, returns after the shutdown is complete
func (c *Control) Stop() {
//TODO: stop tun and udp routines, the lock on hostMap effectively does that though
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
c.f.hostMap.Lock()
for _, h := range c.f.hostMap.Hosts {
if h.ConnectionState.ready {
c.f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message")
}
}
c.f.hostMap.Unlock()
c.l.Info("Goodbye")
}

// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
func (c *Control) ShutdownBlock() {
sigChan := make(chan os.Signal)
signal.Notify(sigChan, syscall.SIGTERM)
signal.Notify(sigChan, syscall.SIGINT)

rawSig := <-sigChan
sig := rawSig.String()
c.l.WithField("signal", sig).Info("Caught signal, shutting down")
c.Stop()
}

// RebindUDPServer asks the UDP listener to rebind it's listener. Mainly used on mobile clients when interfaces change
func (c *Control) RebindUDPServer() {
_ = c.f.outside.Rebind()

// Trigger a lighthouse update, useful for mobile clients that should have an update interval of 0
c.f.lightHouse.SendUpdate(c.f)

// Let the main interface know that we rebound so that underlying tunnels know to trigger punches from their remotes
c.f.rebindCount++
}

// ListHostmap returns details about the actual or pending (handshaking) hostmap
func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
var hm *HostMap
if pendingMap {
hm = c.f.handshakeManager.pendingHostMap
} else {
hm = c.f.hostMap
}

hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Hosts))
i := 0
for _, v := range hm.Hosts {
hosts[i] = copyHostInfo(v)
i++
}
hm.RUnlock()

return hosts
}

// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found
func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInfo {
var hm *HostMap
if pending {
hm = c.f.handshakeManager.pendingHostMap
} else {
hm = c.f.hostMap
}

h, err := hm.QueryVpnIP(vpnIP)
if err != nil {
return nil
}

ch := copyHostInfo(h)
return &ch
}

// SetRemoteForTunnel forces a tunnel to use a specific remote
func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInfo {
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
if err != nil {
return nil
}

hostInfo.SetRemote(addr.Copy())
ch := copyHostInfo(hostInfo)
return &ch
}

// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
func (c *Control) CloseTunnel(vpnIP uint32, localOnly bool) bool {
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
if err != nil {
return false
}

if !localOnly {
c.f.send(
closeTunnel,
0,
hostInfo.ConnectionState,
hostInfo,
hostInfo.remote,
[]byte{},
make([]byte, 12, 12),
make([]byte, mtu),
)
}

c.f.closeTunnel(hostInfo)
return true
}

func copyHostInfo(h *HostInfo) ControlHostInfo {
addrs := h.RemoteUDPAddrs()
chi := ControlHostInfo{
VpnIP: int2ip(h.hostId),
LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId,
RemoteAddrs: make([]udpAddr, len(addrs), len(addrs)),
CachedPackets: len(h.packetStore),
MessageCounter: atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter),
}

if c := h.GetCert(); c != nil {
chi.Cert = c.Copy()
}

if h.remote != nil {
chi.CurrentRemote = *h.remote
}

for i, addr := range addrs {
chi.RemoteAddrs[i] = addr.Copy()
}

return chi
}
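copyHostInfo above is the concrete instance of the rule stated at the top of this file: never hand callers memory that core still owns. The same pattern in isolation, as a hypothetical helper that is not part of this change:

package example

import "net"

// copyVpnIP hands the caller its own copy of a slice-backed net.IP, so the
// caller can never mutate, or race with, the IP the hostmap still holds.
func copyVpnIP(ip net.IP) net.IP {
	out := make(net.IP, len(ip))
	copy(out, ip)
	return out
}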
108 control_test.go (new file)
@@ -0,0 +1,108 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestControl_GetHostInfoByVpnIP(t *testing.T) {
|
||||
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
|
||||
// To properly ensure we are not exposing core memory to the caller
|
||||
hm := NewHostMap("test", &net.IPNet{}, make([]*net.IPNet, 0))
|
||||
remote1 := NewUDPAddr(100, 4444)
|
||||
remote2 := NewUDPAddr(101, 4444)
|
||||
ipNet := net.IPNet{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
}
|
||||
|
||||
ipNet2 := net.IPNet{
|
||||
IP: net.IPv4(1, 2, 3, 5),
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
}
|
||||
|
||||
crt := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "test",
|
||||
Ips: []*net.IPNet{&ipNet},
|
||||
Subnets: []*net.IPNet{},
|
||||
Groups: []string{"default-group"},
|
||||
NotBefore: time.Unix(1, 0),
|
||||
NotAfter: time.Unix(2, 0),
|
||||
PublicKey: []byte{5, 6, 7, 8},
|
||||
IsCA: false,
|
||||
Issuer: "the-issuer",
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||
},
|
||||
Signature: []byte{1, 2, 1, 2, 1, 3},
|
||||
}
|
||||
|
||||
remotes := []*HostInfoDest{NewHostInfoDest(remote1), NewHostInfoDest(remote2)}
|
||||
hm.Add(ip2int(ipNet.IP), &HostInfo{
|
||||
remote: remote1,
|
||||
Remotes: remotes,
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: crt,
|
||||
},
|
||||
remoteIndexId: 200,
|
||||
localIndexId: 201,
|
||||
hostId: ip2int(ipNet.IP),
|
||||
})
|
||||
|
||||
hm.Add(ip2int(ipNet2.IP), &HostInfo{
|
||||
remote: remote1,
|
||||
Remotes: remotes,
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: nil,
|
||||
},
|
||||
remoteIndexId: 200,
|
||||
localIndexId: 201,
|
||||
hostId: ip2int(ipNet2.IP),
|
||||
})
|
||||
|
||||
c := Control{
|
||||
f: &Interface{
|
||||
hostMap: hm,
|
||||
},
|
||||
l: logrus.New(),
|
||||
}
|
||||
|
||||
thi := c.GetHostInfoByVpnIP(ip2int(ipNet.IP), false)
|
||||
|
||||
expectedInfo := ControlHostInfo{
|
||||
VpnIP: net.IPv4(1, 2, 3, 4).To4(),
|
||||
LocalIndex: 201,
|
||||
RemoteIndex: 200,
|
||||
RemoteAddrs: []udpAddr{*remote1, *remote2},
|
||||
CachedPackets: 0,
|
||||
Cert: crt.Copy(),
|
||||
MessageCounter: 0,
|
||||
CurrentRemote: *NewUDPAddr(100, 4444),
|
||||
}
|
||||
|
||||
// Make sure we don't have any unexpected fields
|
||||
assertFields(t, []string{"VpnIP", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
|
||||
util.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
||||
|
||||
// Make sure we don't panic if the host info doesn't have a cert yet
|
||||
assert.NotPanics(t, func() {
|
||||
thi = c.GetHostInfoByVpnIP(ip2int(ipNet2.IP), false)
|
||||
})
|
||||
}
|
||||
|
||||
func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
|
||||
val := reflect.ValueOf(actualStruct).Elem()
|
||||
fields := make([]string, val.NumField())
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
fields[i] = val.Type().Field(i).Name
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, fields)
|
||||
}
|
||||
15 dist/arch/nebula.service (new file, vendored)
@@ -0,0 +1,15 @@
|
||||
[Unit]
|
||||
Description=nebula
|
||||
Wants=basic.target network-online.target
|
||||
After=basic.target network.target network-online.target
|
||||
|
||||
[Service]
|
||||
SyslogIdentifier=nebula
|
||||
StandardOutput=syslog
|
||||
StandardError=syslog
|
||||
ExecReload=/bin/kill -HUP $MAINPID
|
||||
ExecStart=/usr/bin/nebula -config /etc/nebula/config.yml
|
||||
Restart=always
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
113 dist/wireshark/nebula.lua (new file, vendored)
@@ -0,0 +1,113 @@
|
||||
local nebula = Proto("nebula", "nebula")
|
||||
|
||||
local default_settings = {
|
||||
port = 4242,
|
||||
all_ports = false,
|
||||
}
|
||||
|
||||
nebula.prefs.port = Pref.uint("Port number", default_settings.port, "The UDP port number for Nebula")
|
||||
nebula.prefs.all_ports = Pref.bool("All ports", default_settings.all_ports, "Assume nebula packets on any port, useful when dealing with hole punching")
|
||||
|
||||
local pf_version = ProtoField.new("version", "nebula.version", ftypes.UINT8, nil, base.DEC, 0xF0)
|
||||
local pf_type = ProtoField.new("type", "nebula.type", ftypes.UINT8, {
|
||||
[0] = "handshake",
|
||||
[1] = "message",
|
||||
[2] = "recvError",
|
||||
[3] = "lightHouse",
|
||||
[4] = "test",
|
||||
[5] = "closeTunnel",
|
||||
}, base.DEC, 0x0F)
|
||||
|
||||
local pf_subtype = ProtoField.new("subtype", "nebula.subtype", ftypes.UINT8, nil, base.DEC)
|
||||
local pf_subtype_test = ProtoField.new("subtype", "nebula.subtype", ftypes.UINT8, {
|
||||
[0] = "request",
|
||||
[1] = "reply",
|
||||
}, base.DEC)
|
||||
|
||||
local pf_subtype_handshake = ProtoField.new("subtype", "nebula.subtype", ftypes.UINT8, {
|
||||
[0] = "ix_psk0",
|
||||
}, base.DEC)
|
||||
|
||||
local pf_reserved = ProtoField.new("reserved", "nebula.reserved", ftypes.UINT16, nil, base.HEX)
|
||||
local pf_remote_index = ProtoField.new("remote index", "nebula.remote_index", ftypes.UINT32, nil, base.DEC)
|
||||
local pf_message_counter = ProtoField.new("counter", "nebula.counter", ftypes.UINT64, nil, base.DEC)
|
||||
local pf_payload = ProtoField.new("payload", "nebula.payload", ftypes.BYTES, nil, base.NONE)
|
||||
|
||||
nebula.fields = { pf_version, pf_type, pf_subtype, pf_subtype_handshake, pf_subtype_test, pf_reserved, pf_remote_index, pf_message_counter, pf_payload }
|
||||
|
||||
local ef_holepunch = ProtoExpert.new("nebula.holepunch.expert", "Nebula hole punch packet", expert.group.PROTOCOL, expert.severity.NOTE)
|
||||
local ef_punchy = ProtoExpert.new("nebula.punchy.expert", "Nebula punchy keepalive packet", expert.group.PROTOCOL, expert.severity.NOTE)
|
||||
|
||||
nebula.experts = { ef_holepunch, ef_punchy }
|
||||
local type_field = Field.new("nebula.type")
|
||||
local subtype_field = Field.new("nebula.subtype")
|
||||
|
||||
function nebula.dissector(tvbuf, pktinfo, root)
|
||||
-- set the protocol column to show our protocol name
|
||||
pktinfo.cols.protocol:set("NEBULA")
|
||||
|
||||
local pktlen = tvbuf:reported_length_remaining()
|
||||
local tree = root:add(nebula, tvbuf:range(0,pktlen))
|
||||
|
||||
if pktlen == 0 then
|
||||
tree:add_proto_expert_info(ef_holepunch)
|
||||
pktinfo.cols.info:append(" (holepunch)")
|
||||
return
|
||||
elseif pktlen == 1 then
|
||||
tree:add_proto_expert_info(ef_punchy)
|
||||
pktinfo.cols.info:append(" (punchy)")
|
||||
return
|
||||
end
|
||||
|
||||
tree:add(pf_version, tvbuf:range(0,1))
|
||||
local type = tree:add(pf_type, tvbuf:range(0,1))
|
||||
|
||||
local nebula_type = bit32.band(tvbuf:range(0,1):uint(), 0x0F)
|
||||
if nebula_type == 0 then
|
||||
local stage = tvbuf(8,8):uint64()
|
||||
tree:add(pf_subtype_handshake, tvbuf:range(1,1))
|
||||
type:append_text(" stage " .. stage)
|
||||
pktinfo.cols.info:append(" (" .. type_field().display .. ", stage " .. stage .. ", " .. subtype_field().display .. ")")
|
||||
elseif nebula_type == 4 then
|
||||
tree:add(pf_subtype_test, tvbuf:range(1,1))
|
||||
pktinfo.cols.info:append(" (" .. type_field().display .. ", " .. subtype_field().display .. ")")
|
||||
else
|
||||
tree:add(pf_subtype, tvbuf:range(1,1))
|
||||
pktinfo.cols.info:append(" (" .. type_field().display .. ")")
|
||||
end
|
||||
|
||||
tree:add(pf_reserved, tvbuf:range(2,2))
|
||||
tree:add(pf_remote_index, tvbuf:range(4,4))
|
||||
tree:add(pf_message_counter, tvbuf:range(8,8))
|
||||
tree:add(pf_payload, tvbuf:range(16,tvbuf:len() - 16))
|
||||
end
|
||||
|
||||
function nebula.prefs_changed()
|
||||
if default_settings.all_ports == nebula.prefs.all_ports and default_settings.port == nebula.prefs.port then
|
||||
-- Nothing changed, bail
|
||||
return
|
||||
end
|
||||
|
||||
-- Remove our old dissector
|
||||
DissectorTable.get("udp.port"):remove_all(nebula)
|
||||
|
||||
if nebula.prefs.all_ports and default_settings.all_ports ~= nebula.prefs.all_ports then
|
||||
default_settings.all_port = nebula.prefs.all_ports
|
||||
|
||||
for i=0, 65535 do
|
||||
DissectorTable.get("udp.port"):add(i, nebula)
|
||||
end
|
||||
|
||||
-- no need to establish again on specific ports
|
||||
return
|
||||
end
|
||||
|
||||
|
||||
if default_settings.all_ports ~= nebula.prefs.all_ports then
|
||||
-- Add our new port dissector
|
||||
default_settings.port = nebula.prefs.port
|
||||
DissectorTable.get("udp.port"):add(default_settings.port, nebula)
|
||||
end
|
||||
end
|
||||
|
||||
DissectorTable.get("udp.port"):add(default_settings.port, nebula)
|
||||
@@ -7,8 +7,8 @@ pki:
|
||||
ca: /etc/nebula/ca.crt
|
||||
cert: /etc/nebula/host.crt
|
||||
key: /etc/nebula/host.key
|
||||
#blacklist is a list of certificate fingerprints that we will refuse to talk to
|
||||
#blacklist:
|
||||
#blocklist is a list of certificate fingerprints that we will refuse to talk to
|
||||
#blocklist:
|
||||
# - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72
|
||||
|
||||
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
|
||||
@@ -36,9 +36,41 @@ lighthouse:
|
||||
interval: 60
|
||||
# hosts is a list of lighthouse hosts this node should report to and query from
|
||||
# IMPORTANT: THIS SHOULD BE EMPTY ON LIGHTHOUSE NODES
|
||||
# IMPORTANT2: THIS SHOULD BE LIGHTHOUSES' NEBULA IPs, NOT LIGHTHOUSES' REAL ROUTABLE IPs
|
||||
hosts:
|
||||
- "192.168.100.1"
|
||||

# remote_allow_list allows you to control ip ranges that this node will
# consider when handshaking to another node. By default, any remote IPs are
# allowed. You can provide CIDRs here with `true` to allow and `false` to
# deny. The most specific CIDR rule applies to each remote. If all rules are
# "allow", the default will be "deny", and vice-versa. If both "allow" and
# "deny" rules are present, then you MUST set a rule for "0.0.0.0/0" as the
# default.
#remote_allow_list:
# Example to block IPs from this subnet from being used for remote IPs.
#"172.16.0.0/12": false

# A more complicated example, allow public IPs but only private IPs from a specific subnet
#"0.0.0.0/0": true
#"10.0.0.0/8": false
#"10.42.42.0/24": true

# local_allow_list allows you to filter which local IP addresses we advertise
# to the lighthouses. This uses the same logic as `remote_allow_list`, but
# additionally, you can specify an `interfaces` map of regular expressions
# to match against interface names. The regexp must match the entire name.
# All interface rules must be either true or false (and the default will be
# the inverse). CIDR rules are matched after interface name rules.
# Default is all local IP addresses.
#local_allow_list:
# Example to block tun0 and all docker interfaces.
#interfaces:
#tun0: false
#'docker.*': false
# Example to only advertise this subnet to the lighthouse.
#"10.0.0.0/8": true

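The "most specific CIDR rule wins" behavior described in the comments above can be illustrated without Nebula's internal CIDRTree. A standalone sketch using only the standard library; the rule type and allowed function are illustrative names, not Nebula's API:

package main

import (
	"fmt"
	"net"
)

type rule struct {
	cidr  *net.IPNet
	allow bool
}

// allowed returns the verdict of the longest-prefix (most specific) rule that
// contains ip; with no matching rule it returns false, i.e. default deny.
func allowed(rules []rule, ip net.IP) bool {
	best := -1
	verdict := false
	for _, r := range rules {
		if r.cidr.Contains(ip) {
			if ones, _ := r.cidr.Mask.Size(); ones > best {
				best, verdict = ones, r.allow
			}
		}
	}
	return verdict
}

func main() {
	mustCIDR := func(s string) *net.IPNet { _, n, _ := net.ParseCIDR(s); return n }
	rules := []rule{
		{mustCIDR("0.0.0.0/0"), true},
		{mustCIDR("10.0.0.0/8"), false},
		{mustCIDR("10.42.42.0/24"), true},
	}
	fmt.Println(allowed(rules, net.ParseIP("10.42.42.7"))) // true: the /24 beats the /8
	fmt.Println(allowed(rules, net.ParseIP("10.1.2.3")))   // false: the /8 beats the /0
	fmt.Println(allowed(rules, net.ParseIP("8.8.8.8")))    // true: only the /0 matches
}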
# Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined,
|
||||
# however using port 0 will dynamically assign a port and is recommended for roaming nodes.
|
||||
listen:
|
||||
@@ -54,13 +86,28 @@ listen:
|
||||
#read_buffer: 10485760
|
||||
#write_buffer: 10485760
|
||||
|
||||
# Punchy continues to punch inbound/outbound at a regular interval to avoid expiration of firewall nat mappings
|
||||
punchy: true
|
||||
# punch_back means that a node you are trying to reach will connect back out to you if your hole punching fails
|
||||
# this is extremely useful if one node is behind a difficult nat, such as symmetric
|
||||
#punch_back: true
|
||||
# EXPERIMENTAL: This option is currently only supported on linux and may
|
||||
# change in future minor releases.
|
||||
#
|
||||
# Routines is the number of thread pairs to run that consume from the tun and UDP queues.
|
||||
# Currently, this defaults to 1 which means we have 1 tun queue reader and 1
|
||||
# UDP queue reader. Setting this above one will set IFF_MULTI_QUEUE on the tun
|
||||
# device and SO_REUSEPORT on the UDP socket to allow multiple queues.
|
||||
#routines: 1
|
||||
|
||||
# Cipher allows you to choose between the available ciphers for your network.
|
||||
punchy:
|
||||
# Continues to punch inbound/outbound at a regular interval to avoid expiration of firewall nat mappings
|
||||
punch: true
|
||||
|
||||
# respond means that a node you are trying to reach will connect back out to you if your hole punching fails
|
||||
# this is extremely useful if one node is behind a difficult nat, such as a symmetric NAT
|
||||
# Default is false
|
||||
#respond: true
|
||||
|
||||
# delays a punch response for misbehaving NATs, default is 1 second, respond must be true to take effect
|
||||
#delay: 1s
|
||||
|
||||
# Cipher allows you to choose between the available ciphers for your network. Options are chachapoly or aes
|
||||
# IMPORTANT: this value must be identical on ALL NODES/LIGHTHOUSES. We do not/will not support use of different ciphers simultaneously!
|
||||
#cipher: chachapoly
|
||||
|
||||
@@ -86,6 +133,8 @@ punchy: true
|
||||
|
||||
# Configure the private interface. Note: addr is baked into the nebula certificate
|
||||
tun:
|
||||
# When tun is disabled, a lighthouse can be started without a local tun interface (and therefore without root)
|
||||
disabled: false
|
||||
# Name of the device
|
||||
dev: nebula1
|
||||
# Toggles forwarding of local broadcast packets, the address of which depends on the ip/mask encoded in pki.cert
|
||||
@@ -116,6 +165,16 @@ logging:
|
||||
level: info
|
||||
# json or text formats currently available. Default is text
|
||||
format: text
|
||||
# Disable timestamp logging. useful when output is redirected to logging system that already adds timestamps. Default is false
|
||||
#disable_timestamp: true
|
||||
# timestamp format is specified in Go time format, see:
|
||||
# https://golang.org/pkg/time/#pkg-constants
|
||||
# default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339)
|
||||
# default when `format: text`:
|
||||
# when TTY attached: seconds since beginning of execution
|
||||
# otherwise: "2006-01-02T15:04:05Z07:00" (RFC3339)
|
||||
# As an example, to log as RFC3339 with millisecond precision, set to:
|
||||
#timestamp_format: "2006-01-02T15:04:05.000Z07:00"
|
||||
|
||||
#stats:
|
||||
#type: graphite
|
||||
@@ -131,10 +190,31 @@ logging:
|
||||
#subsystem: nebula
|
||||
#interval: 10s
|
||||
|
||||
# enables counter metrics for meta packets
|
||||
# e.g.: `messages.tx.handshake`
|
||||
# NOTE: `message.{tx,rx}.recv_error` is always emitted
|
||||
#message_metrics: false
|
||||
|
||||
# enables detailed counter metrics for lighthouse packets
|
||||
# e.g.: `lighthouse.rx.HostQuery`
|
||||
#lighthouse_metrics: false
|
||||
|
||||
# Handshake Manager Settings
|
||||
#handshakes:
|
||||
# Total time to try a handshake = sequence of `try_interval * retries`
|
||||
# With 100ms interval and 20 retries it is 23.5 seconds
|
||||
#try_interval: 100ms
|
||||
#retries: 20
|
||||
# wait_rotation is the number of handshake attempts to do before starting to try non-local IP addresses
|
||||
#wait_rotation: 5
|
||||
# trigger_buffer is the size of the buffer channel for quickly sending handshakes
|
||||
# after receiving the response for lighthouse queries
|
||||
#trigger_buffer: 64
|
||||
|
||||
# Nebula security group configuration
|
||||
firewall:
|
||||
conntrack:
|
||||
tcp_timeout: 120h
|
||||
tcp_timeout: 12m
|
||||
udp_timeout: 3m
|
||||
default_timeout: 10m
|
||||
max_connections: 100000
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Description=nebula
|
||||
Wants=basic.target
|
||||
After=basic.target network.target
|
||||
Before=sshd.service
|
||||
|
||||
[Service]
|
||||
SyslogIdentifier=nebula
|
||||
|
||||
217 firewall.go
@@ -1,21 +1,22 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
)
|
||||
|
||||
@@ -38,13 +39,19 @@ type FirewallInterface interface {
|
||||
|
||||
type conn struct {
|
||||
Expires time.Time // Time when this conntrack entry will expire
|
||||
Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
|
||||
Sent time.Time // If tcp rtt tracking is enabled this will be when Seq was last set
|
||||
Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
|
||||
|
||||
// record why the original connection passed the firewall, so we can re-validate
|
||||
// after ruleset changes. Note, rulesVersion is a uint16 so that these two
|
||||
// fields pack for free after the uint32 above
|
||||
incoming bool
|
||||
rulesVersion uint16
|
||||
}
|
||||
|
||||
// TODO: need conntrack max tracked connections handling
|
||||
type Firewall struct {
|
||||
Conns map[FirewallPacket]*conn
|
||||
Conntrack *FirewallConntrack
|
||||
|
||||
InRules *FirewallTable
|
||||
OutRules *FirewallTable
|
||||
@@ -55,18 +62,23 @@ type Firewall struct {
|
||||
UDPTimeout time.Duration //linux: 180s max
|
||||
DefaultTimeout time.Duration //linux: 600s
|
||||
|
||||
TimerWheel *TimerWheel
|
||||
|
||||
// Used to ensure we don't emit local packets for ips we don't own
|
||||
localIps *CIDRTree
|
||||
|
||||
connMutex sync.Mutex
|
||||
rules string
|
||||
rules string
|
||||
rulesVersion uint16
|
||||
|
||||
trackTCPRTT bool
|
||||
metricTCPRTT metrics.Histogram
|
||||
}
|
||||
|
||||
type FirewallConntrack struct {
|
||||
sync.Mutex
|
||||
|
||||
Conns map[FirewallPacket]*conn
|
||||
TimerWheel *TimerWheel
|
||||
}
|
||||
|
||||
type FirewallTable struct {
|
||||
TCP firewallPort
|
||||
UDP firewallPort
|
||||
@@ -172,10 +184,12 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
|
||||
}
|
||||
|
||||
return &Firewall{
|
||||
Conns: make(map[FirewallPacket]*conn),
|
||||
Conntrack: &FirewallConntrack{
|
||||
Conns: make(map[FirewallPacket]*conn),
|
||||
TimerWheel: NewTimerWheel(min, max),
|
||||
},
|
||||
InRules: newFirewallTable(),
|
||||
OutRules: newFirewallTable(),
|
||||
TimerWheel: NewTimerWheel(min, max),
|
||||
TCPTimeout: tcpTimeout,
|
||||
UDPTimeout: UDPTimeout,
|
||||
DefaultTimeout: defaultTimeout,
|
||||
@@ -208,11 +222,17 @@ func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, er

// AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
	// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
	// https://github.com/golang/go/issues/14131
	sIp := ""
	if ip != nil {
		sIp = ip.String()
	}

	// We need this rule string because we generate a hash. Removing this will break firewall reload.
	ruleString := fmt.Sprintf(
		"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s",
		incoming, proto, startPort, endPort, groups, host, ip, caName, caSha,
		incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha,
	)
	f.rules += ruleString + "\n"
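The sIp guard above exists because formatting a nil *net.IPNet with fmt can abort under gomobile on iOS, so the value is stringified up front and a nil pointer never reaches fmt. A standalone restatement of that guard (the helper name here is made up for the example):

package main

import (
	"fmt"
	"net"
)

// ipNetString renders an IPNet safely: a nil pointer becomes "" instead of
// being handed to fmt's %v verb.
func ipNetString(ip *net.IPNet) string {
	if ip == nil {
		return ""
	}
	return ip.String()
}

func main() {
	_, cidr, _ := net.ParseCIDR("10.0.0.0/8")
	fmt.Println(ipNetString(cidr))      // 10.0.0.0/8
	fmt.Println(ipNetString(nil) == "") // true
}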
||||
@@ -220,7 +240,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
||||
if !incoming {
|
||||
direction = "outgoing"
|
||||
}
|
||||
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": ip, "caName": caName, "caSha": caSha}).
|
||||
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
|
||||
Info("Firewall rule added")
|
||||
|
||||
var (
|
||||
@@ -347,20 +367,33 @@ func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterfa
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool {
|
||||
var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
|
||||
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
|
||||
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
||||
|
||||
// Drop returns an error if the packet should be dropped, explaining why. It
|
||||
// returns nil if the packet should not be dropped.
|
||||
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) error {
|
||||
// Check if we spoke to this tuple, if we did then allow this packet
|
||||
if f.inConns(packet, fp, incoming) {
|
||||
return false
|
||||
if f.inConns(packet, fp, incoming, h, caPool, localCache) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Make sure remote address matches nebula certificate
|
||||
if h.remoteCidr.Contains(fp.RemoteIP) == nil {
|
||||
return true
|
||||
if remoteCidr := h.remoteCidr; remoteCidr != nil {
|
||||
if remoteCidr.Contains(fp.RemoteIP) == nil {
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
} else {
|
||||
// Simple case: Certificate has one IP and no subnets
|
||||
if fp.RemoteIP != h.hostId {
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure we are supposed to be handling this local ip address
|
||||
if f.localIps.Contains(fp.LocalIP) == nil {
|
||||
return true
|
||||
return ErrInvalidLocalIP
|
||||
}
|
||||
|
||||
table := f.OutRules
|
||||
@@ -370,13 +403,13 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *Host
|
||||
|
||||
// We now know which firewall table to check against
|
||||
if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) {
|
||||
return true
|
||||
return ErrNoMatchingRule
|
||||
}
|
||||
|
||||
// We always want to conntrack since it is a faster operation
|
||||
f.addConn(packet, fp, incoming)
|
||||
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
|
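With the change above, Drop reports why a packet was refused instead of returning a bare bool. A self-contained sketch of that sentinel-error style follows; the error strings are copied from the diff, while the drop function and its inputs are purely illustrative stand-ins, not the firewall's real checks.

package main

import (
	"errors"
	"fmt"
)

var (
	errInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
	errNoMatchingRule  = errors.New("no matching rule in firewall table")
)

func drop(remoteOK, ruleOK bool) error {
	if !remoteOK {
		return errInvalidRemoteIP
	}
	if !ruleOK {
		return errNoMatchingRule
	}
	return nil // nil means the packet is allowed
}

func main() {
	if err := drop(true, false); err != nil {
		fmt.Println("dropped:", err) // dropped: no matching rule in firewall table
	}
}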
||||
// Destroy cleans up any known cyclical references so the object can be free'd by GC. This should be called if a new
|
||||
@@ -386,26 +419,71 @@ func (f *Firewall) Destroy() {
|
||||
}
|
||||
|
||||
func (f *Firewall) EmitStats() {
|
||||
conntrackCount := len(f.Conns)
|
||||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
conntrackCount := len(conntrack.Conns)
|
||||
conntrack.Unlock()
|
||||
metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount))
|
||||
metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
|
||||
}
|
||||
|
||||
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool {
|
||||
f.connMutex.Lock()
|
||||
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) bool {
|
||||
if localCache != nil {
|
||||
if _, ok := localCache[fp]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
|
||||
// Purge every time we test
|
||||
ep, has := f.TimerWheel.Purge()
|
||||
ep, has := conntrack.TimerWheel.Purge()
|
||||
if has {
|
||||
f.evict(ep)
|
||||
}
|
||||
|
||||
c, ok := f.Conns[fp]
|
||||
c, ok := conntrack.Conns[fp]
|
||||
|
||||
if !ok {
|
||||
f.connMutex.Unlock()
|
||||
conntrack.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
if c.rulesVersion != f.rulesVersion {
|
||||
// This conntrack entry was for an older rule set, validate
|
||||
// it still passes with the current rule set
|
||||
table := f.OutRules
|
||||
if c.incoming {
|
||||
table = f.InRules
|
||||
}
|
||||
|
||||
// We now know which firewall table to check against
|
||||
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
h.logger().
|
||||
WithField("fwPacket", fp).
|
||||
WithField("incoming", c.incoming).
|
||||
WithField("rulesVersion", f.rulesVersion).
|
||||
WithField("oldRulesVersion", c.rulesVersion).
|
||||
Debugln("dropping old conntrack entry, does not match new ruleset")
|
||||
}
|
||||
delete(conntrack.Conns, fp)
|
||||
conntrack.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
h.logger().
|
||||
WithField("fwPacket", fp).
|
||||
WithField("incoming", c.incoming).
|
||||
WithField("rulesVersion", f.rulesVersion).
|
||||
WithField("oldRulesVersion", c.rulesVersion).
|
||||
Debugln("keeping old conntrack entry, does match new ruleset")
|
||||
}
|
||||
|
||||
c.rulesVersion = f.rulesVersion
|
||||
}
|
||||
|
||||
switch fp.Protocol {
|
||||
case fwProtoTCP:
|
||||
c.Expires = time.Now().Add(f.TCPTimeout)
|
||||
@@ -420,7 +498,11 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool
|
||||
c.Expires = time.Now().Add(f.DefaultTimeout)
|
||||
}
|
||||
|
||||
f.connMutex.Unlock()
|
||||
conntrack.Unlock()
|
||||
|
||||
if localCache != nil {
|
||||
localCache[fp] = struct{}{}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -441,14 +523,19 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
|
||||
timeout = f.DefaultTimeout
|
||||
}
|
||||
|
||||
f.connMutex.Lock()
|
||||
if _, ok := f.Conns[fp]; !ok {
|
||||
f.TimerWheel.Add(fp, timeout)
|
||||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
if _, ok := conntrack.Conns[fp]; !ok {
|
||||
conntrack.TimerWheel.Add(fp, timeout)
|
||||
}
|
||||
|
||||
// Record which rulesVersion allowed this connection, so we can retest after
|
||||
// firewall reload
|
||||
c.incoming = incoming
|
||||
c.rulesVersion = f.rulesVersion
|
||||
c.Expires = time.Now().Add(timeout)
|
||||
f.Conns[fp] = c
|
||||
f.connMutex.Unlock()
|
||||
conntrack.Conns[fp] = c
|
||||
conntrack.Unlock()
|
||||
}
|
||||
|
||||
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
|
||||
@@ -456,7 +543,8 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
|
||||
func (f *Firewall) evict(p FirewallPacket) {
|
||||
//TODO: report a stat if the tcp rtt tracking was never resolved?
|
||||
// Are we still tracking this conn?
|
||||
t, ok := f.Conns[p]
|
||||
conntrack := f.Conntrack
|
||||
t, ok := conntrack.Conns[p]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -465,12 +553,12 @@ func (f *Firewall) evict(p FirewallPacket) {
|
||||
|
||||
// Timeout is in the future, re-add the timer
|
||||
if newT > 0 {
|
||||
f.TimerWheel.Add(p, newT)
|
||||
conntrack.TimerWheel.Add(p, newT)
|
||||
return
|
||||
}
|
||||
|
||||
// This conn is done
|
||||
delete(f.Conns, p)
|
||||
delete(conntrack.Conns, p)
|
||||
}
|
||||
|
||||
func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
||||
@@ -845,3 +933,54 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
|
||||
c.Seq = 0
|
||||
return true
|
||||
}
|
||||
|
||||
// ConntrackCache is used as a local routine cache to know if a given flow
|
||||
// has been seen in the conntrack table.
|
||||
type ConntrackCache map[FirewallPacket]struct{}
|
||||
|
||||
type ConntrackCacheTicker struct {
|
||||
cacheV uint64
|
||||
cacheTick uint64
|
||||
|
||||
cache ConntrackCache
|
||||
}
|
||||
|
||||
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
|
||||
if d == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := &ConntrackCacheTicker{
|
||||
cache: ConntrackCache{},
|
||||
}
|
||||
|
||||
go c.tick(d)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *ConntrackCacheTicker) tick(d time.Duration) {
|
||||
for {
|
||||
time.Sleep(d)
|
||||
atomic.AddUint64(&c.cacheTick, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// Get checks if the cache ticker has moved to the next version before returning
|
||||
// the map. If it has moved, we reset the map.
|
||||
func (c *ConntrackCacheTicker) Get() ConntrackCache {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
|
||||
c.cacheV = tick
|
||||
if ll := len(c.cache); ll > 0 {
|
||||
if l.GetLevel() == logrus.DebugLevel {
|
||||
l.WithField("len", ll).Debug("resetting conntrack cache")
|
||||
}
|
||||
c.cache = make(ConntrackCache, ll)
|
||||
}
|
||||
}
|
||||
|
||||
return c.cache
|
||||
}
|
||||
|
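ConntrackCache and ConntrackCacheTicker above give each reader goroutine a lock-free local map that a background ticker invalidates by bumping a counter. A compressed, self-contained version of the same idea, using simplified stand-in types rather than nebula's FirewallPacket:

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

type flow struct{ src, dst uint32 } // stand-in for FirewallPacket

type cacheTicker struct {
	tick  uint64 // bumped by the background goroutine
	seen  uint64 // last tick value observed by the owner goroutine
	cache map[flow]struct{}
}

func newCacheTicker(d time.Duration) *cacheTicker {
	c := &cacheTicker{cache: map[flow]struct{}{}}
	go func() {
		for range time.Tick(d) {
			atomic.AddUint64(&c.tick, 1)
		}
	}()
	return c
}

// Get is only called from the single goroutine that owns the cache; it swaps
// in a fresh map whenever the ticker has moved on.
func (c *cacheTicker) Get() map[flow]struct{} {
	if t := atomic.LoadUint64(&c.tick); t != c.seen {
		c.seen = t
		c.cache = make(map[flow]struct{}, len(c.cache))
	}
	return c.cache
}

func main() {
	c := newCacheTicker(10 * time.Millisecond)
	c.Get()[flow{1, 2}] = struct{}{}
	fmt.Println(len(c.Get())) // 1 until the next tick resets the cache
}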
||||
137 firewall_test.go
@@ -17,37 +17,39 @@ import (
|
||||
func TestNewFirewall(t *testing.T) {
|
||||
c := &cert.NebulaCertificate{}
|
||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
assert.NotNil(t, fw.Conns)
|
||||
conntrack := fw.Conntrack
|
||||
assert.NotNil(t, conntrack)
|
||||
assert.NotNil(t, conntrack.Conns)
|
||||
assert.NotNil(t, conntrack.TimerWheel)
|
||||
assert.NotNil(t, fw.InRules)
|
||||
assert.NotNil(t, fw.OutRules)
|
||||
assert.NotNil(t, fw.TimerWheel)
|
||||
assert.Equal(t, time.Second, fw.TCPTimeout)
|
||||
assert.Equal(t, time.Minute, fw.UDPTimeout)
|
||||
assert.Equal(t, time.Hour, fw.DefaultTimeout)
|
||||
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
|
||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
}
|
||||
|
||||
func TestFirewall_AddRule(t *testing.T) {
|
||||
@@ -171,6 +173,7 @@ func TestFirewall_Drop(t *testing.T) {
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c,
|
||||
},
|
||||
hostId: ip2int(ipNet.IP),
|
||||
}
|
||||
h.CreateRemoteCIDR(&c)
|
||||
|
||||
@@ -179,44 +182,44 @@ func TestFirewall_Drop(t *testing.T) {
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
|
||||
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
|
||||
// Allow outbound because conntrack
|
||||
assert.False(t, fw.Drop([]byte{}, p, false, &h, cp))
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
|
||||
|
||||
// test remote mismatch
|
||||
oldRemote := p.RemoteIP
|
||||
p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
|
||||
assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
|
||||
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||
p.RemoteIP = oldRemote
|
||||
|
||||
// ensure signer doesn't get in the way of group checks
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
|
||||
assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caSha doesn't drop on match
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
|
||||
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
|
||||
|
||||
// ensure ca name doesn't get in the way of group checks
|
||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
|
||||
assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caName doesn't drop on match
|
||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
|
||||
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
@@ -344,6 +347,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c,
|
||||
},
|
||||
hostId: ip2int(ipNet.IP),
|
||||
}
|
||||
h.CreateRemoteCIDR(&c)
|
||||
|
||||
@@ -366,10 +370,10 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// h1/c1 lacks the proper groups
|
||||
assert.True(t, fw.Drop([]byte{}, p, true, &h1, cp))
|
||||
assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp, nil), ErrNoMatchingRule)
|
||||
// c has the proper groups
|
||||
resetConntrack(fw)
|
||||
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_Drop3(t *testing.T) {
|
||||
@@ -410,6 +414,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c1,
|
||||
},
|
||||
hostId: ip2int(ipNet.IP),
|
||||
}
|
||||
h1.CreateRemoteCIDR(&c1)
|
||||
|
||||
@@ -424,6 +429,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c2,
|
||||
},
|
||||
hostId: ip2int(ipNet.IP),
|
||||
}
|
||||
h2.CreateRemoteCIDR(&c2)
|
||||
|
||||
@@ -438,6 +444,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c3,
|
||||
},
|
||||
hostId: ip2int(ipNet.IP),
|
||||
}
|
||||
h3.CreateRemoteCIDR(&c3)
|
||||
|
||||
@@ -447,13 +454,81 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// c1 should pass because host match
|
||||
assert.False(t, fw.Drop([]byte{}, p, true, &h1, cp))
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp, nil))
|
||||
// c2 should pass because ca sha match
|
||||
resetConntrack(fw)
|
||||
assert.False(t, fw.Drop([]byte{}, p, true, &h2, cp))
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp, nil))
|
||||
// c3 should fail because no match
|
||||
resetConntrack(fw)
|
||||
assert.True(t, fw.Drop([]byte{}, p, true, &h3, cp))
|
||||
assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp, nil), ErrNoMatchingRule)
|
||||
}
|
||||
|
||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
ob := &bytes.Buffer{}
|
||||
out := l.Out
|
||||
l.SetOutput(ob)
|
||||
defer l.SetOutput(out)
|
||||
|
||||
p := FirewallPacket{
|
||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||
10,
|
||||
90,
|
||||
fwProtoUDP,
|
||||
false,
|
||||
}
|
||||
|
||||
ipNet := net.IPNet{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
}
|
||||
|
||||
c := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "host1",
|
||||
Ips: []*net.IPNet{&ipNet},
|
||||
Groups: []string{"default-group"},
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||
Issuer: "signer-shasum",
|
||||
},
|
||||
}
|
||||
h := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c,
|
||||
},
|
||||
hostId: ip2int(ipNet.IP),
|
||||
}
|
||||
h.CreateRemoteCIDR(&c)
|
||||
|
||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
|
||||
// Allow outbound because conntrack
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
|
||||
|
||||
oldFw := fw
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
|
||||
fw.Conntrack = oldFw.Conntrack
|
||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
||||
// Allow outbound because conntrack and new rules allow port 10
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
|
||||
|
||||
oldFw = fw
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
|
||||
fw.Conntrack = oldFw.Conntrack
|
||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
||||
// Drop outbound because conntrack doesn't match new ruleset
|
||||
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
}
|
||||
|
||||
func BenchmarkLookup(b *testing.B) {
|
||||
@@ -856,7 +931,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
|
||||
}
|
||||
|
||||
func resetConntrack(fw *Firewall) {
|
||||
fw.connMutex.Lock()
|
||||
fw.Conns = map[FirewallPacket]*conn{}
|
||||
fw.connMutex.Unlock()
|
||||
fw.Conntrack.Lock()
|
||||
fw.Conntrack.Conns = map[FirewallPacket]*conn{}
|
||||
fw.Conntrack.Unlock()
|
||||
}
|
||||
|
||||
9 go.mod
@@ -1,6 +1,6 @@
module github.com/slackhq/nebula

go 1.12
go 1.16

require (
	github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239
@@ -11,7 +11,7 @@ require (
	github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6
	github.com/golang/protobuf v1.3.2
	github.com/imdario/mergo v0.3.8
	github.com/kardianos/service v1.0.0
	github.com/kardianos/service v1.1.0
	github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
	github.com/kr/pretty v0.1.0 // indirect
	github.com/miekg/dns v1.1.25
@@ -21,11 +21,12 @@ require (
	github.com/prometheus/procfs v0.0.8 // indirect
	github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563
	github.com/sirupsen/logrus v1.4.2
	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
	github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b
	github.com/stretchr/testify v1.4.0
	github.com/stretchr/testify v1.6.1
	github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a
	github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
	golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413
	golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975
	golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553
	golang.org/x/sys v0.0.0-20191210023423-ac6580df4449
	gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
14 go.sum
@@ -46,6 +46,8 @@ github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/u
|
||||
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
|
||||
github.com/kardianos/service v1.0.0 h1:HgQS3mFfOlyntWX8Oke98JcJLqt1DBcHR4kxShpYef0=
|
||||
github.com/kardianos/service v1.0.0/go.mod h1:8CzDhVuCuugtsHyZoTvsOBuvonN/UDBvl0kH+BUxvbo=
|
||||
github.com/kardianos/service v1.1.0 h1:QV2SiEeWK42P0aEmGcsAgjApw/lRxkwopvT+Gu6t1/0=
|
||||
github.com/kardianos/service v1.1.0/go.mod h1:RrJI2xn5vve/r32U5suTbeaSGoMU6GbNPoj36CVYcHc=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
@@ -96,6 +98,8 @@ github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563/go.mod h1:bCqn
|
||||
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||
github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
|
||||
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
||||
github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b h1:+y4hCMc/WKsDbAPsOQZgBSaSZ26uh2afyaWeVg/3s/c=
|
||||
github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
@@ -103,8 +107,8 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a h1:Bt1IVPhiCDMqwGrc2nnbIN4QKvJGx6SK2NzWBmW00ao=
|
||||
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
|
||||
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
|
||||
@@ -112,8 +116,8 @@ github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17
|
||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY=
|
||||
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 h1:ULYEB3JvPRE/IfO+9uO7vKV/xzVTO7XPAwm8xbf4w2g=
|
||||
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975 h1:/Tl7pH94bvbAAHBdZJT947M/+gp0+CqQXDtMRC0fseo=
|
||||
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
@@ -152,3 +156,5 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
|
||||
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
23 handshake.go
@@ -6,26 +6,23 @@ const (
|
||||
)
|
||||
|
||||
func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) {
|
||||
newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex)
|
||||
//TODO: For stage 1 we won't have hostinfo yet but stage 2 and above would require it, this check may be helpful in those cases
|
||||
//if err != nil {
|
||||
// l.WithError(err).WithField("udpAddr", addr).Error("Error while finding host info for handshake message")
|
||||
// return
|
||||
//}
|
||||
if !f.lightHouse.remoteAllowList.Allow(udp2ipInt(addr)) {
|
||||
l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||
return
|
||||
}
|
||||
|
||||
tearDown := false
|
||||
switch h.Subtype {
|
||||
case handshakeIXPSK0:
|
||||
switch h.MessageCounter {
|
||||
case 1:
|
||||
tearDown = ixHandshakeStage1(f, addr, newHostinfo, packet, h)
|
||||
ixHandshakeStage1(f, addr, packet, h)
|
||||
case 2:
|
||||
tearDown = ixHandshakeStage2(f, addr, newHostinfo, packet, h)
|
||||
newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex)
|
||||
tearDown := ixHandshakeStage2(f, addr, newHostinfo, packet, h)
|
||||
if tearDown && newHostinfo != nil {
|
||||
f.handshakeManager.DeleteHostInfo(newHostinfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tearDown && newHostinfo != nil {
|
||||
f.handshakeManager.DeleteIndex(newHostinfo.localIndexId)
|
||||
f.handshakeManager.DeleteVpnIP(newHostinfo.hostId)
|
||||
}
|
||||
}
|
||||
|
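The remote_allow_list check added to HandleIncomingHandshake above drops handshakes from public addresses that fall outside the allowed ranges. Nebula's real allow list supports a set of allow and deny rules, so the following is only a minimal illustration of the underlying CIDR containment test; the CIDR and addresses are example values.

package main

import (
	"fmt"
	"net"
)

func main() {
	_, allowed, _ := net.ParseCIDR("192.168.0.0/16")
	for _, ip := range []string{"192.168.10.4", "203.0.113.9"} {
		fmt.Println(ip, "allowed:", allowed.Contains(net.ParseIP(ip)))
	}
}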
||||
393 handshake_ix.go
@@ -1,11 +1,10 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"bytes"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/golang/protobuf/proto"
|
||||
)
|
||||
@@ -26,17 +25,17 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
|
||||
}
|
||||
}
|
||||
|
||||
myIndex, err := generateIndex()
|
||||
err := f.handshakeManager.AddIndexHostInfo(hostinfo)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
|
||||
return
|
||||
}
|
||||
|
||||
ci := hostinfo.ConnectionState
|
||||
f.handshakeManager.AddIndexHostInfo(myIndex, hostinfo)
|
||||
|
||||
hsProto := &NebulaHandshakeDetails{
|
||||
InitiatorIndex: myIndex,
|
||||
InitiatorIndex: hostinfo.localIndexId,
|
||||
Time: uint64(time.Now().Unix()),
|
||||
Cert: ci.certState.rawCertificateNoKey,
|
||||
}
|
||||
@@ -55,7 +54,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
|
||||
}
|
||||
|
||||
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, 0, 1)
|
||||
atomic.AddUint64(ci.messageCounter, 1)
|
||||
atomic.AddUint64(&ci.atomicMessageCounter, 1)
|
||||
|
||||
msg, _, _, err := ci.H.WriteMessage(header, hsBytes)
|
||||
if err != nil {
|
||||
@@ -64,188 +63,226 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
|
||||
return
|
||||
}
|
||||
|
||||
// We are sending handshake packet 1, so we don't expect to receive
|
||||
// handshake packet 1 from the responder
|
||||
ci.window.Update(1)
|
||||
|
||||
hostinfo.HandshakePacket[0] = msg
|
||||
hostinfo.HandshakeReady = true
|
||||
hostinfo.handshakeStart = time.Now()
|
||||
|
||||
}
|
||||
|
||||
func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool {
|
||||
var ip uint32
|
||||
if h.RemoteIndex == 0 {
|
||||
ci := f.newConnectionState(false, noise.HandshakeIX, []byte{}, 0)
|
||||
// Mark packet 1 as seen so it doesn't show up as missed
|
||||
ci.window.Update(1)
|
||||
func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||
ci := f.newConnectionState(false, noise.HandshakeIX, []byte{}, 0)
|
||||
// Mark packet 1 as seen so it doesn't show up as missed
|
||||
ci.window.Update(1)
|
||||
|
||||
msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
|
||||
return true
|
||||
}
|
||||
msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
|
||||
return
|
||||
}
|
||||
|
||||
hs := &NebulaHandshake{}
|
||||
err = proto.Unmarshal(msg, hs)
|
||||
/*
|
||||
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
|
||||
*/
|
||||
if err != nil || hs.Details == nil {
|
||||
l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
||||
return true
|
||||
}
|
||||
hs := &NebulaHandshake{}
|
||||
err = proto.Unmarshal(msg, hs)
|
||||
/*
|
||||
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
|
||||
*/
|
||||
if err != nil || hs.Details == nil {
|
||||
l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo, _ := f.handshakeManager.pendingHostMap.QueryReverseIndex(hs.Details.InitiatorIndex)
|
||||
if hostinfo != nil && bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) {
|
||||
if msg, ok := hostinfo.HandshakePacket[2]; ok {
|
||||
err := f.outside.WriteTo(msg, addr)
|
||||
if err != nil {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||
WithError(err).Error("Failed to send handshake message")
|
||||
} else {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||
Info("Handshake message sent")
|
||||
}
|
||||
return false
|
||||
}
|
||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
||||
Info("Invalid certificate from host")
|
||||
return
|
||||
}
|
||||
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
|
||||
certName := remoteCert.Details.Name
|
||||
fingerprint, _ := remoteCert.Sha256Sum()
|
||||
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cached", true).
|
||||
WithField("packets", hostinfo.HandshakePacket).
|
||||
Error("Seen this handshake packet already but don't have a cached packet to return")
|
||||
}
|
||||
|
||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
||||
Info("Invalid certificate from host")
|
||||
return true
|
||||
}
|
||||
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
|
||||
|
||||
myIndex, err := generateIndex()
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
|
||||
return true
|
||||
}
|
||||
|
||||
hostinfo, err = f.handshakeManager.AddIndex(myIndex, ci)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Error adding index to connection manager")
|
||||
|
||||
return true
|
||||
}
|
||||
if vpnIP == ip2int(f.certState.certificate.Details.Ips[0].IP) {
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Info("Handshake message received")
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo.remoteIndexId = hs.Details.InitiatorIndex
|
||||
hs.Details.ResponderIndex = myIndex
|
||||
hs.Details.Cert = ci.certState.rawCertificateNoKey
|
||||
myIndex, err := generateIndex()
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
|
||||
return
|
||||
}
|
||||
|
||||
hsBytes, err := proto.Marshal(hs)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
||||
return true
|
||||
}
|
||||
hostinfo := &HostInfo{
|
||||
ConnectionState: ci,
|
||||
Remotes: []*HostInfoDest{},
|
||||
localIndexId: myIndex,
|
||||
remoteIndexId: hs.Details.InitiatorIndex,
|
||||
hostId: vpnIP,
|
||||
HandshakePacket: make(map[uint8][]byte, 0),
|
||||
}
|
||||
|
||||
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2)
|
||||
msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
||||
return true
|
||||
}
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Info("Handshake message received")
|
||||
|
||||
if f.hostMap.CheckHandshakeCompleteIP(vpnIP) && vpnIP < ip2int(f.certState.certificate.Details.Ips[0].IP) {
|
||||
hs.Details.ResponderIndex = myIndex
|
||||
hs.Details.Cert = ci.certState.rawCertificateNoKey
|
||||
|
||||
hsBytes, err := proto.Marshal(hs)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
||||
return
|
||||
}
|
||||
|
||||
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2)
|
||||
msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
||||
return
|
||||
} else if dKey == nil || eKey == nil {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:]))
|
||||
copy(hostinfo.HandshakePacket[0], packet[HeaderLen:])
|
||||
|
||||
// Regardless of whether you are the sender or receiver, you should arrive here
|
||||
// and complete standing up the connection.
|
||||
hostinfo.HandshakePacket[2] = make([]byte, len(msg))
|
||||
copy(hostinfo.HandshakePacket[2], msg)
|
||||
|
||||
// We are sending handshake packet 2, so we don't expect to receive
|
||||
// handshake packet 2 from the initiator.
|
||||
ci.window.Update(2)
|
||||
|
||||
ci.peerCert = remoteCert
|
||||
ci.dKey = NewNebulaCipherState(dKey)
|
||||
ci.eKey = NewNebulaCipherState(eKey)
|
||||
//l.Debugln("got symmetric pairs")
|
||||
|
||||
//hostinfo.ClearRemotes()
|
||||
hostinfo.AddRemote(*addr)
|
||||
hostinfo.ForcePromoteBest(f.hostMap.preferredRanges)
|
||||
hostinfo.CreateRemoteCIDR(remoteCert)
|
||||
|
||||
hostinfo.Lock()
|
||||
defer hostinfo.Unlock()
|
||||
|
||||
// Only overwrite existing record if we should win the handshake race
|
||||
overwrite := vpnIP > ip2int(f.certState.certificate.Details.Ips[0].IP)
|
||||
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case ErrAlreadySeen:
|
||||
msg = existing.HandshakePacket[2]
|
||||
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
|
||||
err := f.outside.WriteTo(msg, addr)
|
||||
if err != nil {
|
||||
l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||
WithError(err).Error("Failed to send handshake message")
|
||||
} else {
|
||||
l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||
Info("Handshake message sent")
|
||||
}
|
||||
return
|
||||
case ErrExistingHostInfo:
|
||||
// This means there was an existing tunnel and we didn't win
|
||||
// handshake avoidance
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Info("Prevented a handshake race")
|
||||
|
||||
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
||||
f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
return true
|
||||
return
|
||||
case ErrLocalIndexCollision:
|
||||
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
WithField("localIndex", hostinfo.localIndexId).WithField("collision", IntIp(existing.hostId)).
|
||||
Error("Failed to add HostInfo due to localIndex collision")
|
||||
return
|
||||
default:
|
||||
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
|
||||
// And we forget to update it here
|
||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Error("Failed to add HostInfo to HostMap")
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:]))
|
||||
copy(hostinfo.HandshakePacket[0], packet[HeaderLen:])
|
||||
|
||||
// Regardless of whether you are the sender or receiver, you should arrive here
|
||||
// and complete standing up the connection.
|
||||
if dKey != nil && eKey != nil {
|
||||
hostinfo.HandshakePacket[2] = make([]byte, len(msg))
|
||||
copy(hostinfo.HandshakePacket[2], msg)
|
||||
|
||||
err := f.outside.WriteTo(msg, addr)
|
||||
if err != nil {
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
WithError(err).Error("Failed to send handshake")
|
||||
} else {
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
Info("Handshake message sent")
|
||||
}
|
||||
|
||||
ip = ip2int(remoteCert.Details.Ips[0].IP)
|
||||
ci.peerCert = remoteCert
|
||||
ci.dKey = NewNebulaCipherState(dKey)
|
||||
ci.eKey = NewNebulaCipherState(eKey)
|
||||
//l.Debugln("got symmetric pairs")
|
||||
|
||||
//hostinfo.ClearRemotes()
|
||||
hostinfo.AddRemote(*addr)
|
||||
hostinfo.CreateRemoteCIDR(remoteCert)
|
||||
f.lightHouse.AddRemoteAndReset(ip, addr)
|
||||
if f.serveDns {
|
||||
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
|
||||
}
|
||||
|
||||
ho, err := f.hostMap.QueryVpnIP(vpnIP)
|
||||
if err == nil && ho.localIndexId != 0 {
|
||||
l.WithField("vpnIp", vpnIP).
|
||||
WithField("action", "removing stale index").
|
||||
WithField("index", ho.localIndexId).
|
||||
Debug("Handshake processing")
|
||||
f.hostMap.DeleteIndex(ho.localIndexId)
|
||||
}
|
||||
|
||||
f.hostMap.AddIndexHostInfo(hostinfo.localIndexId, hostinfo)
|
||||
f.hostMap.AddVpnIPHostInfo(vpnIP, hostinfo)
|
||||
|
||||
hostinfo.handshakeComplete()
|
||||
} else {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Error("Noise did not arrive at a key")
|
||||
return true
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
f.hostMap.AddRemote(ip, addr)
|
||||
return false
|
||||
// Do the send
|
||||
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
|
||||
err = f.outside.WriteTo(msg, addr)
|
||||
if err != nil {
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
WithError(err).Error("Failed to send handshake")
|
||||
} else {
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
Info("Handshake message sent")
|
||||
}
|
||||
|
||||
hostinfo.handshakeComplete()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool {
|
||||
if hostinfo == nil {
|
||||
return true
|
||||
}
|
||||
hostinfo.Lock()
|
||||
defer hostinfo.Unlock()
|
||||
|
||||
if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
||||
Error("Already seen this handshake packet")
|
||||
Info("Already seen this handshake packet")
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -266,6 +303,11 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||
// to DOS us. Every other error condition after should to allow a possible good handshake to complete in the
|
||||
// near future
|
||||
return false
|
||||
} else if dKey == nil || eKey == nil {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
Error("Noise did not arrive at a key")
|
||||
return true
|
||||
}
|
||||
|
||||
hs := &NebulaHandshake{}
|
||||
@@ -284,9 +326,13 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||
return true
|
||||
}
|
||||
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
|
||||
certName := remoteCert.Details.Name
|
||||
fingerprint, _ := remoteCert.Sha256Sum()
|
||||
|
||||
duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
WithField("durationNs", duration).
|
||||
@@ -306,41 +352,18 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
||||
|
||||
// Regardless of whether you are the sender or receiver, you should arrive here
|
||||
// and complete standing up the connection.
|
||||
if dKey != nil && eKey != nil {
|
||||
ip := ip2int(remoteCert.Details.Ips[0].IP)
|
||||
ci.peerCert = remoteCert
|
||||
ci.dKey = NewNebulaCipherState(dKey)
|
||||
ci.eKey = NewNebulaCipherState(eKey)
|
||||
//l.Debugln("got symmetric pairs")
|
||||
|
||||
//hostinfo.ClearRemotes()
|
||||
f.hostMap.AddRemote(ip, addr)
|
||||
hostinfo.CreateRemoteCIDR(remoteCert)
|
||||
f.lightHouse.AddRemoteAndReset(ip, addr)
|
||||
if f.serveDns {
|
||||
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
|
||||
}
|
||||
ci.peerCert = remoteCert
|
||||
ci.dKey = NewNebulaCipherState(dKey)
|
||||
ci.eKey = NewNebulaCipherState(eKey)
|
||||
//l.Debugln("got symmetric pairs")
|
||||
|
||||
ho, err := f.hostMap.QueryVpnIP(vpnIP)
|
||||
if err == nil && ho.localIndexId != 0 {
|
||||
l.WithField("vpnIp", vpnIP).
|
||||
WithField("action", "removing stale index").
|
||||
WithField("index", ho.localIndexId).
|
||||
Debug("Handshake processing")
|
||||
f.hostMap.DeleteIndex(ho.localIndexId)
|
||||
}
|
||||
hostinfo.SetRemote(*addr)
|
||||
hostinfo.CreateRemoteCIDR(remoteCert)
|
||||
|
||||
f.hostMap.AddVpnIPHostInfo(vpnIP, hostinfo)
|
||||
f.hostMap.AddIndexHostInfo(hostinfo.localIndexId, hostinfo)
|
||||
|
||||
hostinfo.handshakeComplete()
|
||||
f.metricHandshakes.Update(duration)
|
||||
} else {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
Error("Noise did not arrive at a key")
|
||||
return true
|
||||
}
|
||||
f.handshakeManager.Complete(hostinfo, f)
|
||||
hostinfo.handshakeComplete()
|
||||
f.metricHandshakes.Update(duration)
|
||||
|
||||
return false
|
||||
}
|
||||
|
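The handshake-race handling in ixHandshakeStage1 above only lets a new tunnel overwrite an existing one when the peer's VPN IP is numerically greater than the local one, which suggests both sides of a simultaneous handshake settle on the same winner. A tiny self-contained sketch of that tie break; ip2int is redefined locally here and the addresses are example values.

package main

import (
	"encoding/binary"
	"fmt"
	"net"
)

// ip2int mirrors the helper name used in the diff, reimplemented for this sketch.
func ip2int(ip net.IP) uint32 {
	return binary.BigEndian.Uint32(ip.To4())
}

func main() {
	mine := ip2int(net.ParseIP("10.0.0.2"))
	theirs := ip2int(net.ParseIP("10.0.0.7"))
	overwrite := theirs > mine // same comparison direction as the diff above
	fmt.Println("allow overwrite of existing tunnel:", overwrite) // true
}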
||||
@@ -1,9 +1,10 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
@@ -13,39 +14,76 @@ import (
|
||||
const (
|
||||
// Total time to try a handshake = sequence of HandshakeTryInterval * HandshakeRetries
|
||||
// With 100ms interval and 20 retries is 23.5 seconds
|
||||
HandshakeTryInterval = time.Millisecond * 100
|
||||
HandshakeRetries = 20
|
||||
// HandshakeWaitRotation is the number of handshake attempts to do before starting to use other ips addresses
|
||||
HandshakeWaitRotation = 5
|
||||
DefaultHandshakeTryInterval = time.Millisecond * 100
|
||||
DefaultHandshakeRetries = 20
|
||||
// DefaultHandshakeWaitRotation is the number of handshake attempts to do before starting to use other ips addresses
|
||||
DefaultHandshakeWaitRotation = 5
|
||||
DefaultHandshakeTriggerBuffer = 64
|
||||
)
|
||||
|
||||
var (
|
||||
defaultHandshakeConfig = HandshakeConfig{
|
||||
tryInterval: DefaultHandshakeTryInterval,
|
||||
retries: DefaultHandshakeRetries,
|
||||
waitRotation: DefaultHandshakeWaitRotation,
|
||||
triggerBuffer: DefaultHandshakeTriggerBuffer,
|
||||
}
|
||||
)
|
||||
|
||||
type HandshakeConfig struct {
|
||||
tryInterval time.Duration
|
||||
retries int
|
||||
waitRotation int
|
||||
triggerBuffer int
|
||||
|
||||
messageMetrics *MessageMetrics
|
||||
}
|
||||
|
||||
type HandshakeManager struct {
|
||||
pendingHostMap *HostMap
|
||||
mainHostMap *HostMap
|
||||
lightHouse *LightHouse
|
||||
outside *udpConn
|
||||
config HandshakeConfig
|
||||
|
||||
// can be used to trigger outbound handshake for the given vpnIP
|
||||
trigger chan uint32
|
||||
|
||||
OutboundHandshakeTimer *SystemTimerWheel
|
||||
InboundHandshakeTimer *SystemTimerWheel
|
||||
|
||||
messageMetrics *MessageMetrics
|
||||
}
|
||||
|
||||
func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn) *HandshakeManager {
|
||||
func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
|
||||
return &HandshakeManager{
|
||||
pendingHostMap: NewHostMap("pending", tunCidr, preferredRanges),
|
||||
mainHostMap: mainHostMap,
|
||||
lightHouse: lightHouse,
|
||||
outside: outside,
|
||||
|
||||
OutboundHandshakeTimer: NewSystemTimerWheel(HandshakeTryInterval, HandshakeTryInterval*HandshakeRetries),
|
||||
InboundHandshakeTimer: NewSystemTimerWheel(HandshakeTryInterval, HandshakeTryInterval*HandshakeRetries),
|
||||
config: config,
|
||||
|
||||
trigger: make(chan uint32, config.triggerBuffer),
|
||||
|
||||
OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
|
||||
InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
|
||||
|
||||
messageMetrics: config.messageMetrics,
|
||||
}
|
||||
}
|
||||
|
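NewHandshakeManager above sizes both timer wheels as tryInterval multiplied by the retry count; since retries is an int, Go requires an explicit conversion to time.Duration before the multiplication. A quick self-contained check of the resulting span with the default values:

package main

import (
	"fmt"
	"time"
)

func main() {
	tryInterval := 100 * time.Millisecond
	retries := 20
	max := tryInterval * time.Duration(retries)
	fmt.Println(max) // 2s: the span the handshake timer wheels must cover
}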
||||
func (c *HandshakeManager) Run(f EncWriter) {
|
||||
clockSource := time.Tick(HandshakeTryInterval)
|
||||
for now := range clockSource {
|
||||
c.NextOutboundHandshakeTimerTick(now, f)
|
||||
c.NextInboundHandshakeTimerTick(now)
|
||||
clockSource := time.Tick(c.config.tryInterval)
|
||||
for {
|
||||
select {
|
||||
case vpnIP := <-c.trigger:
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
|
||||
c.handleOutbound(vpnIP, f, true)
|
||||
case now := <-clockSource:
|
||||
c.NextOutboundHandshakeTimerTick(now, f)
|
||||
c.NextInboundHandshakeTimerTick(now)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,68 +95,83 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWr
|
||||
break
|
||||
}
|
||||
vpnIP := ep.(uint32)
|
||||
c.handleOutbound(vpnIP, f, false)
|
||||
}
|
||||
}
|
||||
|
||||
index, err := c.pendingHostMap.GetIndexByVpnIP(vpnIP)
|
||||
if err != nil {
|
||||
continue
|
||||
func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseTriggered bool) {
|
||||
hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
hostinfo.Lock()
|
||||
defer hostinfo.Unlock()
|
||||
|
||||
// If we haven't finished the handshake and we haven't hit max retries, query
|
||||
// lighthouse and then send the handshake packet again.
|
||||
if hostinfo.HandshakeCounter < c.config.retries && !hostinfo.HandshakeComplete {
|
||||
if hostinfo.remote == nil {
|
||||
// We continue to query the lighthouse because hosts may
|
||||
// come online during handshake retries. If the query
|
||||
// succeeds (no error), add the lighthouse info to hostinfo
|
||||
ips := c.lightHouse.QueryCache(vpnIP)
|
||||
// If we have no responses yet, or only one IP (the host hadn't
|
||||
// finished reporting its own IPs yet), then send another query to
|
||||
// the LH.
|
||||
if len(ips) <= 1 {
|
||||
ips, err = c.lightHouse.Query(vpnIP, f)
|
||||
}
|
||||
if err == nil {
|
||||
for _, ip := range ips {
|
||||
hostinfo.AddRemote(ip)
|
||||
}
|
||||
hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges)
|
||||
}
|
||||
} else if lighthouseTriggered {
|
||||
// We were triggered by a lighthouse HostQueryReply packet, but
|
||||
// we have already picked a remote for this host (this can happen
|
||||
// if we are configured with multiple lighthouses). So we can skip
|
||||
// this trigger and let the timerwheel handle the rest of the
|
||||
// process
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP)
|
||||
if err != nil {
|
||||
continue
|
||||
hostinfo.HandshakeCounter++
|
||||
|
||||
// We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through
|
||||
// all the others until we can stand up a connection.
|
||||
if hostinfo.HandshakeCounter > c.config.waitRotation {
|
||||
hostinfo.rotateRemote()
|
||||
}
|
||||
|
||||
// If we haven't finished the handshake and we haven't hit max retries, query
|
||||
// lighthouse and then send the handshake packet again.
|
||||
if hostinfo.HandshakeCounter < HandshakeRetries && !hostinfo.HandshakeComplete {
|
||||
if hostinfo.remote == nil {
|
||||
// We continue to query the lighthouse because hosts may
|
||||
// come online during handshake retries. If the query
|
||||
// succeeds (no error), add the lighthouse info to hostinfo
|
||||
ips, err := c.lightHouse.Query(vpnIP, f)
|
||||
if err == nil {
|
||||
for _, ip := range ips {
|
||||
hostinfo.AddRemote(ip)
|
||||
}
|
||||
hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges)
|
||||
}
|
||||
// Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation
|
||||
if hostinfo.HandshakeReady && hostinfo.remote != nil {
|
||||
c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
|
||||
err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
|
||||
if err != nil {
|
||||
hostinfo.logger().WithField("udpAddr", hostinfo.remote).
|
||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||
WithField("remoteIndex", hostinfo.remoteIndexId).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
WithError(err).Error("Failed to send handshake message")
|
||||
} else {
|
||||
//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
|
||||
// keep the real packet struct around for logging purposes
|
||||
hostinfo.logger().WithField("udpAddr", hostinfo.remote).
|
||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||
WithField("remoteIndex", hostinfo.remoteIndexId).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Info("Handshake message sent")
|
||||
}
|
||||
}
|
||||
|
||||
hostinfo.HandshakeCounter++
|
||||
|
||||
// We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through
|
||||
// all the others until we can stand up a connection.
|
||||
if hostinfo.HandshakeCounter > HandshakeWaitRotation {
|
||||
hostinfo.rotateRemote()
|
||||
}
|
||||
|
||||
// Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation
|
||||
if hostinfo.HandshakeReady && hostinfo.remote != nil {
|
||||
err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
|
||||
if err != nil {
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", hostinfo.remote).
|
||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||
WithField("remoteIndex", hostinfo.remoteIndexId).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
WithError(err).Error("Failed to send handshake message")
|
||||
} else {
|
||||
//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
|
||||
// keep the real packet struct around for logging purposes
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", hostinfo.remote).
|
||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||
WithField("remoteIndex", hostinfo.remoteIndexId).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Info("Handshake message sent")
|
||||
}
|
||||
}
|
||||
|
||||
	// Re-add to the timer wheel so we continue trying, waiting HandshakeTryInterval * counter longer for the next try
	if !lighthouseTriggered {
		//l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter))
		c.OutboundHandshakeTimer.Add(vpnIP, HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter))
	} else {
		c.pendingHostMap.DeleteVpnIP(vpnIP)
		c.pendingHostMap.DeleteIndex(index)
		c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
	}
	} else {
		c.pendingHostMap.DeleteHostInfo(hostinfo)
	}
}

@@ -131,12 +184,7 @@ func (c *HandshakeManager) NextInboundHandshakeTimerTick(now time.Time) {
|
||||
}
|
||||
index := ep.(uint32)
|
||||
|
||||
vpnIP, err := c.pendingHostMap.GetVpnIPByIndex(index)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
c.pendingHostMap.DeleteIndex(index)
|
||||
c.pendingHostMap.DeleteVpnIP(vpnIP)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,32 +192,137 @@ func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
|
||||
hostinfo := c.pendingHostMap.AddVpnIP(vpnIP)
|
||||
// We lock here and use an array to insert items to prevent locking the
|
||||
// main receive thread for very long by waiting to add items to the pending map
|
||||
c.OutboundHandshakeTimer.Add(vpnIP, HandshakeTryInterval)
|
||||
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval)
|
||||
|
||||
return hostinfo
|
||||
}
|
||||
|
||||
func (c *HandshakeManager) DeleteVpnIP(vpnIP uint32) {
|
||||
//l.Debugln("Deleting pending vpn ip :", IntIp(vpnIP))
|
||||
c.pendingHostMap.DeleteVpnIP(vpnIP)
|
||||
}
|
||||
var (
	ErrExistingHostInfo    = errors.New("existing hostinfo")
	ErrAlreadySeen         = errors.New("already seen")
	ErrLocalIndexCollision = errors.New("local index collision")
)

func (c *HandshakeManager) AddIndex(index uint32, ci *ConnectionState) (*HostInfo, error) {
|
||||
hostinfo, err := c.pendingHostMap.AddIndex(index, ci)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Issue adding index: %d", index)
|
||||
// CheckAndComplete checks for any conflicts in the main and pending hostmap
|
||||
// before adding hostinfo to main. If err is nil, it was added. Otherwise err will be:
|
||||
|
||||
// ErrAlreadySeen if we already have an entry in the hostmap that has seen the
|
||||
// exact same handshake packet
|
||||
//
|
||||
// ErrExistingHostInfo if we already have an entry in the hostmap for this
|
||||
// VpnIP and overwrite was false.
|
||||
//
|
||||
// ErrLocalIndexCollision if we already have an entry in the main or pending
|
||||
// hostmap for the hostinfo.localIndexId.
|
||||
func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) {
|
||||
c.pendingHostMap.RLock()
|
||||
defer c.pendingHostMap.RUnlock()
|
||||
c.mainHostMap.Lock()
|
||||
defer c.mainHostMap.Unlock()
|
||||
|
||||
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
|
||||
if found && existingHostInfo != nil {
|
||||
if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
|
||||
return existingHostInfo, ErrAlreadySeen
|
||||
}
|
||||
if !overwrite {
|
||||
return existingHostInfo, ErrExistingHostInfo
|
||||
}
|
||||
}
|
||||
//c.mainHostMap.AddIndexHostInfo(index, hostinfo)
|
||||
c.InboundHandshakeTimer.Add(index, time.Second*10)
|
||||
return hostinfo, nil
|
||||
|
||||
existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
|
||||
if found {
|
||||
// We have a collision, but for a different hostinfo
|
||||
return existingIndex, ErrLocalIndexCollision
|
||||
}
|
||||
existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId]
|
||||
if found && existingIndex != hostinfo {
|
||||
// We have a collision, but for a different hostinfo
|
||||
return existingIndex, ErrLocalIndexCollision
|
||||
}
|
||||
|
||||
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
|
||||
if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId {
|
||||
// We have a collision, but this can happen since we can't control
|
||||
// the remote ID. Just log about the situation as a note.
|
||||
hostinfo.logger().
|
||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
|
||||
Info("New host shadows existing host remoteIndex")
|
||||
}
|
||||
|
||||
if existingHostInfo != nil {
|
||||
// We are going to overwrite this entry, so remove the old references
|
||||
delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
|
||||
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
|
||||
delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
|
||||
}
|
||||
|
||||
c.mainHostMap.addHostInfo(hostinfo, f)
|
||||
return existingHostInfo, nil
|
||||
}
|
||||
|
||||
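The sentinel errors let CheckAndComplete's caller distinguish a retransmitted handshake packet from a real conflict. The real call sites live in the handshake packet handlers, which are not part of this compare; the following is only a hedged sketch of how a caller might branch on the result.

// completeOrRecover is a hypothetical helper, for illustration only.
func completeOrRecover(hm *HandshakeManager, hostinfo *HostInfo, f *Interface) {
	existing, err := hm.CheckAndComplete(hostinfo, 0, false, f)
	switch err {
	case nil:
		// hostinfo is now in the main hostmap and can carry traffic.
	case ErrAlreadySeen:
		// Retransmission of a handshake packet we already processed; the
		// cached response for `existing` can simply be re-sent.
	case ErrExistingHostInfo:
		// A tunnel for this vpnIP already exists and overwrite was false.
	case ErrLocalIndexCollision:
		// Our chosen local index clashes with `existing`; abandon this attempt.
	}
	_ = existing
}
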
func (c *HandshakeManager) AddIndexHostInfo(index uint32, h *HostInfo) {
|
||||
c.pendingHostMap.AddIndexHostInfo(index, h)
|
||||
// Complete is a simpler version of CheckAndComplete when we already know we
|
||||
// won't have a localIndexId collision because we already have an entry in the
|
||||
// pendingHostMap
|
||||
func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
|
||||
c.mainHostMap.Lock()
|
||||
defer c.mainHostMap.Unlock()
|
||||
|
||||
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
|
||||
if found && existingHostInfo != nil {
|
||||
// We are going to overwrite this entry, so remove the old references
|
||||
delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
|
||||
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
|
||||
delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
|
||||
}
|
||||
|
||||
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
|
||||
if found && existingRemoteIndex != nil {
|
||||
// We have a collision, but this can happen since we can't control
|
||||
// the remote ID. Just log about the situation as a note.
|
||||
hostinfo.logger().
|
||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
|
||||
Info("New host shadows existing host remoteIndex")
|
||||
}
|
||||
|
||||
c.mainHostMap.addHostInfo(hostinfo, f)
|
||||
}
|
||||
|
||||
func (c *HandshakeManager) DeleteIndex(index uint32) {
	//l.Debugln("Deleting pending index :", index)
	c.pendingHostMap.DeleteIndex(index)
// AddIndexHostInfo generates a unique localIndexId for this HostInfo
// and adds it to the pendingHostMap. Will error if we are unable to generate
// a unique localIndexId
func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
	c.pendingHostMap.Lock()
	defer c.pendingHostMap.Unlock()
	c.mainHostMap.RLock()
	defer c.mainHostMap.RUnlock()

	for i := 0; i < 32; i++ {
		index, err := generateIndex()
		if err != nil {
			return err
		}

		_, inPending := c.pendingHostMap.Indexes[index]
		_, inMain := c.mainHostMap.Indexes[index]

		if !inMain && !inPending {
			h.localIndexId = index
			c.pendingHostMap.Indexes[index] = h
			return nil
		}
	}

	return errors.New("failed to generate unique localIndexId")
}

func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
|
||||
c.pendingHostMap.addRemoteIndexHostInfo(index, h)
|
||||
}
|
||||
|
||||
func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
|
||||
//l.Debugln("Deleting pending hostinfo :", hostinfo)
|
||||
c.pendingHostMap.DeleteHostInfo(hostinfo)
|
||||
}
|
||||
|
||||
func (c *HandshakeManager) QueryIndex(index uint32) (*HostInfo, error) {
|
||||
@@ -185,13 +338,19 @@ func (c *HandshakeManager) EmitStats() {
|
||||
|
||||
func generateIndex() (uint32, error) {
	b := make([]byte, 4)
	_, err := rand.Read(b)
	if err != nil {
		l.Errorln(err)
		return 0, err

	// Let zero mean we don't know the ID, so don't generate zero
	var index uint32
	for index == 0 {
		_, err := rand.Read(b)
		if err != nil {
			l.Errorln(err)
			return 0, err
		}

		index = binary.BigEndian.Uint32(b)
	}

	index := binary.BigEndian.Uint32(b)
	if l.Level >= logrus.DebugLevel {
		l.WithField("index", index).
			Debug("Generated index")

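The loop above reserves zero as the "index unknown" value. A small property-style test sketch (not in this compare) that pins down that invariant:

func TestGenerateIndexIsNeverZero(t *testing.T) {
	// Sketch only: zero is reserved to mean "we don't know the ID".
	for i := 0; i < 1000; i++ {
		index, err := generateIndex()
		assert.NoError(t, err)
		assert.NotEqual(t, uint32(0), index)
	}
}
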
@@ -8,12 +8,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var indexes []uint32 = []uint32{1000, 2000, 3000, 4000}
|
||||
|
||||
//var ips []uint32 = []uint32{9000, 9999999, 3, 292394923}
|
||||
var ips []uint32
|
||||
|
||||
func Test_NewHandshakeManagerIndex(t *testing.T) {
|
||||
|
||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
@@ -21,14 +20,23 @@ func Test_NewHandshakeManagerIndex(t *testing.T) {
|
||||
preferredRanges := []*net.IPNet{localrange}
|
||||
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
||||
|
||||
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{})
|
||||
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
||||
|
||||
now := time.Now()
|
||||
blah.NextInboundHandshakeTimerTick(now)
|
||||
|
||||
var indexes = make([]uint32, 4)
|
||||
var hostinfo = make([]*HostInfo, len(indexes))
|
||||
for i := range indexes {
|
||||
hostinfo[i] = &HostInfo{ConnectionState: &ConnectionState{}}
|
||||
}
|
||||
|
||||
// Add four indexes
|
||||
for _, v := range indexes {
|
||||
blah.AddIndex(v, &ConnectionState{})
|
||||
for i := range indexes {
|
||||
err := blah.AddIndexHostInfo(hostinfo[i])
|
||||
assert.NoError(t, err)
|
||||
indexes[i] = hostinfo[i].localIndexId
|
||||
blah.InboundHandshakeTimer.Add(indexes[i], time.Second*10)
|
||||
}
|
||||
// Confirm they are in the pending index list
|
||||
for _, v := range indexes {
|
||||
@@ -37,8 +45,8 @@ func Test_NewHandshakeManagerIndex(t *testing.T) {
|
||||
// Adding something to pending should not affect the main hostmap
|
||||
assert.Len(t, mainHM.Indexes, 0)
|
||||
// Jump ahead 8 seconds
|
||||
for i := 1; i <= HandshakeRetries; i++ {
|
||||
next_tick := now.Add(HandshakeTryInterval * time.Duration(i))
|
||||
for i := 1; i <= DefaultHandshakeRetries; i++ {
|
||||
next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i))
|
||||
blah.NextInboundHandshakeTimerTick(next_tick)
|
||||
}
|
||||
// Confirm they are still in the pending index list
|
||||
@@ -63,7 +71,7 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
|
||||
mw := &mockEncWriter{}
|
||||
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
||||
|
||||
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{})
|
||||
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
||||
|
||||
now := time.Now()
|
||||
blah.NextOutboundHandshakeTimerTick(now, mw)
|
||||
@@ -81,8 +89,8 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
|
||||
|
||||
// Jump ahead `HandshakeRetries` ticks
|
||||
cumulative := time.Duration(0)
|
||||
for i := 0; i <= HandshakeRetries+1; i++ {
|
||||
cumulative += time.Duration(i)*HandshakeTryInterval + 1
|
||||
for i := 0; i <= DefaultHandshakeRetries+1; i++ {
|
||||
cumulative += time.Duration(i)*DefaultHandshakeTryInterval + 1
|
||||
next_tick := now.Add(cumulative)
|
||||
//l.Infoln(next_tick)
|
||||
blah.NextOutboundHandshakeTimerTick(next_tick, mw)
|
||||
@@ -93,7 +101,7 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
|
||||
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v))
|
||||
}
|
||||
// Jump ahead 1 more second
|
||||
cumulative += time.Duration(HandshakeRetries+1) * HandshakeTryInterval
|
||||
cumulative += time.Duration(DefaultHandshakeRetries+1) * DefaultHandshakeTryInterval
|
||||
next_tick := now.Add(cumulative)
|
||||
//l.Infoln(next_tick)
|
||||
blah.NextOutboundHandshakeTimerTick(next_tick, mw)
|
||||
@@ -103,6 +111,56 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NewHandshakeManagerTrigger(t *testing.T) {
|
||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
ip := ip2int(net.ParseIP("172.1.1.2"))
|
||||
preferredRanges := []*net.IPNet{localrange}
|
||||
mw := &mockEncWriter{}
|
||||
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
||||
lh := &LightHouse{}
|
||||
|
||||
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
|
||||
|
||||
now := time.Now()
|
||||
blah.NextOutboundHandshakeTimerTick(now, mw)
|
||||
|
||||
assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
|
||||
|
||||
blah.AddVpnIP(ip)
|
||||
|
||||
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
|
||||
|
||||
// Trigger the same method the channel will
|
||||
blah.handleOutbound(ip, mw, true)
|
||||
|
||||
// Make sure the trigger doesn't schedule another timer entry
|
||||
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
|
||||
hi := blah.pendingHostMap.Hosts[ip]
|
||||
assert.Nil(t, hi.remote)
|
||||
|
||||
lh.addrMap = map[uint32][]udpAddr{
|
||||
ip: {*NewUDPAddrFromString("10.1.1.1:4242")},
|
||||
}
|
||||
|
||||
// This should trigger the hostmap to populate the hostinfo
|
||||
blah.handleOutbound(ip, mw, true)
|
||||
assert.NotNil(t, hi.remote)
|
||||
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
|
||||
}
|
||||
|
||||
func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
|
||||
for _, i := range tw.wheel {
|
||||
n := i.Head
|
||||
for n != nil {
|
||||
c++
|
||||
n = n.Next
|
||||
}
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
|
||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
@@ -112,21 +170,24 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
|
||||
mw := &mockEncWriter{}
|
||||
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
||||
|
||||
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{})
|
||||
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
||||
|
||||
now := time.Now()
|
||||
blah.NextOutboundHandshakeTimerTick(now, mw)
|
||||
|
||||
hostinfo := blah.AddVpnIP(vpnIP)
|
||||
// Pretend we have an index too
|
||||
blah.AddIndexHostInfo(12341234, hostinfo)
|
||||
assert.Contains(t, blah.pendingHostMap.Indexes, uint32(12341234))
|
||||
err := blah.AddIndexHostInfo(hostinfo)
|
||||
assert.NoError(t, err)
|
||||
blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10)
|
||||
assert.NotZero(t, hostinfo.localIndexId)
|
||||
assert.Contains(t, blah.pendingHostMap.Indexes, hostinfo.localIndexId)
|
||||
|
||||
// Jump ahead `HandshakeRetries` ticks. Eviction should happen in pending
|
||||
// but not main hostmap
|
||||
cumulative := time.Duration(0)
|
||||
for i := 1; i <= HandshakeRetries+2; i++ {
|
||||
cumulative += HandshakeTryInterval * time.Duration(i)
|
||||
for i := 1; i <= DefaultHandshakeRetries+2; i++ {
|
||||
cumulative += DefaultHandshakeTryInterval * time.Duration(i)
|
||||
next_tick := now.Add(cumulative)
|
||||
blah.NextOutboundHandshakeTimerTick(next_tick, mw)
|
||||
}
|
||||
@@ -161,25 +222,28 @@ func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
|
||||
preferredRanges := []*net.IPNet{localrange}
|
||||
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
||||
|
||||
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{})
|
||||
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
||||
|
||||
now := time.Now()
|
||||
blah.NextInboundHandshakeTimerTick(now)
|
||||
|
||||
hostinfo, _ := blah.AddIndex(12341234, &ConnectionState{})
|
||||
hostinfo := &HostInfo{ConnectionState: &ConnectionState{}}
|
||||
err := blah.AddIndexHostInfo(hostinfo)
|
||||
assert.NoError(t, err)
|
||||
blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10)
|
||||
// Pretend we have an index too
|
||||
blah.pendingHostMap.AddVpnIPHostInfo(101010, hostinfo)
|
||||
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(101010))
|
||||
|
||||
for i := 1; i <= HandshakeRetries+2; i++ {
|
||||
next_tick := now.Add(HandshakeTryInterval * time.Duration(i))
|
||||
for i := 1; i <= DefaultHandshakeRetries+2; i++ {
|
||||
next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i))
|
||||
blah.NextInboundHandshakeTimerTick(next_tick)
|
||||
}
|
||||
|
||||
next_tick := now.Add(HandshakeTryInterval*HandshakeRetries + 3)
|
||||
next_tick := now.Add(DefaultHandshakeTryInterval*DefaultHandshakeRetries + 3)
|
||||
blah.NextInboundHandshakeTimerTick(next_tick)
|
||||
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010))
|
||||
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234))
|
||||
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(hostinfo.localIndexId))
|
||||
}
|
||||
|
||||
type mockEncWriter struct {
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type headerTest struct {
|
||||
|
||||
257
hostmap.go
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
@@ -19,20 +20,24 @@ const MaxRemotes = 10
|
||||
|
||||
// How long we should prevent roaming back to the previous IP.
|
||||
// This helps prevent flapping due to packets already in flight
|
||||
const RoamingSupressSeconds = 2
|
||||
const RoamingSuppressSeconds = 2
|
||||
|
||||
type HostMap struct {
|
||||
sync.RWMutex //Because we concurrently read and write to our maps
|
||||
name string
|
||||
Indexes map[uint32]*HostInfo
|
||||
RemoteIndexes map[uint32]*HostInfo
|
||||
Hosts map[uint32]*HostInfo
|
||||
preferredRanges []*net.IPNet
|
||||
vpnCIDR *net.IPNet
|
||||
defaultRoute uint32
|
||||
unsafeRoutes *CIDRTree
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
type HostInfo struct {
|
||||
sync.RWMutex
|
||||
|
||||
remote *udpAddr
|
||||
Remotes []*HostInfoDest
|
||||
promoteCounter uint32
|
||||
@@ -49,6 +54,11 @@ type HostInfo struct {
|
||||
recvError int
|
||||
remoteCidr *CIDRTree
|
||||
|
||||
// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
|
||||
// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
|
||||
// with a handshake
|
||||
lastRebindCount int8
|
||||
|
||||
lastRoam time.Time
|
||||
lastRoamRemote *udpAddr
|
||||
}
|
||||
@@ -63,8 +73,7 @@ type cachedPacket struct {
|
||||
type packetCallback func(t NebulaMessageType, st NebulaMessageSubType, h *HostInfo, p, nb, out []byte)
|
||||
|
||||
type HostInfoDest struct {
|
||||
active bool
|
||||
addr *udpAddr
|
||||
addr *udpAddr
|
||||
//probes [ProbeLen]bool
|
||||
probeCounter int
|
||||
}
|
||||
@@ -77,9 +86,11 @@ type Probe struct {
|
||||
func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
|
||||
h := map[uint32]*HostInfo{}
|
||||
i := map[uint32]*HostInfo{}
|
||||
r := map[uint32]*HostInfo{}
|
||||
m := HostMap{
|
||||
name: name,
|
||||
Indexes: i,
|
||||
RemoteIndexes: r,
|
||||
Hosts: h,
|
||||
preferredRanges: preferredRanges,
|
||||
vpnCIDR: vpnCIDR,
|
||||
@@ -94,10 +105,12 @@ func (hm *HostMap) EmitStats(name string) {
|
||||
hm.RLock()
|
||||
hostLen := len(hm.Hosts)
|
||||
indexLen := len(hm.Indexes)
|
||||
remoteIndexLen := len(hm.RemoteIndexes)
|
||||
hm.RUnlock()
|
||||
|
||||
metrics.GetOrRegisterGauge("hostmap."+name+".hosts", nil).Update(int64(hostLen))
|
||||
metrics.GetOrRegisterGauge("hostmap."+name+".indexes", nil).Update(int64(indexLen))
|
||||
metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen))
|
||||
}
|
||||
|
||||
func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) {
|
||||
@@ -111,17 +124,6 @@ func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) {
|
||||
return 0, errors.New("vpn IP not found")
|
||||
}
|
||||
|
||||
func (hm *HostMap) GetVpnIPByIndex(index uint32) (uint32, error) {
|
||||
hm.RLock()
|
||||
if i, ok := hm.Indexes[index]; ok {
|
||||
vpnIP := i.hostId
|
||||
hm.RUnlock()
|
||||
return vpnIP, nil
|
||||
}
|
||||
hm.RUnlock()
|
||||
return 0, errors.New("vpn IP not found")
|
||||
}
|
||||
|
||||
func (hm *HostMap) Add(ip uint32, hostinfo *HostInfo) {
|
||||
hm.Lock()
|
||||
hm.Hosts[ip] = hostinfo
|
||||
@@ -164,37 +166,17 @@ func (hm *HostMap) DeleteVpnIP(vpnIP uint32) {
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HostMap) AddIndex(index uint32, ci *ConnectionState) (*HostInfo, error) {
|
||||
// Only used by pendingHostMap when the remote index is not initially known
|
||||
func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
|
||||
hm.Lock()
|
||||
if _, ok := hm.Indexes[index]; !ok {
|
||||
h := &HostInfo{
|
||||
ConnectionState: ci,
|
||||
Remotes: []*HostInfoDest{},
|
||||
localIndexId: index,
|
||||
HandshakePacket: make(map[uint8][]byte, 0),
|
||||
}
|
||||
hm.Indexes[index] = h
|
||||
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
|
||||
"hostinfo": m{"existing": false, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
|
||||
Debug("Hostmap index added")
|
||||
|
||||
hm.Unlock()
|
||||
return h, nil
|
||||
}
|
||||
hm.Unlock()
|
||||
return nil, fmt.Errorf("refusing to overwrite existing index: %d", index)
|
||||
}
|
||||
|
||||
func (hm *HostMap) AddIndexHostInfo(index uint32, h *HostInfo) {
|
||||
hm.Lock()
|
||||
h.localIndexId = index
|
||||
hm.Indexes[index] = h
|
||||
h.remoteIndexId = index
|
||||
hm.RemoteIndexes[index] = h
|
||||
hm.Unlock()
|
||||
|
||||
if l.Level > logrus.DebugLevel {
|
||||
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
|
||||
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
|
||||
Debug("Hostmap index added")
|
||||
Debug("Hostmap remoteIndex added")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -202,6 +184,8 @@ func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) {
|
||||
hm.Lock()
|
||||
h.hostId = vpnIP
|
||||
hm.Hosts[vpnIP] = h
|
||||
hm.Indexes[h.localIndexId] = h
|
||||
hm.RemoteIndexes[h.remoteIndexId] = h
|
||||
hm.Unlock()
|
||||
|
||||
if l.Level > logrus.DebugLevel {
|
||||
@@ -211,11 +195,20 @@ func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) {
|
||||
}
|
||||
}
|
||||
|
||||
// This is only called in pendingHostmap, to cleanup an inbound handshake
|
||||
func (hm *HostMap) DeleteIndex(index uint32) {
|
||||
hm.Lock()
|
||||
delete(hm.Indexes, index)
|
||||
if len(hm.Indexes) == 0 {
|
||||
hm.Indexes = map[uint32]*HostInfo{}
|
||||
hostinfo, ok := hm.Indexes[index]
|
||||
if ok {
|
||||
delete(hm.Indexes, index)
|
||||
delete(hm.RemoteIndexes, hostinfo.remoteIndexId)
|
||||
|
||||
// Check if we have an entry under hostId that matches the same hostinfo
|
||||
// instance. Clean it up as well if we do.
|
||||
hostinfo2, ok := hm.Hosts[hostinfo.hostId]
|
||||
if ok && hostinfo2 == hostinfo {
|
||||
delete(hm.Hosts, hostinfo.hostId)
|
||||
}
|
||||
}
|
||||
hm.Unlock()
|
||||
|
||||
@@ -225,6 +218,64 @@ func (hm *HostMap) DeleteIndex(index uint32) {
|
||||
}
|
||||
}
|
||||
|
||||
// This is used to cleanup on recv_error
|
||||
func (hm *HostMap) DeleteReverseIndex(index uint32) {
|
||||
hm.Lock()
|
||||
hostinfo, ok := hm.RemoteIndexes[index]
|
||||
if ok {
|
||||
delete(hm.Indexes, hostinfo.localIndexId)
|
||||
delete(hm.RemoteIndexes, index)
|
||||
|
||||
// Check if we have an entry under hostId that matches the same hostinfo
|
||||
// instance. Clean it up as well if we do (they might not match in pendingHostmap)
|
||||
var hostinfo2 *HostInfo
|
||||
hostinfo2, ok = hm.Hosts[hostinfo.hostId]
|
||||
if ok && hostinfo2 == hostinfo {
|
||||
delete(hm.Hosts, hostinfo.hostId)
|
||||
}
|
||||
}
|
||||
hm.Unlock()
|
||||
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
|
||||
Debug("Hostmap remote index deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
|
||||
hm.Lock()
|
||||
|
||||
// Check if this same hostId is in the hostmap with a different instance.
|
||||
// This could happen if we have an entry in the pending hostmap with different
|
||||
// index values than the one in the main hostmap.
|
||||
hostinfo2, ok := hm.Hosts[hostinfo.hostId]
|
||||
if ok && hostinfo2 != hostinfo {
|
||||
delete(hm.Hosts, hostinfo2.hostId)
|
||||
delete(hm.Indexes, hostinfo2.localIndexId)
|
||||
delete(hm.RemoteIndexes, hostinfo2.remoteIndexId)
|
||||
}
|
||||
|
||||
delete(hm.Hosts, hostinfo.hostId)
|
||||
if len(hm.Hosts) == 0 {
|
||||
hm.Hosts = map[uint32]*HostInfo{}
|
||||
}
|
||||
delete(hm.Indexes, hostinfo.localIndexId)
|
||||
if len(hm.Indexes) == 0 {
|
||||
hm.Indexes = map[uint32]*HostInfo{}
|
||||
}
|
||||
delete(hm.RemoteIndexes, hostinfo.remoteIndexId)
|
||||
if len(hm.RemoteIndexes) == 0 {
|
||||
hm.RemoteIndexes = map[uint32]*HostInfo{}
|
||||
}
|
||||
hm.Unlock()
|
||||
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
|
||||
"vpnIp": IntIp(hostinfo.hostId), "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
||||
Debug("Hostmap hostInfo deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) {
|
||||
//TODO: we probably just want ot return bool instead of error, or at least a static error
|
||||
hm.RLock()
|
||||
@@ -237,23 +288,15 @@ func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// This function needs to range because we don't keep a map of remote indexes.
|
||||
func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) {
|
||||
hm.RLock()
|
||||
for _, h := range hm.Indexes {
|
||||
if h.ConnectionState != nil && h.remoteIndexId == index {
|
||||
hm.RUnlock()
|
||||
return h, nil
|
||||
}
|
||||
if h, ok := hm.RemoteIndexes[index]; ok {
|
||||
hm.RUnlock()
|
||||
return h, nil
|
||||
} else {
|
||||
hm.RUnlock()
|
||||
return nil, fmt.Errorf("unable to find reverse index or connectionstate nil in %s hostmap", hm.name)
|
||||
}
|
||||
for _, h := range hm.Hosts {
|
||||
if h.ConnectionState != nil && h.remoteIndexId == index {
|
||||
hm.RUnlock()
|
||||
return h, nil
|
||||
}
|
||||
}
|
||||
hm.RUnlock()
|
||||
return nil, fmt.Errorf("unable to find reverse index or connectionstate nil in %s hostmap", hm.name)
|
||||
}
|
||||
|
||||
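With RemoteIndexes kept in step with Hosts and Indexes, a reverse lookup is a single map read instead of a scan over every HostInfo, as the rewritten QueryReverseIndex above shows. A hypothetical helper illustrating the same access pattern:

// hasRemoteIndex is illustrative only; it is not part of this change set.
func (hm *HostMap) hasRemoteIndex(index uint32) bool {
	hm.RLock()
	_, ok := hm.RemoteIndexes[index]
	hm.RUnlock()
	return ok
}
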
func (hm *HostMap) AddRemote(vpnIp uint32, remote *udpAddr) *HostInfo {
|
||||
@@ -319,36 +362,26 @@ func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 {
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HostMap) CheckHandshakeCompleteIP(vpnIP uint32) bool {
|
||||
hm.RLock()
|
||||
if i, ok := hm.Hosts[vpnIP]; ok {
|
||||
if i == nil {
|
||||
hm.RUnlock()
|
||||
return false
|
||||
}
|
||||
complete := i.HandshakeComplete
|
||||
hm.RUnlock()
|
||||
return complete
|
||||
// We already have the hm Lock when this is called, so make sure to not call
|
||||
// any other methods that might try to grab it again
|
||||
func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
|
||||
remoteCert := hostinfo.ConnectionState.peerCert
|
||||
ip := ip2int(remoteCert.Details.Ips[0].IP)
|
||||
|
||||
f.lightHouse.AddRemoteAndReset(ip, hostinfo.remote)
|
||||
if f.serveDns {
|
||||
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
|
||||
}
|
||||
hm.RUnlock()
|
||||
return false
|
||||
}
|
||||
|
||||
func (hm *HostMap) CheckHandshakeCompleteIndex(index uint32) bool {
|
||||
hm.RLock()
|
||||
if i, ok := hm.Indexes[index]; ok {
|
||||
if i == nil {
|
||||
hm.RUnlock()
|
||||
return false
|
||||
}
|
||||
complete := i.HandshakeComplete
|
||||
hm.RUnlock()
|
||||
return complete
|
||||
hm.Hosts[hostinfo.hostId] = hostinfo
|
||||
hm.Indexes[hostinfo.localIndexId] = hostinfo
|
||||
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
|
||||
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts),
|
||||
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}).
|
||||
Debug("Hostmap vpnIp added")
|
||||
}
|
||||
hm.RUnlock()
|
||||
return false
|
||||
}
|
||||
|
||||
func (hm *HostMap) ClearRemotes(vpnIP uint32) {
|
||||
@@ -384,8 +417,16 @@ func (hm *HostMap) PunchList() []*udpAddr {
|
||||
}
|
||||
|
||||
func (hm *HostMap) Punchy(conn *udpConn) {
|
||||
var metricsTxPunchy metrics.Counter
|
||||
if hm.metricsEnabled {
|
||||
metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
|
||||
} else {
|
||||
metricsTxPunchy = metrics.NilCounter{}
|
||||
}
|
||||
|
||||
for {
|
||||
for _, addr := range hm.PunchList() {
|
||||
metricsTxPunchy.Inc(1)
|
||||
conn.WriteTo([]byte{1}, addr)
|
||||
}
|
||||
time.Sleep(time.Second * 30)
|
||||
@@ -430,8 +471,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
|
||||
return
|
||||
}
|
||||
|
||||
i.promoteCounter++
|
||||
if i.promoteCounter%PromoteEvery == 0 {
|
||||
if atomic.AddUint32(&i.promoteCounter, 1)&PromoteEvery == 0 {
|
||||
// return early if we are already on a preferred remote
|
||||
rIP := udp2ip(i.remote)
|
||||
for _, l := range preferredRanges {
|
||||
@@ -532,13 +572,15 @@ func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, pac
|
||||
copy(tempPacket, packet)
|
||||
//l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket)
|
||||
i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket})
|
||||
l.WithField("vpnIp", IntIp(i.hostId)).
|
||||
WithField("length", len(i.packetStore)).
|
||||
WithField("stored", true).
|
||||
Debugf("Packet store")
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
i.logger().
|
||||
WithField("length", len(i.packetStore)).
|
||||
WithField("stored", true).
|
||||
Debugf("Packet store")
|
||||
}
|
||||
|
||||
} else if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("vpnIp", IntIp(i.hostId)).
|
||||
i.logger().
|
||||
WithField("length", len(i.packetStore)).
|
||||
WithField("stored", false).
|
||||
Debugf("Packet store")
|
||||
@@ -555,8 +597,8 @@ func (i *HostInfo) handshakeComplete() {
|
||||
i.HandshakeComplete = true
|
||||
//TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen.
|
||||
// Clamping it to 2 gets us out of the woods for now
|
||||
*i.ConnectionState.messageCounter = 2
|
||||
l.WithField("vpnIp", IntIp(i.hostId)).Debugf("Sending %d stored packets", len(i.packetStore))
|
||||
atomic.StoreUint64(&i.ConnectionState.atomicMessageCounter, 2)
|
||||
i.logger().Debugf("Sending %d stored packets", len(i.packetStore))
|
||||
nb := make([]byte, 12, 12)
|
||||
out := make([]byte, mtu)
|
||||
for _, cp := range i.packetStore {
|
||||
@@ -623,6 +665,11 @@ func (i *HostInfo) RecvErrorExceeded() bool {
|
||||
}
|
||||
|
||||
func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
|
||||
if len(c.Details.Ips) == 1 && len(c.Details.Subnets) == 0 {
|
||||
// Simple case, no CIDRTree needed
|
||||
return
|
||||
}
|
||||
|
||||
remoteCidr := NewCIDRTree()
|
||||
for _, ip := range c.Details.Ips {
|
||||
remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
|
||||
@@ -634,6 +681,22 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
|
||||
i.remoteCidr = remoteCidr
|
||||
}
|
||||
|
||||
func (i *HostInfo) logger() *logrus.Entry {
|
||||
if i == nil {
|
||||
return logrus.NewEntry(l)
|
||||
}
|
||||
|
||||
li := l.WithField("vpnIp", IntIp(i.hostId))
|
||||
|
||||
if connState := i.ConnectionState; connState != nil {
|
||||
if peerCert := connState.peerCert; peerCert != nil {
|
||||
li = li.WithField("certName", peerCert.Details.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return li
|
||||
}
|
||||
|
||||
//########################
|
||||
|
||||
func NewHostInfoDest(addr *udpAddr) *HostInfoDest {
|
||||
@@ -645,7 +708,6 @@ func NewHostInfoDest(addr *udpAddr) *HostInfoDest {
|
||||
|
||||
func (hid *HostInfoDest) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(m{
|
||||
"active": hid.active,
|
||||
"address": hid.addr,
|
||||
"probe_count": hid.probeCounter,
|
||||
})
|
||||
@@ -734,11 +796,16 @@ func (d *HostInfoDest) ProbeReceived(probeCount int) {
|
||||
|
||||
// Utility functions
|
||||
|
||||
func localIps() *[]net.IP {
|
||||
func localIps(allowList *AllowList) *[]net.IP {
|
||||
//FIXME: This function is pretty garbage
|
||||
var ips []net.IP
|
||||
ifaces, _ := net.Interfaces()
|
||||
for _, i := range ifaces {
|
||||
allow := allowList.AllowName(i.Name)
|
||||
l.WithField("interfaceName", i.Name).WithField("allow", allow).Debug("localAllowList.AllowName")
|
||||
if !allow {
|
||||
continue
|
||||
}
|
||||
addrs, _ := i.Addrs()
|
||||
for _, addr := range addrs {
|
||||
var ip net.IP
|
||||
@@ -750,6 +817,12 @@ func localIps() *[]net.IP {
|
||||
ip = v.IP
|
||||
}
|
||||
if ip.To4() != nil && ip.IsLoopback() == false {
|
||||
allow := allowList.Allow(ip2int(ip))
|
||||
l.WithField("localIp", ip).WithField("allow", allow).Debug("localAllowList.Allow")
|
||||
if !allow {
|
||||
continue
|
||||
}
|
||||
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -161,6 +161,4 @@ func BenchmarkHostmappromote2(b *testing.B) {
|
||||
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), g)
|
||||
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
|
||||
}
|
||||
b.Errorf("hi")
|
||||
|
||||
}
|
||||
|
||||
109
inside.go
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte) {
|
||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) {
|
||||
err := newPacket(packet, false, fwPacket)
|
||||
if err != nil {
|
||||
l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
||||
@@ -19,12 +19,25 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
|
||||
return
|
||||
}
|
||||
|
||||
// Ignore packets from self to self
|
||||
if fwPacket.RemoteIP == f.lightHouse.myIp {
|
||||
return
|
||||
}
|
||||
|
||||
// Ignore broadcast packets
|
||||
if f.dropMulticast && isMulticast(fwPacket.RemoteIP) {
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
|
||||
if hostinfo == nil {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
|
||||
WithField("fwPacket", fwPacket).
|
||||
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
|
||||
}
|
||||
return
|
||||
}
|
||||
ci := hostinfo.ConnectionState
|
||||
|
||||
if ci.ready == false {
|
||||
@@ -39,21 +52,28 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
|
||||
ci.queueLock.Unlock()
|
||||
}
|
||||
|
||||
if !f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs) {
|
||||
f.send(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out)
|
||||
if f.lightHouse != nil && *ci.messageCounter%5000 == 0 {
|
||||
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs, localCache)
|
||||
if dropReason == nil {
|
||||
mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
|
||||
if f.lightHouse != nil && mc%5000 == 0 {
|
||||
f.lightHouse.Query(fwPacket.RemoteIP, f)
|
||||
}
|
||||
|
||||
} else if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket).
|
||||
hostinfo.logger().
|
||||
WithField("fwPacket", fwPacket).
|
||||
WithField("reason", dropReason).
|
||||
Debugln("dropping outbound packet")
|
||||
}
|
||||
}
|
||||
|
||||
// getOrHandshake returns nil if the vpnIp is not routable
|
||||
func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
|
||||
if f.hostMap.vpnCIDR.Contains(int2ip(vpnIp)) == false {
|
||||
vpnIp = f.hostMap.queryUnsafeRoute(vpnIp)
|
||||
if vpnIp == 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f)
|
||||
|
||||
@@ -71,6 +91,17 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
|
||||
return hostinfo
|
||||
}
|
||||
|
||||
// Handshake is not ready, we need to grab the lock now before we start
|
||||
// the handshake process
|
||||
hostinfo.Lock()
|
||||
defer hostinfo.Unlock()
|
||||
|
||||
// Double check, now that we have the lock
|
||||
ci = hostinfo.ConnectionState
|
||||
if ci != nil && ci.eKey != nil && ci.ready {
|
||||
return hostinfo
|
||||
}
|
||||
|
||||
if ci == nil {
|
||||
// if we don't have a connection state, then send a handshake initiation
|
||||
ci = f.newConnectionState(true, noise.HandshakeIX, []byte{}, 0)
|
||||
@@ -86,6 +117,15 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
|
||||
ixHandshakeStage0(f, vpnIp, hostinfo)
|
||||
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
|
||||
//xx_handshakeStage0(f, ip, hostinfo)
|
||||
|
||||
// If this is a static host, we don't need to wait for the HostQueryReply
|
||||
// We can trigger the handshake right now
|
||||
if _, ok := f.lightHouse.staticList[vpnIp]; ok {
|
||||
select {
|
||||
case f.handshakeManager.trigger <- vpnIp:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return hostinfo
|
||||
@@ -100,13 +140,18 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
|
||||
}
|
||||
|
||||
// check if packet is in outbound fw rules
|
||||
if f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs) {
|
||||
l.WithField("fwPacket", fp).Debugln("dropping cached packet")
|
||||
dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs, nil)
|
||||
if dropReason != nil {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("fwPacket", fp).
|
||||
WithField("reason", dropReason).
|
||||
Debugln("dropping cached packet")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
f.send(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
|
||||
if f.lightHouse != nil && *hostInfo.ConnectionState.messageCounter%5000 == 0 {
|
||||
messageCounter := f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
|
||||
if f.lightHouse != nil && messageCounter%5000 == 0 {
|
||||
f.lightHouse.Query(fp.RemoteIP, f)
|
||||
}
|
||||
}
|
||||
@@ -114,6 +159,13 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
|
||||
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
|
||||
func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
|
||||
hostInfo := f.getOrHandshake(vpnIp)
|
||||
if hostInfo == nil {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("vpnIp", IntIp(vpnIp)).
|
||||
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !hostInfo.ConnectionState.ready {
|
||||
// Because we might be sending stored packets, lock here to stop new things going to
|
||||
@@ -138,6 +190,13 @@ func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
|
||||
// SendMessageToAll handles real ip:port lookup and sends to all known addresses for vpnIp
|
||||
func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
|
||||
hostInfo := f.getOrHandshake(vpnIp)
|
||||
if hostInfo == nil {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("vpnIp", IntIp(vpnIp)).
|
||||
Debugln("dropping SendMessageToAll, vpnIp not in our CIDR or in unsafe routes")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if hostInfo.ConnectionState.ready == false {
|
||||
// Because we might be sending stored packets, lock here to stop new things going to
|
||||
@@ -162,36 +221,54 @@ func (f *Interface) sendMessageToAll(t NebulaMessageType, st NebulaMessageSubTyp
|
||||
}
|
||||
|
||||
func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) {
|
||||
f.messageMetrics.Tx(t, st, 1)
|
||||
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
|
||||
}
|
||||
|
||||
func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte, q int) uint64 {
|
||||
if ci.eKey == nil {
|
||||
//TODO: log warning
|
||||
return
|
||||
return 0
|
||||
}
|
||||
|
||||
var err error
|
||||
//TODO: enable if we do more than 1 tun queue
|
||||
//ci.writeLock.Lock()
|
||||
c := atomic.AddUint64(ci.messageCounter, 1)
|
||||
c := atomic.AddUint64(&ci.atomicMessageCounter, 1)
|
||||
|
||||
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
|
||||
out = HeaderEncode(out, Version, uint8(t), uint8(st), hostinfo.remoteIndexId, c)
|
||||
f.connectionManager.Out(hostinfo.hostId)
|
||||
|
||||
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
|
||||
// all our IPs and enable a faster roaming.
|
||||
if hostinfo.lastRebindCount != f.rebindCount {
|
||||
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
|
||||
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
||||
f.lightHouse.Query(hostinfo.hostId, f)
|
||||
hostinfo.lastRebindCount = f.rebindCount
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter")
|
||||
}
|
||||
}
|
||||
|
||||
out, err = ci.eKey.EncryptDanger(out, out, p, c, nb)
|
||||
//TODO: see above note on lock
|
||||
//ci.writeLock.Unlock()
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).
|
||||
hostinfo.logger().WithError(err).
|
||||
WithField("udpAddr", remote).WithField("counter", c).
|
||||
WithField("attemptedCounter", ci.messageCounter).
|
||||
WithField("attemptedCounter", c).
|
||||
Error("Failed to encrypt outgoing packet")
|
||||
return
|
||||
return c
|
||||
}
|
||||
|
||||
err = f.outside.WriteTo(out, remote)
|
||||
err = f.writers[q].WriteTo(out, remote)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).
|
||||
hostinfo.logger().WithError(err).
|
||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func isMulticast(ip uint32) bool {
|
||||
|
||||
137
interface.go
@@ -2,7 +2,10 @@ package nebula
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
@@ -10,10 +13,19 @@ import (
|
||||
|
||||
const mtu = 9001
|
||||
|
||||
type Inside interface {
|
||||
io.ReadWriteCloser
|
||||
Activate() error
|
||||
CidrNet() *net.IPNet
|
||||
DeviceName() string
|
||||
WriteRaw([]byte) error
|
||||
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
||||
}
|
||||
|
||||
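The new Inside interface is what allows run() to treat the tun device generically, including the multi-queue reader setup. As a rough idea, a no-op single-queue implementation for tests might look like the sketch below; nebula's real tun implementations live in the platform-specific files and are not shown here.

// loopTun is a hypothetical single-queue Inside implementation for tests.
type loopTun struct {
	io.ReadWriteCloser
	cidr *net.IPNet
}

func (t *loopTun) Activate() error         { return nil }
func (t *loopTun) CidrNet() *net.IPNet     { return t.cidr }
func (t *loopTun) DeviceName() string      { return "loop0" }
func (t *loopTun) WriteRaw(b []byte) error { _, err := t.Write(b); return err }
func (t *loopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
	// A single-queue device can hand back the same reader for every routine.
	return t, nil
}
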
type InterfaceConfig struct {
|
||||
HostMap *HostMap
|
||||
Outside *udpConn
|
||||
Inside *Tun
|
||||
Inside Inside
|
||||
certState *CertState
|
||||
Cipher string
|
||||
Firewall *Firewall
|
||||
@@ -25,12 +37,17 @@ type InterfaceConfig struct {
|
||||
DropLocalBroadcast bool
|
||||
DropMulticast bool
|
||||
UDPBatchSize int
|
||||
routines int
|
||||
MessageMetrics *MessageMetrics
|
||||
version string
|
||||
|
||||
ConntrackCacheTimeout time.Duration
|
||||
}
|
||||
|
||||
type Interface struct {
|
||||
hostMap *HostMap
|
||||
outside *udpConn
|
||||
inside *Tun
|
||||
inside Inside
|
||||
certState *CertState
|
||||
cipher string
|
||||
firewall *Firewall
|
||||
@@ -43,11 +60,19 @@ type Interface struct {
|
||||
dropLocalBroadcast bool
|
||||
dropMulticast bool
|
||||
udpBatchSize int
|
||||
version string
|
||||
routines int
|
||||
|
||||
metricRxRecvError metrics.Counter
|
||||
metricTxRecvError metrics.Counter
|
||||
metricHandshakes metrics.Histogram
|
||||
// rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
|
||||
rebindCount int8
|
||||
version string
|
||||
|
||||
conntrackCacheTimeout time.Duration
|
||||
|
||||
writers []*udpConn
|
||||
readers []io.ReadWriteCloser
|
||||
|
||||
metricHandshakes metrics.Histogram
|
||||
messageMetrics *MessageMetrics
|
||||
}
|
||||
|
||||
func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
||||
@@ -79,10 +104,15 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
||||
dropLocalBroadcast: c.DropLocalBroadcast,
|
||||
dropMulticast: c.DropMulticast,
|
||||
udpBatchSize: c.UDPBatchSize,
|
||||
routines: c.routines,
|
||||
version: c.version,
|
||||
writers: make([]*udpConn, c.routines),
|
||||
readers: make([]io.ReadWriteCloser, c.routines),
|
||||
|
||||
metricRxRecvError: metrics.GetOrRegisterCounter("messages.rx.recv_error", nil),
|
||||
metricTxRecvError: metrics.GetOrRegisterCounter("messages.tx.recv_error", nil),
|
||||
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||
|
||||
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||
messageMetrics: c.MessageMetrics,
|
||||
}
|
||||
|
||||
ifce.connectionManager = newConnectionManager(ifce, c.checkInterval, c.pendingDeletionInterval)
|
||||
@@ -90,64 +120,79 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
||||
return ifce, nil
|
||||
}
|
||||
|
||||
func (f *Interface) Run(tunRoutines, udpRoutines int, buildVersion string) {
|
||||
func (f *Interface) run() {
|
||||
// actually turn on tun dev
|
||||
|
||||
addr, err := f.outside.LocalAddr()
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Failed to get udp listen address")
|
||||
}
|
||||
|
||||
l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
|
||||
WithField("build", f.version).WithField("udpAddr", addr).
|
||||
Info("Nebula interface is active")
|
||||
|
||||
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
||||
|
||||
// Prepare n tun queues
|
||||
var reader io.ReadWriteCloser = f.inside
|
||||
for i := 0; i < f.routines; i++ {
|
||||
if i > 0 {
|
||||
reader, err = f.inside.NewMultiQueueReader()
|
||||
if err != nil {
|
||||
l.Fatal(err)
|
||||
}
|
||||
}
|
||||
f.readers[i] = reader
|
||||
}
|
||||
|
||||
if err := f.inside.Activate(); err != nil {
|
||||
l.Fatal(err)
|
||||
}
|
||||
|
||||
f.version = buildVersion
|
||||
l.WithField("interface", f.inside.Device).WithField("network", f.inside.Cidr.String()).
|
||||
WithField("build", buildVersion).
|
||||
Info("Nebula interface is active")
|
||||
|
||||
// Launch n queues to read packets from udp
|
||||
for i := 0; i < udpRoutines; i++ {
|
||||
for i := 0; i < f.routines; i++ {
|
||||
go f.listenOut(i)
|
||||
}
|
||||
|
||||
// Launch n queues to read packets from tun dev
|
||||
for i := 0; i < tunRoutines; i++ {
|
||||
go f.listenIn(i)
|
||||
for i := 0; i < f.routines; i++ {
|
||||
go f.listenIn(f.readers[i], i)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) listenOut(i int) {
|
||||
//TODO: handle error
|
||||
addr, err := f.outside.LocalAddr()
|
||||
if err != nil {
|
||||
l.WithError(err).Error("failed to discover udp listening address")
|
||||
}
|
||||
runtime.LockOSThread()
|
||||
|
||||
var li *udpConn
|
||||
// TODO clean this up with a coherent interface for each outside connection
|
||||
if i > 0 {
|
||||
//TODO: handle error
|
||||
li, err = NewListener(udp2ip(addr).String(), int(addr.Port), i > 0)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("failed to make a new udp listener")
|
||||
}
|
||||
li = f.writers[i]
|
||||
} else {
|
||||
li = f.outside
|
||||
}
|
||||
|
||||
li.ListenOut(f)
|
||||
li.ListenOut(f, i)
|
||||
}
|
||||
|
||||
func (f *Interface) listenIn(i int) {
|
||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||
runtime.LockOSThread()
|
||||
|
||||
packet := make([]byte, mtu)
|
||||
out := make([]byte, mtu)
|
||||
fwPacket := &FirewallPacket{}
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
|
||||
for {
|
||||
n, err := f.inside.Read(packet)
|
||||
n, err := reader.Read(packet)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Error while reading outbound packet")
|
||||
// This only seems to happen when something fatal happens to the fd, so exit.
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out)
|
||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,7 +200,9 @@ func (f *Interface) RegisterConfigChangeCallbacks(c *Config) {
|
||||
c.RegisterReloadCallback(f.reloadCA)
|
||||
c.RegisterReloadCallback(f.reloadCertKey)
|
||||
c.RegisterReloadCallback(f.reloadFirewall)
|
||||
c.RegisterReloadCallback(f.outside.reloadConfig)
|
||||
for _, udpConn := range f.writers {
|
||||
c.RegisterReloadCallback(udpConn.reloadConfig)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) reloadCA(c *Config) {
|
||||
@@ -205,18 +252,40 @@ func (f *Interface) reloadFirewall(c *Config) {
|
||||
}
|
||||
|
||||
oldFw := f.firewall
|
||||
conntrack := oldFw.Conntrack
|
||||
conntrack.Lock()
|
||||
defer conntrack.Unlock()
|
||||
|
||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
// If rulesVersion is back to zero, we have wrapped all the way around. Be
|
||||
// safe and just reset conntrack in this case.
|
||||
if fw.rulesVersion == 0 {
|
||||
l.WithField("firewallHash", fw.GetRuleHash()).
|
||||
WithField("oldFirewallHash", oldFw.GetRuleHash()).
|
||||
WithField("rulesVersion", fw.rulesVersion).
|
||||
Warn("firewall rulesVersion has overflowed, resetting conntrack")
|
||||
} else {
|
||||
fw.Conntrack = conntrack
|
||||
}
|
||||
|
||||
f.firewall = fw
|
||||
|
||||
oldFw.Destroy()
|
||||
l.WithField("firewallHash", fw.GetRuleHash()).
|
||||
WithField("oldFirewallHash", oldFw.GetRuleHash()).
|
||||
WithField("rulesVersion", fw.rulesVersion).
|
||||
Info("New firewall has been installed")
|
||||
}
|
||||
|
||||
func (f *Interface) emitStats(i time.Duration) {
|
||||
ticker := time.NewTicker(i)
|
||||
|
||||
udpStats := NewUDPStatsEmitter(f.writers)
|
||||
|
||||
for range ticker.C {
|
||||
f.firewall.EmitStats()
|
||||
f.handshakeManager.EmitStats()
|
||||
|
||||
udpStats()
|
||||
}
|
||||
}
|
||||
|
||||
274
lighthouse.go
@@ -1,15 +1,19 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
)
|
||||
|
||||
var ErrHostNotKnown = errors.New("host not known")
|
||||
|
||||
type LightHouse struct {
|
||||
sync.RWMutex //Because we concurrently read and write to our maps
|
||||
amLighthouse bool
|
||||
@@ -19,13 +23,30 @@ type LightHouse struct {
|
||||
// Local cache of answers from light houses
|
||||
addrMap map[uint32][]udpAddr
|
||||
|
||||
// filters remote addresses allowed for each host
|
||||
// - When we are a lighthouse, this filters what addresses we store and
|
||||
// respond with.
|
||||
// - When we are not a lighthouse, this filters which addresses we accept
|
||||
// from lighthouses.
|
||||
remoteAllowList *AllowList
|
||||
|
||||
// filters local addresses that we advertise to lighthouses
|
||||
localAllowList *AllowList
|
||||
|
||||
// used to trigger the HandshakeManager when we receive HostQueryReply
|
||||
handshakeTrigger chan<- uint32
|
||||
|
||||
// staticList exists to avoid having a bool in each addrMap entry
|
||||
// since static should be rare
|
||||
staticList map[uint32]struct{}
|
||||
lighthouses map[uint32]struct{}
|
||||
interval int
|
||||
nebulaPort int
|
||||
nebulaPort uint32
|
||||
punchBack bool
|
||||
punchDelay time.Duration
|
||||
|
||||
metrics *MessageMetrics
|
||||
metricHolepunchTx metrics.Counter
|
||||
}
|
||||
|
||||
type EncWriter interface {
|
||||
@@ -33,7 +54,7 @@ type EncWriter interface {
|
||||
SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
|
||||
}
|
||||
|
||||
func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort int, pc *udpConn, punchBack bool) *LightHouse {
|
||||
func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
|
||||
h := LightHouse{
|
||||
amLighthouse: amLighthouse,
|
||||
myIp: myIp,
|
||||
@@ -44,6 +65,15 @@ func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, n
|
||||
interval: interval,
|
||||
punchConn: pc,
|
||||
punchBack: punchBack,
|
||||
punchDelay: punchDelay,
|
||||
}
|
||||
|
||||
if metricsEnabled {
|
||||
h.metrics = newLighthouseMetrics()
|
||||
|
||||
h.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil)
|
||||
} else {
|
||||
h.metricHolepunchTx = metrics.NilCounter{}
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
@@ -53,6 +83,20 @@ func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, n
|
||||
return &h
|
||||
}
|
||||
|
||||
func (lh *LightHouse) SetRemoteAllowList(allowList *AllowList) {
|
||||
lh.Lock()
|
||||
defer lh.Unlock()
|
||||
|
||||
lh.remoteAllowList = allowList
|
||||
}
|
||||
|
||||
func (lh *LightHouse) SetLocalAllowList(allowList *AllowList) {
|
||||
lh.Lock()
|
||||
defer lh.Unlock()
|
||||
|
||||
lh.localAllowList = allowList
|
||||
}
|
||||
|
||||
func (lh *LightHouse) ValidateLHStaticEntries() error {
|
||||
for lhIP, _ := range lh.lighthouses {
|
||||
if _, ok := lh.staticList[lhIP]; !ok {
|
||||
@@ -72,7 +116,7 @@ func (lh *LightHouse) Query(ip uint32, f EncWriter) ([]udpAddr, error) {
|
||||
return v, nil
|
||||
}
|
||||
lh.RUnlock()
|
||||
return nil, fmt.Errorf("host %s not known, queries sent to lighthouses", IntIp(ip))
|
||||
return nil, ErrHostNotKnown
|
||||
}
|
||||
|
||||
// This is asynchronous so no reply should be expected
|
||||
@@ -85,6 +129,7 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
|
||||
return
|
||||
}
|
||||
|
||||
lh.metricTx(NebulaMeta_HostQuery, int64(len(lh.lighthouses)))
|
||||
nb := make([]byte, 12, 12)
|
||||
out := make([]byte, mtu)
|
||||
for n := range lh.lighthouses {
|
||||
@@ -127,18 +172,24 @@ func (lh *LightHouse) AddRemote(vpnIP uint32, toIp *udpAddr, static bool) {
|
||||
}
|
||||
|
||||
lh.Lock()
|
||||
defer lh.Unlock()
|
||||
for _, v := range lh.addrMap[vpnIP] {
|
||||
if v.Equals(toIp) {
|
||||
lh.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
allow := lh.remoteAllowList.Allow(udp2ipInt(toIp))
|
||||
l.WithField("remoteIp", toIp).WithField("allow", allow).Debug("remoteAllowList.Allow")
|
||||
if !allow {
|
||||
return
|
||||
}
|
||||
|
||||
//l.Debugf("Adding reply of %s as %s\n", IntIp(vpnIP), toIp)
|
||||
if static {
|
||||
lh.staticList[vpnIP] = struct{}{}
|
||||
}
|
||||
lh.addrMap[vpnIP] = append(lh.addrMap[vpnIP], *toIp)
|
||||
lh.Unlock()
|
||||
}
|
||||
|
||||
func (lh *LightHouse) AddRemoteAndReset(vpnIP uint32, toIp *udpAddr) {
|
||||
@@ -156,12 +207,6 @@ func (lh *LightHouse) IsLighthouseIP(vpnIP uint32) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Quick generators for protobuf
|
||||
|
||||
func NewLhQueryByIpString(VpnIp string) *NebulaMeta {
|
||||
return NewLhQueryByInt(ip2int(net.ParseIP(VpnIp)))
|
||||
}
|
||||
|
||||
func NewLhQueryByInt(VpnIp uint32) *NebulaMeta {
|
||||
return &NebulaMeta{
|
||||
Type: NebulaMeta_HostQuery,
|
||||
@@ -171,26 +216,12 @@ func NewLhQueryByInt(VpnIp uint32) *NebulaMeta {
|
||||
}
|
||||
}
|
||||
|
||||
func NewLhWhoami() *NebulaMeta {
|
||||
return &NebulaMeta{
|
||||
Type: NebulaMeta_HostWhoami,
|
||||
Details: &NebulaMetaDetails{},
|
||||
}
|
||||
func NewIpAndPort(ip net.IP, port uint32) IpAndPort {
|
||||
return IpAndPort{Ip: ip2int(ip), Port: port}
|
||||
}
|
||||
|
||||
// End Quick generators for protobuf
|
||||
|
||||
func NewIpAndPortFromUDPAddr(addr udpAddr) *IpAndPort {
|
||||
return &IpAndPort{Ip: udp2ipInt(&addr), Port: uint32(addr.Port)}
|
||||
}
|
||||
|
||||
func NewIpAndPortsFromNetIps(ips []udpAddr) *[]*IpAndPort {
|
||||
var iap []*IpAndPort
|
||||
for _, e := range ips {
|
||||
// Only add IPs that aren't my VPN/tun IP
|
||||
iap = append(iap, NewIpAndPortFromUDPAddr(e))
|
||||
}
|
||||
return &iap
|
||||
func NewIpAndPortFromUDPAddr(addr udpAddr) IpAndPort {
|
||||
return IpAndPort{Ip: udp2ipInt(&addr), Port: uint32(addr.Port)}
|
||||
}
|
||||
|
||||
func (lh *LightHouse) LhUpdateWorker(f EncWriter) {
|
||||
@@ -199,41 +230,105 @@ func (lh *LightHouse) LhUpdateWorker(f EncWriter) {
|
||||
}
|
||||
|
||||
for {
|
||||
ipp := []*IpAndPort{}
|
||||
|
||||
for _, e := range *localIps() {
|
||||
// Only add IPs that aren't my VPN/tun IP
|
||||
if ip2int(e) != lh.myIp {
|
||||
ipp = append(ipp, &IpAndPort{Ip: ip2int(e), Port: uint32(lh.nebulaPort)})
|
||||
//fmt.Println(e)
|
||||
}
|
||||
}
|
||||
m := &NebulaMeta{
|
||||
Type: NebulaMeta_HostUpdateNotification,
|
||||
Details: &NebulaMetaDetails{
|
||||
VpnIp: lh.myIp,
|
||||
IpAndPorts: ipp,
|
||||
},
|
||||
}
|
||||
|
||||
nb := make([]byte, 12, 12)
|
||||
out := make([]byte, mtu)
|
||||
for vpnIp := range lh.lighthouses {
|
||||
mm, err := proto.Marshal(m)
|
||||
if err != nil {
|
||||
l.Debugf("Invalid marshal to update")
|
||||
}
|
||||
//l.Error("LIGHTHOUSE PACKET SEND", mm)
|
||||
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out)
|
||||
|
||||
}
|
||||
lh.SendUpdate(f)
|
||||
time.Sleep(time.Second * time.Duration(lh.interval))
|
||||
}
|
||||
}
|
||||
|
||||
func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *cert.NebulaCertificate, f EncWriter) {
|
||||
n := &NebulaMeta{}
|
||||
err := proto.Unmarshal(p, n)
|
||||
func (lh *LightHouse) SendUpdate(f EncWriter) {
|
||||
var ipps []*IpAndPort
|
||||
|
||||
for _, e := range *localIps(lh.localAllowList) {
|
||||
// Only add IPs that aren't my VPN/tun IP
|
||||
if ip2int(e) != lh.myIp {
|
||||
ipp := NewIpAndPort(e, lh.nebulaPort)
|
||||
ipps = append(ipps, &ipp)
|
||||
}
|
||||
}
|
||||
m := &NebulaMeta{
|
||||
Type: NebulaMeta_HostUpdateNotification,
|
||||
Details: &NebulaMetaDetails{
|
||||
VpnIp: lh.myIp,
|
||||
IpAndPorts: ipps,
|
||||
},
|
||||
}
|
||||
|
||||
lh.metricTx(NebulaMeta_HostUpdateNotification, int64(len(lh.lighthouses)))
|
||||
nb := make([]byte, 12, 12)
|
||||
out := make([]byte, mtu)
|
||||
for vpnIp := range lh.lighthouses {
|
||||
mm, err := proto.Marshal(m)
|
||||
if err != nil {
|
||||
l.Debugf("Invalid marshal to update")
|
||||
}
|
||||
//l.Error("LIGHTHOUSE PACKET SEND", mm)
|
||||
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
type LightHouseHandler struct {
|
||||
lh *LightHouse
|
||||
nb []byte
|
||||
out []byte
|
||||
meta *NebulaMeta
|
||||
iap []IpAndPort
|
||||
iapp []*IpAndPort
|
||||
}
|
||||
|
||||
func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
|
||||
lhh := &LightHouseHandler{
|
||||
lh: lh,
|
||||
nb: make([]byte, 12, 12),
|
||||
out: make([]byte, mtu),
|
||||
|
||||
meta: &NebulaMeta{
|
||||
Details: &NebulaMetaDetails{},
|
||||
},
|
||||
}
|
||||
|
||||
lhh.resizeIpAndPorts(10)
|
||||
|
||||
return lhh
|
||||
}
|
||||
|
||||
// This method is similar to Reset(), but it re-uses the pointer structs
|
||||
// so that we don't have to re-allocate them
|
||||
func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
|
||||
details := lhh.meta.Details
|
||||
|
||||
details.Reset()
|
||||
lhh.meta.Reset()
|
||||
lhh.meta.Details = details
|
||||
|
||||
return lhh.meta
|
||||
}
|
||||
|
||||
func (lhh *LightHouseHandler) resizeIpAndPorts(n int) {
|
||||
if cap(lhh.iap) < n {
|
||||
lhh.iap = make([]IpAndPort, n)
|
||||
lhh.iapp = make([]*IpAndPort, n)
|
||||
|
||||
for i := range lhh.iap {
|
||||
lhh.iapp[i] = &lhh.iap[i]
|
||||
}
|
||||
}
|
||||
lhh.iap = lhh.iap[:n]
|
||||
lhh.iapp = lhh.iapp[:n]
|
||||
}
|
||||
|
||||
func (lhh *LightHouseHandler) setIpAndPortsFromNetIps(ips []udpAddr) []*IpAndPort {
|
||||
lhh.resizeIpAndPorts(len(ips))
|
||||
for i, e := range ips {
|
||||
lhh.iap[i] = NewIpAndPortFromUDPAddr(e)
|
||||
}
|
||||
return lhh.iapp
|
||||
}
|
||||
|
||||
func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *cert.NebulaCertificate, f EncWriter) {
|
||||
lh := lhh.lh
|
||||
n := lhh.resetMeta()
|
||||
err := proto.UnmarshalMerge(p, n)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
|
||||
Error("Failed to unmarshal lighthouse packet")
|
||||
@@ -248,6 +343,8 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
|
||||
return
|
||||
}
|
||||
|
||||
lh.metricRx(n.Type, 1)
|
||||
|
||||
switch n.Type {
|
||||
case NebulaMeta_HostQuery:
|
||||
// Exit if we don't answer queries
|
||||
@@ -262,20 +359,18 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
|
||||
//l.Debugf("Can't answer query %s from %s because error: %s", IntIp(n.Details.VpnIp), rAddr, err)
|
||||
return
|
||||
} else {
|
||||
iap := NewIpAndPortsFromNetIps(ips)
|
||||
answer := &NebulaMeta{
|
||||
Type: NebulaMeta_HostQueryReply,
|
||||
Details: &NebulaMetaDetails{
|
||||
VpnIp: n.Details.VpnIp,
|
||||
IpAndPorts: *iap,
|
||||
},
|
||||
}
|
||||
reply, err := proto.Marshal(answer)
|
||||
reqVpnIP := n.Details.VpnIp
|
||||
n = lhh.resetMeta()
|
||||
n.Type = NebulaMeta_HostQueryReply
|
||||
n.Details.VpnIp = reqVpnIP
|
||||
n.Details.IpAndPorts = lhh.setIpAndPortsFromNetIps(ips)
|
||||
reply, err := proto.Marshal(n)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
|
||||
return
|
||||
}
|
||||
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, reply, make([]byte, 12, 12), make([]byte, mtu))
|
||||
lh.metricTx(NebulaMeta_HostQueryReply, 1)
|
||||
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, reply, lhh.nb, lhh.out[:0])
|
||||
|
||||
// This signals the other side to punch some zero byte udp packets
|
||||
ips, err = lh.Query(vpnIp, f)
|
||||
@@ -284,16 +379,13 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
|
||||
return
|
||||
} else {
|
||||
//l.Debugln("Notify host to punch", iap)
|
||||
iap = NewIpAndPortsFromNetIps(ips)
|
||||
answer = &NebulaMeta{
|
||||
Type: NebulaMeta_HostPunchNotification,
|
||||
Details: &NebulaMetaDetails{
|
||||
VpnIp: vpnIp,
|
||||
IpAndPorts: *iap,
|
||||
},
|
||||
}
|
||||
reply, _ := proto.Marshal(answer)
|
||||
f.SendMessageToVpnIp(lightHouse, 0, n.Details.VpnIp, reply, make([]byte, 12, 12), make([]byte, mtu))
|
||||
n = lhh.resetMeta()
|
||||
n.Type = NebulaMeta_HostPunchNotification
|
||||
n.Details.VpnIp = vpnIp
|
||||
n.Details.IpAndPorts = lhh.setIpAndPortsFromNetIps(ips)
|
||||
reply, _ := proto.Marshal(n)
|
||||
lh.metricTx(NebulaMeta_HostPunchNotification, 1)
|
||||
f.SendMessageToVpnIp(lightHouse, 0, reqVpnIP, reply, lhh.nb, lhh.out[:0])
|
||||
}
|
||||
//fmt.Println(reply, remoteaddr)
|
||||
}
|
||||
@@ -307,6 +399,11 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
|
||||
ans := NewUDPAddr(a.Ip, uint16(a.Port))
|
||||
lh.AddRemote(n.Details.VpnIp, ans, false)
|
||||
}
|
||||
// Non-blocking attempt to trigger, skip if it would block
|
||||
select {
|
||||
case lh.handshakeTrigger <- n.Details.VpnIp:
|
||||
default:
|
||||
}
|
||||
|
||||
case NebulaMeta_HostUpdateNotification:
|
||||
//Simple check that the host sent this not someone else
|
||||
@@ -328,10 +425,9 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
|
||||
for _, a := range n.Details.IpAndPorts {
|
||||
vpnPeer := NewUDPAddr(a.Ip, uint16(a.Port))
|
||||
go func() {
|
||||
for i := 0; i < 5; i++ {
|
||||
lh.punchConn.WriteTo(empty, vpnPeer)
|
||||
time.Sleep(time.Second * 1)
|
||||
}
|
||||
time.Sleep(lh.punchDelay)
|
||||
lh.metricHolepunchTx.Inc(1)
|
||||
lh.punchConn.WriteTo(empty, vpnPeer)
|
||||
|
||||
}()
|
||||
l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
|
||||
@@ -343,12 +439,22 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
|
||||
go func() {
|
||||
time.Sleep(time.Second * 5)
|
||||
l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp))
|
||||
// TODO we have to allocate a new output buffer here since we are spawning a new goroutine
|
||||
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
||||
// managed by a channel.
|
||||
f.SendMessageToVpnIp(test, testRequest, n.Details.VpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (lh *LightHouse) metricRx(t NebulaMeta_MessageType, i int64) {
|
||||
lh.metrics.Rx(NebulaMessageType(t), 0, i)
|
||||
}
|
||||
func (lh *LightHouse) metricTx(t NebulaMeta_MessageType, i int64) {
|
||||
lh.metrics.Tx(NebulaMessageType(t), 0, i)
|
||||
}
|
||||
|
||||
/*
|
||||
func (f *Interface) sendPathCheck(ci *ConnectionState, endpoint *net.UDPAddr, counter int) {
|
||||
c := ci.messageCounter
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -36,12 +36,19 @@ func TestNewipandportfromudpaddr(t *testing.T) {
|
||||
assert.Equal(t, uint32(12345), meh.Port)
|
||||
}
|
||||
|
||||
func TestNewipandportsfromudpaddrs(t *testing.T) {
|
||||
func TestSetipandportsfromudpaddrs(t *testing.T) {
|
||||
blah := NewUDPAddrFromString("1.2.2.3:12345")
|
||||
blah2 := NewUDPAddrFromString("9.9.9.9:47828")
|
||||
group := []udpAddr{*blah, *blah2}
|
||||
hah := NewIpAndPortsFromNetIps(group)
|
||||
assert.IsType(t, &[]*IpAndPort{}, hah)
|
||||
var lh *LightHouse
|
||||
lhh := lh.NewRequestHandler()
|
||||
result := lhh.setIpAndPortsFromNetIps(group)
|
||||
assert.IsType(t, []*IpAndPort{}, result)
|
||||
assert.Len(t, result, 2)
|
||||
assert.Equal(t, uint32(0x01020203), result[0].Ip)
|
||||
assert.Equal(t, uint32(12345), result[0].Port)
|
||||
assert.Equal(t, uint32(0x09090909), result[1].Ip)
|
||||
assert.Equal(t, uint32(47828), result[1].Port)
|
||||
//t.Error(reflect.TypeOf(hah))
|
||||
|
||||
}
|
||||
@@ -52,7 +59,7 @@ func Test_lhStaticMapping(t *testing.T) {
|
||||
|
||||
udpServer, _ := NewListener("0.0.0.0", 0, true)
|
||||
|
||||
meh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false)
|
||||
meh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
||||
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true)
|
||||
err := meh.ValidateLHStaticEntries()
|
||||
assert.Nil(t, err)
|
||||
@@ -60,12 +67,92 @@ func Test_lhStaticMapping(t *testing.T) {
|
||||
lh2 := "10.128.0.3"
|
||||
lh2IP := net.ParseIP(lh2)
|
||||
|
||||
meh = NewLightHouse(true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false)
|
||||
meh = NewLightHouse(true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
|
||||
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true)
|
||||
err = meh.ValidateLHStaticEntries()
|
||||
assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
|
||||
}
|
||||
|
||||
func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||
lh1 := "10.128.0.2"
|
||||
lh1IP := net.ParseIP(lh1)
|
||||
|
||||
udpServer, _ := NewListener("0.0.0.0", 0, true)
|
||||
|
||||
lh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
||||
|
||||
hAddr := NewUDPAddrFromString("4.5.6.7:12345")
|
||||
hAddr2 := NewUDPAddrFromString("4.5.6.7:12346")
|
||||
lh.addrMap[3] = []udpAddr{*hAddr, *hAddr2}
|
||||
|
||||
rAddr := NewUDPAddrFromString("1.2.2.3:12345")
|
||||
rAddr2 := NewUDPAddrFromString("1.2.2.3:12346")
|
||||
lh.addrMap[2] = []udpAddr{*rAddr, *rAddr2}
|
||||
|
||||
mw := &mockEncWriter{}
|
||||
|
||||
b.Run("notfound", func(b *testing.B) {
|
||||
lhh := lh.NewRequestHandler()
|
||||
req := &NebulaMeta{
|
||||
Type: NebulaMeta_HostQuery,
|
||||
Details: &NebulaMetaDetails{
|
||||
VpnIp: 4,
|
||||
IpAndPorts: nil,
|
||||
},
|
||||
}
|
||||
p, err := proto.Marshal(req)
|
||||
assert.NoError(b, err)
|
||||
for n := 0; n < b.N; n++ {
|
||||
lhh.HandleRequest(rAddr, 2, p, nil, mw)
|
||||
}
|
||||
})
|
||||
b.Run("found", func(b *testing.B) {
|
||||
lhh := lh.NewRequestHandler()
|
||||
req := &NebulaMeta{
|
||||
Type: NebulaMeta_HostQuery,
|
||||
Details: &NebulaMetaDetails{
|
||||
VpnIp: 3,
|
||||
IpAndPorts: nil,
|
||||
},
|
||||
}
|
||||
p, err := proto.Marshal(req)
|
||||
assert.NoError(b, err)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
lhh.HandleRequest(rAddr, 2, p, nil, mw)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_lhRemoteAllowList(t *testing.T) {
|
||||
c := NewConfig()
|
||||
c.Settings["remoteallowlist"] = map[interface{}]interface{}{
|
||||
"10.20.0.0/12": false,
|
||||
}
|
||||
allowList, err := c.GetAllowList("remoteallowlist", false)
|
||||
assert.Nil(t, err)
|
||||
|
||||
lh1 := "10.128.0.2"
|
||||
lh1IP := net.ParseIP(lh1)
|
||||
|
||||
udpServer, _ := NewListener("0.0.0.0", 0, true)
|
||||
|
||||
lh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
||||
lh.SetRemoteAllowList(allowList)
|
||||
|
||||
remote1 := "10.20.0.3"
|
||||
remote1IP := net.ParseIP(remote1)
|
||||
lh.AddRemote(ip2int(remote1IP), NewUDPAddr(ip2int(remote1IP), uint16(4242)), true)
|
||||
assert.Nil(t, lh.addrMap[ip2int(remote1IP)])
|
||||
|
||||
remote2 := "10.128.0.3"
|
||||
remote2IP := net.ParseIP(remote2)
|
||||
remote2UDPAddr := NewUDPAddr(ip2int(remote2IP), uint16(4242))
|
||||
|
||||
lh.AddRemote(ip2int(remote2IP), remote2UDPAddr, true)
|
||||
assert.Equal(t, remote2UDPAddr, &lh.addrMap[ip2int(remote2IP)][0])
|
||||
}
|
||||
|
||||
//func NewLightHouse(amLighthouse bool, myIp uint32, ips []string, interval int, nebulaPort int, pc *udpConn, punchBack bool) *LightHouse {
|
||||
|
||||
/*
|
||||
|
||||
39 logger.go (new file)
@@ -0,0 +1,39 @@
package nebula

import (
"errors"

"github.com/sirupsen/logrus"
)

type ContextualError struct {
RealError error
Fields map[string]interface{}
Context string
}

func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
return ContextualError{Context: msg, Fields: fields, RealError: realError}
}

func (ce ContextualError) Error() string {
if ce.RealError == nil {
return ce.Context
}
return ce.RealError.Error()
}

func (ce ContextualError) Unwrap() error {
if ce.RealError == nil {
return errors.New(ce.Context)
}
return ce.RealError
}

func (ce *ContextualError) Log(lr *logrus.Logger) {
if ce.RealError != nil {
lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
} else {
lr.WithFields(ce.Fields).Error(ce.Context)
}
}
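Main now hands these contextual errors back to its caller instead of calling Fatal itself. A rough sketch of how a caller could surface them; the config path and the surrounding cmd wiring are assumptions for illustration, only the Main signature and ContextualError come from this change:

package main

import (
	"os"

	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula"
)

func main() {
	l := logrus.New()
	l.Out = os.Stdout

	config := nebula.NewConfig()
	if err := config.Load("/etc/nebula/config.yml"); err != nil { // hypothetical path
		l.WithError(err).Error("Failed to load config")
		os.Exit(1)
	}

	// Main no longer exits on its own; errors come back to the caller.
	ctrl, err := nebula.Main(config, false, "dev", l, nil)
	if err != nil {
		// ContextualError carries its own fields, so let it log itself.
		if ce, ok := err.(nebula.ContextualError); ok {
			ce.Log(l)
		} else {
			l.WithError(err).Error("Failed to start")
		}
		os.Exit(1)
	}

	// ctrl owns the running interface; how the caller drives it (run loop,
	// signals, shutdown) lives outside this diff.
	_ = ctrl
}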
67 logger_test.go (new file)
@@ -0,0 +1,67 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type TestLogWriter struct {
|
||||
Logs []string
|
||||
}
|
||||
|
||||
func NewTestLogWriter() *TestLogWriter {
|
||||
return &TestLogWriter{Logs: make([]string, 0)}
|
||||
}
|
||||
|
||||
func (tl *TestLogWriter) Write(p []byte) (n int, err error) {
|
||||
tl.Logs = append(tl.Logs, string(p))
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (tl *TestLogWriter) Reset() {
|
||||
tl.Logs = tl.Logs[:0]
|
||||
}
|
||||
|
||||
func TestContextualError_Log(t *testing.T) {
|
||||
l := logrus.New()
|
||||
l.Formatter = &logrus.TextFormatter{
|
||||
DisableTimestamp: true,
|
||||
DisableColors: true,
|
||||
}
|
||||
|
||||
tl := NewTestLogWriter()
|
||||
l.Out = tl
|
||||
|
||||
// Test a full context line
|
||||
tl.Reset()
|
||||
e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
|
||||
e.Log(l)
|
||||
assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs)
|
||||
|
||||
// Test a line with an error and msg but no fields
|
||||
tl.Reset()
|
||||
e = NewContextualError("test message", nil, errors.New("error"))
|
||||
e.Log(l)
|
||||
assert.Equal(t, []string{"level=error msg=\"test message\" error=error\n"}, tl.Logs)
|
||||
|
||||
// Test just a context and fields
|
||||
tl.Reset()
|
||||
e = NewContextualError("test message", m{"field": "1"}, nil)
|
||||
e.Log(l)
|
||||
assert.Equal(t, []string{"level=error msg=\"test message\" field=1\n"}, tl.Logs)
|
||||
|
||||
// Test just a context
|
||||
tl.Reset()
|
||||
e = NewContextualError("test message", nil, nil)
|
||||
e.Log(l)
|
||||
assert.Equal(t, []string{"level=error msg=\"test message\"\n"}, tl.Logs)
|
||||
|
||||
// Test just an error
|
||||
tl.Reset()
|
||||
e = NewContextualError("", nil, errors.New("error"))
|
||||
e.Log(l)
|
||||
assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs)
|
||||
}
|
||||
290 main.go
@@ -4,11 +4,8 @@ import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -16,36 +13,31 @@ import (
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
// The caller should provide a real logger, we have one just in case
|
||||
var l = logrus.New()
|
||||
|
||||
type m map[string]interface{}
|
||||
|
||||
func Main(configPath string, configTest bool, buildVersion string) {
|
||||
l.Out = os.Stdout
|
||||
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) {
|
||||
l = logger
|
||||
l.Formatter = &logrus.TextFormatter{
|
||||
FullTimestamp: true,
|
||||
}
|
||||
|
||||
config := NewConfig()
|
||||
err := config.Load(configPath)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Failed to load config")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Print the config if in test, the exit comes later
|
||||
if configTest {
|
||||
b, err := yaml.Marshal(config.Settings)
|
||||
if err != nil {
|
||||
l.Println(err)
|
||||
os.Exit(1)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Print the final config
|
||||
l.Println(string(b))
|
||||
}
|
||||
|
||||
err = configLogger(config)
|
||||
err := configLogger(config)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Failed to configure the logger")
|
||||
return nil, NewContextualError("Failed to configure the logger", nil, err)
|
||||
}
|
||||
|
||||
config.RegisterReloadCallback(func(c *Config) {
|
||||
@@ -59,20 +51,20 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
||||
trustedCAs, err = loadCAFromConfig(config)
|
||||
if err != nil {
|
||||
//The errors coming out of loadCA are already nicely formatted
|
||||
l.WithError(err).Fatal("Failed to load ca from config")
|
||||
return nil, NewContextualError("Failed to load ca from config", nil, err)
|
||||
}
|
||||
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
|
||||
|
||||
cs, err := NewCertStateFromConfig(config)
|
||||
if err != nil {
|
||||
//The errors coming out of NewCertStateFromConfig are already nicely formatted
|
||||
l.WithError(err).Fatal("Failed to load certificate from config")
|
||||
return nil, NewContextualError("Failed to load certificate from config", nil, err)
|
||||
}
|
||||
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
|
||||
|
||||
fw, err := NewFirewallFromConfig(cs.certificate, config)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Error while loading firewall rules")
|
||||
return nil, NewContextualError("Error while loading firewall rules", nil, err)
|
||||
}
|
||||
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
|
||||
|
||||
@@ -80,11 +72,11 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
||||
tunCidr := cs.certificate.Details.Ips[0]
|
||||
routes, err := parseRoutes(config, tunCidr)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Could not parse tun.routes")
|
||||
return nil, NewContextualError("Could not parse tun.routes", nil, err)
|
||||
}
|
||||
unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Could not parse tun.unsafe_routes")
|
||||
return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
||||
}
|
||||
|
||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
||||
@@ -92,7 +84,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
||||
if config.GetBool("sshd.enabled", false) {
|
||||
err = configSSH(ssh, config)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Error while configuring the sshd")
|
||||
return nil, NewContextualError("Error while configuring the sshd", nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,32 +93,98 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
||||
// tun config, listeners, anything modifying the computer should be below
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
if configTest {
|
||||
os.Exit(0)
|
||||
var routines int
|
||||
|
||||
// If `routines` is set, use that and ignore the specific values
|
||||
if routines = config.GetInt("routines", 0); routines != 0 {
|
||||
if routines < 1 {
|
||||
routines = 1
|
||||
}
|
||||
if routines > 1 {
|
||||
l.WithField("routines", routines).Info("Using multiple routines")
|
||||
}
|
||||
} else {
|
||||
// deprecated and undocumented
|
||||
tunQueues := config.GetInt("tun.routines", 1)
|
||||
udpQueues := config.GetInt("listen.routines", 1)
|
||||
if tunQueues > udpQueues {
|
||||
routines = tunQueues
|
||||
} else {
|
||||
routines = udpQueues
|
||||
}
|
||||
if routines != 1 {
|
||||
l.WithField("routines", routines).Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead")
|
||||
}
|
||||
}
|
||||
|
||||
config.CatchHUP()
|
||||
// EXPERIMENTAL
|
||||
// Intentionally not documented yet while we do more testing and determine
|
||||
// a good default value.
|
||||
conntrackCacheTimeout := config.GetDuration("firewall.conntrack.routine_cache_timeout", 0)
|
||||
if routines > 1 && !config.IsSet("firewall.conntrack.routine_cache_timeout") {
|
||||
// Use a different default if we are running with multiple routines
|
||||
conntrackCacheTimeout = 1 * time.Second
|
||||
}
|
||||
if conntrackCacheTimeout > 0 {
|
||||
l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache")
|
||||
}
|
||||
|
||||
// set up our tun dev
|
||||
tun, err := newTun(
|
||||
config.GetString("tun.dev", ""),
|
||||
tunCidr,
|
||||
config.GetInt("tun.mtu", DEFAULT_MTU),
|
||||
routes,
|
||||
unsafeRoutes,
|
||||
config.GetInt("tun.tx_queue", 500),
|
||||
)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Failed to get a tun/tap device")
|
||||
var tun Inside
|
||||
if !configTest {
|
||||
config.CatchHUP()
|
||||
|
||||
switch {
|
||||
case config.GetBool("tun.disabled", false):
|
||||
tun = newDisabledTun(tunCidr, config.GetInt("tun.tx_queue", 500), config.GetBool("stats.message_metrics", false), l)
|
||||
case tunFd != nil:
|
||||
tun, err = newTunFromFd(
|
||||
*tunFd,
|
||||
tunCidr,
|
||||
config.GetInt("tun.mtu", DEFAULT_MTU),
|
||||
routes,
|
||||
unsafeRoutes,
|
||||
config.GetInt("tun.tx_queue", 500),
|
||||
)
|
||||
default:
|
||||
tun, err = newTun(
|
||||
config.GetString("tun.dev", ""),
|
||||
tunCidr,
|
||||
config.GetInt("tun.mtu", DEFAULT_MTU),
|
||||
routes,
|
||||
unsafeRoutes,
|
||||
config.GetInt("tun.tx_queue", 500),
|
||||
routines > 1,
|
||||
)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Failed to get a tun/tap device", nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
// set up our UDP listener
|
||||
udpQueues := config.GetInt("listen.routines", 1)
|
||||
udpServer, err := NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Failed to open udp listener")
|
||||
udpConns := make([]*udpConn, routines)
|
||||
port := config.GetInt("listen.port", 0)
|
||||
|
||||
if !configTest {
|
||||
for i := 0; i < routines; i++ {
|
||||
udpServer, err := NewListener(config.GetString("listen.host", "0.0.0.0"), port, routines > 1)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||
}
|
||||
udpServer.reloadConfig(config)
|
||||
udpConns[i] = udpServer
|
||||
|
||||
// If port is dynamic, discover it
|
||||
if port == 0 {
|
||||
uPort, err := udpServer.LocalAddr()
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Failed to get listening port", nil, err)
|
||||
}
|
||||
port = int(uPort.Port)
|
||||
}
|
||||
}
|
||||
}
|
||||
udpServer.reloadConfig(config)
|
||||
|
||||
// Set up my internal host map
|
||||
var preferredRanges []*net.IPNet
|
||||
@@ -136,7 +194,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
||||
for _, rawPreferredRange := range rawPreferredRanges {
|
||||
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Failed to parse preferred ranges")
|
||||
return nil, NewContextualError("Failed to parse preferred ranges", nil, err)
|
||||
}
|
||||
preferredRanges = append(preferredRanges, preferredRange)
|
||||
}
|
||||
@@ -149,7 +207,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
||||
if rawLocalRange != "" {
|
||||
_, localRange, err := net.ParseCIDR(rawLocalRange)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Failed to parse local range")
|
||||
return nil, NewContextualError("Failed to parse local_range", nil, err)
|
||||
}
|
||||
|
||||
// Check if the entry for local_range was already specified in
|
||||
@@ -169,6 +227,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
||||
hostMap := NewHostMap("main", tunCidr, preferredRanges)
|
||||
hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
|
||||
hostMap.addUnsafeRoutes(&unsafeRoutes)
|
||||
hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
|
||||
|
||||
l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
|
||||
|
||||
@@ -177,25 +236,19 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
||||
go hostMap.Promoter(config.GetInt("promoter.interval"))
|
||||
*/
|
||||
|
||||
punchy := config.GetBool("punchy", false)
|
||||
if punchy == true {
|
||||
punchy := NewPunchyFromConfig(config)
|
||||
if punchy.Punch && !configTest {
|
||||
l.Info("UDP hole punching enabled")
|
||||
go hostMap.Punchy(udpServer)
|
||||
go hostMap.Punchy(udpConns[0])
|
||||
}
|
||||
|
||||
port := config.GetInt("listen.port", 0)
|
||||
// If port is dynamic, discover it
|
||||
if port == 0 {
|
||||
uPort, err := udpServer.LocalAddr()
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Failed to get listening port")
|
||||
}
|
||||
port = int(uPort.Port)
|
||||
}
|
||||
|
||||
punchBack := config.GetBool("punch_back", false)
|
||||
amLighthouse := config.GetBool("lighthouse.am_lighthouse", false)
|
||||
|
||||
// fatal if am_lighthouse is enabled but we are using an ephemeral port
|
||||
if amLighthouse && (config.GetInt("listen.port", 0) == 0) {
|
||||
return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
|
||||
}
|
||||
|
||||
// warn if am_lighthouse is enabled but upstream lighthouses exists
|
||||
rawLighthouseHosts := config.GetStringSlice("lighthouse.hosts", []string{})
|
||||
if amLighthouse && len(rawLighthouseHosts) != 0 {
|
||||
@@ -206,7 +259,10 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
||||
for i, host := range rawLighthouseHosts {
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
l.WithField("host", host).Fatalf("Unable to parse lighthouse host entry %v", i+1)
|
||||
return nil, NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
|
||||
}
|
||||
if !tunCidr.Contains(ip) {
|
||||
return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
|
||||
}
|
||||
lighthouseHosts[i] = ip2int(ip)
|
||||
}
|
||||
@@ -217,14 +273,31 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
||||
lighthouseHosts,
|
||||
//TODO: change to a duration
|
||||
config.GetInt("lighthouse.interval", 10),
|
||||
port,
|
||||
udpServer,
|
||||
punchBack,
|
||||
uint32(port),
|
||||
udpConns[0],
|
||||
punchy.Respond,
|
||||
punchy.Delay,
|
||||
config.GetBool("stats.lighthouse_metrics", false),
|
||||
)
|
||||
|
||||
remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
|
||||
}
|
||||
lightHouse.SetRemoteAllowList(remoteAllowList)
|
||||
|
||||
localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
|
||||
}
|
||||
lightHouse.SetLocalAllowList(localAllowList)
|
||||
|
||||
//TODO: Move all of this inside functions in lighthouse.go
|
||||
for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
|
||||
vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
|
||||
if !tunCidr.Contains(vpnIp) {
|
||||
return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
|
||||
}
|
||||
vals, ok := v.([]interface{})
|
||||
if ok {
|
||||
for _, v := range vals {
|
||||
@@ -234,7 +307,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
||||
ip := addr.IP
|
||||
port, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v)
|
||||
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||
}
|
||||
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
|
||||
}
|
||||
@@ -247,7 +320,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
||||
ip := addr.IP
|
||||
port, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v)
|
||||
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||
}
|
||||
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
|
||||
}
|
||||
@@ -259,7 +332,24 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
||||
l.WithError(err).Error("Lighthouse unreachable")
|
||||
}
|
||||
|
||||
handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpServer)
|
||||
var messageMetrics *MessageMetrics
|
||||
if config.GetBool("stats.message_metrics", false) {
|
||||
messageMetrics = newMessageMetrics()
|
||||
} else {
|
||||
messageMetrics = newMessageMetricsOnlyRecvError()
|
||||
}
|
||||
|
||||
handshakeConfig := HandshakeConfig{
|
||||
tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
|
||||
retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries),
|
||||
waitRotation: config.GetInt("handshakes.wait_rotation", DefaultHandshakeWaitRotation),
|
||||
triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
|
||||
|
||||
messageMetrics: messageMetrics,
|
||||
}
|
||||
|
||||
handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig)
|
||||
lightHouse.handshakeTrigger = handshakeManager.trigger
|
||||
|
||||
//TODO: These will be reused for psk
|
||||
//handshakeMACKey := config.GetString("handshake_mac.key", "")
|
||||
@@ -271,7 +361,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
||||
ifConfig := &InterfaceConfig{
|
||||
HostMap: hostMap,
|
||||
Inside: tun,
|
||||
Outside: udpServer,
|
||||
Outside: udpConns[0],
|
||||
certState: cs,
|
||||
Cipher: config.GetString("cipher", "aes"),
|
||||
Firewall: fw,
|
||||
@@ -283,37 +373,52 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
||||
DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false),
|
||||
DropMulticast: config.GetBool("tun.drop_multicast", false),
|
||||
UDPBatchSize: config.GetInt("listen.batch", 64),
|
||||
routines: routines,
|
||||
MessageMetrics: messageMetrics,
|
||||
version: buildVersion,
|
||||
|
||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||
}
|
||||
|
||||
switch ifConfig.Cipher {
|
||||
case "aes":
|
||||
noiseEndiannes = binary.BigEndian
|
||||
noiseEndianness = binary.BigEndian
|
||||
case "chachapoly":
|
||||
noiseEndiannes = binary.LittleEndian
|
||||
noiseEndianness = binary.LittleEndian
|
||||
default:
|
||||
l.Fatalf("Unknown cipher: %v", ifConfig.Cipher)
|
||||
return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
|
||||
}
|
||||
|
||||
ifce, err := NewInterface(ifConfig)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Failed to initialize interface")
|
||||
var ifce *Interface
|
||||
if !configTest {
|
||||
ifce, err = NewInterface(ifConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize interface: %s", err)
|
||||
}
|
||||
|
||||
// TODO: Better way to attach these, probably want a new interface in InterfaceConfig
|
||||
// I don't want to make this initial commit too far-reaching though
|
||||
ifce.writers = udpConns
|
||||
|
||||
ifce.RegisterConfigChangeCallbacks(config)
|
||||
|
||||
go handshakeManager.Run(ifce)
|
||||
go lightHouse.LhUpdateWorker(ifce)
|
||||
}
|
||||
|
||||
ifce.RegisterConfigChangeCallbacks(config)
|
||||
|
||||
go handshakeManager.Run(ifce)
|
||||
go lightHouse.LhUpdateWorker(ifce)
|
||||
|
||||
err = startStats(config)
|
||||
err = startStats(config, configTest)
|
||||
if err != nil {
|
||||
l.WithError(err).Fatal("Failed to start stats emitter")
|
||||
return nil, NewContextualError("Failed to start stats emitter", nil, err)
|
||||
}
|
||||
|
||||
if configTest {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
//TODO: check if we _should_ be emitting stats
|
||||
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
|
||||
|
||||
attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
|
||||
ifce.Run(config.GetInt("tun.routines", 1), udpQueues, buildVersion)
|
||||
|
||||
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
|
||||
if amLighthouse && serveDns {
|
||||
@@ -321,30 +426,5 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
||||
go dnsMain(hostMap, config)
|
||||
}
|
||||
|
||||
// Just sit here and be friendly, main thread.
|
||||
shutdownBlock(ifce)
|
||||
}
|
||||
|
||||
func shutdownBlock(ifce *Interface) {
|
||||
var sigChan = make(chan os.Signal)
|
||||
signal.Notify(sigChan, syscall.SIGTERM)
|
||||
signal.Notify(sigChan, syscall.SIGINT)
|
||||
|
||||
sig := <-sigChan
|
||||
l.WithField("signal", sig).Info("Caught signal, shutting down")
|
||||
|
||||
//TODO: stop tun and udp routines, the lock on hostMap does effectively does that though
|
||||
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
|
||||
ifce.hostMap.Lock()
|
||||
for _, h := range ifce.hostMap.Hosts {
|
||||
if h.ConnectionState.ready {
|
||||
ifce.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
||||
l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
|
||||
Debug("Sending close tunnel message")
|
||||
}
|
||||
}
|
||||
ifce.hostMap.Unlock()
|
||||
|
||||
l.WithField("signal", sig).Info("Goodbye")
|
||||
os.Exit(0)
|
||||
return &Control{ifce, l}, nil
|
||||
}
|
||||
|
||||
97 message_metrics.go (new file)
@@ -0,0 +1,97 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
)
|
||||
|
||||
type MessageMetrics struct {
|
||||
rx [][]metrics.Counter
|
||||
tx [][]metrics.Counter
|
||||
|
||||
rxUnknown metrics.Counter
|
||||
txUnknown metrics.Counter
|
||||
}
|
||||
|
||||
func (m *MessageMetrics) Rx(t NebulaMessageType, s NebulaMessageSubType, i int64) {
|
||||
if m != nil {
|
||||
if t >= 0 && int(t) < len(m.rx) && s >= 0 && int(s) < len(m.rx[t]) {
|
||||
m.rx[t][s].Inc(i)
|
||||
} else if m.rxUnknown != nil {
|
||||
m.rxUnknown.Inc(i)
|
||||
}
|
||||
}
|
||||
}
|
||||
func (m *MessageMetrics) Tx(t NebulaMessageType, s NebulaMessageSubType, i int64) {
|
||||
if m != nil {
|
||||
if t >= 0 && int(t) < len(m.tx) && s >= 0 && int(s) < len(m.tx[t]) {
|
||||
m.tx[t][s].Inc(i)
|
||||
} else if m.txUnknown != nil {
|
||||
m.txUnknown.Inc(i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newMessageMetrics() *MessageMetrics {
|
||||
gen := func(t string) [][]metrics.Counter {
|
||||
return [][]metrics.Counter{
|
||||
{
|
||||
metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.handshake_ixpsk0", t), nil),
|
||||
},
|
||||
nil,
|
||||
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.recv_error", t), nil)},
|
||||
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.lighthouse", t), nil)},
|
||||
{
|
||||
metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.test_request", t), nil),
|
||||
metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.test_response", t), nil),
|
||||
},
|
||||
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.close_tunnel", t), nil)},
|
||||
}
|
||||
}
|
||||
return &MessageMetrics{
|
||||
rx: gen("rx"),
|
||||
tx: gen("tx"),
|
||||
|
||||
rxUnknown: metrics.GetOrRegisterCounter("messages.rx.other", nil),
|
||||
txUnknown: metrics.GetOrRegisterCounter("messages.tx.other", nil),
|
||||
}
|
||||
}
|
||||
|
||||
// Historically we only recorded recv_error, so this is backwards compat
|
||||
func newMessageMetricsOnlyRecvError() *MessageMetrics {
|
||||
gen := func(t string) [][]metrics.Counter {
|
||||
return [][]metrics.Counter{
|
||||
nil,
|
||||
nil,
|
||||
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.recv_error", t), nil)},
|
||||
}
|
||||
}
|
||||
return &MessageMetrics{
|
||||
rx: gen("rx"),
|
||||
tx: gen("tx"),
|
||||
}
|
||||
}
|
||||
|
||||
func newLighthouseMetrics() *MessageMetrics {
|
||||
gen := func(t string) [][]metrics.Counter {
|
||||
h := make([][]metrics.Counter, len(NebulaMeta_MessageType_name))
|
||||
used := []NebulaMeta_MessageType{
|
||||
NebulaMeta_HostQuery,
|
||||
NebulaMeta_HostQueryReply,
|
||||
NebulaMeta_HostUpdateNotification,
|
||||
NebulaMeta_HostPunchNotification,
|
||||
}
|
||||
for _, i := range used {
|
||||
h[i] = []metrics.Counter{metrics.GetOrRegisterCounter(fmt.Sprintf("lighthouse.%s.%s", t, i.String()), nil)}
|
||||
}
|
||||
return h
|
||||
}
|
||||
return &MessageMetrics{
|
||||
rx: gen("rx"),
|
||||
tx: gen("tx"),
|
||||
|
||||
rxUnknown: metrics.GetOrRegisterCounter("lighthouse.rx.other", nil),
|
||||
txUnknown: metrics.GetOrRegisterCounter("lighthouse.tx.other", nil),
|
||||
}
|
||||
}
|
||||
8 noise.go
@@ -8,11 +8,11 @@ import (
|
||||
"github.com/flynn/noise"
|
||||
)
|
||||
|
||||
type endiannes interface {
|
||||
type endianness interface {
|
||||
PutUint64(b []byte, v uint64)
|
||||
}
|
||||
|
||||
var noiseEndiannes endiannes = binary.BigEndian
|
||||
var noiseEndianness endianness = binary.BigEndian
|
||||
|
||||
type NebulaCipherState struct {
|
||||
c noise.Cipher
|
||||
@@ -37,7 +37,7 @@ func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, n
|
||||
nb[1] = 0
|
||||
nb[2] = 0
|
||||
nb[3] = 0
|
||||
noiseEndiannes.PutUint64(nb[4:], n)
|
||||
noiseEndianness.PutUint64(nb[4:], n)
|
||||
out = s.c.(cipher.AEAD).Seal(out, nb, plaintext, ad)
|
||||
//l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext))
|
||||
return out, nil
|
||||
@@ -52,7 +52,7 @@ func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64,
|
||||
nb[1] = 0
|
||||
nb[2] = 0
|
||||
nb[3] = 0
|
||||
noiseEndiannes.PutUint64(nb[4:], n)
|
||||
noiseEndianness.PutUint64(nb[4:], n)
|
||||
return s.c.(cipher.AEAD).Open(out, nb, ciphertext, ad)
|
||||
} else {
|
||||
return []byte{}, nil
|
||||
|
||||
88 outside.go
@@ -2,18 +2,14 @@ package nebula
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
// "github.com/google/gopacket"
|
||||
// "github.com/google/gopacket/layers"
|
||||
// "encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
@@ -21,7 +17,7 @@ const (
|
||||
minFwPacketLen = 4
|
||||
)
|
||||
|
||||
func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, nb []byte) {
|
||||
func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, lhh *LightHouseHandler, nb []byte, q int, localCache ConntrackCache) {
|
||||
err := header.Parse(packet)
|
||||
if err != nil {
|
||||
// TODO: best if we return this and let caller log
|
||||
@@ -49,18 +45,19 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
|
||||
return
|
||||
}
|
||||
|
||||
f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb)
|
||||
f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb, q, localCache)
|
||||
|
||||
// Fallthrough to the bottom to record incoming traffic
|
||||
|
||||
case lightHouse:
|
||||
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
|
||||
if !f.handleEncrypted(ci, addr, header) {
|
||||
return
|
||||
}
|
||||
|
||||
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("udpAddr", addr).WithField("vpnIp", IntIp(hostinfo.hostId)).
|
||||
hostinfo.logger().WithError(err).WithField("udpAddr", addr).
|
||||
WithField("packet", packet).
|
||||
Error("Failed to decrypt lighthouse packet")
|
||||
|
||||
@@ -69,18 +66,19 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
|
||||
return
|
||||
}
|
||||
|
||||
f.lightHouse.HandleRequest(addr, hostinfo.hostId, d, hostinfo.GetCert(), f)
|
||||
lhh.HandleRequest(addr, hostinfo.hostId, d, hostinfo.GetCert(), f)
|
||||
|
||||
// Fallthrough to the bottom to record incoming traffic
|
||||
|
||||
case test:
|
||||
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
|
||||
if !f.handleEncrypted(ci, addr, header) {
|
||||
return
|
||||
}
|
||||
|
||||
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("udpAddr", addr).WithField("vpnIp", IntIp(hostinfo.hostId)).
|
||||
hostinfo.logger().WithError(err).WithField("udpAddr", addr).
|
||||
WithField("packet", packet).
|
||||
Error("Failed to decrypt test packet")
|
||||
|
||||
@@ -102,27 +100,30 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
|
||||
// are unauthenticated
|
||||
|
||||
case handshake:
|
||||
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
|
||||
HandleIncomingHandshake(f, addr, packet, header, hostinfo)
|
||||
return
|
||||
|
||||
case recvError:
|
||||
// TODO: Remove this with recv_error deprecation
|
||||
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
|
||||
f.handleRecvError(addr, header)
|
||||
return
|
||||
|
||||
case closeTunnel:
|
||||
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
|
||||
if !f.handleEncrypted(ci, addr, header) {
|
||||
return
|
||||
}
|
||||
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
hostinfo.logger().WithField("udpAddr", addr).
|
||||
Info("Close tunnel received, tearing down.")
|
||||
|
||||
f.closeTunnel(hostinfo)
|
||||
return
|
||||
|
||||
default:
|
||||
l.Debugf("Unexpected packet received from %s", addr)
|
||||
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
|
||||
hostinfo.logger().Debugf("Unexpected packet received from %s", addr)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -136,21 +137,24 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) {
|
||||
f.connectionManager.ClearIP(hostInfo.hostId)
|
||||
f.connectionManager.ClearPendingDeletion(hostInfo.hostId)
|
||||
f.lightHouse.DeleteVpnIP(hostInfo.hostId)
|
||||
f.hostMap.DeleteVpnIP(hostInfo.hostId)
|
||||
f.hostMap.DeleteIndex(hostInfo.localIndexId)
|
||||
f.hostMap.DeleteHostInfo(hostInfo)
|
||||
}
|
||||
|
||||
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
|
||||
if hostDidRoam(hostinfo.remote, addr) {
|
||||
if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSupressSeconds*time.Second {
|
||||
if !f.lightHouse.remoteAllowList.Allow(udp2ipInt(addr)) {
|
||||
hostinfo.logger().WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
|
||||
return
|
||||
}
|
||||
if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
|
||||
Debugf("Supressing roam back to previous remote for %d seconds", RoamingSupressSeconds)
|
||||
hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
|
||||
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
|
||||
hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
|
||||
Info("Host roamed to new udp ip/port.")
|
||||
hostinfo.lastRoam = time.Now()
|
||||
remoteCopy := *hostinfo.remote
|
||||
@@ -244,7 +248,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
||||
}
|
||||
|
||||
if !hostinfo.ConnectionState.window.Update(mc) {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("header", header).
|
||||
hostinfo.logger().WithField("header", header).
|
||||
Debugln("dropping out of window packet")
|
||||
return nil, errors.New("out of window packet")
|
||||
}
|
||||
@@ -252,12 +256,12 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
|
||||
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte, q int, localCache ConntrackCache) {
|
||||
var err error
|
||||
|
||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).Error("Failed to decrypt packet")
|
||||
hostinfo.logger().WithError(err).Error("Failed to decrypt packet")
|
||||
//TODO: maybe after build 64 is out? 06/14/2018 - NB
|
||||
//f.sendRecvError(hostinfo.remote, header.RemoteIndex)
|
||||
return
|
||||
@@ -265,32 +269,36 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||
|
||||
err = newPacket(out, true, fwPacket)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("packet", out).WithField("hostInfo", IntIp(hostinfo.hostId)).
|
||||
hostinfo.logger().WithError(err).WithField("packet", out).
|
||||
Warnf("Error while validating inbound packet")
|
||||
return
|
||||
}
|
||||
|
||||
if !hostinfo.ConnectionState.window.Update(messageCounter) {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket).
|
||||
hostinfo.logger().WithField("fwPacket", fwPacket).
|
||||
Debugln("dropping out of window packet")
|
||||
return
|
||||
}
|
||||
|
||||
if f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs) {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket).
|
||||
Debugln("dropping inbound packet")
|
||||
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs, localCache)
|
||||
if dropReason != nil {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger().WithField("fwPacket", fwPacket).
|
||||
WithField("reason", dropReason).
|
||||
Debugln("dropping inbound packet")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
f.connectionManager.In(hostinfo.hostId)
|
||||
err = f.inside.WriteRaw(out)
|
||||
_, err = f.readers[q].Write(out)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Failed to write to tun")
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
|
||||
f.metricTxRecvError.Inc(1)
|
||||
f.messageMetrics.Tx(recvError, 0, 1)
|
||||
|
||||
//TODO: this should be a signed message so we can trust that we should drop the index
|
||||
b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0)
|
||||
@@ -303,22 +311,24 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
|
||||
}
|
||||
|
||||
func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
|
||||
f.metricRxRecvError.Inc(1)
|
||||
|
||||
// This flag is to stop caring about recv_error from old versions
|
||||
// This should go away when the old version is gone from prod
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("index", h.RemoteIndex).
|
||||
WithField("udpAddr", addr).
|
||||
Debug("Recv error received")
|
||||
}
|
||||
|
||||
// First, clean up in the pending hostmap
|
||||
f.handshakeManager.pendingHostMap.DeleteReverseIndex(h.RemoteIndex)
|
||||
|
||||
hostinfo, err := f.hostMap.QueryReverseIndex(h.RemoteIndex)
|
||||
if err != nil {
|
||||
l.Debugln(err, ": ", h.RemoteIndex)
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo.Lock()
|
||||
defer hostinfo.Unlock()
|
||||
|
||||
if !hostinfo.RecvErrorExceeded() {
|
||||
return
|
||||
}
|
||||
@@ -327,17 +337,13 @@ func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
|
||||
return
|
||||
}
|
||||
|
||||
id := hostinfo.localIndexId
|
||||
host := hostinfo.hostId
|
||||
// We delete this host from the main hostmap
|
||||
f.hostMap.DeleteIndex(id)
|
||||
f.hostMap.DeleteVpnIP(host)
|
||||
f.hostMap.DeleteHostInfo(hostinfo)
|
||||
// We also delete it from pending to allow for
|
||||
// fast reconnect. We must null the connectionstate
|
||||
// or a counter reuse may happen
|
||||
hostinfo.ConnectionState = nil
|
||||
f.handshakeManager.DeleteIndex(id)
|
||||
f.handshakeManager.DeleteVpnIP(host)
|
||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
@@ -1,10 +1,11 @@
package nebula

import (
	"github.com/stretchr/testify/assert"
	"golang.org/x/net/ipv4"
	"net"
	"testing"

	"github.com/stretchr/testify/assert"
	"golang.org/x/net/ipv4"
)

func Test_newPacket(t *testing.T) {
30 punchy.go Normal file
@@ -0,0 +1,30 @@
package nebula

import "time"

type Punchy struct {
	Punch   bool
	Respond bool
	Delay   time.Duration
}

func NewPunchyFromConfig(c *Config) *Punchy {
	p := &Punchy{}

	if c.IsSet("punchy.punch") {
		p.Punch = c.GetBool("punchy.punch", false)
	} else {
		// Deprecated fallback
		p.Punch = c.GetBool("punchy", false)
	}

	if c.IsSet("punchy.respond") {
		p.Respond = c.GetBool("punchy.respond", false)
	} else {
		// Deprecated fallback
		p.Respond = c.GetBool("punch_back", false)
	}

	p.Delay = c.GetDuration("punchy.delay", time.Second)
	return p
}
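For context, a minimal sketch of how the new punchy settings are consumed, mirroring the cases exercised in punchy_test.go below. The nested map mimics what the YAML loader would put in Config.Settings; the specific values are illustrative assumptions, not part of this change.

// Hypothetical usage sketch (not part of the diff): build a config the way the
// YAML loader would and read it through NewPunchyFromConfig.
c := NewConfig()
c.Settings["punchy"] = map[interface{}]interface{}{
	"punch":   true,
	"respond": true,
	"delay":   "1m",
}
p := NewPunchyFromConfig(c)
// p.Punch == true, p.Respond == true, p.Delay == time.Minute, while the old
// top-level `punchy` / `punch_back` booleans still work as deprecated fallbacks.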
44 punchy_test.go Normal file
@@ -0,0 +1,44 @@
package nebula

import (
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

func TestNewPunchyFromConfig(t *testing.T) {
	c := NewConfig()

	// Test defaults
	p := NewPunchyFromConfig(c)
	assert.Equal(t, false, p.Punch)
	assert.Equal(t, false, p.Respond)
	assert.Equal(t, time.Second, p.Delay)

	// punchy deprecation
	c.Settings["punchy"] = true
	p = NewPunchyFromConfig(c)
	assert.Equal(t, true, p.Punch)

	// punchy.punch
	c.Settings["punchy"] = map[interface{}]interface{}{"punch": true}
	p = NewPunchyFromConfig(c)
	assert.Equal(t, true, p.Punch)

	// punch_back deprecation
	c.Settings["punch_back"] = true
	p = NewPunchyFromConfig(c)
	assert.Equal(t, true, p.Respond)

	// punchy.respond
	c.Settings["punchy"] = map[interface{}]interface{}{"respond": true}
	c.Settings["punch_back"] = false
	p = NewPunchyFromConfig(c)
	assert.Equal(t, true, p.Respond)

	// punchy.delay
	c.Settings["punchy"] = map[interface{}]interface{}{"delay": "1m"}
	p = NewPunchyFromConfig(c)
	assert.Equal(t, time.Minute, p.Delay)
}
59
ssh.go
59
ssh.go
@@ -5,15 +5,17 @@ import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/sshd"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/sshd"
|
||||
)
|
||||
|
||||
type sshListHostMapFlags struct {
|
||||
@@ -65,10 +67,11 @@ func configSSH(ssh *sshd.SSHServer, c *Config) error {
|
||||
return fmt.Errorf("sshd.listen must be provided")
|
||||
}
|
||||
|
||||
port := strings.Split(listen, ":")
|
||||
if len(port) < 2 {
|
||||
return fmt.Errorf("sshd.listen does not have a port")
|
||||
} else if port[1] == "22" {
|
||||
_, port, err := net.SplitHostPort(listen)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid sshd.listen address: %s", err)
|
||||
}
|
||||
if port == "22" {
|
||||
return fmt.Errorf("sshd.listen can not use port 22")
|
||||
}
|
||||
|
||||
@@ -351,7 +354,7 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
|
||||
}
|
||||
|
||||
if v.ConnectionState != nil {
|
||||
h["messageCounter"] = v.ConnectionState.messageCounter
|
||||
h["messageCounter"] = atomic.LoadUint64(&v.ConnectionState.atomicMessageCounter)
|
||||
}
|
||||
|
||||
d[x] = h
|
||||
@@ -461,7 +464,12 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
|
||||
return w.WriteLine("No vpn ip was provided")
|
||||
}
|
||||
|
||||
vpnIp := ip2int(net.ParseIP(a[0]))
|
||||
parsedIp := net.ParseIP(a[0])
|
||||
if parsedIp == nil {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
|
||||
vpnIp := ip2int(parsedIp)
|
||||
if vpnIp == 0 {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
@@ -481,7 +489,12 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
|
||||
return w.WriteLine("No vpn ip was provided")
|
||||
}
|
||||
|
||||
vpnIp := ip2int(net.ParseIP(a[0]))
|
||||
parsedIp := net.ParseIP(a[0])
|
||||
if parsedIp == nil {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
|
||||
vpnIp := ip2int(parsedIp)
|
||||
if vpnIp == 0 {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
@@ -519,7 +532,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
|
||||
return w.WriteLine("No vpn ip was provided")
|
||||
}
|
||||
|
||||
vpnIp := ip2int(net.ParseIP(a[0]))
|
||||
parsedIp := net.ParseIP(a[0])
|
||||
if parsedIp == nil {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
|
||||
vpnIp := ip2int(parsedIp)
|
||||
if vpnIp == 0 {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
@@ -571,7 +589,12 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
|
||||
return w.WriteLine("Address could not be parsed")
|
||||
}
|
||||
|
||||
vpnIp := ip2int(net.ParseIP(a[0]))
|
||||
parsedIp := net.ParseIP(a[0])
|
||||
if parsedIp == nil {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
|
||||
vpnIp := ip2int(parsedIp)
|
||||
if vpnIp == 0 {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
@@ -647,7 +670,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
|
||||
|
||||
cert := ifce.certState.certificate
|
||||
if len(a) > 0 {
|
||||
vpnIp := ip2int(net.ParseIP(a[0]))
|
||||
parsedIp := net.ParseIP(a[0])
|
||||
if parsedIp == nil {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
|
||||
vpnIp := ip2int(parsedIp)
|
||||
if vpnIp == 0 {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
@@ -694,7 +722,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
|
||||
return w.WriteLine("No vpn ip was provided")
|
||||
}
|
||||
|
||||
vpnIp := ip2int(net.ParseIP(a[0]))
|
||||
parsedIp := net.ParseIP(a[0])
|
||||
if parsedIp == nil {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
|
||||
vpnIp := ip2int(parsedIp)
|
||||
if vpnIp == 0 {
|
||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||
}
|
||||
|
||||
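The same parse-then-validate pattern now repeats in each SSH command above. Shown here only as an illustration, not as part of the change: a small standalone helper capturing that intent. The name parseVpnIp is hypothetical, and the sketch assumes the net, fmt and encoding/binary packages are imported; it mirrors what ip2int does for a valid IPv4 address.

// parseVpnIp is a hypothetical helper illustrating the validation added above:
// reject strings that do not parse as an IPv4 address, and reject the all-zero address.
func parseVpnIp(s string) (uint32, error) {
	ip := net.ParseIP(s).To4()
	if ip == nil {
		return 0, fmt.Errorf("the provided vpn ip could not be parsed: %s", s)
	}
	v := binary.BigEndian.Uint32(ip)
	if v == 0 {
		return 0, fmt.Errorf("the provided vpn ip could not be parsed: %s", s)
	}
	return v, nil
}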
@@ -4,9 +4,10 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/armon/go-radix"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/armon/go-radix"
|
||||
)
|
||||
|
||||
// CommandFlags is a function called before help or command execution to parse command line flags
|
||||
|
||||
@@ -2,10 +2,11 @@ package sshd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/armon/go-radix"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"net"
|
||||
)
|
||||
|
||||
type SSHServer struct {
|
||||
|
||||
@@ -2,13 +2,14 @@ package sshd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/anmitsu/go-shlex"
|
||||
"github.com/armon/go-radix"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type session struct {
|
||||
|
||||
37
stats.go
37
stats.go
@@ -3,18 +3,19 @@ package nebula
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/cyberdelia/go-metrics-graphite"
|
||||
mp "github.com/nbrownus/go-metrics-prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
graphite "github.com/cyberdelia/go-metrics-graphite"
|
||||
mp "github.com/nbrownus/go-metrics-prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/rcrowley/go-metrics"
|
||||
)
|
||||
|
||||
func startStats(c *Config) error {
|
||||
func startStats(c *Config, configTest bool) error {
|
||||
mType := c.GetString("stats.type", "")
|
||||
if mType == "" || mType == "none" {
|
||||
return nil
|
||||
@@ -27,9 +28,9 @@ func startStats(c *Config) error {
|
||||
|
||||
switch mType {
|
||||
case "graphite":
|
||||
startGraphiteStats(interval, c)
|
||||
startGraphiteStats(interval, c, configTest)
|
||||
case "prometheus":
|
||||
startPrometheusStats(interval, c)
|
||||
startPrometheusStats(interval, c, configTest)
|
||||
default:
|
||||
return fmt.Errorf("stats.type was not understood: %s", mType)
|
||||
}
|
||||
@@ -43,7 +44,7 @@ func startStats(c *Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func startGraphiteStats(i time.Duration, c *Config) error {
|
||||
func startGraphiteStats(i time.Duration, c *Config, configTest bool) error {
|
||||
proto := c.GetString("stats.protocol", "tcp")
|
||||
host := c.GetString("stats.host", "")
|
||||
if host == "" {
|
||||
@@ -57,11 +58,13 @@ func startGraphiteStats(i time.Duration, c *Config) error {
|
||||
}
|
||||
|
||||
l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr)
|
||||
go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr)
|
||||
if !configTest {
|
||||
go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func startPrometheusStats(i time.Duration, c *Config) error {
|
||||
func startPrometheusStats(i time.Duration, c *Config, configTest bool) error {
|
||||
namespace := c.GetString("stats.namespace", "")
|
||||
subsystem := c.GetString("stats.subsystem", "")
|
||||
|
||||
@@ -79,11 +82,13 @@ func startPrometheusStats(i time.Duration, c *Config) error {
|
||||
pClient := mp.NewPrometheusProvider(metrics.DefaultRegistry, namespace, subsystem, pr, i)
|
||||
go pClient.UpdatePrometheusMetrics()
|
||||
|
||||
go func() {
|
||||
l.Infof("Prometheus stats listening on %s at %s", listen, path)
|
||||
http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l}))
|
||||
log.Fatal(http.ListenAndServe(listen, nil))
|
||||
}()
|
||||
if !configTest {
|
||||
go func() {
|
||||
l.Infof("Prometheus stats listening on %s at %s", listen, path)
|
||||
http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l}))
|
||||
log.Fatal(http.ListenAndServe(listen, nil))
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
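The effect of the new configTest flag is that stats settings are still parsed and validated during a config check, but no graphite emitter goroutine or Prometheus listener is started. A hedged sketch of a call site; the surrounding wiring (config, configTest, l) is assumed and not shown in this diff.

// Sketch only: how startStats might be invoked from both the normal start path
// and a config-test path. With configTest=true nothing binds or emits.
if err := startStats(config, configTest); err != nil {
	l.WithError(err).Error("Failed to start stats collection")
	return
}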
@@ -1,9 +1,10 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewTimerWheel(t *testing.T) {
|
||||
|
||||
80
tun_android.go
Normal file
80
tun_android.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type Tun struct {
|
||||
io.ReadWriteCloser
|
||||
fd int
|
||||
Device string
|
||||
Cidr *net.IPNet
|
||||
MaxMTU int
|
||||
DefaultMTU int
|
||||
TXQueueLen int
|
||||
Routes []route
|
||||
UnsafeRoutes []route
|
||||
}
|
||||
|
||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
||||
ifce = &Tun{
|
||||
ReadWriteCloser: file,
|
||||
fd: int(file.Fd()),
|
||||
Device: "android",
|
||||
Cidr: cidr,
|
||||
DefaultMTU: defaultMTU,
|
||||
TXQueueLen: txQueueLen,
|
||||
Routes: routes,
|
||||
UnsafeRoutes: unsafeRoutes,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
||||
return nil, fmt.Errorf("newTun not supported in Android")
|
||||
}
|
||||
|
||||
func (c *Tun) WriteRaw(b []byte) error {
|
||||
var nn int
|
||||
for {
|
||||
max := len(b)
|
||||
n, err := unix.Write(c.fd, b[nn:max])
|
||||
if n > 0 {
|
||||
nn += n
|
||||
}
|
||||
if nn == len(b) {
|
||||
return err
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c Tun) Activate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Tun) CidrNet() *net.IPNet {
|
||||
return c.Cidr
|
||||
}
|
||||
|
||||
func (c *Tun) DeviceName() string {
|
||||
return c.Device
|
||||
}
|
||||
|
||||
func (t *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
|
||||
}
|
||||
@@ -132,7 +132,7 @@ func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) {
|
||||
|
||||
via, ok := rVia.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: %v", i+1, err)
|
||||
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia)
|
||||
}
|
||||
|
||||
nVia := net.ParseIP(via)
|
||||
@@ -147,6 +147,7 @@ func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) {
|
||||
|
||||
r := route{
|
||||
via: &nVia,
|
||||
mtu: mtu,
|
||||
}
|
||||
|
||||
_, r.route, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
// +build !ios
|
||||
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
@@ -18,10 +21,11 @@ type Tun struct {
|
||||
*water.Interface
|
||||
}
|
||||
|
||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
||||
if len(routes) > 0 {
|
||||
return nil, fmt.Errorf("Route MTU not supported in Darwin")
|
||||
return nil, fmt.Errorf("route MTU not supported in Darwin")
|
||||
}
|
||||
|
||||
// NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate()
|
||||
return &Tun{
|
||||
Cidr: cidr,
|
||||
@@ -30,30 +34,34 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
||||
}
|
||||
|
||||
func (c *Tun) Activate() error {
|
||||
var err error
|
||||
c.Interface, err = water.New(water.Config{
|
||||
DeviceType: water.TUN,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("Activate failed: %v", err)
|
||||
return fmt.Errorf("activate failed: %v", err)
|
||||
}
|
||||
|
||||
c.Device = c.Interface.Name()
|
||||
|
||||
// TODO use syscalls instead of exec.Command
|
||||
if err = exec.Command("ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil {
|
||||
if err = exec.Command("/sbin/ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
if err = exec.Command("route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil {
|
||||
if err = exec.Command("/sbin/route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||
}
|
||||
if err = exec.Command("ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil {
|
||||
if err = exec.Command("/sbin/ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
// Unsafe path routes
|
||||
for _, r := range c.UnsafeRoutes {
|
||||
if err = exec.Command("route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil {
|
||||
if err = exec.Command("/sbin/route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.route.String(), err)
|
||||
}
|
||||
}
|
||||
@@ -61,7 +69,19 @@ func (c *Tun) Activate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Tun) CidrNet() *net.IPNet {
|
||||
return c.Cidr
|
||||
}
|
||||
|
||||
func (c *Tun) DeviceName() string {
|
||||
return c.Device
|
||||
}
|
||||
|
||||
func (c *Tun) WriteRaw(b []byte) error {
|
||||
_, err := c.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||
}
|
||||
|
||||
175
tun_disabled.go
Normal file
175
tun_disabled.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type disabledTun struct {
|
||||
read chan []byte
|
||||
cidr *net.IPNet
|
||||
logger *log.Logger
|
||||
|
||||
// Track these metrics since we don't have the tun device to do it for us
|
||||
tx metrics.Counter
|
||||
rx metrics.Counter
|
||||
}
|
||||
|
||||
func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *log.Logger) *disabledTun {
|
||||
tun := &disabledTun{
|
||||
cidr: cidr,
|
||||
read: make(chan []byte, queueLen),
|
||||
logger: l,
|
||||
}
|
||||
|
||||
if metricsEnabled {
|
||||
tun.tx = metrics.GetOrRegisterCounter("messages.tx.message", nil)
|
||||
tun.rx = metrics.GetOrRegisterCounter("messages.rx.message", nil)
|
||||
} else {
|
||||
tun.tx = &metrics.NilCounter{}
|
||||
tun.rx = &metrics.NilCounter{}
|
||||
}
|
||||
|
||||
return tun
|
||||
}
|
||||
|
||||
func (*disabledTun) Activate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *disabledTun) CidrNet() *net.IPNet {
|
||||
return t.cidr
|
||||
}
|
||||
|
||||
func (*disabledTun) DeviceName() string {
|
||||
return "disabled"
|
||||
}
|
||||
|
||||
func (t *disabledTun) Read(b []byte) (int, error) {
|
||||
r, ok := <-t.read
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if len(r) > len(b) {
|
||||
return 0, fmt.Errorf("packet larger than mtu: %d > %d bytes", len(r), len(b))
|
||||
}
|
||||
|
||||
t.tx.Inc(1)
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
t.logger.WithField("raw", prettyPacket(r)).Debugf("Write payload")
|
||||
}
|
||||
|
||||
return copy(b, r), nil
|
||||
}
|
||||
|
||||
func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
|
||||
// Return early if this is not a simple ICMP Echo Request
|
||||
if !(len(b) >= 28 && len(b) <= mtu && b[0] == 0x45 && b[9] == 0x01 && b[20] == 0x08) {
|
||||
return false
|
||||
}
|
||||
|
||||
// We don't support fragmented packets
|
||||
if b[7] != 0 || (b[6]&0x2F != 0) {
|
||||
return false
|
||||
}
|
||||
|
||||
buf := make([]byte, len(b))
|
||||
copy(buf, b)
|
||||
|
||||
// Swap dest / src IPs and recalculate checksum
|
||||
ipv4 := buf[0:20]
|
||||
copy(ipv4[12:16], b[16:20])
|
||||
copy(ipv4[16:20], b[12:16])
|
||||
ipv4[10] = 0
|
||||
ipv4[11] = 0
|
||||
binary.BigEndian.PutUint16(ipv4[10:], ipChecksum(ipv4))
|
||||
|
||||
// Change type to ICMP Echo Reply and recalculate checksum
|
||||
icmp := buf[20:]
|
||||
icmp[0] = 0
|
||||
icmp[2] = 0
|
||||
icmp[3] = 0
|
||||
binary.BigEndian.PutUint16(icmp[2:], ipChecksum(icmp))
|
||||
|
||||
// attempt to write it, but don't block
|
||||
select {
|
||||
case t.read <- buf:
|
||||
default:
|
||||
t.logger.Debugf("tun_disabled: dropped ICMP Echo Reply response")
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *disabledTun) Write(b []byte) (int, error) {
|
||||
t.rx.Inc(1)
|
||||
|
||||
// Check for ICMP Echo Request before spending time doing the full parsing
|
||||
if t.handleICMPEchoRequest(b) {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request")
|
||||
}
|
||||
} else if l.Level >= logrus.DebugLevel {
|
||||
t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (t *disabledTun) WriteRaw(b []byte) error {
|
||||
_, err := t.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *disabledTun) Close() error {
|
||||
if t.read != nil {
|
||||
close(t.read)
|
||||
t.read = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type prettyPacket []byte
|
||||
|
||||
func (p prettyPacket) String() string {
|
||||
var s strings.Builder
|
||||
|
||||
for i, b := range p {
|
||||
if i > 0 && i%8 == 0 {
|
||||
s.WriteString(" ")
|
||||
}
|
||||
s.WriteString(fmt.Sprintf("%02x ", b))
|
||||
}
|
||||
|
||||
return s.String()
|
||||
}
|
||||
|
||||
func ipChecksum(b []byte) uint16 {
|
||||
var c uint32
|
||||
sz := len(b) - 1
|
||||
|
||||
for i := 0; i < sz; i += 2 {
|
||||
c += uint32(b[i]) << 8
|
||||
c += uint32(b[i+1])
|
||||
}
|
||||
if sz%2 == 0 {
|
||||
c += uint32(b[sz]) << 8
|
||||
}
|
||||
|
||||
for (c >> 16) > 0 {
|
||||
c = (c & 0xffff) + (c >> 16)
|
||||
}
|
||||
|
||||
return ^uint16(c)
|
||||
}
|
||||
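The checksum helper above is the standard RFC 1071 ones'-complement sum. A quick way to sanity-check it, assuming ipChecksum is in scope along with encoding/binary and fmt, and using an illustrative sample header: compute the checksum with the field zeroed, write it back, and confirm that re-running the sum over the full header returns 0 (a correct checksum makes the folded sum 0xffff).

// Standalone sketch: sample 20-byte IPv4 header with the checksum field zeroed.
hdr := []byte{
	0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00,
	0x40, 0x06, 0x00, 0x00, 0xac, 0x10, 0x0a, 0x63,
	0xac, 0x10, 0x0a, 0x0c,
}
sum := ipChecksum(hdr)                    // checksum field (bytes 10-11) is zero here
binary.BigEndian.PutUint16(hdr[10:], sum) // fill it in
fmt.Printf("%#04x %#04x\n", sum, ipChecksum(hdr)) // prints 0xb1e6 0x0000 for this header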
93
tun_freebsd.go
Normal file
93
tun_freebsd.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
type Tun struct {
|
||||
Device string
|
||||
Cidr *net.IPNet
|
||||
MTU int
|
||||
UnsafeRoutes []route
|
||||
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
||||
}
|
||||
|
||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
||||
if len(routes) > 0 {
|
||||
return nil, fmt.Errorf("Route MTU not supported in FreeBSD")
|
||||
}
|
||||
if strings.HasPrefix(deviceName, "/dev/") {
|
||||
deviceName = strings.TrimPrefix(deviceName, "/dev/")
|
||||
}
|
||||
if !deviceNameRE.MatchString(deviceName) {
|
||||
return nil, fmt.Errorf("tun.dev must match `tun[0-9]+`")
|
||||
}
|
||||
return &Tun{
|
||||
Device: deviceName,
|
||||
Cidr: cidr,
|
||||
MTU: defaultMTU,
|
||||
UnsafeRoutes: unsafeRoutes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Tun) Activate() error {
|
||||
var err error
|
||||
c.ReadWriteCloser, err = os.OpenFile("/dev/"+c.Device, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Activate failed: %v", err)
|
||||
}
|
||||
|
||||
// TODO use syscalls instead of exec.Command
|
||||
l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String())
|
||||
if err = exec.Command("/sbin/ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device)
|
||||
if err = exec.Command("/sbin/route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||
}
|
||||
l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU))
|
||||
if err = exec.Command("/sbin/ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
// Unsafe path routes
|
||||
for _, r := range c.UnsafeRoutes {
|
||||
l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device)
|
||||
if err = exec.Command("/sbin/route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.route.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Tun) CidrNet() *net.IPNet {
|
||||
return c.Cidr
|
||||
}
|
||||
|
||||
func (c *Tun) DeviceName() string {
|
||||
return c.Device
|
||||
}
|
||||
|
||||
func (c *Tun) WriteRaw(b []byte) error {
|
||||
_, err := c.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
||||
}
|
||||
117
tun_ios.go
Normal file
117
tun_ios.go
Normal file
@@ -0,0 +1,117 @@
|
||||
// +build ios
|
||||
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
type Tun struct {
|
||||
io.ReadWriteCloser
|
||||
Device string
|
||||
Cidr *net.IPNet
|
||||
}
|
||||
|
||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
||||
return nil, fmt.Errorf("newTun not supported in iOS")
|
||||
}
|
||||
|
||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
if len(routes) > 0 {
|
||||
return nil, fmt.Errorf("route MTU not supported in Darwin")
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
||||
ifce = &Tun{
|
||||
Cidr: cidr,
|
||||
Device: "iOS",
|
||||
ReadWriteCloser: &tunReadCloser{f: file},
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Tun) Activate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Tun) WriteRaw(b []byte) error {
|
||||
_, err := c.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
// The following is hoisted up from water, we do this so we can inject our own fd on iOS
|
||||
type tunReadCloser struct {
|
||||
f io.ReadWriteCloser
|
||||
|
||||
rMu sync.Mutex
|
||||
rBuf []byte
|
||||
|
||||
wMu sync.Mutex
|
||||
wBuf []byte
|
||||
}
|
||||
|
||||
func (t *tunReadCloser) Read(to []byte) (int, error) {
|
||||
t.rMu.Lock()
|
||||
defer t.rMu.Unlock()
|
||||
|
||||
if cap(t.rBuf) < len(to)+4 {
|
||||
t.rBuf = make([]byte, len(to)+4)
|
||||
}
|
||||
t.rBuf = t.rBuf[:len(to)+4]
|
||||
|
||||
n, err := t.f.Read(t.rBuf)
|
||||
copy(to, t.rBuf[4:])
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
func (t *tunReadCloser) Write(from []byte) (int, error) {
|
||||
|
||||
if len(from) == 0 {
|
||||
return 0, syscall.EIO
|
||||
}
|
||||
|
||||
t.wMu.Lock()
|
||||
defer t.wMu.Unlock()
|
||||
|
||||
if cap(t.wBuf) < len(from)+4 {
|
||||
t.wBuf = make([]byte, len(from)+4)
|
||||
}
|
||||
t.wBuf = t.wBuf[:len(from)+4]
|
||||
|
||||
// Determine the IP Family for the NULL L2 Header
|
||||
ipVer := from[0] >> 4
|
||||
if ipVer == 4 {
|
||||
t.wBuf[3] = syscall.AF_INET
|
||||
} else if ipVer == 6 {
|
||||
t.wBuf[3] = syscall.AF_INET6
|
||||
} else {
|
||||
return 0, errors.New("unable to determine IP version from packet")
|
||||
}
|
||||
|
||||
copy(t.wBuf[4:], from)
|
||||
|
||||
n, err := t.f.Write(t.wBuf)
|
||||
return n - 4, err
|
||||
}
|
||||
|
||||
func (t *tunReadCloser) Close() error {
|
||||
return t.f.Close()
|
||||
}
|
||||
|
||||
func (c *Tun) CidrNet() *net.IPNet {
|
||||
return c.Cidr
|
||||
}
|
||||
|
||||
func (c *Tun) DeviceName() string {
|
||||
return c.Device
|
||||
}
|
||||
|
||||
func (t *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
|
||||
}
|
||||
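The tunReadCloser above exists because the iOS utun fd frames every packet with a 4-byte protocol-family header rather than carrying a bare IP packet. A standalone sketch of that framing, under the assumption that the family is written big-endian so only byte 3 is non-zero (matching the wBuf[3] assignment above); it assumes encoding/binary, errors and syscall imports and builds on platforms that define AF_INET6.

// Sketch: prepend the 4-byte protocol-family header a utun fd expects.
func prependFamilyHeader(packet []byte) ([]byte, error) {
	if len(packet) == 0 {
		return nil, errors.New("empty packet")
	}
	out := make([]byte, 4+len(packet))
	switch packet[0] >> 4 {
	case 4:
		binary.BigEndian.PutUint32(out[:4], uint32(syscall.AF_INET))
	case 6:
		binary.BigEndian.PutUint32(out[:4], uint32(syscall.AF_INET6))
	default:
		return nil, errors.New("unable to determine IP version from packet")
	}
	copy(out[4:], packet)
	return out, nil
}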
77
tun_linux.go
77
tun_linux.go
@@ -1,3 +1,5 @@
|
||||
// +build !android
|
||||
|
||||
package nebula
|
||||
|
||||
import (
|
||||
@@ -53,8 +55,9 @@ func ipv4(addr string) (o [4]byte, err error) {
|
||||
*/
|
||||
|
||||
const (
|
||||
cIFF_TUN = 0x0001
|
||||
cIFF_NO_PI = 0x1000
|
||||
cIFF_TUN = 0x0001
|
||||
cIFF_NO_PI = 0x1000
|
||||
cIFF_MULTI_QUEUE = 0x0100
|
||||
)
|
||||
|
||||
type ifreqAddr struct {
|
||||
@@ -75,7 +78,24 @@ type ifreqQLEN struct {
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
||||
ifce = &Tun{
|
||||
ReadWriteCloser: file,
|
||||
fd: int(file.Fd()),
|
||||
Device: "tun0",
|
||||
Cidr: cidr,
|
||||
DefaultMTU: defaultMTU,
|
||||
TXQueueLen: txQueueLen,
|
||||
Routes: routes,
|
||||
UnsafeRoutes: unsafeRoutes,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -83,9 +103,12 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
|
||||
|
||||
var req ifReq
|
||||
req.Flags = uint16(cIFF_TUN | cIFF_NO_PI)
|
||||
if multiqueue {
|
||||
req.Flags |= cIFF_MULTI_QUEUE
|
||||
}
|
||||
copy(req.Name[:], deviceName)
|
||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||
|
||||
@@ -112,6 +135,24 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var req ifReq
|
||||
req.Flags = uint16(cIFF_TUN | cIFF_NO_PI | cIFF_MULTI_QUEUE)
|
||||
copy(req.Name[:], c.Device)
|
||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func (c *Tun) WriteRaw(b []byte) error {
|
||||
var nn int
|
||||
for {
|
||||
@@ -134,6 +175,10 @@ func (c *Tun) WriteRaw(b []byte) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Tun) Write(b []byte) (int, error) {
|
||||
return len(b), c.WriteRaw(b)
|
||||
}
|
||||
|
||||
func (c Tun) deviceBytes() (o [16]byte) {
|
||||
for i, c := range c.Device {
|
||||
o[i] = byte(c)
|
||||
@@ -216,6 +261,7 @@ func (c Tun) Activate() error {
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Dst: dr,
|
||||
MTU: c.DefaultMTU,
|
||||
AdvMSS: c.advMSS(route{}),
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
Src: c.Cidr.IP,
|
||||
Protocol: unix.RTPROT_KERNEL,
|
||||
@@ -233,6 +279,7 @@ func (c Tun) Activate() error {
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Dst: r.route,
|
||||
MTU: r.mtu,
|
||||
AdvMSS: c.advMSS(r),
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
}
|
||||
|
||||
@@ -248,6 +295,7 @@ func (c Tun) Activate() error {
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Dst: r.route,
|
||||
MTU: r.mtu,
|
||||
AdvMSS: c.advMSS(r),
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
}
|
||||
|
||||
@@ -265,3 +313,24 @@ func (c Tun) Activate() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Tun) CidrNet() *net.IPNet {
|
||||
return c.Cidr
|
||||
}
|
||||
|
||||
func (c *Tun) DeviceName() string {
|
||||
return c.Device
|
||||
}
|
||||
|
||||
func (c Tun) advMSS(r route) int {
|
||||
mtu := r.mtu
|
||||
if r.mtu == 0 {
|
||||
mtu = c.DefaultMTU
|
||||
}
|
||||
|
||||
// We only need to set advmss if the route MTU does not match the device MTU
|
||||
if mtu != c.MaxMTU {
|
||||
return mtu - 40
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
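To tie the multiqueue pieces above together: newTun is opened with multiqueue set when more than one queue is wanted, and every additional queue gets its own fd via NewMultiQueueReader. A rough sketch of how a caller might fan these out; the function name, the routines count and the handle callback are assumptions for illustration, not code from this change.

// Hypothetical fan-out: queue 0 reads from the tun device itself, queues
// 1..routines-1 each read from a dedicated multiqueue fd.
func startReaders(t *Tun, routines int, handle func(q int, r io.ReadWriteCloser)) error {
	go handle(0, t)
	for q := 1; q < routines; q++ {
		r, err := t.NewMultiQueueReader()
		if err != nil {
			return err
		}
		go handle(q, r)
	}
	return nil
}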
31 tun_linux_test.go Normal file
@@ -0,0 +1,31 @@
package nebula

import "testing"

var runAdvMSSTests = []struct {
	name     string
	tun      Tun
	r        route
	expected int
}{
	// Standard case, default MTU is the device max MTU
	{"default", Tun{DefaultMTU: 1440, MaxMTU: 1440}, route{}, 0},
	{"default-min", Tun{DefaultMTU: 1440, MaxMTU: 1440}, route{mtu: 1440}, 0},
	{"default-low", Tun{DefaultMTU: 1440, MaxMTU: 1440}, route{mtu: 1200}, 1160},

	// Case where we have a route MTU set higher than the default
	{"route", Tun{DefaultMTU: 1440, MaxMTU: 8941}, route{}, 1400},
	{"route-min", Tun{DefaultMTU: 1440, MaxMTU: 8941}, route{mtu: 1440}, 1400},
	{"route-high", Tun{DefaultMTU: 1440, MaxMTU: 8941}, route{mtu: 8941}, 0},
}

func TestTunAdvMSS(t *testing.T) {
	for _, tt := range runAdvMSSTests {
		t.Run(tt.name, func(t *testing.T) {
			o := tt.tun.advMSS(tt.r)
			if o != tt.expected {
				t.Errorf("got %d, want %d", o, tt.expected)
			}
		})
	}
}
127
tun_test.go
127
tun_test.go
@@ -1,9 +1,11 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_parseRoutes(t *testing.T) {
|
||||
@@ -100,3 +102,126 @@ func Test_parseRoutes(t *testing.T) {
|
||||
t.Fatal("Did not see both routes")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseUnsafeRoutes(t *testing.T) {
|
||||
c := NewConfig()
|
||||
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
||||
|
||||
// test no routes config
|
||||
routes, err := parseUnsafeRoutes(c, n)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, routes, 0)
|
||||
|
||||
// not an array
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "tun.unsafe_routes is not an array")
|
||||
|
||||
// no routes
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, routes, 0)
|
||||
|
||||
// weird route
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
|
||||
|
||||
// no via
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
|
||||
|
||||
// invalid via
|
||||
for _, invalidValue := range []interface{}{
|
||||
127, false, nil, 1.0, []string{"1", "2"},
|
||||
} {
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
|
||||
}
|
||||
|
||||
// unparsable via
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: nope")
|
||||
|
||||
// missing route
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
|
||||
|
||||
// unparsable route
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: invalid CIDR address: nope")
|
||||
|
||||
// within network range
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24")
|
||||
|
||||
// below network range
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
assert.Len(t, routes, 1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// above network range
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
assert.Len(t, routes, 1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// no mtu
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
assert.Len(t, routes, 1)
|
||||
assert.Equal(t, DEFAULT_MTU, routes[0].mtu)
|
||||
|
||||
// bad mtu
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
|
||||
|
||||
// low mtu
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
assert.Nil(t, routes)
|
||||
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
|
||||
|
||||
// happy case
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
|
||||
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29"},
|
||||
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32"},
|
||||
}}
|
||||
routes, err = parseUnsafeRoutes(c, n)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, routes, 2)
|
||||
|
||||
tested := 0
|
||||
for _, r := range routes {
|
||||
if r.mtu == 8000 {
|
||||
assert.Equal(t, "1.0.0.1/32", r.route.String())
|
||||
tested++
|
||||
} else {
|
||||
assert.Equal(t, 9000, r.mtu)
|
||||
assert.Equal(t, "1.0.0.0/29", r.route.String())
|
||||
tested++
|
||||
}
|
||||
}
|
||||
|
||||
if tested != 2 {
|
||||
t.Fatal("Did not see both unsafe_routes")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,31 +2,37 @@ package nebula
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
|
||||
"github.com/songgao/water"
|
||||
)
|
||||
|
||||
type Tun struct {
|
||||
Device string
|
||||
Cidr *net.IPNet
|
||||
MTU int
|
||||
Device string
|
||||
Cidr *net.IPNet
|
||||
MTU int
|
||||
UnsafeRoutes []route
|
||||
|
||||
*water.Interface
|
||||
}
|
||||
|
||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
||||
}
|
||||
|
||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
||||
if len(routes) > 0 {
|
||||
return nil, fmt.Errorf("Route MTU not supported in Windows")
|
||||
}
|
||||
if len(unsafeRoutes) > 0 {
|
||||
return nil, fmt.Errorf("unsafeRoutes not supported in Windows")
|
||||
return nil, fmt.Errorf("route MTU not supported in Windows")
|
||||
}
|
||||
|
||||
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
|
||||
return &Tun{
|
||||
Cidr: cidr,
|
||||
MTU: defaultMTU,
|
||||
Cidr: cidr,
|
||||
MTU: defaultMTU,
|
||||
UnsafeRoutes: unsafeRoutes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -47,7 +53,7 @@ func (c *Tun) Activate() error {
|
||||
|
||||
// TODO use syscalls instead of exec.Command
|
||||
err = exec.Command(
|
||||
"netsh", "interface", "ipv4", "set", "address",
|
||||
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address",
|
||||
fmt.Sprintf("name=%s", c.Device),
|
||||
"source=static",
|
||||
fmt.Sprintf("addr=%s", c.Cidr.IP),
|
||||
@@ -58,7 +64,7 @@ func (c *Tun) Activate() error {
|
||||
return fmt.Errorf("failed to run 'netsh' to set address: %s", err)
|
||||
}
|
||||
err = exec.Command(
|
||||
"netsh", "interface", "ipv4", "set", "interface",
|
||||
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "interface",
|
||||
c.Device,
|
||||
fmt.Sprintf("mtu=%d", c.MTU),
|
||||
).Run()
|
||||
@@ -66,10 +72,36 @@ func (c *Tun) Activate() error {
|
||||
return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err)
|
||||
}
|
||||
|
||||
iface, err := net.InterfaceByName(c.Device)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find interface named %s: %v", c.Device, err)
|
||||
}
|
||||
|
||||
for _, r := range c.UnsafeRoutes {
|
||||
err = exec.Command(
|
||||
"C:\\Windows\\System32\\route.exe", "add", r.route.String(), r.via.String(), "IF", strconv.Itoa(iface.Index),
|
||||
).Run()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add the unsafe_route %s: %v", r.route.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Tun) CidrNet() *net.IPNet {
|
||||
return c.Cidr
|
||||
}
|
||||
|
||||
func (c *Tun) DeviceName() string {
|
||||
return c.Device
|
||||
}
|
||||
|
||||
func (c *Tun) WriteRaw(b []byte) error {
|
||||
_, err := c.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
||||
}
|
||||
|
||||
36 udp_android.go Normal file
@@ -0,0 +1,36 @@
package nebula

import (
	"fmt"
	"net"
	"syscall"

	"golang.org/x/sys/unix"
)

func NewListenConfig(multi bool) net.ListenConfig {
	return net.ListenConfig{
		Control: func(network, address string, c syscall.RawConn) error {
			if multi {
				var controlErr error
				err := c.Control(func(fd uintptr) {
					if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
						controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err)
						return
					}
				})
				if err != nil {
					return err
				}
				if controlErr != nil {
					return controlErr
				}
			}
			return nil
		},
	}
}

func (u *udpConn) Rebind() error {
	return nil
}
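On platforms that go through the generic UDP path, NewListenConfig is the hook where these per-platform socket options are applied before bind. A short sketch of opening a listener through it, using only the standard library; the address is illustrative and the context/net imports are assumed.

// Sketch: open a UDP socket via the platform's NewListenConfig so options like
// SO_REUSEPORT (and the buffer sizes set on other platforms) apply before bind.
lc := NewListenConfig(true)
pc, err := lc.ListenPacket(context.Background(), "udp4", "0.0.0.0:4242")
if err != nil {
	return err
}
uc := pc.(*net.UDPConn) // the generic udpConn wraps this *net.UDPConn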
@@ -13,6 +13,32 @@ import (
|
||||
func NewListenConfig(multi bool) net.ListenConfig {
|
||||
return net.ListenConfig{
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
var controlErr error
|
||||
err := c.Control(func(fd uintptr) {
|
||||
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_SNDBUF, 999999); err != nil {
|
||||
controlErr = fmt.Errorf("SO_SNDBUF failed: %v", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if controlErr != nil {
|
||||
return controlErr
|
||||
}
|
||||
err = c.Control(func(fd uintptr) {
|
||||
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUF, 999999); err != nil {
|
||||
controlErr = fmt.Errorf("SO_RCVBUF failed: %v", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if controlErr != nil {
|
||||
return controlErr
|
||||
}
|
||||
|
||||
if multi {
|
||||
var controlErr error
|
||||
err := c.Control(func(fd uintptr) {
|
||||
@@ -32,3 +58,12 @@ func NewListenConfig(multi bool) net.ListenConfig {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (u *udpConn) Rebind() error {
|
||||
file, err := u.File()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return syscall.SetsockoptInt(int(file.Fd()), unix.IPPROTO_IP, unix.IP_BOUND_IF, 0)
|
||||
}
|
||||
|
||||
38
udp_freebsd.go
Normal file
38
udp_freebsd.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package nebula
|
||||
|
||||
// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func NewListenConfig(multi bool) net.ListenConfig {
|
||||
return net.ListenConfig{
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
if multi {
|
||||
var controlErr error
|
||||
err := c.Control(func(fd uintptr) {
|
||||
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
|
||||
controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if controlErr != nil {
|
||||
return controlErr
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (u *udpConn) Rebind() error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
// +build !linux
|
||||
// +build !linux android
|
||||
|
||||
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
|
||||
// means it can be used on platforms like Darwin and Windows.
|
||||
@@ -65,6 +65,17 @@ func (ua *udpAddr) Equals(t *udpAddr) bool {
|
||||
return ua.IP.Equal(t.IP) && ua.Port == t.Port
|
||||
}
|
||||
|
||||
func (ua *udpAddr) Copy() udpAddr {
|
||||
nu := udpAddr{net.UDPAddr{
|
||||
Port: ua.Port,
|
||||
Zone: ua.Zone,
|
||||
IP: make(net.IP, len(ua.IP)),
|
||||
}}
|
||||
|
||||
copy(nu.IP, ua.IP)
|
||||
return nu
|
||||
}
|
||||
|
||||
func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error {
|
||||
_, err := uc.UDPConn.WriteToUDP(b, &addr.UDPAddr)
|
||||
return err
|
||||
@@ -85,11 +96,16 @@ func (u *udpConn) reloadConfig(c *Config) {
|
||||
// TODO
|
||||
}
|
||||
|
||||
func NewUDPStatsEmitter(udpConns []*udpConn) func() {
|
||||
// No UDP stats for non-linux
|
||||
return func() {}
|
||||
}
|
||||
|
||||
type rawMessage struct {
|
||||
Len uint32
|
||||
}
|
||||
|
||||
func (u *udpConn) ListenOut(f *Interface) {
|
||||
func (u *udpConn) ListenOut(f *Interface, q int) {
|
||||
plaintext := make([]byte, mtu)
|
||||
buffer := make([]byte, mtu)
|
||||
header := &Header{}
|
||||
@@ -97,6 +113,10 @@ func (u *udpConn) ListenOut(f *Interface) {
|
||||
udpAddr := &udpAddr{}
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
lhh := f.lightHouse.NewRequestHandler()
|
||||
|
||||
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
|
||||
for {
|
||||
// Just read one packet at a time
|
||||
n, rua, err := u.ReadFromUDP(buffer)
|
||||
@@ -106,7 +126,7 @@ func (u *udpConn) ListenOut(f *Interface) {
|
||||
}
|
||||
|
||||
udpAddr.UDPAddr = *rua
|
||||
f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, nb)
|
||||
f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
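The reason the generic Copy above allocates a fresh IP slice, while the Linux implementation can simply dereference its value type, is that net.IP is a slice header: a plain value copy still shares the backing bytes. A tiny standalone illustration using only the standard library:

// net.IP is a []byte: copying the surrounding struct copies only the slice header.
a := &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 4242}
shallow := *a       // shares a.IP's backing array
shallow.IP[15] = 2  // also rewrites a.IP, which now reads 10.0.0.2
deep := *a
deep.IP = make(net.IP, len(a.IP))
copy(deep.IP, a.IP) // now independent, which is what udpAddr.Copy does above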
126
udp_linux.go
126
udp_linux.go
@@ -1,3 +1,5 @@
|
||||
// +build !android
|
||||
|
||||
package nebula
|
||||
|
||||
import (
|
||||
@@ -10,6 +12,7 @@ import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
@@ -53,6 +56,23 @@ type rawSockaddrAny struct {
|
||||
|
||||
var x int
|
||||
|
||||
// From linux/sock_diag.h
|
||||
const (
|
||||
_SK_MEMINFO_RMEM_ALLOC = iota
|
||||
_SK_MEMINFO_RCVBUF
|
||||
_SK_MEMINFO_WMEM_ALLOC
|
||||
_SK_MEMINFO_SNDBUF
|
||||
_SK_MEMINFO_FWD_ALLOC
|
||||
_SK_MEMINFO_WMEM_QUEUED
|
||||
_SK_MEMINFO_OPTMEM
|
||||
_SK_MEMINFO_BACKLOG
|
||||
_SK_MEMINFO_DROPS
|
||||
|
||||
_SK_MEMINFO_VARS
|
||||
)
|
||||
|
||||
type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
|
||||
|
||||
func NewListener(ip string, port int, multi bool) (*udpConn, error) {
|
||||
syscall.ForkLock.RLock()
|
||||
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
|
||||
@@ -69,8 +89,10 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
|
||||
var lip [4]byte
|
||||
copy(lip[:], net.ParseIP(ip).To4())
|
||||
|
||||
if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
|
||||
return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err)
|
||||
if multi {
|
||||
if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
|
||||
return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err = unix.Bind(fd, &unix.SockaddrInet4{Addr: lip, Port: port}); err != nil {
|
||||
@@ -85,6 +107,14 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
|
||||
return &udpConn{sysFd: fd}, err
|
||||
}
|
||||
|
||||
func (u *udpConn) Rebind() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ua *udpAddr) Copy() udpAddr {
|
||||
return *ua
|
||||
}
|
||||
|
||||
func (u *udpConn) SetRecvBuffer(n int) error {
|
||||
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
|
||||
}
|
||||
@@ -127,19 +157,27 @@ func (u *udpConn) LocalAddr() (*udpAddr, error) {
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
func (u *udpConn) ListenOut(f *Interface) {
|
||||
func (u *udpConn) ListenOut(f *Interface, q int) {
|
||||
plaintext := make([]byte, mtu)
|
||||
header := &Header{}
|
||||
fwPacket := &FirewallPacket{}
|
||||
udpAddr := &udpAddr{}
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
lhh := f.lightHouse.NewRequestHandler()
|
||||
|
||||
//TODO: should we track this?
|
||||
//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
|
||||
msgs, buffers, names := u.PrepareRawMessages(f.udpBatchSize)
|
||||
read := u.ReadMulti
|
||||
if f.udpBatchSize == 1 {
|
||||
read = u.ReadSingle
|
||||
}
|
||||
|
||||
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
|
||||
for {
|
||||
n, err := u.ReadMulti(msgs)
|
||||
n, err := read(msgs)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Failed to read packets")
|
||||
continue
|
||||
@@ -150,39 +188,29 @@ func (u *udpConn) ListenOut(f *Interface) {
|
||||
udpAddr.IP = binary.BigEndian.Uint32(names[i][4:8])
|
||||
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
|
||||
|
||||
f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, nb)
|
||||
f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *udpConn) Read(addr *udpAddr, b []byte) ([]byte, error) {
|
||||
var rsa rawSockaddrAny
|
||||
var rLen = unix.SizeofSockaddrAny
|
||||
|
||||
func (u *udpConn) ReadSingle(msgs []rawMessage) (int, error) {
|
||||
for {
|
||||
n, _, err := unix.Syscall6(
|
||||
unix.SYS_RECVFROM,
|
||||
unix.SYS_RECVMSG,
|
||||
uintptr(u.sysFd),
|
||||
uintptr(unsafe.Pointer(&b[0])),
|
||||
uintptr(len(b)),
|
||||
uintptr(0),
|
||||
uintptr(unsafe.Pointer(&rsa)),
|
||||
uintptr(unsafe.Pointer(&rLen)),
|
||||
uintptr(unsafe.Pointer(&(msgs[0].Hdr))),
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
|
||||
if err != 0 {
|
||||
return nil, &net.OpError{Op: "read", Err: err}
|
||||
return 0, &net.OpError{Op: "recvmsg", Err: err}
|
||||
}
|
||||
|
||||
if rsa.Addr.Family == unix.AF_INET {
|
||||
addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1])
|
||||
addr.IP = uint32(rsa.Addr.Data[2])<<24 + uint32(rsa.Addr.Data[3])<<16 + uint32(rsa.Addr.Data[4])<<8 + uint32(rsa.Addr.Data[5])
|
||||
} else {
|
||||
addr.Port = 0
|
||||
addr.IP = 0
|
||||
}
|
||||
|
||||
return b[:n], nil
|
||||
msgs[0].Len = uint32(n)
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,6 +301,47 @@ func (u *udpConn) reloadConfig(c *Config) {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *udpConn) getMemInfo(meminfo *_SK_MEMINFO) error {
|
||||
var vallen uint32 = 4 * _SK_MEMINFO_VARS
|
||||
_, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)
|
||||
if err != 0 {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewUDPStatsEmitter(udpConns []*udpConn) func() {
|
||||
// Check if our kernel supports SO_MEMINFO before registering the gauges
|
||||
var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge
|
||||
var meminfo _SK_MEMINFO
|
||||
if err := udpConns[0].getMemInfo(&meminfo); err == nil {
|
||||
udpGauges = make([][_SK_MEMINFO_VARS]metrics.Gauge, len(udpConns))
|
||||
for i := range udpConns {
|
||||
udpGauges[i] = [_SK_MEMINFO_VARS]metrics.Gauge{
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", i), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", i), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", i), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.sndbuf", i), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.fwd_alloc", i), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_queued", i), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.optmem", i), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.backlog", i), nil),
|
||||
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.drops", i), nil),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return func() {
|
||||
for i, gauges := range udpGauges {
|
||||
if err := udpConns[i].getMemInfo(&meminfo); err == nil {
|
||||
for j := 0; j < _SK_MEMINFO_VARS; j++ {
|
||||
gauges[j].Update(int64(meminfo[j]))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ua *udpAddr) Equals(t *udpAddr) bool {
|
||||
if t == nil || ua == nil {
|
||||
return t == nil && ua == nil
|
||||
@@ -280,13 +349,6 @@ func (ua *udpAddr) Equals(t *udpAddr) bool {
|
||||
return ua.IP == t.IP && ua.Port == t.Port
|
||||
}
|
||||
|
||||
func (ua *udpAddr) Copy() *udpAddr {
|
||||
return &udpAddr{
|
||||
Port: ua.Port,
|
||||
IP: ua.IP,
|
||||
}
|
||||
}
|
||||
|
||||
func (ua *udpAddr) String() string {
|
||||
return fmt.Sprintf("%s:%v", int2ip(ua.IP), ua.Port)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
// +build linux
// +build 386 amd64p32 arm mips mipsle
// +build !android

package nebula

@@ -1,5 +1,6 @@
// +build linux
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x
// +build !android

package nebula

@@ -11,6 +11,32 @@ import (
func NewListenConfig(multi bool) net.ListenConfig {
    return net.ListenConfig{
        Control: func(network, address string, c syscall.RawConn) error {
            var controlErr error
            err := c.Control(func(fd uintptr) {
                if err := syscall.SetsockoptInt(syscall.Handle(fd), syscall.SOL_SOCKET, syscall.SO_SNDBUF, 999999); err != nil {
                    controlErr = fmt.Errorf("SO_SNDBUF failed: %v", err)
                    return
                }
            })
            if err != nil {
                return err
            }
            if controlErr != nil {
                return controlErr
            }
            err = c.Control(func(fd uintptr) {
                if err := syscall.SetsockoptInt(syscall.Handle(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUF, 999999); err != nil {
                    controlErr = fmt.Errorf("SO_RCVBUF failed: %v", err)
                    return
                }
            })
            if err != nil {
                return err
            }
            if controlErr != nil {
                return controlErr
            }

            if multi {
                // There is no way to support multiple listeners safely on Windows:
                // https://docs.microsoft.com/en-us/windows/desktop/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
@@ -20,3 +46,7 @@ func NewListenConfig(multi bool) net.ListenConfig {
        },
    }
}

func (u *udpConn) Rebind() error {
    return nil
}
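NewListenConfig only installs socket options via the Control hook; the actual bind still goes through the standard library, which runs that hook before the socket is bound. A usage sketch under those assumptions (the address, function name, and error handling are illustrative, not taken from the diff):

package nebula

import (
    "context"
    "fmt"
)

// listenUDPSketch is a hypothetical example of applying NewListenConfig:
// ListenPacket invokes the Control hook, so SO_SNDBUF/SO_RCVBUF are set
// before the UDP socket is bound.
func listenUDPSketch() error {
    lc := NewListenConfig(false) // multiple listeners are not supported on Windows
    pc, err := lc.ListenPacket(context.Background(), "udp4", "0.0.0.0:4242")
    if err != nil {
        return fmt.Errorf("bind failed: %w", err)
    }
    defer pc.Close()
    // ... hand pc off to the packet read loop ...
    return nil
}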
util/assert.go (new file, 130 lines)
@@ -0,0 +1,130 @@
package util

import (
    "fmt"
    "reflect"
    "testing"
    "time"
    "unsafe"

    "github.com/stretchr/testify/assert"
)

// AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory
// There is currently a special case for `time.loc` (as this code traverses into unexported fields)
func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) {
    v1 := reflect.ValueOf(a)
    v2 := reflect.ValueOf(b)

    if !assert.Equal(t, v1.Type(), v2.Type()) {
        return
    }

    traverseDeepCopy(t, v1, v2, v1.Type().String())
}

func traverseDeepCopy(t *testing.T, v1 reflect.Value, v2 reflect.Value, name string) bool {
    switch v1.Kind() {
    case reflect.Array:
        for i := 0; i < v1.Len(); i++ {
            if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
                return false
            }
        }
        return true

    case reflect.Slice:
        if v1.IsNil() || v2.IsNil() {
            return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil %+v, %+v", name, v1, v2)
        }

        if !assert.Equal(t, v1.Len(), v2.Len(), "%s did not have the same length", name) {
            return false
        }

        // A slice with cap 0
        if v1.Cap() != 0 && !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same slice %v == %v", name, v1.Pointer(), v2.Pointer()) {
            return false
        }

        v1c := v1.Cap()
        v2c := v2.Cap()
        if v1c > 0 && v2c > 0 && v1.Slice(0, v1c).Slice(v1c-1, v1c-1).Pointer() == v2.Slice(0, v2c).Slice(v2c-1, v2c-1).Pointer() {
            return assert.Fail(t, "", "%s share some underlying memory", name)
        }

        for i := 0; i < v1.Len(); i++ {
            if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
                return false
            }
        }
        return true

    case reflect.Interface:
        if v1.IsNil() || v2.IsNil() {
            return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
        }
        return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)

    case reflect.Ptr:
        local := reflect.ValueOf(time.Local).Pointer()
        if local == v1.Pointer() && local == v2.Pointer() {
            return true
        }

        if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s points to the same memory", name) {
            return false
        }

        return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)

    case reflect.Struct:
        for i, n := 0, v1.NumField(); i < n; i++ {
            if !traverseDeepCopy(t, v1.Field(i), v2.Field(i), name+"."+v1.Type().Field(i).Name) {
                return false
            }
        }
        return true

    case reflect.Map:
        if v1.IsNil() || v2.IsNil() {
            return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
        }

        if !assert.Equal(t, v1.Len(), v2.Len(), "%s are not the same length", name) {
            return false
        }

        if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same memory", name) {
            return false
        }

        for _, k := range v1.MapKeys() {
            val1 := v1.MapIndex(k)
            val2 := v2.MapIndex(k)
            if !assert.True(t, val1.IsValid(), "%s is an invalid key in %s", k, name) {
                return false
            }

            if !assert.True(t, val2.IsValid(), "%s is an invalid key in %s", k, name) {
                return false
            }

            if !traverseDeepCopy(t, val1, val2, name+fmt.Sprintf("%s[%s]", name, k)) {
                return false
            }
        }

        return true

    default:
        if v1.CanInterface() && v2.CanInterface() {
            return assert.Equal(t, v1.Interface(), v2.Interface(), "%s was not equal", name)
        }

        e1 := reflect.NewAt(v1.Type(), unsafe.Pointer(v1.UnsafeAddr())).Elem().Interface()
        e2 := reflect.NewAt(v2.Type(), unsafe.Pointer(v2.UnsafeAddr())).Elem().Interface()

        return assert.Equal(t, e1, e2, "%s (unexported) was not equal", name)
    }
}
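A short example of how this helper is meant to be used in a test: build a value, deep-copy it by whatever means the production code uses, then assert the result is equal without sharing memory. The config type and copy function below are hypothetical, shown only to illustrate the assertion:

package util

import "testing"

// fakeConfig and deepCopyFakeConfig are hypothetical, for illustration only.
type fakeConfig struct {
    Name  string
    Ports []int
}

func deepCopyFakeConfig(c *fakeConfig) *fakeConfig {
    out := &fakeConfig{Name: c.Name, Ports: make([]int, len(c.Ports))}
    copy(out.Ports, c.Ports)
    return out
}

func TestDeepCopySketch(t *testing.T) {
    orig := &fakeConfig{Name: "lighthouse", Ports: []int{4242, 4243}}
    dup := deepCopyFakeConfig(orig)

    // Passes: same values, distinct backing memory for the struct and the slice.
    AssertDeepCopyEqual(t, orig, dup)

    // Would fail: aliasing the slice means the two values share memory.
    // alias := &fakeConfig{Name: orig.Name, Ports: orig.Ports}
    // AssertDeepCopyEqual(t, orig, alias)
}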