mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-23 17:04:25 +01:00
Compare commits
38 Commits
v1.2.0
...
interface-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
db11e2f1af | ||
|
|
2ee428b067 | ||
|
|
e9657d571e | ||
|
|
3cebf38504 | ||
|
|
ae3ee42469 | ||
|
|
fa034a6d83 | ||
|
|
55d72ac46f | ||
|
|
2c931d5691 | ||
|
|
0d6b55e495 | ||
|
|
c71c84882e | ||
|
|
0010db46e4 | ||
|
|
68e3e84fdc | ||
|
|
6238f1550b | ||
|
|
50b04413c7 | ||
|
|
ef498a31da | ||
|
|
2e5a477a50 | ||
|
|
32fe9bfe75 | ||
|
|
9b8b3c478b | ||
|
|
7b3f23d9a1 | ||
|
|
25964b54f6 | ||
|
|
ac557f381b | ||
|
|
a54f3fc681 | ||
|
|
5545cff6ef | ||
|
|
f3a6d8d990 | ||
|
|
9b06748506 | ||
|
|
4756c9613d | ||
|
|
4645e6034b | ||
|
|
aba42f9fa6 | ||
|
|
41578ca971 | ||
|
|
1ea8847085 | ||
|
|
55858c64cc | ||
|
|
e94c6b0125 | ||
|
|
b37a91cfbc | ||
|
|
3212b769d4 | ||
|
|
ecf0e5a9f6 | ||
|
|
ff13aba8fc | ||
|
|
cc03ff9e9a | ||
|
|
363c836422 |
20
.github/workflows/gofmt.yml
vendored
20
.github/workflows/gofmt.yml
vendored
@@ -14,19 +14,31 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- name: Set up Go 1.14
|
- name: Set up Go 1.15
|
||||||
uses: actions/setup-go@v1
|
uses: actions/setup-go@v1
|
||||||
with:
|
with:
|
||||||
go-version: 1.14
|
go-version: 1.15
|
||||||
id: go
|
id: go
|
||||||
|
|
||||||
- name: Check out code into the Go module directory
|
- name: Check out code into the Go module directory
|
||||||
uses: actions/checkout@v1
|
uses: actions/checkout@v1
|
||||||
|
|
||||||
|
- uses: actions/cache@v1
|
||||||
|
with:
|
||||||
|
path: ~/go/pkg/mod
|
||||||
|
key: ${{ runner.os }}-gofmt-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gofmt-
|
||||||
|
|
||||||
|
- name: Install goimports
|
||||||
|
run: |
|
||||||
|
go get golang.org/x/tools/cmd/goimports
|
||||||
|
go build golang.org/x/tools/cmd/goimports
|
||||||
|
|
||||||
- name: gofmt
|
- name: gofmt
|
||||||
run: |
|
run: |
|
||||||
if [ "$(find . -iname '*.go' | xargs gofmt -l)" ]
|
if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs ./goimports -l)" ]
|
||||||
then
|
then
|
||||||
find . -iname '*.go' | xargs gofmt -d
|
find . -iname '*.go' | grep -v '\.pb\.go$' | xargs ./goimports -d
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|||||||
34
.github/workflows/release.yml
vendored
34
.github/workflows/release.yml
vendored
@@ -10,17 +10,17 @@ jobs:
|
|||||||
name: Build Linux All
|
name: Build Linux All
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Set up Go 1.14
|
- name: Set up Go 1.15
|
||||||
uses: actions/setup-go@v1
|
uses: actions/setup-go@v1
|
||||||
with:
|
with:
|
||||||
go-version: 1.14
|
go-version: 1.15
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
run: |
|
run: |
|
||||||
make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" release-linux
|
make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" release-linux release-freebsd
|
||||||
mkdir release
|
mkdir release
|
||||||
mv build/*.tar.gz release
|
mv build/*.tar.gz release
|
||||||
|
|
||||||
@@ -34,10 +34,10 @@ jobs:
|
|||||||
name: Build Windows amd64
|
name: Build Windows amd64
|
||||||
runs-on: windows-latest
|
runs-on: windows-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Set up Go 1.14
|
- name: Set up Go 1.15
|
||||||
uses: actions/setup-go@v1
|
uses: actions/setup-go@v1
|
||||||
with:
|
with:
|
||||||
go-version: 1.14
|
go-version: 1.15
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v2
|
||||||
@@ -58,10 +58,10 @@ jobs:
|
|||||||
name: Build Darwin amd64
|
name: Build Darwin amd64
|
||||||
runs-on: macOS-latest
|
runs-on: macOS-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Set up Go 1.14
|
- name: Set up Go 1.15
|
||||||
uses: actions/setup-go@v1
|
uses: actions/setup-go@v1
|
||||||
with:
|
with:
|
||||||
go-version: 1.14
|
go-version: 1.15
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v2
|
||||||
@@ -278,3 +278,23 @@ jobs:
|
|||||||
asset_path: ./linux-latest/nebula-linux-mips64le.tar.gz
|
asset_path: ./linux-latest/nebula-linux-mips64le.tar.gz
|
||||||
asset_name: nebula-linux-mips64le.tar.gz
|
asset_name: nebula-linux-mips64le.tar.gz
|
||||||
asset_content_type: application/gzip
|
asset_content_type: application/gzip
|
||||||
|
|
||||||
|
- name: Upload linux-mips-softfloat
|
||||||
|
uses: actions/upload-release-asset@v1.0.1
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
with:
|
||||||
|
upload_url: ${{ steps.create_release.outputs.upload_url }}
|
||||||
|
asset_path: ./linux-latest/nebula-linux-mips-softfloat.tar.gz
|
||||||
|
asset_name: nebula-linux-mips-softfloat.tar.gz
|
||||||
|
asset_content_type: application/gzip
|
||||||
|
|
||||||
|
- name: Upload freebsd-amd64
|
||||||
|
uses: actions/upload-release-asset@v1.0.1
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
with:
|
||||||
|
upload_url: ${{ steps.create_release.outputs.upload_url }}
|
||||||
|
asset_path: ./linux-latest/nebula-freebsd-amd64.tar.gz
|
||||||
|
asset_name: nebula-freebsd-amd64.tar.gz
|
||||||
|
asset_content_type: application/gzip
|
||||||
|
|||||||
6
.github/workflows/smoke.yml
vendored
6
.github/workflows/smoke.yml
vendored
@@ -14,14 +14,14 @@ on:
|
|||||||
jobs:
|
jobs:
|
||||||
|
|
||||||
smoke:
|
smoke:
|
||||||
name: Run 3 node smoke test
|
name: Run multi node smoke test
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- name: Set up Go 1.14
|
- name: Set up Go 1.15
|
||||||
uses: actions/setup-go@v1
|
uses: actions/setup-go@v1
|
||||||
with:
|
with:
|
||||||
go-version: 1.14
|
go-version: 1.15
|
||||||
id: go
|
id: go
|
||||||
|
|
||||||
- name: Check out code into the Go module directory
|
- name: Check out code into the Go module directory
|
||||||
|
|||||||
27
.github/workflows/smoke/build.sh
vendored
27
.github/workflows/smoke/build.sh
vendored
@@ -11,14 +11,29 @@ mkdir ./build
|
|||||||
cp ../../../../nebula .
|
cp ../../../../nebula .
|
||||||
cp ../../../../nebula-cert .
|
cp ../../../../nebula-cert .
|
||||||
|
|
||||||
HOST="lighthouse1" AM_LIGHTHOUSE=true ../genconfig.sh >lighthouse1.yml
|
HOST="lighthouse1" \
|
||||||
HOST="host2" LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" ../genconfig.sh >host2.yml
|
AM_LIGHTHOUSE=true \
|
||||||
HOST="host3" LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" ../genconfig.sh >host3.yml
|
../genconfig.sh >lighthouse1.yml
|
||||||
|
|
||||||
|
HOST="host2" \
|
||||||
|
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
|
||||||
|
../genconfig.sh >host2.yml
|
||||||
|
|
||||||
|
HOST="host3" \
|
||||||
|
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
|
||||||
|
INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
||||||
|
../genconfig.sh >host3.yml
|
||||||
|
|
||||||
|
HOST="host4" \
|
||||||
|
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
|
||||||
|
OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
|
||||||
|
../genconfig.sh >host4.yml
|
||||||
|
|
||||||
./nebula-cert ca -name "Smoke Test"
|
./nebula-cert ca -name "Smoke Test"
|
||||||
./nebula-cert sign -name "lighthouse1" -ip "192.168.100.1/24"
|
./nebula-cert sign -name "lighthouse1" -groups "lighthouse,lighthouse1" -ip "192.168.100.1/24"
|
||||||
./nebula-cert sign -name "host2" -ip "192.168.100.2/24"
|
./nebula-cert sign -name "host2" -groups "host,host2" -ip "192.168.100.2/24"
|
||||||
./nebula-cert sign -name "host3" -ip "192.168.100.3/24"
|
./nebula-cert sign -name "host3" -groups "host,host3" -ip "192.168.100.3/24"
|
||||||
|
./nebula-cert sign -name "host4" -groups "host,host4" -ip "192.168.100.4/24"
|
||||||
)
|
)
|
||||||
|
|
||||||
docker build -t nebula:smoke .
|
docker build -t nebula:smoke .
|
||||||
|
|||||||
12
.github/workflows/smoke/genconfig.sh
vendored
12
.github/workflows/smoke/genconfig.sh
vendored
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
|
FIREWALL_ALL='[{"port": "any", "proto": "any", "host": "any"}]'
|
||||||
|
|
||||||
if [ "$STATIC_HOSTS" ] || [ "$LIGHTHOUSES" ]
|
if [ "$STATIC_HOSTS" ] || [ "$LIGHTHOUSES" ]
|
||||||
then
|
then
|
||||||
@@ -48,13 +49,6 @@ tun:
|
|||||||
dev: ${TUN_DEV:-nebula1}
|
dev: ${TUN_DEV:-nebula1}
|
||||||
|
|
||||||
firewall:
|
firewall:
|
||||||
outbound:
|
outbound: ${OUTBOUND:-$FIREWALL_ALL}
|
||||||
- port: any
|
inbound: ${INBOUND:-$FIREWALL_ALL}
|
||||||
proto: any
|
|
||||||
host: any
|
|
||||||
|
|
||||||
inbound:
|
|
||||||
- port: any
|
|
||||||
proto: any
|
|
||||||
host: any
|
|
||||||
EOF
|
EOF
|
||||||
|
|||||||
27
.github/workflows/smoke/smoke.sh
vendored
27
.github/workflows/smoke/smoke.sh
vendored
@@ -5,6 +5,7 @@ set -e -x
|
|||||||
docker run --name lighthouse1 --rm nebula:smoke -config lighthouse1.yml -test
|
docker run --name lighthouse1 --rm nebula:smoke -config lighthouse1.yml -test
|
||||||
docker run --name host2 --rm nebula:smoke -config host2.yml -test
|
docker run --name host2 --rm nebula:smoke -config host2.yml -test
|
||||||
docker run --name host3 --rm nebula:smoke -config host3.yml -test
|
docker run --name host3 --rm nebula:smoke -config host3.yml -test
|
||||||
|
docker run --name host4 --rm nebula:smoke -config host4.yml -test
|
||||||
|
|
||||||
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config lighthouse1.yml &
|
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config lighthouse1.yml &
|
||||||
sleep 1
|
sleep 1
|
||||||
@@ -12,6 +13,8 @@ docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN -
|
|||||||
sleep 1
|
sleep 1
|
||||||
docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host3.yml &
|
docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host3.yml &
|
||||||
sleep 1
|
sleep 1
|
||||||
|
docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host4.yml &
|
||||||
|
sleep 1
|
||||||
|
|
||||||
set +x
|
set +x
|
||||||
echo
|
echo
|
||||||
@@ -27,7 +30,8 @@ echo " *** Testing ping from host2"
|
|||||||
echo
|
echo
|
||||||
set -x
|
set -x
|
||||||
docker exec host2 ping -c1 192.168.100.1
|
docker exec host2 ping -c1 192.168.100.1
|
||||||
docker exec host2 ping -c1 192.168.100.3
|
# Should fail because not allowed by host3 inbound firewall
|
||||||
|
! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
|
||||||
|
|
||||||
set +x
|
set +x
|
||||||
echo
|
echo
|
||||||
@@ -36,3 +40,24 @@ echo
|
|||||||
set -x
|
set -x
|
||||||
docker exec host3 ping -c1 192.168.100.1
|
docker exec host3 ping -c1 192.168.100.1
|
||||||
docker exec host3 ping -c1 192.168.100.2
|
docker exec host3 ping -c1 192.168.100.2
|
||||||
|
|
||||||
|
set +x
|
||||||
|
echo
|
||||||
|
echo " *** Testing ping from host4"
|
||||||
|
echo
|
||||||
|
set -x
|
||||||
|
docker exec host4 ping -c1 192.168.100.1
|
||||||
|
# Should fail because not allowed by host4 outbound firewall
|
||||||
|
! docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1
|
||||||
|
! docker exec host4 ping -c1 192.168.100.3 -w5 || exit 1
|
||||||
|
|
||||||
|
set +x
|
||||||
|
echo
|
||||||
|
echo " *** Testing conntrack"
|
||||||
|
echo
|
||||||
|
set -x
|
||||||
|
# host2 can ping host3 now that host3 pinged it first
|
||||||
|
docker exec host2 ping -c1 192.168.100.3
|
||||||
|
# host4 can ping host2 once conntrack established
|
||||||
|
docker exec host2 ping -c1 192.168.100.4
|
||||||
|
docker exec host4 ping -c1 192.168.100.2
|
||||||
|
|||||||
8
.github/workflows/test.yml
vendored
8
.github/workflows/test.yml
vendored
@@ -18,10 +18,10 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- name: Set up Go 1.14
|
- name: Set up Go 1.15
|
||||||
uses: actions/setup-go@v1
|
uses: actions/setup-go@v1
|
||||||
with:
|
with:
|
||||||
go-version: 1.14
|
go-version: 1.15
|
||||||
id: go
|
id: go
|
||||||
|
|
||||||
- name: Check out code into the Go module directory
|
- name: Check out code into the Go module directory
|
||||||
@@ -48,10 +48,10 @@ jobs:
|
|||||||
os: [windows-latest, macOS-latest]
|
os: [windows-latest, macOS-latest]
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- name: Set up Go 1.14
|
- name: Set up Go 1.15
|
||||||
uses: actions/setup-go@v1
|
uses: actions/setup-go@v1
|
||||||
with:
|
with:
|
||||||
go-version: 1.14
|
go-version: 1.15
|
||||||
id: go
|
id: go
|
||||||
|
|
||||||
- name: Check out code into the Go module directory
|
- name: Check out code into the Go module directory
|
||||||
|
|||||||
70
CHANGELOG.md
70
CHANGELOG.md
@@ -7,6 +7,73 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- Updated the kardianos/service go library from 1.0.0 to 1.1.0, which
|
||||||
|
now creates launchd plist to write stdout/stderr to files by default.
|
||||||
|
|
||||||
|
## [1.3.0] - 2020-09-22
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- You can emit statistics about non-message packets by setting the option
|
||||||
|
`stats.message_metrics`. You can similarly emit detailed statistics about
|
||||||
|
lighthouse packets by setting the option `stats.lighthouse_metrics`. See
|
||||||
|
the example config for more details. (#230)
|
||||||
|
|
||||||
|
- We now support freebsd/amd64. This is experimental, please give us feedback.
|
||||||
|
(#103)
|
||||||
|
|
||||||
|
- We now release a binary for `linux/mips-softfloat` which has also been
|
||||||
|
stripped to reduce filesize and hopefully have a better chance on running on
|
||||||
|
small mips devices. (#231)
|
||||||
|
|
||||||
|
- You can set `tun.disabled` to true to run a standalone lighthouse without a
|
||||||
|
tun device (and thus, without root). (#269)
|
||||||
|
|
||||||
|
- You can set `logging.disable_timestamp` to remove timestamps from log lines,
|
||||||
|
which is useful when output is redirected to a logging system that already
|
||||||
|
adds timestamps. (#288)
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- Handshakes should now trigger faster, as we try to be proactive with sending
|
||||||
|
them instead of waiting for the next timer tick in most cases. (#246, #265)
|
||||||
|
|
||||||
|
- Previously, we would drop the conntrack table whenever firewall rules were
|
||||||
|
changed during a SIGHUP. Now, we will maintain the table and just validate
|
||||||
|
that an entry still matches with the new rule set. (#233)
|
||||||
|
|
||||||
|
- Debug logs for firewall drops now include the reason. (#220, #239)
|
||||||
|
|
||||||
|
- Logs for handshakes now include the fingerprint of the remote host. (#262)
|
||||||
|
|
||||||
|
- Config item `pki.blacklist` is now `pki.blocklist`. (#272)
|
||||||
|
|
||||||
|
- Better support for older Linux kernels. We now only set `SO_REUSEPORT` if
|
||||||
|
`tun.routines` is greater than 1 (default is 1). We also only use the
|
||||||
|
`recvmmsg` syscall if `listen.batch` is greater than 1 (default is 64).
|
||||||
|
(#275)
|
||||||
|
|
||||||
|
- It is possible to run Nebula as a library inside of another process now.
|
||||||
|
Note that this is still experimental and the internal APIs around this might
|
||||||
|
change in minor version releases. (#279)
|
||||||
|
|
||||||
|
### Deprecated
|
||||||
|
|
||||||
|
- `pki.blacklist` is deprecated in favor of `pki.blocklist` with the same
|
||||||
|
functionality. Existing configs will continue to load for this release to
|
||||||
|
allow for migrations. (#272)
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- `advmss` is now set correctly for each route table entry when `tun.routes`
|
||||||
|
is configured to have some routes with higher MTU. (#245)
|
||||||
|
|
||||||
|
- Packets that arrive on the tun device with an unroutable destination IP are
|
||||||
|
now dropped correctly, instead of wasting time making queries to the
|
||||||
|
lighthouses for IP `0.0.0.0` (#267)
|
||||||
|
|
||||||
## [1.2.0] - 2020-04-08
|
## [1.2.0] - 2020-04-08
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
@@ -118,7 +185,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
|
|
||||||
- Initial public release.
|
- Initial public release.
|
||||||
|
|
||||||
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.2.0...HEAD
|
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.3.0...HEAD
|
||||||
|
[1.3.0]: https://github.com/slackhq/nebula/releases/tag/v1.3.0
|
||||||
[1.2.0]: https://github.com/slackhq/nebula/releases/tag/v1.2.0
|
[1.2.0]: https://github.com/slackhq/nebula/releases/tag/v1.2.0
|
||||||
[1.1.0]: https://github.com/slackhq/nebula/releases/tag/v1.1.0
|
[1.1.0]: https://github.com/slackhq/nebula/releases/tag/v1.1.0
|
||||||
[1.0.0]: https://github.com/slackhq/nebula/releases/tag/v1.0.0
|
[1.0.0]: https://github.com/slackhq/nebula/releases/tag/v1.0.0
|
||||||
|
|||||||
35
Makefile
35
Makefile
@@ -3,6 +3,8 @@ BUILD_NUMBER ?= dev+$(shell date -u '+%Y%m%d%H%M%S')
|
|||||||
GO111MODULE = on
|
GO111MODULE = on
|
||||||
export GO111MODULE
|
export GO111MODULE
|
||||||
|
|
||||||
|
LDFLAGS = -X main.Build=$(BUILD_NUMBER)
|
||||||
|
|
||||||
ALL_LINUX = linux-amd64 \
|
ALL_LINUX = linux-amd64 \
|
||||||
linux-386 \
|
linux-386 \
|
||||||
linux-ppc64le \
|
linux-ppc64le \
|
||||||
@@ -13,10 +15,12 @@ ALL_LINUX = linux-amd64 \
|
|||||||
linux-mips \
|
linux-mips \
|
||||||
linux-mipsle \
|
linux-mipsle \
|
||||||
linux-mips64 \
|
linux-mips64 \
|
||||||
linux-mips64le
|
linux-mips64le \
|
||||||
|
linux-mips-softfloat
|
||||||
|
|
||||||
ALL = $(ALL_LINUX) \
|
ALL = $(ALL_LINUX) \
|
||||||
darwin-amd64 \
|
darwin-amd64 \
|
||||||
|
freebsd-amd64 \
|
||||||
windows-amd64
|
windows-amd64
|
||||||
|
|
||||||
all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert)
|
all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert)
|
||||||
@@ -25,31 +29,40 @@ release: $(ALL:%=build/nebula-%.tar.gz)
|
|||||||
|
|
||||||
release-linux: $(ALL_LINUX:%=build/nebula-%.tar.gz)
|
release-linux: $(ALL_LINUX:%=build/nebula-%.tar.gz)
|
||||||
|
|
||||||
|
release-freebsd: build/nebula-freebsd-amd64.tar.gz
|
||||||
|
|
||||||
bin-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe
|
bin-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe
|
||||||
mv $? .
|
mv $? .
|
||||||
|
|
||||||
bin-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert
|
bin-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert
|
||||||
mv $? .
|
mv $? .
|
||||||
|
|
||||||
|
bin-freebsd: build/freebsd-amd64/nebula build/freebsd-amd64/nebula-cert
|
||||||
|
mv $? .
|
||||||
|
|
||||||
bin:
|
bin:
|
||||||
go build -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" -o ./nebula ${NEBULA_CMD_PATH}
|
go build -trimpath -ldflags "$(LDFLAGS)" -o ./nebula ${NEBULA_CMD_PATH}
|
||||||
go build -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" -o ./nebula-cert ./cmd/nebula-cert
|
go build -trimpath -ldflags "$(LDFLAGS)" -o ./nebula-cert ./cmd/nebula-cert
|
||||||
|
|
||||||
install:
|
install:
|
||||||
go install -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" ${NEBULA_CMD_PATH}
|
go install -trimpath -ldflags "$(LDFLAGS)" ${NEBULA_CMD_PATH}
|
||||||
go install -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula-cert
|
go install -trimpath -ldflags "$(LDFLAGS)" ./cmd/nebula-cert
|
||||||
|
|
||||||
|
build/linux-arm-%: GOENV += GOARM=$(word 3, $(subst -, ,$*))
|
||||||
|
build/linux-mips-%: GOENV += GOMIPS=$(word 3, $(subst -, ,$*))
|
||||||
|
|
||||||
|
# Build an extra small binary for mips-softfloat
|
||||||
|
build/linux-mips-softfloat/%: LDFLAGS += -s -w
|
||||||
|
|
||||||
build/%/nebula: .FORCE
|
build/%/nebula: .FORCE
|
||||||
GOOS=$(firstword $(subst -, , $*)) \
|
GOOS=$(firstword $(subst -, , $*)) \
|
||||||
GOARCH=$(word 2, $(subst -, ,$*)) \
|
GOARCH=$(word 2, $(subst -, ,$*)) $(GOENV) \
|
||||||
GOARM=$(word 3, $(subst -, ,$*)) \
|
go build -trimpath -o $@ -ldflags "$(LDFLAGS)" ${NEBULA_CMD_PATH}
|
||||||
go build -trimpath -o $@ -ldflags "-X main.Build=$(BUILD_NUMBER)" ${NEBULA_CMD_PATH}
|
|
||||||
|
|
||||||
build/%/nebula-cert: .FORCE
|
build/%/nebula-cert: .FORCE
|
||||||
GOOS=$(firstword $(subst -, , $*)) \
|
GOOS=$(firstword $(subst -, , $*)) \
|
||||||
GOARCH=$(word 2, $(subst -, ,$*)) \
|
GOARCH=$(word 2, $(subst -, ,$*)) $(GOENV) \
|
||||||
GOARM=$(word 3, $(subst -, ,$*)) \
|
go build -trimpath -o $@ -ldflags "$(LDFLAGS)" ./cmd/nebula-cert
|
||||||
go build -trimpath -o $@ -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula-cert
|
|
||||||
|
|
||||||
build/%/nebula.exe: build/%/nebula
|
build/%/nebula.exe: build/%/nebula
|
||||||
mv $< $@
|
mv $< $@
|
||||||
|
|||||||
@@ -212,10 +212,10 @@ func TestBitsLostCounter(t *testing.T) {
|
|||||||
func BenchmarkBits(b *testing.B) {
|
func BenchmarkBits(b *testing.B) {
|
||||||
z := NewBits(10)
|
z := NewBits(10)
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
for i, _ := range z.bits {
|
for i := range z.bits {
|
||||||
z.bits[i] = true
|
z.bits[i] = true
|
||||||
}
|
}
|
||||||
for i, _ := range z.bits {
|
for i := range z.bits {
|
||||||
z.bits[i] = false
|
z.bits[i] = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
12
cert.go
12
cert.go
@@ -149,10 +149,16 @@ func loadCAFromConfig(c *Config) (*cert.NebulaCAPool, error) {
|
|||||||
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
|
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// pki.blacklist entered the scene at about the same time we aliased x509 to pki, not supporting backwards compat
|
for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
|
||||||
|
l.WithField("fingerprint", fp).Infof("Blocklisting cert")
|
||||||
|
CAs.BlocklistFingerprint(fp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Support deprecated config for at leaast one minor release to allow for migrations
|
||||||
for _, fp := range c.GetStringSlice("pki.blacklist", []string{}) {
|
for _, fp := range c.GetStringSlice("pki.blacklist", []string{}) {
|
||||||
l.WithField("fingerprint", fp).Infof("Blacklisting cert")
|
l.WithField("fingerprint", fp).Infof("Blocklisting cert")
|
||||||
CAs.BlacklistFingerprint(fp)
|
l.Warn("pki.blacklist is deprecated and will not be supported in a future release. Please migrate your config to use pki.blocklist")
|
||||||
|
CAs.BlocklistFingerprint(fp)
|
||||||
}
|
}
|
||||||
|
|
||||||
return CAs, nil
|
return CAs, nil
|
||||||
|
|||||||
22
cert/ca.go
22
cert/ca.go
@@ -8,14 +8,14 @@ import (
|
|||||||
|
|
||||||
type NebulaCAPool struct {
|
type NebulaCAPool struct {
|
||||||
CAs map[string]*NebulaCertificate
|
CAs map[string]*NebulaCertificate
|
||||||
certBlacklist map[string]struct{}
|
certBlocklist map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCAPool creates a CAPool
|
// NewCAPool creates a CAPool
|
||||||
func NewCAPool() *NebulaCAPool {
|
func NewCAPool() *NebulaCAPool {
|
||||||
ca := NebulaCAPool{
|
ca := NebulaCAPool{
|
||||||
CAs: make(map[string]*NebulaCertificate),
|
CAs: make(map[string]*NebulaCertificate),
|
||||||
certBlacklist: make(map[string]struct{}),
|
certBlocklist: make(map[string]struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ca
|
return &ca
|
||||||
@@ -67,24 +67,24 @@ func (ncp *NebulaCAPool) AddCACertificate(pemBytes []byte) ([]byte, error) {
|
|||||||
return pemBytes, nil
|
return pemBytes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// BlacklistFingerprint adds a cert fingerprint to the blacklist
|
// BlocklistFingerprint adds a cert fingerprint to the blocklist
|
||||||
func (ncp *NebulaCAPool) BlacklistFingerprint(f string) {
|
func (ncp *NebulaCAPool) BlocklistFingerprint(f string) {
|
||||||
ncp.certBlacklist[f] = struct{}{}
|
ncp.certBlocklist[f] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResetCertBlacklist removes all previously blacklisted cert fingerprints
|
// ResetCertBlocklist removes all previously blocklisted cert fingerprints
|
||||||
func (ncp *NebulaCAPool) ResetCertBlacklist() {
|
func (ncp *NebulaCAPool) ResetCertBlocklist() {
|
||||||
ncp.certBlacklist = make(map[string]struct{})
|
ncp.certBlocklist = make(map[string]struct{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsBlacklisted returns true if the fingerprint fails to generate or has been explicitly blacklisted
|
// IsBlocklisted returns true if the fingerprint fails to generate or has been explicitly blocklisted
|
||||||
func (ncp *NebulaCAPool) IsBlacklisted(c *NebulaCertificate) bool {
|
func (ncp *NebulaCAPool) IsBlocklisted(c *NebulaCertificate) bool {
|
||||||
h, err := c.Sha256Sum()
|
h, err := c.Sha256Sum()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := ncp.certBlacklist[h]; ok {
|
if _, ok := ncp.certBlocklist[h]; ok {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
67
cert/cert.go
67
cert/cert.go
@@ -1,18 +1,18 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
"golang.org/x/crypto/curve25519"
|
"golang.org/x/crypto/curve25519"
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
@@ -244,10 +244,10 @@ func (nc *NebulaCertificate) Expired(t time.Time) bool {
|
|||||||
return nc.Details.NotBefore.After(t) || nc.Details.NotAfter.Before(t)
|
return nc.Details.NotBefore.After(t) || nc.Details.NotAfter.Before(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blacklist, etc)
|
// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
|
||||||
func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) {
|
func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) {
|
||||||
if ncp.IsBlacklisted(nc) {
|
if ncp.IsBlocklisted(nc) {
|
||||||
return false, fmt.Errorf("certificate has been blacklisted")
|
return false, fmt.Errorf("certificate has been blocked")
|
||||||
}
|
}
|
||||||
|
|
||||||
signer, err := ncp.GetCAForCert(nc)
|
signer, err := ncp.GetCAForCert(nc)
|
||||||
@@ -468,6 +468,63 @@ func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) {
|
|||||||
return json.Marshal(jc)
|
return json.Marshal(jc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//func (nc *NebulaCertificate) Copy() *NebulaCertificate {
|
||||||
|
// r, err := nc.Marshal()
|
||||||
|
// if err != nil {
|
||||||
|
// //TODO
|
||||||
|
// return nil
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// c, err := UnmarshalNebulaCertificate(r)
|
||||||
|
// return c
|
||||||
|
//}
|
||||||
|
|
||||||
|
func (nc *NebulaCertificate) Copy() *NebulaCertificate {
|
||||||
|
c := &NebulaCertificate{
|
||||||
|
Details: NebulaCertificateDetails{
|
||||||
|
Name: nc.Details.Name,
|
||||||
|
Groups: make([]string, len(nc.Details.Groups)),
|
||||||
|
Ips: make([]*net.IPNet, len(nc.Details.Ips)),
|
||||||
|
Subnets: make([]*net.IPNet, len(nc.Details.Subnets)),
|
||||||
|
NotBefore: nc.Details.NotBefore,
|
||||||
|
NotAfter: nc.Details.NotAfter,
|
||||||
|
PublicKey: make([]byte, len(nc.Details.PublicKey)),
|
||||||
|
IsCA: nc.Details.IsCA,
|
||||||
|
Issuer: nc.Details.Issuer,
|
||||||
|
InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)),
|
||||||
|
},
|
||||||
|
Signature: make([]byte, len(nc.Signature)),
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(c.Signature, nc.Signature)
|
||||||
|
copy(c.Details.Groups, nc.Details.Groups)
|
||||||
|
copy(c.Details.PublicKey, nc.Details.PublicKey)
|
||||||
|
|
||||||
|
for i, p := range nc.Details.Ips {
|
||||||
|
c.Details.Ips[i] = &net.IPNet{
|
||||||
|
IP: make(net.IP, len(p.IP)),
|
||||||
|
Mask: make(net.IPMask, len(p.Mask)),
|
||||||
|
}
|
||||||
|
copy(c.Details.Ips[i].IP, p.IP)
|
||||||
|
copy(c.Details.Ips[i].Mask, p.Mask)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, p := range nc.Details.Subnets {
|
||||||
|
c.Details.Subnets[i] = &net.IPNet{
|
||||||
|
IP: make(net.IP, len(p.IP)),
|
||||||
|
Mask: make(net.IPMask, len(p.Mask)),
|
||||||
|
}
|
||||||
|
copy(c.Details.Subnets[i].IP, p.IP)
|
||||||
|
copy(c.Details.Subnets[i].Mask, p.Mask)
|
||||||
|
}
|
||||||
|
|
||||||
|
for g := range nc.Details.InvertedGroups {
|
||||||
|
c.Details.InvertedGroups[g] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool {
|
func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool {
|
||||||
for _, net := range rootIps {
|
for _, net := range rootIps {
|
||||||
if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) {
|
if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) {
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.org/x/crypto/curve25519"
|
"golang.org/x/crypto/curve25519"
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
@@ -172,13 +173,13 @@ func TestNebulaCertificate_Verify(t *testing.T) {
|
|||||||
|
|
||||||
f, err := c.Sha256Sum()
|
f, err := c.Sha256Sum()
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
caPool.BlacklistFingerprint(f)
|
caPool.BlocklistFingerprint(f)
|
||||||
|
|
||||||
v, err := c.Verify(time.Now(), caPool)
|
v, err := c.Verify(time.Now(), caPool)
|
||||||
assert.False(t, v)
|
assert.False(t, v)
|
||||||
assert.EqualError(t, err, "certificate has been blacklisted")
|
assert.EqualError(t, err, "certificate has been blocked")
|
||||||
|
|
||||||
caPool.ResetCertBlacklist()
|
caPool.ResetCertBlocklist()
|
||||||
v, err = c.Verify(time.Now(), caPool)
|
v, err = c.Verify(time.Now(), caPool)
|
||||||
assert.True(t, v)
|
assert.True(t, v)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
@@ -487,6 +488,17 @@ func TestMarshalingNebulaCertificateConsistency(t *testing.T) {
|
|||||||
assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
|
assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNebulaCertificate_Copy(t *testing.T) {
|
||||||
|
ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||||
|
assert.Nil(t, err)
|
||||||
|
cc := c.Copy()
|
||||||
|
|
||||||
|
util.AssertDeepCopyEqual(t, c, cc)
|
||||||
|
}
|
||||||
|
|
||||||
func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
|
func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
|
||||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
if before.IsZero() {
|
if before.IsZero() {
|
||||||
@@ -498,11 +510,12 @@ func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
|
|||||||
|
|
||||||
nc := &NebulaCertificate{
|
nc := &NebulaCertificate{
|
||||||
Details: NebulaCertificateDetails{
|
Details: NebulaCertificateDetails{
|
||||||
Name: "test ca",
|
Name: "test ca",
|
||||||
NotBefore: before,
|
NotBefore: time.Unix(before.Unix(), 0),
|
||||||
NotAfter: after,
|
NotAfter: time.Unix(after.Unix(), 0),
|
||||||
PublicKey: pub,
|
PublicKey: pub,
|
||||||
IsCA: true,
|
IsCA: true,
|
||||||
|
InvertedGroups: make(map[string]struct{}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -544,17 +557,17 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
|
|||||||
|
|
||||||
if len(ips) == 0 {
|
if len(ips) == 0 {
|
||||||
ips = []*net.IPNet{
|
ips = []*net.IPNet{
|
||||||
{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
|
{IP: net.ParseIP("10.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
|
||||||
{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
|
{IP: net.ParseIP("10.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
|
||||||
{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
|
{IP: net.ParseIP("10.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(subnets) == 0 {
|
if len(subnets) == 0 {
|
||||||
subnets = []*net.IPNet{
|
subnets = []*net.IPNet{
|
||||||
{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
|
{IP: net.ParseIP("9.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
|
||||||
{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
|
{IP: net.ParseIP("9.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
|
||||||
{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
|
{IP: net.ParseIP("9.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -562,15 +575,16 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
|
|||||||
|
|
||||||
nc := &NebulaCertificate{
|
nc := &NebulaCertificate{
|
||||||
Details: NebulaCertificateDetails{
|
Details: NebulaCertificateDetails{
|
||||||
Name: "testing",
|
Name: "testing",
|
||||||
Ips: ips,
|
Ips: ips,
|
||||||
Subnets: subnets,
|
Subnets: subnets,
|
||||||
Groups: groups,
|
Groups: groups,
|
||||||
NotBefore: before,
|
NotBefore: time.Unix(before.Unix(), 0),
|
||||||
NotAfter: after,
|
NotAfter: time.Unix(after.Unix(), 0),
|
||||||
PublicKey: pub,
|
PublicKey: pub,
|
||||||
IsCA: false,
|
IsCA: false,
|
||||||
Issuer: issuer,
|
Issuer: issuer,
|
||||||
|
InvertedGroups: make(map[string]struct{}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,10 +3,11 @@ package main
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
//TODO: all flag parsing continueOnError will print to stderr on its own currently
|
//TODO: all flag parsing continueOnError will print to stderr on its own currently
|
||||||
|
|||||||
@@ -4,11 +4,12 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/slackhq/nebula/cert"
|
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
)
|
)
|
||||||
|
|
||||||
type printFlags struct {
|
type printFlags struct {
|
||||||
|
|||||||
@@ -2,12 +2,13 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"github.com/slackhq/nebula/cert"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_printSummary(t *testing.T) {
|
func Test_printSummary(t *testing.T) {
|
||||||
|
|||||||
@@ -3,12 +3,13 @@ package main
|
|||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/slackhq/nebula/cert"
|
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
)
|
)
|
||||||
|
|
||||||
type verifyFlags struct {
|
type verifyFlags struct {
|
||||||
|
|||||||
@@ -3,13 +3,14 @@ package main
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"github.com/slackhq/nebula/cert"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"golang.org/x/crypto/ed25519"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/crypto/ed25519"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_verifySummary(t *testing.T) {
|
func Test_verifySummary(t *testing.T) {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -45,5 +46,30 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
nebula.Main(*configPath, *configTest, Build)
|
config := nebula.NewConfig()
|
||||||
|
err := config.Load(*configPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("failed to load config: %s", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
l := logrus.New()
|
||||||
|
l.Out = os.Stdout
|
||||||
|
c, err := nebula.Main(config, *configTest, Build, l, nil)
|
||||||
|
|
||||||
|
switch v := err.(type) {
|
||||||
|
case nebula.ContextualError:
|
||||||
|
v.Log(l)
|
||||||
|
os.Exit(1)
|
||||||
|
case error:
|
||||||
|
l.WithError(err).Error("Failed to start")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !*configTest {
|
||||||
|
c.Start()
|
||||||
|
c.ShutdownBlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,44 +1,53 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger service.Logger
|
var logger service.Logger
|
||||||
|
|
||||||
type program struct {
|
type program struct {
|
||||||
exit chan struct{}
|
|
||||||
configPath *string
|
configPath *string
|
||||||
configTest *bool
|
configTest *bool
|
||||||
build string
|
build string
|
||||||
|
control *nebula.Control
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *program) Start(s service.Service) error {
|
func (p *program) Start(s service.Service) error {
|
||||||
logger.Info("Nebula service starting.")
|
|
||||||
p.exit = make(chan struct{})
|
|
||||||
// Start should not block.
|
// Start should not block.
|
||||||
go p.run()
|
logger.Info("Nebula service starting.")
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *program) run() error {
|
config := nebula.NewConfig()
|
||||||
nebula.Main(*p.configPath, *p.configTest, Build)
|
err := config.Load(*p.configPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load config: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
l := logrus.New()
|
||||||
|
l.Out = os.Stdout
|
||||||
|
p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.control.Start()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *program) Stop(s service.Service) error {
|
func (p *program) Stop(s service.Service) error {
|
||||||
logger.Info("Nebula service stopping.")
|
logger.Info("Nebula service stopping.")
|
||||||
close(p.exit)
|
p.control.Stop()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {
|
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {
|
||||||
|
|
||||||
if *configPath == "" {
|
if *configPath == "" {
|
||||||
ex, err := os.Executable()
|
ex, err := os.Executable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -39,5 +40,30 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
nebula.Main(*configPath, *configTest, Build)
|
config := nebula.NewConfig()
|
||||||
|
err := config.Load(*configPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("failed to load config: %s", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
l := logrus.New()
|
||||||
|
l.Out = os.Stdout
|
||||||
|
c, err := nebula.Main(config, *configTest, Build, l, nil)
|
||||||
|
|
||||||
|
switch v := err.(type) {
|
||||||
|
case nebula.ContextualError:
|
||||||
|
v.Log(l)
|
||||||
|
os.Exit(1)
|
||||||
|
case error:
|
||||||
|
l.WithError(err).Error("Failed to start")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !*configTest {
|
||||||
|
c.Start()
|
||||||
|
c.ShutdownBlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|||||||
36
config.go
36
config.go
@@ -1,10 +1,8 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/imdario/mergo"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"gopkg.in/yaml.v2"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
@@ -16,6 +14,10 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/imdario/mergo"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -56,6 +58,13 @@ func (c *Config) Load(path string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Config) LoadString(raw string) error {
|
||||||
|
if raw == "" {
|
||||||
|
return errors.New("Empty configuration")
|
||||||
|
}
|
||||||
|
return c.parseRaw([]byte(raw))
|
||||||
|
}
|
||||||
|
|
||||||
// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
|
// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
|
||||||
// here should decide if they need to make a change to the current process before making the change. HasChanged can be
|
// here should decide if they need to make a change to the current process before making the change. HasChanged can be
|
||||||
// used to help decide if a change is necessary.
|
// used to help decide if a change is necessary.
|
||||||
@@ -407,6 +416,18 @@ func (c *Config) addFile(path string, direct bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Config) parseRaw(b []byte) error {
|
||||||
|
var m map[interface{}]interface{}
|
||||||
|
|
||||||
|
err := yaml.Unmarshal(b, &m)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Settings = m
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Config) parse() error {
|
func (c *Config) parse() error {
|
||||||
var m map[interface{}]interface{}
|
var m map[interface{}]interface{}
|
||||||
|
|
||||||
@@ -459,6 +480,7 @@ func configLogger(c *Config) error {
|
|||||||
}
|
}
|
||||||
l.SetLevel(logLevel)
|
l.SetLevel(logLevel)
|
||||||
|
|
||||||
|
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
|
||||||
timestampFormat := c.GetString("logging.timestamp_format", "")
|
timestampFormat := c.GetString("logging.timestamp_format", "")
|
||||||
fullTimestamp := (timestampFormat != "")
|
fullTimestamp := (timestampFormat != "")
|
||||||
if timestampFormat == "" {
|
if timestampFormat == "" {
|
||||||
@@ -469,12 +491,14 @@ func configLogger(c *Config) error {
|
|||||||
switch logFormat {
|
switch logFormat {
|
||||||
case "text":
|
case "text":
|
||||||
l.Formatter = &logrus.TextFormatter{
|
l.Formatter = &logrus.TextFormatter{
|
||||||
TimestampFormat: timestampFormat,
|
TimestampFormat: timestampFormat,
|
||||||
FullTimestamp: fullTimestamp,
|
FullTimestamp: fullTimestamp,
|
||||||
|
DisableTimestamp: disableTimestamp,
|
||||||
}
|
}
|
||||||
case "json":
|
case "json":
|
||||||
l.Formatter = &logrus.JSONFormatter{
|
l.Formatter = &logrus.JSONFormatter{
|
||||||
TimestampFormat: timestampFormat,
|
TimestampFormat: timestampFormat,
|
||||||
|
DisableTimestamp: disableTimestamp,
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
|
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConfig_Load(t *testing.T) {
|
func TestConfig_Load(t *testing.T) {
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
rawCertificateNoKey: []byte{},
|
rawCertificateNoKey: []byte{},
|
||||||
}
|
}
|
||||||
|
|
||||||
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1)
|
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &Tun{},
|
inside: &Tun{},
|
||||||
@@ -91,7 +91,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
rawCertificateNoKey: []byte{},
|
rawCertificateNoKey: []byte{},
|
||||||
}
|
}
|
||||||
|
|
||||||
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1)
|
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
inside: &Tun{},
|
inside: &Tun{},
|
||||||
|
|||||||
216
control.go
Normal file
216
control.go
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
||||||
|
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
||||||
|
|
||||||
|
type Control struct {
|
||||||
|
f *Interface
|
||||||
|
l *logrus.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
type ControlHostInfo struct {
|
||||||
|
VpnIP net.IP `json:"vpnIp"`
|
||||||
|
LocalIndex uint32 `json:"localIndex"`
|
||||||
|
RemoteIndex uint32 `json:"remoteIndex"`
|
||||||
|
RemoteAddrs []udpAddr `json:"remoteAddrs"`
|
||||||
|
CachedPackets int `json:"cachedPackets"`
|
||||||
|
Cert *cert.NebulaCertificate `json:"cert"`
|
||||||
|
MessageCounter uint64 `json:"messageCounter"`
|
||||||
|
CurrentRemote udpAddr `json:"currentRemote"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
|
||||||
|
func (c *Control) Start() {
|
||||||
|
c.f.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop signals nebula to shutdown, returns after the shutdown is complete
|
||||||
|
func (c *Control) Stop() {
|
||||||
|
//TODO: stop tun and udp routines, the lock on hostMap effectively does that though
|
||||||
|
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
|
||||||
|
c.f.hostMap.Lock()
|
||||||
|
for _, h := range c.f.hostMap.Hosts {
|
||||||
|
if h.ConnectionState.ready {
|
||||||
|
c.f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
||||||
|
c.l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
|
||||||
|
Debug("Sending close tunnel message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.f.hostMap.Unlock()
|
||||||
|
c.l.Info("Goodbye")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
||||||
|
func (c *Control) ShutdownBlock() {
|
||||||
|
sigChan := make(chan os.Signal)
|
||||||
|
signal.Notify(sigChan, syscall.SIGTERM)
|
||||||
|
signal.Notify(sigChan, syscall.SIGINT)
|
||||||
|
|
||||||
|
rawSig := <-sigChan
|
||||||
|
sig := rawSig.String()
|
||||||
|
c.l.WithField("signal", sig).Info("Caught signal, shutting down")
|
||||||
|
c.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RebindUDPServer asks the UDP listener to rebind it's listener. Mainly used on mobile clients when interfaces change
|
||||||
|
func (c *Control) RebindUDPServer() {
|
||||||
|
_ = c.f.outside.Rebind()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListHostmap returns details about the actual or pending (handshaking) hostmap
|
||||||
|
func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
|
||||||
|
var hm *HostMap
|
||||||
|
if pendingMap {
|
||||||
|
hm = c.f.handshakeManager.pendingHostMap
|
||||||
|
} else {
|
||||||
|
hm = c.f.hostMap
|
||||||
|
}
|
||||||
|
|
||||||
|
hm.RLock()
|
||||||
|
hosts := make([]ControlHostInfo, len(hm.Hosts))
|
||||||
|
i := 0
|
||||||
|
for _, v := range hm.Hosts {
|
||||||
|
hosts[i] = copyHostInfo(v)
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
hm.RUnlock()
|
||||||
|
|
||||||
|
return hosts
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found
|
||||||
|
func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInfo {
|
||||||
|
var hm *HostMap
|
||||||
|
if pending {
|
||||||
|
hm = c.f.handshakeManager.pendingHostMap
|
||||||
|
} else {
|
||||||
|
hm = c.f.hostMap
|
||||||
|
}
|
||||||
|
|
||||||
|
h, err := hm.QueryVpnIP(vpnIP)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := copyHostInfo(h)
|
||||||
|
return &ch
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRemoteForTunnel forces a tunnel to use a specific remote
|
||||||
|
func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInfo {
|
||||||
|
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hostInfo.SetRemote(addr.Copy())
|
||||||
|
ch := copyHostInfo(hostInfo)
|
||||||
|
return &ch
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
|
||||||
|
func (c *Control) CloseTunnel(vpnIP uint32, localOnly bool) bool {
|
||||||
|
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !localOnly {
|
||||||
|
c.f.send(
|
||||||
|
closeTunnel,
|
||||||
|
0,
|
||||||
|
hostInfo.ConnectionState,
|
||||||
|
hostInfo,
|
||||||
|
hostInfo.remote,
|
||||||
|
[]byte{},
|
||||||
|
make([]byte, 12, 12),
|
||||||
|
make([]byte, mtu),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.f.closeTunnel(hostInfo)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyHostInfo(h *HostInfo) ControlHostInfo {
|
||||||
|
addrs := h.RemoteUDPAddrs()
|
||||||
|
chi := ControlHostInfo{
|
||||||
|
VpnIP: int2ip(h.hostId),
|
||||||
|
LocalIndex: h.localIndexId,
|
||||||
|
RemoteIndex: h.remoteIndexId,
|
||||||
|
RemoteAddrs: make([]udpAddr, len(addrs), len(addrs)),
|
||||||
|
CachedPackets: len(h.packetStore),
|
||||||
|
MessageCounter: *h.ConnectionState.messageCounter,
|
||||||
|
}
|
||||||
|
|
||||||
|
if c := h.GetCert(); c != nil {
|
||||||
|
chi.Cert = c.Copy()
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.remote != nil {
|
||||||
|
chi.CurrentRemote = *h.remote
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, addr := range addrs {
|
||||||
|
chi.RemoteAddrs[i] = addr.Copy()
|
||||||
|
}
|
||||||
|
|
||||||
|
return chi
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hook provides the ability to hook into the network path for a particular
|
||||||
|
// message sub type. Any received message of that subtype that is allowed by
|
||||||
|
// the firewall will be written to the provided write func instead of the
|
||||||
|
// inside interface.
|
||||||
|
// TODO: make this an io.Writer
|
||||||
|
func (c *Control) Hook(t NebulaMessageSubType, w func([]byte) error) error {
|
||||||
|
if t == 0 {
|
||||||
|
return fmt.Errorf("non-default message subtype must be specified")
|
||||||
|
}
|
||||||
|
if _, ok := c.f.handlers[Version][message][t]; ok {
|
||||||
|
return fmt.Errorf("message subtype %d already hooked", t)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.f.handlers[Version][message][t] = c.f.newHook(w)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send provides the ability to send arbitrary message packets to peer nodes.
|
||||||
|
// The provided payload will be encapsulated in a Nebula Firewall packet
|
||||||
|
// (IPv4 plus ports) from the node IP to the provided destination nebula IP.
|
||||||
|
// Any protocol handling above layer 3 (IP) must be managed by the caller.
|
||||||
|
func (c *Control) Send(ip uint32, port uint16, st NebulaMessageSubType, payload []byte) {
|
||||||
|
headerLen := ipv4.HeaderLen + minFwPacketLen
|
||||||
|
length := headerLen + len(payload)
|
||||||
|
packet := make([]byte, length)
|
||||||
|
packet[0] = 0x45 // IPv4 HL=20
|
||||||
|
packet[9] = 114 // Declare as arbitrary 0-hop protocol
|
||||||
|
binary.BigEndian.PutUint16(packet[2:4], uint16(length))
|
||||||
|
binary.BigEndian.PutUint32(packet[12:16], ip2int(c.f.inside.CidrNet().IP.To4()))
|
||||||
|
binary.BigEndian.PutUint32(packet[16:20], ip)
|
||||||
|
|
||||||
|
// Set identical values for src and dst port as they're only
|
||||||
|
// used for nebula firewall rule/conntrack matching.
|
||||||
|
binary.BigEndian.PutUint16(packet[20:22], port)
|
||||||
|
binary.BigEndian.PutUint16(packet[22:24], port)
|
||||||
|
|
||||||
|
copy(packet[headerLen:], payload)
|
||||||
|
|
||||||
|
fp := &FirewallPacket{}
|
||||||
|
nb := make([]byte, 12)
|
||||||
|
out := make([]byte, mtu)
|
||||||
|
c.f.consumeInsidePacket(st, packet, fp, nb, out)
|
||||||
|
}
|
||||||
111
control_test.go
Normal file
111
control_test.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
|
"github.com/slackhq/nebula/util"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestControl_GetHostInfoByVpnIP(t *testing.T) {
|
||||||
|
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
|
||||||
|
// To properly ensure we are not exposing core memory to the caller
|
||||||
|
hm := NewHostMap("test", &net.IPNet{}, make([]*net.IPNet, 0))
|
||||||
|
remote1 := NewUDPAddr(100, 4444)
|
||||||
|
remote2 := NewUDPAddr(101, 4444)
|
||||||
|
ipNet := net.IPNet{
|
||||||
|
IP: net.IPv4(1, 2, 3, 4),
|
||||||
|
Mask: net.IPMask{255, 255, 255, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
ipNet2 := net.IPNet{
|
||||||
|
IP: net.IPv4(1, 2, 3, 5),
|
||||||
|
Mask: net.IPMask{255, 255, 255, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
crt := &cert.NebulaCertificate{
|
||||||
|
Details: cert.NebulaCertificateDetails{
|
||||||
|
Name: "test",
|
||||||
|
Ips: []*net.IPNet{&ipNet},
|
||||||
|
Subnets: []*net.IPNet{},
|
||||||
|
Groups: []string{"default-group"},
|
||||||
|
NotBefore: time.Unix(1, 0),
|
||||||
|
NotAfter: time.Unix(2, 0),
|
||||||
|
PublicKey: []byte{5, 6, 7, 8},
|
||||||
|
IsCA: false,
|
||||||
|
Issuer: "the-issuer",
|
||||||
|
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||||
|
},
|
||||||
|
Signature: []byte{1, 2, 1, 2, 1, 3},
|
||||||
|
}
|
||||||
|
counter := uint64(0)
|
||||||
|
|
||||||
|
remotes := []*HostInfoDest{NewHostInfoDest(remote1), NewHostInfoDest(remote2)}
|
||||||
|
hm.Add(ip2int(ipNet.IP), &HostInfo{
|
||||||
|
remote: remote1,
|
||||||
|
Remotes: remotes,
|
||||||
|
ConnectionState: &ConnectionState{
|
||||||
|
peerCert: crt,
|
||||||
|
messageCounter: &counter,
|
||||||
|
},
|
||||||
|
remoteIndexId: 200,
|
||||||
|
localIndexId: 201,
|
||||||
|
hostId: ip2int(ipNet.IP),
|
||||||
|
})
|
||||||
|
|
||||||
|
hm.Add(ip2int(ipNet2.IP), &HostInfo{
|
||||||
|
remote: remote1,
|
||||||
|
Remotes: remotes,
|
||||||
|
ConnectionState: &ConnectionState{
|
||||||
|
peerCert: nil,
|
||||||
|
messageCounter: &counter,
|
||||||
|
},
|
||||||
|
remoteIndexId: 200,
|
||||||
|
localIndexId: 201,
|
||||||
|
hostId: ip2int(ipNet2.IP),
|
||||||
|
})
|
||||||
|
|
||||||
|
c := Control{
|
||||||
|
f: &Interface{
|
||||||
|
hostMap: hm,
|
||||||
|
},
|
||||||
|
l: logrus.New(),
|
||||||
|
}
|
||||||
|
|
||||||
|
thi := c.GetHostInfoByVpnIP(ip2int(ipNet.IP), false)
|
||||||
|
|
||||||
|
expectedInfo := ControlHostInfo{
|
||||||
|
VpnIP: net.IPv4(1, 2, 3, 4).To4(),
|
||||||
|
LocalIndex: 201,
|
||||||
|
RemoteIndex: 200,
|
||||||
|
RemoteAddrs: []udpAddr{*remote1, *remote2},
|
||||||
|
CachedPackets: 0,
|
||||||
|
Cert: crt.Copy(),
|
||||||
|
MessageCounter: 0,
|
||||||
|
CurrentRemote: *NewUDPAddr(100, 4444),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure we don't have any unexpected fields
|
||||||
|
assertFields(t, []string{"VpnIP", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
|
||||||
|
util.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
||||||
|
|
||||||
|
// Make sure we don't panic if the host info doesn't have a cert yet
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
thi = c.GetHostInfoByVpnIP(ip2int(ipNet2.IP), false)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
|
||||||
|
val := reflect.ValueOf(actualStruct).Elem()
|
||||||
|
fields := make([]string, val.NumField())
|
||||||
|
for i := 0; i < val.NumField(); i++ {
|
||||||
|
fields[i] = val.Type().Field(i).Name
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, expected, fields)
|
||||||
|
}
|
||||||
4
dist/arch/nebula.service
vendored
4
dist/arch/nebula.service
vendored
@@ -1,7 +1,7 @@
|
|||||||
[Unit]
|
[Unit]
|
||||||
Description=nebula
|
Description=nebula
|
||||||
Wants=basic.target
|
Wants=basic.target network-online.target
|
||||||
After=basic.target network.target
|
After=basic.target network.target network-online.target
|
||||||
|
|
||||||
[Service]
|
[Service]
|
||||||
SyslogIdentifier=nebula
|
SyslogIdentifier=nebula
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ pki:
|
|||||||
ca: /etc/nebula/ca.crt
|
ca: /etc/nebula/ca.crt
|
||||||
cert: /etc/nebula/host.crt
|
cert: /etc/nebula/host.crt
|
||||||
key: /etc/nebula/host.key
|
key: /etc/nebula/host.key
|
||||||
#blacklist is a list of certificate fingerprints that we will refuse to talk to
|
#blocklist is a list of certificate fingerprints that we will refuse to talk to
|
||||||
#blacklist:
|
#blocklist:
|
||||||
# - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72
|
# - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72
|
||||||
|
|
||||||
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
|
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
|
||||||
@@ -64,7 +64,7 @@ lighthouse:
|
|||||||
# the inverse). CIDR rules are matched after interface name rules.
|
# the inverse). CIDR rules are matched after interface name rules.
|
||||||
# Default is all local IP addresses.
|
# Default is all local IP addresses.
|
||||||
#local_allow_list:
|
#local_allow_list:
|
||||||
# Example to blacklist tun0 and all docker interfaces.
|
# Example to block tun0 and all docker interfaces.
|
||||||
#interfaces:
|
#interfaces:
|
||||||
#tun0: false
|
#tun0: false
|
||||||
#'docker.*': false
|
#'docker.*': false
|
||||||
@@ -124,6 +124,8 @@ punchy:
|
|||||||
|
|
||||||
# Configure the private interface. Note: addr is baked into the nebula certificate
|
# Configure the private interface. Note: addr is baked into the nebula certificate
|
||||||
tun:
|
tun:
|
||||||
|
# When tun is disabled, a lighthouse can be started without a local tun interface (and therefore without root)
|
||||||
|
disabled: false
|
||||||
# Name of the device
|
# Name of the device
|
||||||
dev: nebula1
|
dev: nebula1
|
||||||
# Toggles forwarding of local broadcast packets, the address of which depends on the ip/mask encoded in pki.cert
|
# Toggles forwarding of local broadcast packets, the address of which depends on the ip/mask encoded in pki.cert
|
||||||
@@ -154,6 +156,8 @@ logging:
|
|||||||
level: info
|
level: info
|
||||||
# json or text formats currently available. Default is text
|
# json or text formats currently available. Default is text
|
||||||
format: text
|
format: text
|
||||||
|
# Disable timestamp logging. useful when output is redirected to logging system that already adds timestamps. Default is false
|
||||||
|
#disable_timestamp: true
|
||||||
# timestamp format is specified in Go time format, see:
|
# timestamp format is specified in Go time format, see:
|
||||||
# https://golang.org/pkg/time/#pkg-constants
|
# https://golang.org/pkg/time/#pkg-constants
|
||||||
# default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339)
|
# default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339)
|
||||||
@@ -177,6 +181,15 @@ logging:
|
|||||||
#subsystem: nebula
|
#subsystem: nebula
|
||||||
#interval: 10s
|
#interval: 10s
|
||||||
|
|
||||||
|
# enables counter metrics for meta packets
|
||||||
|
# e.g.: `messages.tx.handshake`
|
||||||
|
# NOTE: `message.{tx,rx}.recv_error` is always emitted
|
||||||
|
#message_metrics: false
|
||||||
|
|
||||||
|
# enables detailed counter metrics for lighthouse packets
|
||||||
|
# e.g.: `lighthouse.rx.HostQuery`
|
||||||
|
#lighthouse_metrics: false
|
||||||
|
|
||||||
# Handshake Manger Settings
|
# Handshake Manger Settings
|
||||||
#handshakes:
|
#handshakes:
|
||||||
# Total time to try a handshake = sequence of `try_interval * retries`
|
# Total time to try a handshake = sequence of `try_interval * retries`
|
||||||
@@ -185,11 +198,14 @@ logging:
|
|||||||
#retries: 20
|
#retries: 20
|
||||||
# wait_rotation is the number of handshake attempts to do before starting to try non-local IP addresses
|
# wait_rotation is the number of handshake attempts to do before starting to try non-local IP addresses
|
||||||
#wait_rotation: 5
|
#wait_rotation: 5
|
||||||
|
# trigger_buffer is the size of the buffer channel for quickly sending handshakes
|
||||||
|
# after receiving the response for lighthouse queries
|
||||||
|
#trigger_buffer: 64
|
||||||
|
|
||||||
# Nebula security group configuration
|
# Nebula security group configuration
|
||||||
firewall:
|
firewall:
|
||||||
conntrack:
|
conntrack:
|
||||||
tcp_timeout: 120h
|
tcp_timeout: 12m
|
||||||
udp_timeout: 3m
|
udp_timeout: 3m
|
||||||
default_timeout: 10m
|
default_timeout: 10m
|
||||||
max_connections: 100000
|
max_connections: 100000
|
||||||
|
|||||||
149
firewall.go
149
firewall.go
@@ -1,21 +1,21 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"errors"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -38,13 +38,19 @@ type FirewallInterface interface {
|
|||||||
|
|
||||||
type conn struct {
|
type conn struct {
|
||||||
Expires time.Time // Time when this conntrack entry will expire
|
Expires time.Time // Time when this conntrack entry will expire
|
||||||
Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
|
|
||||||
Sent time.Time // If tcp rtt tracking is enabled this will be when Seq was last set
|
Sent time.Time // If tcp rtt tracking is enabled this will be when Seq was last set
|
||||||
|
Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
|
||||||
|
|
||||||
|
// record why the original connection passed the firewall, so we can re-validate
|
||||||
|
// after ruleset changes. Note, rulesVersion is a uint16 so that these two
|
||||||
|
// fields pack for free after the uint32 above
|
||||||
|
incoming bool
|
||||||
|
rulesVersion uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: need conntrack max tracked connections handling
|
// TODO: need conntrack max tracked connections handling
|
||||||
type Firewall struct {
|
type Firewall struct {
|
||||||
Conns map[FirewallPacket]*conn
|
Conntrack *FirewallConntrack
|
||||||
|
|
||||||
InRules *FirewallTable
|
InRules *FirewallTable
|
||||||
OutRules *FirewallTable
|
OutRules *FirewallTable
|
||||||
@@ -55,18 +61,23 @@ type Firewall struct {
|
|||||||
UDPTimeout time.Duration //linux: 180s max
|
UDPTimeout time.Duration //linux: 180s max
|
||||||
DefaultTimeout time.Duration //linux: 600s
|
DefaultTimeout time.Duration //linux: 600s
|
||||||
|
|
||||||
TimerWheel *TimerWheel
|
|
||||||
|
|
||||||
// Used to ensure we don't emit local packets for ips we don't own
|
// Used to ensure we don't emit local packets for ips we don't own
|
||||||
localIps *CIDRTree
|
localIps *CIDRTree
|
||||||
|
|
||||||
connMutex sync.Mutex
|
rules string
|
||||||
rules string
|
rulesVersion uint16
|
||||||
|
|
||||||
trackTCPRTT bool
|
trackTCPRTT bool
|
||||||
metricTCPRTT metrics.Histogram
|
metricTCPRTT metrics.Histogram
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type FirewallConntrack struct {
|
||||||
|
sync.Mutex
|
||||||
|
|
||||||
|
Conns map[FirewallPacket]*conn
|
||||||
|
TimerWheel *TimerWheel
|
||||||
|
}
|
||||||
|
|
||||||
type FirewallTable struct {
|
type FirewallTable struct {
|
||||||
TCP firewallPort
|
TCP firewallPort
|
||||||
UDP firewallPort
|
UDP firewallPort
|
||||||
@@ -172,10 +183,12 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Firewall{
|
return &Firewall{
|
||||||
Conns: make(map[FirewallPacket]*conn),
|
Conntrack: &FirewallConntrack{
|
||||||
|
Conns: make(map[FirewallPacket]*conn),
|
||||||
|
TimerWheel: NewTimerWheel(min, max),
|
||||||
|
},
|
||||||
InRules: newFirewallTable(),
|
InRules: newFirewallTable(),
|
||||||
OutRules: newFirewallTable(),
|
OutRules: newFirewallTable(),
|
||||||
TimerWheel: NewTimerWheel(min, max),
|
|
||||||
TCPTimeout: tcpTimeout,
|
TCPTimeout: tcpTimeout,
|
||||||
UDPTimeout: UDPTimeout,
|
UDPTimeout: UDPTimeout,
|
||||||
DefaultTimeout: defaultTimeout,
|
DefaultTimeout: defaultTimeout,
|
||||||
@@ -208,11 +221,17 @@ func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, er
|
|||||||
|
|
||||||
// AddRule properly creates the in memory rule structure for a firewall table.
|
// AddRule properly creates the in memory rule structure for a firewall table.
|
||||||
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
|
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
|
||||||
|
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
|
||||||
|
// https://github.com/golang/go/issues/14131
|
||||||
|
sIp := ""
|
||||||
|
if ip != nil {
|
||||||
|
sIp = ip.String()
|
||||||
|
}
|
||||||
|
|
||||||
// We need this rule string because we generate a hash. Removing this will break firewall reload.
|
// We need this rule string because we generate a hash. Removing this will break firewall reload.
|
||||||
ruleString := fmt.Sprintf(
|
ruleString := fmt.Sprintf(
|
||||||
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s",
|
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s",
|
||||||
incoming, proto, startPort, endPort, groups, host, ip, caName, caSha,
|
incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha,
|
||||||
)
|
)
|
||||||
f.rules += ruleString + "\n"
|
f.rules += ruleString + "\n"
|
||||||
|
|
||||||
@@ -220,7 +239,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
|||||||
if !incoming {
|
if !incoming {
|
||||||
direction = "outgoing"
|
direction = "outgoing"
|
||||||
}
|
}
|
||||||
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": ip, "caName": caName, "caSha": caSha}).
|
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
|
||||||
Info("Firewall rule added")
|
Info("Firewall rule added")
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -347,27 +366,33 @@ func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterfa
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool {
|
var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
|
||||||
|
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
|
||||||
|
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
||||||
|
|
||||||
|
// Drop returns an error if the packet should be dropped, explaining why. It
|
||||||
|
// returns nil if the packet should not be dropped.
|
||||||
|
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) error {
|
||||||
// Check if we spoke to this tuple, if we did then allow this packet
|
// Check if we spoke to this tuple, if we did then allow this packet
|
||||||
if f.inConns(packet, fp, incoming) {
|
if f.inConns(packet, fp, incoming, h, caPool) {
|
||||||
return false
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure remote address matches nebula certificate
|
// Make sure remote address matches nebula certificate
|
||||||
if remoteCidr := h.remoteCidr; remoteCidr != nil {
|
if remoteCidr := h.remoteCidr; remoteCidr != nil {
|
||||||
if remoteCidr.Contains(fp.RemoteIP) == nil {
|
if remoteCidr.Contains(fp.RemoteIP) == nil {
|
||||||
return true
|
return ErrInvalidRemoteIP
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Simple case: Certificate has one IP and no subnets
|
// Simple case: Certificate has one IP and no subnets
|
||||||
if fp.RemoteIP != h.hostId {
|
if fp.RemoteIP != h.hostId {
|
||||||
return true
|
return ErrInvalidRemoteIP
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure we are supposed to be handling this local ip address
|
// Make sure we are supposed to be handling this local ip address
|
||||||
if f.localIps.Contains(fp.LocalIP) == nil {
|
if f.localIps.Contains(fp.LocalIP) == nil {
|
||||||
return true
|
return ErrInvalidLocalIP
|
||||||
}
|
}
|
||||||
|
|
||||||
table := f.OutRules
|
table := f.OutRules
|
||||||
@@ -377,13 +402,13 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *Host
|
|||||||
|
|
||||||
// We now know which firewall table to check against
|
// We now know which firewall table to check against
|
||||||
if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) {
|
if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) {
|
||||||
return true
|
return ErrNoMatchingRule
|
||||||
}
|
}
|
||||||
|
|
||||||
// We always want to conntrack since it is a faster operation
|
// We always want to conntrack since it is a faster operation
|
||||||
f.addConn(packet, fp, incoming)
|
f.addConn(packet, fp, incoming)
|
||||||
|
|
||||||
return false
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new
|
// Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new
|
||||||
@@ -393,26 +418,66 @@ func (f *Firewall) Destroy() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Firewall) EmitStats() {
|
func (f *Firewall) EmitStats() {
|
||||||
conntrackCount := len(f.Conns)
|
conntrack := f.Conntrack
|
||||||
|
conntrack.Lock()
|
||||||
|
conntrackCount := len(conntrack.Conns)
|
||||||
|
conntrack.Unlock()
|
||||||
metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount))
|
metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount))
|
||||||
|
metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool {
|
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool {
|
||||||
f.connMutex.Lock()
|
conntrack := f.Conntrack
|
||||||
|
conntrack.Lock()
|
||||||
|
|
||||||
// Purge every time we test
|
// Purge every time we test
|
||||||
ep, has := f.TimerWheel.Purge()
|
ep, has := conntrack.TimerWheel.Purge()
|
||||||
if has {
|
if has {
|
||||||
f.evict(ep)
|
f.evict(ep)
|
||||||
}
|
}
|
||||||
|
|
||||||
c, ok := f.Conns[fp]
|
c, ok := conntrack.Conns[fp]
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
f.connMutex.Unlock()
|
conntrack.Unlock()
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.rulesVersion != f.rulesVersion {
|
||||||
|
// This conntrack entry was for an older rule set, validate
|
||||||
|
// it still passes with the current rule set
|
||||||
|
table := f.OutRules
|
||||||
|
if c.incoming {
|
||||||
|
table = f.InRules
|
||||||
|
}
|
||||||
|
|
||||||
|
// We now know which firewall table to check against
|
||||||
|
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
|
||||||
|
if l.Level >= logrus.DebugLevel {
|
||||||
|
h.logger().
|
||||||
|
WithField("fwPacket", fp).
|
||||||
|
WithField("incoming", c.incoming).
|
||||||
|
WithField("rulesVersion", f.rulesVersion).
|
||||||
|
WithField("oldRulesVersion", c.rulesVersion).
|
||||||
|
Debugln("dropping old conntrack entry, does not match new ruleset")
|
||||||
|
}
|
||||||
|
delete(conntrack.Conns, fp)
|
||||||
|
conntrack.Unlock()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if l.Level >= logrus.DebugLevel {
|
||||||
|
h.logger().
|
||||||
|
WithField("fwPacket", fp).
|
||||||
|
WithField("incoming", c.incoming).
|
||||||
|
WithField("rulesVersion", f.rulesVersion).
|
||||||
|
WithField("oldRulesVersion", c.rulesVersion).
|
||||||
|
Debugln("keeping old conntrack entry, does match new ruleset")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.rulesVersion = f.rulesVersion
|
||||||
|
}
|
||||||
|
|
||||||
switch fp.Protocol {
|
switch fp.Protocol {
|
||||||
case fwProtoTCP:
|
case fwProtoTCP:
|
||||||
c.Expires = time.Now().Add(f.TCPTimeout)
|
c.Expires = time.Now().Add(f.TCPTimeout)
|
||||||
@@ -427,7 +492,7 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool
|
|||||||
c.Expires = time.Now().Add(f.DefaultTimeout)
|
c.Expires = time.Now().Add(f.DefaultTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.connMutex.Unlock()
|
conntrack.Unlock()
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -448,14 +513,19 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
|
|||||||
timeout = f.DefaultTimeout
|
timeout = f.DefaultTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
f.connMutex.Lock()
|
conntrack := f.Conntrack
|
||||||
if _, ok := f.Conns[fp]; !ok {
|
conntrack.Lock()
|
||||||
f.TimerWheel.Add(fp, timeout)
|
if _, ok := conntrack.Conns[fp]; !ok {
|
||||||
|
conntrack.TimerWheel.Add(fp, timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Record which rulesVersion allowed this connection, so we can retest after
|
||||||
|
// firewall reload
|
||||||
|
c.incoming = incoming
|
||||||
|
c.rulesVersion = f.rulesVersion
|
||||||
c.Expires = time.Now().Add(timeout)
|
c.Expires = time.Now().Add(timeout)
|
||||||
f.Conns[fp] = c
|
conntrack.Conns[fp] = c
|
||||||
f.connMutex.Unlock()
|
conntrack.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
|
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
|
||||||
@@ -463,7 +533,8 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
|
|||||||
func (f *Firewall) evict(p FirewallPacket) {
|
func (f *Firewall) evict(p FirewallPacket) {
|
||||||
//TODO: report a stat if the tcp rtt tracking was never resolved?
|
//TODO: report a stat if the tcp rtt tracking was never resolved?
|
||||||
// Are we still tracking this conn?
|
// Are we still tracking this conn?
|
||||||
t, ok := f.Conns[p]
|
conntrack := f.Conntrack
|
||||||
|
t, ok := conntrack.Conns[p]
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -472,12 +543,12 @@ func (f *Firewall) evict(p FirewallPacket) {
|
|||||||
|
|
||||||
// Timeout is in the future, re-add the timer
|
// Timeout is in the future, re-add the timer
|
||||||
if newT > 0 {
|
if newT > 0 {
|
||||||
f.TimerWheel.Add(p, newT)
|
conntrack.TimerWheel.Add(p, newT)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// This conn is done
|
// This conn is done
|
||||||
delete(f.Conns, p)
|
delete(conntrack.Conns, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
||||||
|
|||||||
132
firewall_test.go
132
firewall_test.go
@@ -17,37 +17,39 @@ import (
|
|||||||
func TestNewFirewall(t *testing.T) {
|
func TestNewFirewall(t *testing.T) {
|
||||||
c := &cert.NebulaCertificate{}
|
c := &cert.NebulaCertificate{}
|
||||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||||
assert.NotNil(t, fw.Conns)
|
conntrack := fw.Conntrack
|
||||||
|
assert.NotNil(t, conntrack)
|
||||||
|
assert.NotNil(t, conntrack.Conns)
|
||||||
|
assert.NotNil(t, conntrack.TimerWheel)
|
||||||
assert.NotNil(t, fw.InRules)
|
assert.NotNil(t, fw.InRules)
|
||||||
assert.NotNil(t, fw.OutRules)
|
assert.NotNil(t, fw.OutRules)
|
||||||
assert.NotNil(t, fw.TimerWheel)
|
|
||||||
assert.Equal(t, time.Second, fw.TCPTimeout)
|
assert.Equal(t, time.Second, fw.TCPTimeout)
|
||||||
assert.Equal(t, time.Minute, fw.UDPTimeout)
|
assert.Equal(t, time.Minute, fw.UDPTimeout)
|
||||||
assert.Equal(t, time.Hour, fw.DefaultTimeout)
|
assert.Equal(t, time.Hour, fw.DefaultTimeout)
|
||||||
|
|
||||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||||
|
|
||||||
fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
|
fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
|
||||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||||
|
|
||||||
fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
|
fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
|
||||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||||
|
|
||||||
fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
|
fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
|
||||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||||
|
|
||||||
fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
|
fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
|
||||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||||
|
|
||||||
fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
|
fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
|
||||||
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||||
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_AddRule(t *testing.T) {
|
func TestFirewall_AddRule(t *testing.T) {
|
||||||
@@ -180,44 +182,44 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
|
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
|
||||||
// Allow inbound
|
// Allow inbound
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||||
// Allow outbound because conntrack
|
// Allow outbound because conntrack
|
||||||
assert.False(t, fw.Drop([]byte{}, p, false, &h, cp))
|
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
|
||||||
|
|
||||||
// test remote mismatch
|
// test remote mismatch
|
||||||
oldRemote := p.RemoteIP
|
oldRemote := p.RemoteIP
|
||||||
p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
|
p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
|
||||||
assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
|
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrInvalidRemoteIP)
|
||||||
p.RemoteIP = oldRemote
|
p.RemoteIP = oldRemote
|
||||||
|
|
||||||
// ensure signer doesn't get in the way of group checks
|
// ensure signer doesn't get in the way of group checks
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
|
||||||
assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
|
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caSha doesn't drop on match
|
// test caSha doesn't drop on match
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
|
||||||
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||||
|
|
||||||
// ensure ca name doesn't get in the way of group checks
|
// ensure ca name doesn't get in the way of group checks
|
||||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
|
||||||
assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
|
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caName doesn't drop on match
|
// test caName doesn't drop on match
|
||||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
|
||||||
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkFirewallTable_match(b *testing.B) {
|
func BenchmarkFirewallTable_match(b *testing.B) {
|
||||||
@@ -368,10 +370,10 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// h1/c1 lacks the proper groups
|
// h1/c1 lacks the proper groups
|
||||||
assert.True(t, fw.Drop([]byte{}, p, true, &h1, cp))
|
assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp), ErrNoMatchingRule)
|
||||||
// c has the proper groups
|
// c has the proper groups
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop3(t *testing.T) {
|
func TestFirewall_Drop3(t *testing.T) {
|
||||||
@@ -452,13 +454,81 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// c1 should pass because host match
|
// c1 should pass because host match
|
||||||
assert.False(t, fw.Drop([]byte{}, p, true, &h1, cp))
|
assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp))
|
||||||
// c2 should pass because ca sha match
|
// c2 should pass because ca sha match
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
assert.False(t, fw.Drop([]byte{}, p, true, &h2, cp))
|
assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp))
|
||||||
// c3 should fail because no match
|
// c3 should fail because no match
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
assert.True(t, fw.Drop([]byte{}, p, true, &h3, cp))
|
assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp), ErrNoMatchingRule)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||||
|
ob := &bytes.Buffer{}
|
||||||
|
out := l.Out
|
||||||
|
l.SetOutput(ob)
|
||||||
|
defer l.SetOutput(out)
|
||||||
|
|
||||||
|
p := FirewallPacket{
|
||||||
|
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||||
|
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||||
|
10,
|
||||||
|
90,
|
||||||
|
fwProtoUDP,
|
||||||
|
false,
|
||||||
|
}
|
||||||
|
|
||||||
|
ipNet := net.IPNet{
|
||||||
|
IP: net.IPv4(1, 2, 3, 4),
|
||||||
|
Mask: net.IPMask{255, 255, 255, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
c := cert.NebulaCertificate{
|
||||||
|
Details: cert.NebulaCertificateDetails{
|
||||||
|
Name: "host1",
|
||||||
|
Ips: []*net.IPNet{&ipNet},
|
||||||
|
Groups: []string{"default-group"},
|
||||||
|
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||||
|
Issuer: "signer-shasum",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := HostInfo{
|
||||||
|
ConnectionState: &ConnectionState{
|
||||||
|
peerCert: &c,
|
||||||
|
},
|
||||||
|
hostId: ip2int(ipNet.IP),
|
||||||
|
}
|
||||||
|
h.CreateRemoteCIDR(&c)
|
||||||
|
|
||||||
|
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||||
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
||||||
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
|
// Drop outbound
|
||||||
|
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
|
||||||
|
// Allow inbound
|
||||||
|
resetConntrack(fw)
|
||||||
|
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||||
|
// Allow outbound because conntrack
|
||||||
|
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
|
||||||
|
|
||||||
|
oldFw := fw
|
||||||
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||||
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
|
||||||
|
fw.Conntrack = oldFw.Conntrack
|
||||||
|
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||||
|
|
||||||
|
// Allow outbound because conntrack and new rules allow port 10
|
||||||
|
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
|
||||||
|
|
||||||
|
oldFw = fw
|
||||||
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||||
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
|
||||||
|
fw.Conntrack = oldFw.Conntrack
|
||||||
|
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||||
|
|
||||||
|
// Drop outbound because conntrack doesn't match new ruleset
|
||||||
|
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkLookup(b *testing.B) {
|
func BenchmarkLookup(b *testing.B) {
|
||||||
@@ -861,7 +931,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
|
|||||||
}
|
}
|
||||||
|
|
||||||
func resetConntrack(fw *Firewall) {
|
func resetConntrack(fw *Firewall) {
|
||||||
fw.connMutex.Lock()
|
fw.Conntrack.Lock()
|
||||||
fw.Conns = map[FirewallPacket]*conn{}
|
fw.Conntrack.Conns = map[FirewallPacket]*conn{}
|
||||||
fw.connMutex.Unlock()
|
fw.Conntrack.Unlock()
|
||||||
}
|
}
|
||||||
|
|||||||
4
go.mod
4
go.mod
@@ -11,7 +11,7 @@ require (
|
|||||||
github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6
|
github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6
|
||||||
github.com/golang/protobuf v1.3.2
|
github.com/golang/protobuf v1.3.2
|
||||||
github.com/imdario/mergo v0.3.8
|
github.com/imdario/mergo v0.3.8
|
||||||
github.com/kardianos/service v1.0.0
|
github.com/kardianos/service v1.1.0
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
|
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
|
||||||
github.com/kr/pretty v0.1.0 // indirect
|
github.com/kr/pretty v0.1.0 // indirect
|
||||||
github.com/miekg/dns v1.1.25
|
github.com/miekg/dns v1.1.25
|
||||||
@@ -22,7 +22,7 @@ require (
|
|||||||
github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563
|
github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563
|
||||||
github.com/sirupsen/logrus v1.4.2
|
github.com/sirupsen/logrus v1.4.2
|
||||||
github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b
|
github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b
|
||||||
github.com/stretchr/testify v1.4.0
|
github.com/stretchr/testify v1.6.1
|
||||||
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a
|
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a
|
||||||
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
|
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
|
||||||
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975
|
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975
|
||||||
|
|||||||
10
go.sum
10
go.sum
@@ -46,6 +46,8 @@ github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/u
|
|||||||
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
|
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
|
||||||
github.com/kardianos/service v1.0.0 h1:HgQS3mFfOlyntWX8Oke98JcJLqt1DBcHR4kxShpYef0=
|
github.com/kardianos/service v1.0.0 h1:HgQS3mFfOlyntWX8Oke98JcJLqt1DBcHR4kxShpYef0=
|
||||||
github.com/kardianos/service v1.0.0/go.mod h1:8CzDhVuCuugtsHyZoTvsOBuvonN/UDBvl0kH+BUxvbo=
|
github.com/kardianos/service v1.0.0/go.mod h1:8CzDhVuCuugtsHyZoTvsOBuvonN/UDBvl0kH+BUxvbo=
|
||||||
|
github.com/kardianos/service v1.1.0 h1:QV2SiEeWK42P0aEmGcsAgjApw/lRxkwopvT+Gu6t1/0=
|
||||||
|
github.com/kardianos/service v1.1.0/go.mod h1:RrJI2xn5vve/r32U5suTbeaSGoMU6GbNPoj36CVYcHc=
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
|
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||||
@@ -103,8 +105,8 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
|
|||||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
|
||||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a h1:Bt1IVPhiCDMqwGrc2nnbIN4QKvJGx6SK2NzWBmW00ao=
|
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a h1:Bt1IVPhiCDMqwGrc2nnbIN4QKvJGx6SK2NzWBmW00ao=
|
||||||
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
|
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
|
||||||
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
|
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
|
||||||
@@ -112,8 +114,6 @@ github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17
|
|||||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY=
|
golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY=
|
||||||
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 h1:ULYEB3JvPRE/IfO+9uO7vKV/xzVTO7XPAwm8xbf4w2g=
|
|
||||||
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
|
||||||
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975 h1:/Tl7pH94bvbAAHBdZJT947M/+gp0+CqQXDtMRC0fseo=
|
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975 h1:/Tl7pH94bvbAAHBdZJT947M/+gp0+CqQXDtMRC0fseo=
|
||||||
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
@@ -154,3 +154,5 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
|||||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
|
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
|
||||||
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
|||||||
83
handler.go
Normal file
83
handler.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
func (f *Interface) newHook(w func([]byte) error) InsideHandler {
|
||||||
|
fn := func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
|
||||||
|
f.decryptTo(w, hostInfo, header.MessageCounter, out, packet, fwPacket, nb)
|
||||||
|
}
|
||||||
|
return f.encrypted(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) encrypted(h InsideHandler) InsideHandler {
|
||||||
|
return func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
|
||||||
|
if !f.handleEncrypted(ci, addr, header) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h(hostInfo, ci, addr, header, out, packet, fwPacket, nb)
|
||||||
|
|
||||||
|
f.handleHostRoaming(hostInfo, addr)
|
||||||
|
f.connectionManager.In(hostInfo.hostId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) rxMetrics(h InsideHandler) InsideHandler {
|
||||||
|
return func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
|
||||||
|
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
|
||||||
|
h(hostInfo, ci, addr, header, out, packet, fwPacket, nb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) handleMessagePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
|
||||||
|
f.decryptTo(f.inside.WriteRaw, hostInfo, header.MessageCounter, out, packet, fwPacket, nb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) handleLighthousePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
|
||||||
|
d, err := f.decrypt(hostInfo, header.MessageCounter, out, packet, header, nb)
|
||||||
|
if err != nil {
|
||||||
|
hostInfo.logger().WithError(err).WithField("udpAddr", addr).
|
||||||
|
WithField("packet", packet).
|
||||||
|
Error("Failed to decrypt lighthouse packet")
|
||||||
|
|
||||||
|
//TODO: maybe after build 64 is out? 06/14/2018 - NB
|
||||||
|
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
f.lightHouse.HandleRequest(addr, hostInfo.hostId, d, hostInfo.GetCert(), f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) handleTestPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
|
||||||
|
d, err := f.decrypt(hostInfo, header.MessageCounter, out, packet, header, nb)
|
||||||
|
if err != nil {
|
||||||
|
hostInfo.logger().WithError(err).WithField("udpAddr", addr).
|
||||||
|
WithField("packet", packet).
|
||||||
|
Error("Failed to decrypt test packet")
|
||||||
|
|
||||||
|
//TODO: maybe after build 64 is out? 06/14/2018 - NB
|
||||||
|
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if header.Subtype == testRequest {
|
||||||
|
// This testRequest might be from TryPromoteBest, so we should roam
|
||||||
|
// to the new IP address before responding
|
||||||
|
f.handleHostRoaming(hostInfo, addr)
|
||||||
|
f.send(test, testReply, ci, hostInfo, hostInfo.remote, d, nb, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) handleHandshakePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
|
||||||
|
HandleIncomingHandshake(f, addr, packet, header, hostInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) handleRecvErrorPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
|
||||||
|
// TODO: Remove this with recv_error deprecation
|
||||||
|
f.handleRecvError(addr, header)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) handleCloseTunnelPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
|
||||||
|
hostInfo.logger().WithField("udpAddr", addr).
|
||||||
|
Info("Close tunnel received, tearing down.")
|
||||||
|
|
||||||
|
f.closeTunnel(hostInfo)
|
||||||
|
}
|
||||||
@@ -1,11 +1,10 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
"github.com/flynn/noise"
|
"github.com/flynn/noise"
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
)
|
)
|
||||||
@@ -98,6 +97,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||||||
hostinfo, _ := f.handshakeManager.pendingHostMap.QueryReverseIndex(hs.Details.InitiatorIndex)
|
hostinfo, _ := f.handshakeManager.pendingHostMap.QueryReverseIndex(hs.Details.InitiatorIndex)
|
||||||
if hostinfo != nil && bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) {
|
if hostinfo != nil && bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) {
|
||||||
if msg, ok := hostinfo.HandshakePacket[2]; ok {
|
if msg, ok := hostinfo.HandshakePacket[2]; ok {
|
||||||
|
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
|
||||||
err := f.outside.WriteTo(msg, addr)
|
err := f.outside.WriteTo(msg, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||||
@@ -126,11 +126,13 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||||||
}
|
}
|
||||||
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
|
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
|
||||||
certName := remoteCert.Details.Name
|
certName := remoteCert.Details.Name
|
||||||
|
fingerprint, _ := remoteCert.Sha256Sum()
|
||||||
|
|
||||||
myIndex, err := generateIndex()
|
myIndex, err := generateIndex()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -139,12 +141,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Error adding index to connection manager")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Error adding index to connection manager")
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
Info("Handshake message received")
|
Info("Handshake message received")
|
||||||
@@ -157,6 +161,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -166,6 +171,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -173,6 +179,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||||||
if f.hostMap.CheckHandshakeCompleteIP(vpnIP) && vpnIP < ip2int(f.certState.certificate.Details.Ips[0].IP) {
|
if f.hostMap.CheckHandshakeCompleteIP(vpnIP) && vpnIP < ip2int(f.certState.certificate.Details.Ips[0].IP) {
|
||||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
Info("Prevented a handshake race")
|
Info("Prevented a handshake race")
|
||||||
@@ -191,16 +198,19 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||||||
hostinfo.HandshakePacket[2] = make([]byte, len(msg))
|
hostinfo.HandshakePacket[2] = make([]byte, len(msg))
|
||||||
copy(hostinfo.HandshakePacket[2], msg)
|
copy(hostinfo.HandshakePacket[2], msg)
|
||||||
|
|
||||||
|
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
|
||||||
err := f.outside.WriteTo(msg, addr)
|
err := f.outside.WriteTo(msg, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
WithError(err).Error("Failed to send handshake")
|
WithError(err).Error("Failed to send handshake")
|
||||||
} else {
|
} else {
|
||||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
Info("Handshake message sent")
|
Info("Handshake message sent")
|
||||||
@@ -224,6 +234,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||||||
if err == nil && ho.localIndexId != 0 {
|
if err == nil && ho.localIndexId != 0 {
|
||||||
l.WithField("vpnIp", vpnIP).
|
l.WithField("vpnIp", vpnIP).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("action", "removing stale index").
|
WithField("action", "removing stale index").
|
||||||
WithField("index", ho.localIndexId).
|
WithField("index", ho.localIndexId).
|
||||||
Debug("Handshake processing")
|
Debug("Handshake processing")
|
||||||
@@ -237,6 +248,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||||||
} else {
|
} else {
|
||||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
Error("Noise did not arrive at a key")
|
Error("Noise did not arrive at a key")
|
||||||
return true
|
return true
|
||||||
@@ -296,10 +308,12 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||||||
}
|
}
|
||||||
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
|
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
|
||||||
certName := remoteCert.Details.Name
|
certName := remoteCert.Details.Name
|
||||||
|
fingerprint, _ := remoteCert.Sha256Sum()
|
||||||
|
|
||||||
duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
|
duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
|
||||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
WithField("durationNs", duration).
|
WithField("durationNs", duration).
|
||||||
@@ -338,6 +352,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||||||
if err == nil && ho.localIndexId != 0 {
|
if err == nil && ho.localIndexId != 0 {
|
||||||
l.WithField("vpnIp", vpnIP).
|
l.WithField("vpnIp", vpnIP).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("action", "removing stale index").
|
WithField("action", "removing stale index").
|
||||||
WithField("index", ho.localIndexId).
|
WithField("index", ho.localIndexId).
|
||||||
Debug("Handshake processing")
|
Debug("Handshake processing")
|
||||||
@@ -352,6 +367,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||||||
} else {
|
} else {
|
||||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
Error("Noise did not arrive at a key")
|
Error("Noise did not arrive at a key")
|
||||||
return true
|
return true
|
||||||
|
|||||||
@@ -16,21 +16,26 @@ const (
|
|||||||
DefaultHandshakeTryInterval = time.Millisecond * 100
|
DefaultHandshakeTryInterval = time.Millisecond * 100
|
||||||
DefaultHandshakeRetries = 20
|
DefaultHandshakeRetries = 20
|
||||||
// DefaultHandshakeWaitRotation is the number of handshake attempts to do before starting to use other ips addresses
|
// DefaultHandshakeWaitRotation is the number of handshake attempts to do before starting to use other ips addresses
|
||||||
DefaultHandshakeWaitRotation = 5
|
DefaultHandshakeWaitRotation = 5
|
||||||
|
DefaultHandshakeTriggerBuffer = 64
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
defaultHandshakeConfig = HandshakeConfig{
|
defaultHandshakeConfig = HandshakeConfig{
|
||||||
tryInterval: DefaultHandshakeTryInterval,
|
tryInterval: DefaultHandshakeTryInterval,
|
||||||
retries: DefaultHandshakeRetries,
|
retries: DefaultHandshakeRetries,
|
||||||
waitRotation: DefaultHandshakeWaitRotation,
|
waitRotation: DefaultHandshakeWaitRotation,
|
||||||
|
triggerBuffer: DefaultHandshakeTriggerBuffer,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
type HandshakeConfig struct {
|
type HandshakeConfig struct {
|
||||||
tryInterval time.Duration
|
tryInterval time.Duration
|
||||||
retries int
|
retries int
|
||||||
waitRotation int
|
waitRotation int
|
||||||
|
triggerBuffer int
|
||||||
|
|
||||||
|
messageMetrics *MessageMetrics
|
||||||
}
|
}
|
||||||
|
|
||||||
type HandshakeManager struct {
|
type HandshakeManager struct {
|
||||||
@@ -40,8 +45,13 @@ type HandshakeManager struct {
|
|||||||
outside *udpConn
|
outside *udpConn
|
||||||
config HandshakeConfig
|
config HandshakeConfig
|
||||||
|
|
||||||
|
// can be used to trigger outbound handshake for the given vpnIP
|
||||||
|
trigger chan uint32
|
||||||
|
|
||||||
OutboundHandshakeTimer *SystemTimerWheel
|
OutboundHandshakeTimer *SystemTimerWheel
|
||||||
InboundHandshakeTimer *SystemTimerWheel
|
InboundHandshakeTimer *SystemTimerWheel
|
||||||
|
|
||||||
|
messageMetrics *MessageMetrics
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
|
func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
|
||||||
@@ -53,16 +63,26 @@ func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainH
|
|||||||
|
|
||||||
config: config,
|
config: config,
|
||||||
|
|
||||||
|
trigger: make(chan uint32, config.triggerBuffer),
|
||||||
|
|
||||||
OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
|
OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
|
||||||
InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
|
InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
|
||||||
|
|
||||||
|
messageMetrics: config.messageMetrics,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandshakeManager) Run(f EncWriter) {
|
func (c *HandshakeManager) Run(f EncWriter) {
|
||||||
clockSource := time.Tick(c.config.tryInterval)
|
clockSource := time.Tick(c.config.tryInterval)
|
||||||
for now := range clockSource {
|
for {
|
||||||
c.NextOutboundHandshakeTimerTick(now, f)
|
select {
|
||||||
c.NextInboundHandshakeTimerTick(now)
|
case vpnIP := <-c.trigger:
|
||||||
|
l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
|
||||||
|
c.handleOutbound(vpnIP, f, true)
|
||||||
|
case now := <-clockSource:
|
||||||
|
c.NextOutboundHandshakeTimerTick(now, f)
|
||||||
|
c.NextInboundHandshakeTimerTick(now)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,68 +94,86 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWr
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
vpnIP := ep.(uint32)
|
vpnIP := ep.(uint32)
|
||||||
|
c.handleOutbound(vpnIP, f, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
index, err := c.pendingHostMap.GetIndexByVpnIP(vpnIP)
|
func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseTriggered bool) {
|
||||||
if err != nil {
|
index, err := c.pendingHostMap.GetIndexByVpnIP(vpnIP)
|
||||||
continue
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we haven't finished the handshake and we haven't hit max retries, query
|
||||||
|
// lighthouse and then send the handshake packet again.
|
||||||
|
if hostinfo.HandshakeCounter < c.config.retries && !hostinfo.HandshakeComplete {
|
||||||
|
if hostinfo.remote == nil {
|
||||||
|
// We continue to query the lighthouse because hosts may
|
||||||
|
// come online during handshake retries. If the query
|
||||||
|
// succeeds (no error), add the lighthouse info to hostinfo
|
||||||
|
ips := c.lightHouse.QueryCache(vpnIP)
|
||||||
|
// If we have no responses yet, or only one IP (the host hadn't
|
||||||
|
// finished reporting its own IPs yet), then send another query to
|
||||||
|
// the LH.
|
||||||
|
if len(ips) <= 1 {
|
||||||
|
ips, err = c.lightHouse.Query(vpnIP, f)
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
for _, ip := range ips {
|
||||||
|
hostinfo.AddRemote(ip)
|
||||||
|
}
|
||||||
|
hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges)
|
||||||
|
}
|
||||||
|
} else if lighthouseTriggered {
|
||||||
|
// We were triggered by a lighthouse HostQueryReply packet, but
|
||||||
|
// we have already picked a remote for this host (this can happen
|
||||||
|
// if we are configured with multiple lighthouses). So we can skip
|
||||||
|
// this trigger and let the timerwheel handle the rest of the
|
||||||
|
// process
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP)
|
hostinfo.HandshakeCounter++
|
||||||
if err != nil {
|
|
||||||
continue
|
// We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through
|
||||||
|
// all the others until we can stand up a connection.
|
||||||
|
if hostinfo.HandshakeCounter > c.config.waitRotation {
|
||||||
|
hostinfo.rotateRemote()
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we haven't finished the handshake and we haven't hit max retries, query
|
// Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation
|
||||||
// lighthouse and then send the handshake packet again.
|
if hostinfo.HandshakeReady && hostinfo.remote != nil {
|
||||||
if hostinfo.HandshakeCounter < c.config.retries && !hostinfo.HandshakeComplete {
|
c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
|
||||||
if hostinfo.remote == nil {
|
err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
|
||||||
// We continue to query the lighthouse because hosts may
|
if err != nil {
|
||||||
// come online during handshake retries. If the query
|
hostinfo.logger().WithField("udpAddr", hostinfo.remote).
|
||||||
// succeeds (no error), add the lighthouse info to hostinfo
|
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||||
ips, err := c.lightHouse.Query(vpnIP, f)
|
WithField("remoteIndex", hostinfo.remoteIndexId).
|
||||||
if err == nil {
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
for _, ip := range ips {
|
WithError(err).Error("Failed to send handshake message")
|
||||||
hostinfo.AddRemote(ip)
|
} else {
|
||||||
}
|
//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
|
||||||
hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges)
|
// keep the real packet struct around for logging purposes
|
||||||
}
|
hostinfo.logger().WithField("udpAddr", hostinfo.remote).
|
||||||
|
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||||
|
WithField("remoteIndex", hostinfo.remoteIndexId).
|
||||||
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
|
Info("Handshake message sent")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
hostinfo.HandshakeCounter++
|
// Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try
|
||||||
|
if !lighthouseTriggered {
|
||||||
// We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through
|
|
||||||
// all the others until we can stand up a connection.
|
|
||||||
if hostinfo.HandshakeCounter > c.config.waitRotation {
|
|
||||||
hostinfo.rotateRemote()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation
|
|
||||||
if hostinfo.HandshakeReady && hostinfo.remote != nil {
|
|
||||||
err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger().WithField("udpAddr", hostinfo.remote).
|
|
||||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
|
||||||
WithField("remoteIndex", hostinfo.remoteIndexId).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
WithError(err).Error("Failed to send handshake message")
|
|
||||||
} else {
|
|
||||||
//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
|
|
||||||
// keep the real packet struct around for logging purposes
|
|
||||||
hostinfo.logger().WithField("udpAddr", hostinfo.remote).
|
|
||||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
|
||||||
WithField("remoteIndex", hostinfo.remoteIndexId).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
Info("Handshake message sent")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try
|
|
||||||
//l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter))
|
//l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter))
|
||||||
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
|
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
|
||||||
} else {
|
|
||||||
c.pendingHostMap.DeleteVpnIP(vpnIP)
|
|
||||||
c.pendingHostMap.DeleteIndex(index)
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
c.pendingHostMap.DeleteVpnIP(vpnIP)
|
||||||
|
c.pendingHostMap.DeleteIndex(index)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -162,6 +200,7 @@ func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
|
|||||||
// We lock here and use an array to insert items to prevent locking the
|
// We lock here and use an array to insert items to prevent locking the
|
||||||
// main receive thread for very long by waiting to add items to the pending map
|
// main receive thread for very long by waiting to add items to the pending map
|
||||||
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval)
|
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval)
|
||||||
|
|
||||||
return hostinfo
|
return hostinfo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -103,6 +103,56 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_NewHandshakeManagerTrigger(t *testing.T) {
|
||||||
|
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
|
ip := ip2int(net.ParseIP("172.1.1.2"))
|
||||||
|
preferredRanges := []*net.IPNet{localrange}
|
||||||
|
mw := &mockEncWriter{}
|
||||||
|
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
||||||
|
lh := &LightHouse{}
|
||||||
|
|
||||||
|
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
blah.NextOutboundHandshakeTimerTick(now, mw)
|
||||||
|
|
||||||
|
assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
|
||||||
|
|
||||||
|
blah.AddVpnIP(ip)
|
||||||
|
|
||||||
|
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
|
||||||
|
|
||||||
|
// Trigger the same method the channel will
|
||||||
|
blah.handleOutbound(ip, mw, true)
|
||||||
|
|
||||||
|
// Make sure the trigger doesn't schedule another timer entry
|
||||||
|
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
|
||||||
|
hi := blah.pendingHostMap.Hosts[ip]
|
||||||
|
assert.Nil(t, hi.remote)
|
||||||
|
|
||||||
|
lh.addrMap = map[uint32][]udpAddr{
|
||||||
|
ip: {*NewUDPAddrFromString("10.1.1.1:4242")},
|
||||||
|
}
|
||||||
|
|
||||||
|
// This should trigger the hostmap to populate the hostinfo
|
||||||
|
blah.handleOutbound(ip, mw, true)
|
||||||
|
assert.NotNil(t, hi.remote)
|
||||||
|
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
|
||||||
|
}
|
||||||
|
|
||||||
|
func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
|
||||||
|
for _, i := range tw.wheel {
|
||||||
|
n := i.Head
|
||||||
|
for n != nil {
|
||||||
|
c++
|
||||||
|
n = n.Next
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
|
func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
|
||||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ var typeMap = map[NebulaMessageType]string{
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
subTypeNone NebulaMessageSubType = 0
|
||||||
testRequest NebulaMessageSubType = 0
|
testRequest NebulaMessageSubType = 0
|
||||||
testReply NebulaMessageSubType = 1
|
testReply NebulaMessageSubType = 1
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
type headerTest struct {
|
type headerTest struct {
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ type HostMap struct {
|
|||||||
vpnCIDR *net.IPNet
|
vpnCIDR *net.IPNet
|
||||||
defaultRoute uint32
|
defaultRoute uint32
|
||||||
unsafeRoutes *CIDRTree
|
unsafeRoutes *CIDRTree
|
||||||
|
metricsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type HostInfo struct {
|
type HostInfo struct {
|
||||||
@@ -384,8 +385,16 @@ func (hm *HostMap) PunchList() []*udpAddr {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) Punchy(conn *udpConn) {
|
func (hm *HostMap) Punchy(conn *udpConn) {
|
||||||
|
var metricsTxPunchy metrics.Counter
|
||||||
|
if hm.metricsEnabled {
|
||||||
|
metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
|
||||||
|
} else {
|
||||||
|
metricsTxPunchy = metrics.NilCounter{}
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
for _, addr := range hm.PunchList() {
|
for _, addr := range hm.PunchList() {
|
||||||
|
metricsTxPunchy.Inc(1)
|
||||||
conn.WriteTo([]byte{1}, addr)
|
conn.WriteTo([]byte{1}, addr)
|
||||||
}
|
}
|
||||||
time.Sleep(time.Second * 30)
|
time.Sleep(time.Second * 30)
|
||||||
|
|||||||
@@ -161,6 +161,4 @@ func BenchmarkHostmappromote2(b *testing.B) {
|
|||||||
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), g)
|
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), g)
|
||||||
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
|
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
|
||||||
}
|
}
|
||||||
b.Errorf("hi")
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
71
inside.go
71
inside.go
@@ -7,7 +7,7 @@ import (
|
|||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte) {
|
func (f *Interface) consumeInsidePacket(st NebulaMessageSubType, packet []byte, fwPacket *FirewallPacket, nb, out []byte) {
|
||||||
err := newPacket(packet, false, fwPacket)
|
err := newPacket(packet, false, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
||||||
@@ -30,6 +30,14 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
|
|||||||
}
|
}
|
||||||
|
|
||||||
hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
|
hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
|
||||||
|
if hostinfo == nil {
|
||||||
|
if l.Level >= logrus.DebugLevel {
|
||||||
|
l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
|
||||||
|
WithField("fwPacket", fwPacket).
|
||||||
|
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
ci := hostinfo.ConnectionState
|
ci := hostinfo.ConnectionState
|
||||||
|
|
||||||
if ci.ready == false {
|
if ci.ready == false {
|
||||||
@@ -37,28 +45,35 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
|
|||||||
// the packet queue.
|
// the packet queue.
|
||||||
ci.queueLock.Lock()
|
ci.queueLock.Lock()
|
||||||
if !ci.ready {
|
if !ci.ready {
|
||||||
hostinfo.cachePacket(message, 0, packet, f.sendMessageNow)
|
hostinfo.cachePacket(message, st, packet, f.sendMessageNow)
|
||||||
ci.queueLock.Unlock()
|
ci.queueLock.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ci.queueLock.Unlock()
|
ci.queueLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
if !f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs) {
|
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs)
|
||||||
f.send(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out)
|
if dropReason == nil {
|
||||||
if f.lightHouse != nil && *ci.messageCounter%5000 == 0 {
|
mc := f.sendNoMetrics(message, st, ci, hostinfo, hostinfo.remote, packet, nb, out)
|
||||||
|
if f.lightHouse != nil && mc%5000 == 0 {
|
||||||
f.lightHouse.Query(fwPacket.RemoteIP, f)
|
f.lightHouse.Query(fwPacket.RemoteIP, f)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if l.Level >= logrus.DebugLevel {
|
} else if l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger().WithField("fwPacket", fwPacket).
|
hostinfo.logger().
|
||||||
|
WithField("fwPacket", fwPacket).
|
||||||
|
WithField("reason", dropReason).
|
||||||
Debugln("dropping outbound packet")
|
Debugln("dropping outbound packet")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getOrHandshake returns nil if the vpnIp is not routable
|
||||||
func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
|
func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
|
||||||
if f.hostMap.vpnCIDR.Contains(int2ip(vpnIp)) == false {
|
if f.hostMap.vpnCIDR.Contains(int2ip(vpnIp)) == false {
|
||||||
vpnIp = f.hostMap.queryUnsafeRoute(vpnIp)
|
vpnIp = f.hostMap.queryUnsafeRoute(vpnIp)
|
||||||
|
if vpnIp == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f)
|
hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f)
|
||||||
|
|
||||||
@@ -91,6 +106,15 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
|
|||||||
ixHandshakeStage0(f, vpnIp, hostinfo)
|
ixHandshakeStage0(f, vpnIp, hostinfo)
|
||||||
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
|
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
|
||||||
//xx_handshakeStage0(f, ip, hostinfo)
|
//xx_handshakeStage0(f, ip, hostinfo)
|
||||||
|
|
||||||
|
// If this is a static host, we don't need to wait for the HostQueryReply
|
||||||
|
// We can trigger the handshake right now
|
||||||
|
if _, ok := f.lightHouse.staticList[vpnIp]; ok {
|
||||||
|
select {
|
||||||
|
case f.handshakeManager.trigger <- vpnIp:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return hostinfo
|
return hostinfo
|
||||||
@@ -105,12 +129,17 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check if packet is in outbound fw rules
|
// check if packet is in outbound fw rules
|
||||||
if f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs) {
|
dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs)
|
||||||
l.WithField("fwPacket", fp).Debugln("dropping cached packet")
|
if dropReason != nil {
|
||||||
|
if l.Level >= logrus.DebugLevel {
|
||||||
|
l.WithField("fwPacket", fp).
|
||||||
|
WithField("reason", dropReason).
|
||||||
|
Debugln("dropping cached packet")
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.send(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
|
f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
|
||||||
if f.lightHouse != nil && *hostInfo.ConnectionState.messageCounter%5000 == 0 {
|
if f.lightHouse != nil && *hostInfo.ConnectionState.messageCounter%5000 == 0 {
|
||||||
f.lightHouse.Query(fp.RemoteIP, f)
|
f.lightHouse.Query(fp.RemoteIP, f)
|
||||||
}
|
}
|
||||||
@@ -119,6 +148,13 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
|
|||||||
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
|
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
|
||||||
func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
|
func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
|
||||||
hostInfo := f.getOrHandshake(vpnIp)
|
hostInfo := f.getOrHandshake(vpnIp)
|
||||||
|
if hostInfo == nil {
|
||||||
|
if l.Level >= logrus.DebugLevel {
|
||||||
|
l.WithField("vpnIp", IntIp(vpnIp)).
|
||||||
|
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if !hostInfo.ConnectionState.ready {
|
if !hostInfo.ConnectionState.ready {
|
||||||
// Because we might be sending stored packets, lock here to stop new things going to
|
// Because we might be sending stored packets, lock here to stop new things going to
|
||||||
@@ -143,6 +179,13 @@ func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
|
|||||||
// SendMessageToAll handles real ip:port lookup and sends to all known addresses for vpnIp
|
// SendMessageToAll handles real ip:port lookup and sends to all known addresses for vpnIp
|
||||||
func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
|
func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
|
||||||
hostInfo := f.getOrHandshake(vpnIp)
|
hostInfo := f.getOrHandshake(vpnIp)
|
||||||
|
if hostInfo == nil {
|
||||||
|
if l.Level >= logrus.DebugLevel {
|
||||||
|
l.WithField("vpnIp", IntIp(vpnIp)).
|
||||||
|
Debugln("dropping SendMessageToAll, vpnIp not in our CIDR or in unsafe routes")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if hostInfo.ConnectionState.ready == false {
|
if hostInfo.ConnectionState.ready == false {
|
||||||
// Because we might be sending stored packets, lock here to stop new things going to
|
// Because we might be sending stored packets, lock here to stop new things going to
|
||||||
@@ -167,9 +210,14 @@ func (f *Interface) sendMessageToAll(t NebulaMessageType, st NebulaMessageSubTyp
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) {
|
func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) {
|
||||||
|
f.messageMetrics.Tx(t, st, 1)
|
||||||
|
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) uint64 {
|
||||||
if ci.eKey == nil {
|
if ci.eKey == nil {
|
||||||
//TODO: log warning
|
//TODO: log warning
|
||||||
return
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
@@ -189,7 +237,7 @@ func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *Conne
|
|||||||
WithField("udpAddr", remote).WithField("counter", c).
|
WithField("udpAddr", remote).WithField("counter", c).
|
||||||
WithField("attemptedCounter", ci.messageCounter).
|
WithField("attemptedCounter", ci.messageCounter).
|
||||||
Error("Failed to encrypt outgoing packet")
|
Error("Failed to encrypt outgoing packet")
|
||||||
return
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
err = f.outside.WriteTo(out, remote)
|
err = f.outside.WriteTo(out, remote)
|
||||||
@@ -197,6 +245,7 @@ func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *Conne
|
|||||||
hostinfo.logger().WithError(err).
|
hostinfo.logger().WithError(err).
|
||||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
||||||
}
|
}
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func isMulticast(ip uint32) bool {
|
func isMulticast(ip uint32) bool {
|
||||||
|
|||||||
91
interface.go
91
interface.go
@@ -2,6 +2,8 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -10,10 +12,20 @@ import (
|
|||||||
|
|
||||||
const mtu = 9001
|
const mtu = 9001
|
||||||
|
|
||||||
|
type Inside interface {
|
||||||
|
io.ReadWriteCloser
|
||||||
|
Activate() error
|
||||||
|
CidrNet() *net.IPNet
|
||||||
|
DeviceName() string
|
||||||
|
WriteRaw([]byte) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type InsideHandler func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fp *FirewallPacket, nb []byte)
|
||||||
|
|
||||||
type InterfaceConfig struct {
|
type InterfaceConfig struct {
|
||||||
HostMap *HostMap
|
HostMap *HostMap
|
||||||
Outside *udpConn
|
Outside *udpConn
|
||||||
Inside *Tun
|
Inside Inside
|
||||||
certState *CertState
|
certState *CertState
|
||||||
Cipher string
|
Cipher string
|
||||||
Firewall *Firewall
|
Firewall *Firewall
|
||||||
@@ -25,12 +37,16 @@ type InterfaceConfig struct {
|
|||||||
DropLocalBroadcast bool
|
DropLocalBroadcast bool
|
||||||
DropMulticast bool
|
DropMulticast bool
|
||||||
UDPBatchSize int
|
UDPBatchSize int
|
||||||
|
udpQueues int
|
||||||
|
tunQueues int
|
||||||
|
MessageMetrics *MessageMetrics
|
||||||
|
version string
|
||||||
}
|
}
|
||||||
|
|
||||||
type Interface struct {
|
type Interface struct {
|
||||||
hostMap *HostMap
|
hostMap *HostMap
|
||||||
outside *udpConn
|
outside *udpConn
|
||||||
inside *Tun
|
inside Inside
|
||||||
certState *CertState
|
certState *CertState
|
||||||
cipher string
|
cipher string
|
||||||
firewall *Firewall
|
firewall *Firewall
|
||||||
@@ -43,11 +59,15 @@ type Interface struct {
|
|||||||
dropLocalBroadcast bool
|
dropLocalBroadcast bool
|
||||||
dropMulticast bool
|
dropMulticast bool
|
||||||
udpBatchSize int
|
udpBatchSize int
|
||||||
|
udpQueues int
|
||||||
|
tunQueues int
|
||||||
version string
|
version string
|
||||||
|
|
||||||
metricRxRecvError metrics.Counter
|
// handlers are mapped by protocol version -> type -> subtype
|
||||||
metricTxRecvError metrics.Counter
|
handlers map[uint8]map[NebulaMessageType]map[NebulaMessageSubType]InsideHandler
|
||||||
metricHandshakes metrics.Histogram
|
|
||||||
|
metricHandshakes metrics.Histogram
|
||||||
|
messageMetrics *MessageMetrics
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
||||||
@@ -79,40 +99,64 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
|||||||
dropLocalBroadcast: c.DropLocalBroadcast,
|
dropLocalBroadcast: c.DropLocalBroadcast,
|
||||||
dropMulticast: c.DropMulticast,
|
dropMulticast: c.DropMulticast,
|
||||||
udpBatchSize: c.UDPBatchSize,
|
udpBatchSize: c.UDPBatchSize,
|
||||||
|
udpQueues: c.udpQueues,
|
||||||
|
tunQueues: c.tunQueues,
|
||||||
|
version: c.version,
|
||||||
|
|
||||||
metricRxRecvError: metrics.GetOrRegisterCounter("messages.rx.recv_error", nil),
|
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||||
metricTxRecvError: metrics.GetOrRegisterCounter("messages.tx.recv_error", nil),
|
messageMetrics: c.MessageMetrics,
|
||||||
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ifce.connectionManager = newConnectionManager(ifce, c.checkInterval, c.pendingDeletionInterval)
|
ifce.connectionManager = newConnectionManager(ifce, c.checkInterval, c.pendingDeletionInterval)
|
||||||
|
ifce.handlers = map[uint8]map[NebulaMessageType]map[NebulaMessageSubType]InsideHandler{
|
||||||
|
Version: {
|
||||||
|
handshake: {
|
||||||
|
handshakeIXPSK0: ifce.rxMetrics(ifce.handleHandshakePacket),
|
||||||
|
},
|
||||||
|
message: {
|
||||||
|
subTypeNone: ifce.encrypted(ifce.handleMessagePacket),
|
||||||
|
},
|
||||||
|
recvError: {
|
||||||
|
subTypeNone: ifce.rxMetrics(ifce.handleRecvErrorPacket),
|
||||||
|
},
|
||||||
|
lightHouse: {
|
||||||
|
subTypeNone: ifce.rxMetrics(ifce.encrypted(ifce.handleLighthousePacket)),
|
||||||
|
},
|
||||||
|
test: {
|
||||||
|
testRequest: ifce.rxMetrics(ifce.encrypted(ifce.handleTestPacket)),
|
||||||
|
testReply: ifce.rxMetrics(ifce.encrypted(ifce.handleTestPacket)),
|
||||||
|
},
|
||||||
|
closeTunnel: {
|
||||||
|
subTypeNone: ifce.rxMetrics(ifce.encrypted(ifce.handleCloseTunnelPacket)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
return ifce, nil
|
return ifce, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) Run(tunRoutines, udpRoutines int, buildVersion string) {
|
func (f *Interface) run() {
|
||||||
// actually turn on tun dev
|
// actually turn on tun dev
|
||||||
if err := f.inside.Activate(); err != nil {
|
if err := f.inside.Activate(); err != nil {
|
||||||
l.Fatal(err)
|
l.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.version = buildVersion
|
|
||||||
addr, err := f.outside.LocalAddr()
|
addr, err := f.outside.LocalAddr()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("Failed to get udp listen address")
|
l.WithError(err).Error("Failed to get udp listen address")
|
||||||
}
|
}
|
||||||
|
|
||||||
l.WithField("interface", f.inside.Device).WithField("network", f.inside.Cidr.String()).
|
l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
|
||||||
WithField("build", buildVersion).WithField("udpAddr", addr).
|
WithField("build", f.version).WithField("udpAddr", addr).
|
||||||
Info("Nebula interface is active")
|
Info("Nebula interface is active")
|
||||||
|
|
||||||
// Launch n queues to read packets from udp
|
// Launch n queues to read packets from udp
|
||||||
for i := 0; i < udpRoutines; i++ {
|
for i := 0; i < f.udpQueues; i++ {
|
||||||
go f.listenOut(i)
|
go f.listenOut(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launch n queues to read packets from tun dev
|
// Launch n queues to read packets from tun dev
|
||||||
for i := 0; i < tunRoutines; i++ {
|
for i := 0; i < f.tunQueues; i++ {
|
||||||
go f.listenIn(i)
|
go f.listenIn(i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -152,7 +196,7 @@ func (f *Interface) listenIn(i int) {
|
|||||||
os.Exit(2)
|
os.Exit(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out)
|
f.consumeInsidePacket(subTypeNone, packet[:n], fwPacket, nb, out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -210,11 +254,28 @@ func (f *Interface) reloadFirewall(c *Config) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
oldFw := f.firewall
|
oldFw := f.firewall
|
||||||
|
conntrack := oldFw.Conntrack
|
||||||
|
conntrack.Lock()
|
||||||
|
defer conntrack.Unlock()
|
||||||
|
|
||||||
|
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||||
|
// If rulesVersion is back to zero, we have wrapped all the way around. Be
|
||||||
|
// safe and just reset conntrack in this case.
|
||||||
|
if fw.rulesVersion == 0 {
|
||||||
|
l.WithField("firewallHash", fw.GetRuleHash()).
|
||||||
|
WithField("oldFirewallHash", oldFw.GetRuleHash()).
|
||||||
|
WithField("rulesVersion", fw.rulesVersion).
|
||||||
|
Warn("firewall rulesVersion has overflowed, resetting conntrack")
|
||||||
|
} else {
|
||||||
|
fw.Conntrack = conntrack
|
||||||
|
}
|
||||||
|
|
||||||
f.firewall = fw
|
f.firewall = fw
|
||||||
|
|
||||||
oldFw.Destroy()
|
oldFw.Destroy()
|
||||||
l.WithField("firewallHash", fw.GetRuleHash()).
|
l.WithField("firewallHash", fw.GetRuleHash()).
|
||||||
WithField("oldFirewallHash", oldFw.GetRuleHash()).
|
WithField("oldFirewallHash", oldFw.GetRuleHash()).
|
||||||
|
WithField("rulesVersion", fw.rulesVersion).
|
||||||
Info("New firewall has been installed")
|
Info("New firewall has been installed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,6 +30,9 @@ type LightHouse struct {
|
|||||||
// filters local addresses that we advertise to lighthouses
|
// filters local addresses that we advertise to lighthouses
|
||||||
localAllowList *AllowList
|
localAllowList *AllowList
|
||||||
|
|
||||||
|
// used to trigger the HandshakeManager when we receive HostQueryReply
|
||||||
|
handshakeTrigger chan<- uint32
|
||||||
|
|
||||||
// staticList exists to avoid having a bool in each addrMap entry
|
// staticList exists to avoid having a bool in each addrMap entry
|
||||||
// since static should be rare
|
// since static should be rare
|
||||||
staticList map[uint32]struct{}
|
staticList map[uint32]struct{}
|
||||||
@@ -37,6 +41,9 @@ type LightHouse struct {
|
|||||||
nebulaPort int
|
nebulaPort int
|
||||||
punchBack bool
|
punchBack bool
|
||||||
punchDelay time.Duration
|
punchDelay time.Duration
|
||||||
|
|
||||||
|
metrics *MessageMetrics
|
||||||
|
metricHolepunchTx metrics.Counter
|
||||||
}
|
}
|
||||||
|
|
||||||
type EncWriter interface {
|
type EncWriter interface {
|
||||||
@@ -44,7 +51,7 @@ type EncWriter interface {
|
|||||||
SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
|
SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort int, pc *udpConn, punchBack bool, punchDelay time.Duration) *LightHouse {
|
func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort int, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
|
||||||
h := LightHouse{
|
h := LightHouse{
|
||||||
amLighthouse: amLighthouse,
|
amLighthouse: amLighthouse,
|
||||||
myIp: myIp,
|
myIp: myIp,
|
||||||
@@ -58,6 +65,14 @@ func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, n
|
|||||||
punchDelay: punchDelay,
|
punchDelay: punchDelay,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if metricsEnabled {
|
||||||
|
h.metrics = newLighthouseMetrics()
|
||||||
|
|
||||||
|
h.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil)
|
||||||
|
} else {
|
||||||
|
h.metricHolepunchTx = metrics.NilCounter{}
|
||||||
|
}
|
||||||
|
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
h.lighthouses[ip] = struct{}{}
|
h.lighthouses[ip] = struct{}{}
|
||||||
}
|
}
|
||||||
@@ -111,6 +126,7 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lh.metricTx(NebulaMeta_HostQuery, int64(len(lh.lighthouses)))
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
for n := range lh.lighthouses {
|
for n := range lh.lighthouses {
|
||||||
@@ -249,6 +265,7 @@ func (lh *LightHouse) LhUpdateWorker(f EncWriter) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lh.metricTx(NebulaMeta_HostUpdateNotification, int64(len(lh.lighthouses)))
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
for vpnIp := range lh.lighthouses {
|
for vpnIp := range lh.lighthouses {
|
||||||
@@ -281,6 +298,8 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lh.metricRx(n.Type, 1)
|
||||||
|
|
||||||
switch n.Type {
|
switch n.Type {
|
||||||
case NebulaMeta_HostQuery:
|
case NebulaMeta_HostQuery:
|
||||||
// Exit if we don't answer queries
|
// Exit if we don't answer queries
|
||||||
@@ -308,6 +327,7 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
|
|||||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
|
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
lh.metricTx(NebulaMeta_HostQueryReply, 1)
|
||||||
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, reply, make([]byte, 12, 12), make([]byte, mtu))
|
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, reply, make([]byte, 12, 12), make([]byte, mtu))
|
||||||
|
|
||||||
// This signals the other side to punch some zero byte udp packets
|
// This signals the other side to punch some zero byte udp packets
|
||||||
@@ -326,6 +346,7 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
reply, _ := proto.Marshal(answer)
|
reply, _ := proto.Marshal(answer)
|
||||||
|
lh.metricTx(NebulaMeta_HostPunchNotification, 1)
|
||||||
f.SendMessageToVpnIp(lightHouse, 0, n.Details.VpnIp, reply, make([]byte, 12, 12), make([]byte, mtu))
|
f.SendMessageToVpnIp(lightHouse, 0, n.Details.VpnIp, reply, make([]byte, 12, 12), make([]byte, mtu))
|
||||||
}
|
}
|
||||||
//fmt.Println(reply, remoteaddr)
|
//fmt.Println(reply, remoteaddr)
|
||||||
@@ -340,6 +361,11 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
|
|||||||
ans := NewUDPAddr(a.Ip, uint16(a.Port))
|
ans := NewUDPAddr(a.Ip, uint16(a.Port))
|
||||||
lh.AddRemote(n.Details.VpnIp, ans, false)
|
lh.AddRemote(n.Details.VpnIp, ans, false)
|
||||||
}
|
}
|
||||||
|
// Non-blocking attempt to trigger, skip if it would block
|
||||||
|
select {
|
||||||
|
case lh.handshakeTrigger <- n.Details.VpnIp:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
case NebulaMeta_HostUpdateNotification:
|
case NebulaMeta_HostUpdateNotification:
|
||||||
//Simple check that the host sent this not someone else
|
//Simple check that the host sent this not someone else
|
||||||
@@ -362,6 +388,7 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
|
|||||||
vpnPeer := NewUDPAddr(a.Ip, uint16(a.Port))
|
vpnPeer := NewUDPAddr(a.Ip, uint16(a.Port))
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(lh.punchDelay)
|
time.Sleep(lh.punchDelay)
|
||||||
|
lh.metricHolepunchTx.Inc(1)
|
||||||
lh.punchConn.WriteTo(empty, vpnPeer)
|
lh.punchConn.WriteTo(empty, vpnPeer)
|
||||||
|
|
||||||
}()
|
}()
|
||||||
@@ -380,6 +407,13 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (lh *LightHouse) metricRx(t NebulaMeta_MessageType, i int64) {
|
||||||
|
lh.metrics.Rx(NebulaMessageType(t), 0, i)
|
||||||
|
}
|
||||||
|
func (lh *LightHouse) metricTx(t NebulaMeta_MessageType, i int64) {
|
||||||
|
lh.metrics.Tx(NebulaMessageType(t), 0, i)
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
func (f *Interface) sendPathCheck(ci *ConnectionState, endpoint *net.UDPAddr, counter int) {
|
func (f *Interface) sendPathCheck(ci *ConnectionState, endpoint *net.UDPAddr, counter int) {
|
||||||
c := ci.messageCounter
|
c := ci.messageCounter
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
proto "github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -52,7 +52,7 @@ func Test_lhStaticMapping(t *testing.T) {
|
|||||||
|
|
||||||
udpServer, _ := NewListener("0.0.0.0", 0, true)
|
udpServer, _ := NewListener("0.0.0.0", 0, true)
|
||||||
|
|
||||||
meh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1)
|
meh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
||||||
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true)
|
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true)
|
||||||
err := meh.ValidateLHStaticEntries()
|
err := meh.ValidateLHStaticEntries()
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
@@ -60,7 +60,7 @@ func Test_lhStaticMapping(t *testing.T) {
|
|||||||
lh2 := "10.128.0.3"
|
lh2 := "10.128.0.3"
|
||||||
lh2IP := net.ParseIP(lh2)
|
lh2IP := net.ParseIP(lh2)
|
||||||
|
|
||||||
meh = NewLightHouse(true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1)
|
meh = NewLightHouse(true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
|
||||||
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true)
|
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true)
|
||||||
err = meh.ValidateLHStaticEntries()
|
err = meh.ValidateLHStaticEntries()
|
||||||
assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
|
assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
|
||||||
|
|||||||
39
logger.go
Normal file
39
logger.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ContextualError struct {
|
||||||
|
RealError error
|
||||||
|
Fields map[string]interface{}
|
||||||
|
Context string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
|
||||||
|
return ContextualError{Context: msg, Fields: fields, RealError: realError}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ce ContextualError) Error() string {
|
||||||
|
if ce.RealError == nil {
|
||||||
|
return ce.Context
|
||||||
|
}
|
||||||
|
return ce.RealError.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ce ContextualError) Unwrap() error {
|
||||||
|
if ce.RealError == nil {
|
||||||
|
return errors.New(ce.Context)
|
||||||
|
}
|
||||||
|
return ce.RealError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ce *ContextualError) Log(lr *logrus.Logger) {
|
||||||
|
if ce.RealError != nil {
|
||||||
|
lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
|
||||||
|
} else {
|
||||||
|
lr.WithFields(ce.Fields).Error(ce.Context)
|
||||||
|
}
|
||||||
|
}
|
||||||
67
logger_test.go
Normal file
67
logger_test.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TestLogWriter struct {
|
||||||
|
Logs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTestLogWriter() *TestLogWriter {
|
||||||
|
return &TestLogWriter{Logs: make([]string, 0)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tl *TestLogWriter) Write(p []byte) (n int, err error) {
|
||||||
|
tl.Logs = append(tl.Logs, string(p))
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tl *TestLogWriter) Reset() {
|
||||||
|
tl.Logs = tl.Logs[:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextualError_Log(t *testing.T) {
|
||||||
|
l := logrus.New()
|
||||||
|
l.Formatter = &logrus.TextFormatter{
|
||||||
|
DisableTimestamp: true,
|
||||||
|
DisableColors: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
tl := NewTestLogWriter()
|
||||||
|
l.Out = tl
|
||||||
|
|
||||||
|
// Test a full context line
|
||||||
|
tl.Reset()
|
||||||
|
e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
|
||||||
|
e.Log(l)
|
||||||
|
assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs)
|
||||||
|
|
||||||
|
// Test a line with an error and msg but no fields
|
||||||
|
tl.Reset()
|
||||||
|
e = NewContextualError("test message", nil, errors.New("error"))
|
||||||
|
e.Log(l)
|
||||||
|
assert.Equal(t, []string{"level=error msg=\"test message\" error=error\n"}, tl.Logs)
|
||||||
|
|
||||||
|
// Test just a context and fields
|
||||||
|
tl.Reset()
|
||||||
|
e = NewContextualError("test message", m{"field": "1"}, nil)
|
||||||
|
e.Log(l)
|
||||||
|
assert.Equal(t, []string{"level=error msg=\"test message\" field=1\n"}, tl.Logs)
|
||||||
|
|
||||||
|
// Test just a context
|
||||||
|
tl.Reset()
|
||||||
|
e = NewContextualError("test message", nil, nil)
|
||||||
|
e.Log(l)
|
||||||
|
assert.Equal(t, []string{"level=error msg=\"test message\"\n"}, tl.Logs)
|
||||||
|
|
||||||
|
// Test just an error
|
||||||
|
tl.Reset()
|
||||||
|
e = NewContextualError("", nil, errors.New("error"))
|
||||||
|
e.Log(l)
|
||||||
|
assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs)
|
||||||
|
}
|
||||||
153
main.go
153
main.go
@@ -4,11 +4,8 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -16,36 +13,31 @@ import (
|
|||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// The caller should provide a real logger, we have one just in case
|
||||||
var l = logrus.New()
|
var l = logrus.New()
|
||||||
|
|
||||||
type m map[string]interface{}
|
type m map[string]interface{}
|
||||||
|
|
||||||
func Main(configPath string, configTest bool, buildVersion string) {
|
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) {
|
||||||
l.Out = os.Stdout
|
l = logger
|
||||||
l.Formatter = &logrus.TextFormatter{
|
l.Formatter = &logrus.TextFormatter{
|
||||||
FullTimestamp: true,
|
FullTimestamp: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
config := NewConfig()
|
|
||||||
err := config.Load(configPath)
|
|
||||||
if err != nil {
|
|
||||||
l.WithError(err).Error("Failed to load config")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print the config if in test, the exit comes later
|
// Print the config if in test, the exit comes later
|
||||||
if configTest {
|
if configTest {
|
||||||
b, err := yaml.Marshal(config.Settings)
|
b, err := yaml.Marshal(config.Settings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Println(err)
|
return nil, err
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Print the final config
|
||||||
l.Println(string(b))
|
l.Println(string(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
err = configLogger(config)
|
err := configLogger(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("Failed to configure the logger")
|
return nil, NewContextualError("Failed to configure the logger", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config.RegisterReloadCallback(func(c *Config) {
|
config.RegisterReloadCallback(func(c *Config) {
|
||||||
@@ -59,20 +51,20 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||||||
trustedCAs, err = loadCAFromConfig(config)
|
trustedCAs, err = loadCAFromConfig(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//The errors coming out of loadCA are already nicely formatted
|
//The errors coming out of loadCA are already nicely formatted
|
||||||
l.WithError(err).Fatal("Failed to load ca from config")
|
return nil, NewContextualError("Failed to load ca from config", nil, err)
|
||||||
}
|
}
|
||||||
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
|
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
|
||||||
|
|
||||||
cs, err := NewCertStateFromConfig(config)
|
cs, err := NewCertStateFromConfig(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//The errors coming out of NewCertStateFromConfig are already nicely formatted
|
//The errors coming out of NewCertStateFromConfig are already nicely formatted
|
||||||
l.WithError(err).Fatal("Failed to load certificate from config")
|
return nil, NewContextualError("Failed to load certificate from config", nil, err)
|
||||||
}
|
}
|
||||||
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
|
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
|
||||||
|
|
||||||
fw, err := NewFirewallFromConfig(cs.certificate, config)
|
fw, err := NewFirewallFromConfig(cs.certificate, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Fatal("Error while loading firewall rules")
|
return nil, NewContextualError("Error while loading firewall rules", nil, err)
|
||||||
}
|
}
|
||||||
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
|
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
|
||||||
|
|
||||||
@@ -80,11 +72,11 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||||||
tunCidr := cs.certificate.Details.Ips[0]
|
tunCidr := cs.certificate.Details.Ips[0]
|
||||||
routes, err := parseRoutes(config, tunCidr)
|
routes, err := parseRoutes(config, tunCidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Fatal("Could not parse tun.routes")
|
return nil, NewContextualError("Could not parse tun.routes", nil, err)
|
||||||
}
|
}
|
||||||
unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
|
unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Fatal("Could not parse tun.unsafe_routes")
|
return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
||||||
@@ -92,7 +84,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||||||
if config.GetBool("sshd.enabled", false) {
|
if config.GetBool("sshd.enabled", false) {
|
||||||
err = configSSH(ssh, config)
|
err = configSSH(ssh, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Fatal("Error while configuring the sshd")
|
return nil, NewContextualError("Error while configuring the sshd", nil, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,21 +93,35 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||||||
// tun config, listeners, anything modifying the computer should be below
|
// tun config, listeners, anything modifying the computer should be below
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
var tun *Tun
|
var tun Inside
|
||||||
if !configTest {
|
if !configTest {
|
||||||
config.CatchHUP()
|
config.CatchHUP()
|
||||||
|
|
||||||
// set up our tun dev
|
switch {
|
||||||
tun, err = newTun(
|
case config.GetBool("tun.disabled", false):
|
||||||
config.GetString("tun.dev", ""),
|
tun = newDisabledTun(tunCidr, l)
|
||||||
tunCidr,
|
case tunFd != nil:
|
||||||
config.GetInt("tun.mtu", DEFAULT_MTU),
|
tun, err = newTunFromFd(
|
||||||
routes,
|
*tunFd,
|
||||||
unsafeRoutes,
|
tunCidr,
|
||||||
config.GetInt("tun.tx_queue", 500),
|
config.GetInt("tun.mtu", DEFAULT_MTU),
|
||||||
)
|
routes,
|
||||||
|
unsafeRoutes,
|
||||||
|
config.GetInt("tun.tx_queue", 500),
|
||||||
|
)
|
||||||
|
default:
|
||||||
|
tun, err = newTun(
|
||||||
|
config.GetString("tun.dev", ""),
|
||||||
|
tunCidr,
|
||||||
|
config.GetInt("tun.mtu", DEFAULT_MTU),
|
||||||
|
routes,
|
||||||
|
unsafeRoutes,
|
||||||
|
config.GetInt("tun.tx_queue", 500),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Fatal("Failed to get a tun/tap device")
|
return nil, NewContextualError("Failed to get a tun/tap device", nil, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,7 +132,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||||||
if !configTest {
|
if !configTest {
|
||||||
udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1)
|
udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Fatal("Failed to open udp listener")
|
return nil, NewContextualError("Failed to open udp listener", nil, err)
|
||||||
}
|
}
|
||||||
udpServer.reloadConfig(config)
|
udpServer.reloadConfig(config)
|
||||||
}
|
}
|
||||||
@@ -139,7 +145,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||||||
for _, rawPreferredRange := range rawPreferredRanges {
|
for _, rawPreferredRange := range rawPreferredRanges {
|
||||||
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
|
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Fatal("Failed to parse preferred ranges")
|
return nil, NewContextualError("Failed to parse preferred ranges", nil, err)
|
||||||
}
|
}
|
||||||
preferredRanges = append(preferredRanges, preferredRange)
|
preferredRanges = append(preferredRanges, preferredRange)
|
||||||
}
|
}
|
||||||
@@ -152,7 +158,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||||||
if rawLocalRange != "" {
|
if rawLocalRange != "" {
|
||||||
_, localRange, err := net.ParseCIDR(rawLocalRange)
|
_, localRange, err := net.ParseCIDR(rawLocalRange)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Fatal("Failed to parse local range")
|
return nil, NewContextualError("Failed to parse local_range", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the entry for local_range was already specified in
|
// Check if the entry for local_range was already specified in
|
||||||
@@ -172,6 +178,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||||||
hostMap := NewHostMap("main", tunCidr, preferredRanges)
|
hostMap := NewHostMap("main", tunCidr, preferredRanges)
|
||||||
hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
|
hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
|
||||||
hostMap.addUnsafeRoutes(&unsafeRoutes)
|
hostMap.addUnsafeRoutes(&unsafeRoutes)
|
||||||
|
hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
|
||||||
|
|
||||||
l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
|
l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
|
||||||
|
|
||||||
@@ -191,7 +198,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||||||
if port == 0 && !configTest {
|
if port == 0 && !configTest {
|
||||||
uPort, err := udpServer.LocalAddr()
|
uPort, err := udpServer.LocalAddr()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Fatal("Failed to get listening port")
|
return nil, NewContextualError("Failed to get listening port", nil, err)
|
||||||
}
|
}
|
||||||
port = int(uPort.Port)
|
port = int(uPort.Port)
|
||||||
}
|
}
|
||||||
@@ -208,10 +215,10 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||||||
for i, host := range rawLighthouseHosts {
|
for i, host := range rawLighthouseHosts {
|
||||||
ip := net.ParseIP(host)
|
ip := net.ParseIP(host)
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
l.WithField("host", host).Fatalf("Unable to parse lighthouse host entry %v", i+1)
|
return nil, NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
|
||||||
}
|
}
|
||||||
if !tunCidr.Contains(ip) {
|
if !tunCidr.Contains(ip) {
|
||||||
l.WithField("vpnIp", ip).WithField("network", tunCidr.String()).Fatalf("lighthouse host is not in our subnet, invalid")
|
return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
|
||||||
}
|
}
|
||||||
lighthouseHosts[i] = ip2int(ip)
|
lighthouseHosts[i] = ip2int(ip)
|
||||||
}
|
}
|
||||||
@@ -226,17 +233,18 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||||||
udpServer,
|
udpServer,
|
||||||
punchy.Respond,
|
punchy.Respond,
|
||||||
punchy.Delay,
|
punchy.Delay,
|
||||||
|
config.GetBool("stats.lighthouse_metrics", false),
|
||||||
)
|
)
|
||||||
|
|
||||||
remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false)
|
remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Fatal("Invalid lighthouse.remote_allow_list")
|
return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
|
||||||
}
|
}
|
||||||
lightHouse.SetRemoteAllowList(remoteAllowList)
|
lightHouse.SetRemoteAllowList(remoteAllowList)
|
||||||
|
|
||||||
localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true)
|
localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Fatal("Invalid lighthouse.local_allow_list")
|
return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
|
||||||
}
|
}
|
||||||
lightHouse.SetLocalAllowList(localAllowList)
|
lightHouse.SetLocalAllowList(localAllowList)
|
||||||
|
|
||||||
@@ -244,7 +252,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||||||
for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
|
for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
|
||||||
vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
|
vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
|
||||||
if !tunCidr.Contains(vpnIp) {
|
if !tunCidr.Contains(vpnIp) {
|
||||||
l.WithField("vpnIp", vpnIp).WithField("network", tunCidr.String()).Fatalf("static_host_map key is not in our subnet, invalid")
|
return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
|
||||||
}
|
}
|
||||||
vals, ok := v.([]interface{})
|
vals, ok := v.([]interface{})
|
||||||
if ok {
|
if ok {
|
||||||
@@ -255,7 +263,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||||||
ip := addr.IP
|
ip := addr.IP
|
||||||
port, err := strconv.Atoi(parts[1])
|
port, err := strconv.Atoi(parts[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v)
|
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||||
}
|
}
|
||||||
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
|
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
|
||||||
}
|
}
|
||||||
@@ -268,7 +276,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||||||
ip := addr.IP
|
ip := addr.IP
|
||||||
port, err := strconv.Atoi(parts[1])
|
port, err := strconv.Atoi(parts[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v)
|
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||||
}
|
}
|
||||||
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
|
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
|
||||||
}
|
}
|
||||||
@@ -280,13 +288,24 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||||||
l.WithError(err).Error("Lighthouse unreachable")
|
l.WithError(err).Error("Lighthouse unreachable")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var messageMetrics *MessageMetrics
|
||||||
|
if config.GetBool("stats.message_metrics", false) {
|
||||||
|
messageMetrics = newMessageMetrics()
|
||||||
|
} else {
|
||||||
|
messageMetrics = newMessageMetricsOnlyRecvError()
|
||||||
|
}
|
||||||
|
|
||||||
handshakeConfig := HandshakeConfig{
|
handshakeConfig := HandshakeConfig{
|
||||||
tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
|
tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
|
||||||
retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries),
|
retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries),
|
||||||
waitRotation: config.GetInt("handshakes.wait_rotation", DefaultHandshakeWaitRotation),
|
waitRotation: config.GetInt("handshakes.wait_rotation", DefaultHandshakeWaitRotation),
|
||||||
|
triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
|
||||||
|
|
||||||
|
messageMetrics: messageMetrics,
|
||||||
}
|
}
|
||||||
|
|
||||||
handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpServer, handshakeConfig)
|
handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpServer, handshakeConfig)
|
||||||
|
lightHouse.handshakeTrigger = handshakeManager.trigger
|
||||||
|
|
||||||
//TODO: These will be reused for psk
|
//TODO: These will be reused for psk
|
||||||
//handshakeMACKey := config.GetString("handshake_mac.key", "")
|
//handshakeMACKey := config.GetString("handshake_mac.key", "")
|
||||||
@@ -310,6 +329,10 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||||||
DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false),
|
DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false),
|
||||||
DropMulticast: config.GetBool("tun.drop_multicast", false),
|
DropMulticast: config.GetBool("tun.drop_multicast", false),
|
||||||
UDPBatchSize: config.GetInt("listen.batch", 64),
|
UDPBatchSize: config.GetInt("listen.batch", 64),
|
||||||
|
udpQueues: udpQueues,
|
||||||
|
tunQueues: config.GetInt("tun.routines", 1),
|
||||||
|
MessageMetrics: messageMetrics,
|
||||||
|
version: buildVersion,
|
||||||
}
|
}
|
||||||
|
|
||||||
switch ifConfig.Cipher {
|
switch ifConfig.Cipher {
|
||||||
@@ -318,14 +341,14 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||||||
case "chachapoly":
|
case "chachapoly":
|
||||||
noiseEndianness = binary.LittleEndian
|
noiseEndianness = binary.LittleEndian
|
||||||
default:
|
default:
|
||||||
l.Fatalf("Unknown cipher: %v", ifConfig.Cipher)
|
return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
|
||||||
}
|
}
|
||||||
|
|
||||||
var ifce *Interface
|
var ifce *Interface
|
||||||
if !configTest {
|
if !configTest {
|
||||||
ifce, err = NewInterface(ifConfig)
|
ifce, err = NewInterface(ifConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Fatal("Failed to initialize interface")
|
return nil, fmt.Errorf("failed to initialize interface: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ifce.RegisterConfigChangeCallbacks(config)
|
ifce.RegisterConfigChangeCallbacks(config)
|
||||||
@@ -336,18 +359,17 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||||||
|
|
||||||
err = startStats(config, configTest)
|
err = startStats(config, configTest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Fatal("Failed to start stats emitter")
|
return nil, NewContextualError("Failed to start stats emitter", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if configTest {
|
if configTest {
|
||||||
os.Exit(0)
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: check if we _should_ be emitting stats
|
//TODO: check if we _should_ be emitting stats
|
||||||
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
|
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
|
||||||
|
|
||||||
attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
|
attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
|
||||||
ifce.Run(config.GetInt("tun.routines", 1), udpQueues, buildVersion)
|
|
||||||
|
|
||||||
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
|
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
|
||||||
if amLighthouse && serveDns {
|
if amLighthouse && serveDns {
|
||||||
@@ -355,30 +377,5 @@ func Main(configPath string, configTest bool, buildVersion string) {
|
|||||||
go dnsMain(hostMap, config)
|
go dnsMain(hostMap, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Just sit here and be friendly, main thread.
|
return &Control{ifce, l}, nil
|
||||||
shutdownBlock(ifce)
|
|
||||||
}
|
|
||||||
|
|
||||||
func shutdownBlock(ifce *Interface) {
|
|
||||||
var sigChan = make(chan os.Signal)
|
|
||||||
signal.Notify(sigChan, syscall.SIGTERM)
|
|
||||||
signal.Notify(sigChan, syscall.SIGINT)
|
|
||||||
|
|
||||||
sig := <-sigChan
|
|
||||||
l.WithField("signal", sig).Info("Caught signal, shutting down")
|
|
||||||
|
|
||||||
//TODO: stop tun and udp routines, the lock on hostMap does effectively does that though
|
|
||||||
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
|
|
||||||
ifce.hostMap.Lock()
|
|
||||||
for _, h := range ifce.hostMap.Hosts {
|
|
||||||
if h.ConnectionState.ready {
|
|
||||||
ifce.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
|
||||||
l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
|
|
||||||
Debug("Sending close tunnel message")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ifce.hostMap.Unlock()
|
|
||||||
|
|
||||||
l.WithField("signal", sig).Info("Goodbye")
|
|
||||||
os.Exit(0)
|
|
||||||
}
|
}
|
||||||
|
|||||||
97
message_metrics.go
Normal file
97
message_metrics.go
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/rcrowley/go-metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MessageMetrics struct {
|
||||||
|
rx [][]metrics.Counter
|
||||||
|
tx [][]metrics.Counter
|
||||||
|
|
||||||
|
rxUnknown metrics.Counter
|
||||||
|
txUnknown metrics.Counter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MessageMetrics) Rx(t NebulaMessageType, s NebulaMessageSubType, i int64) {
|
||||||
|
if m != nil {
|
||||||
|
if t >= 0 && int(t) < len(m.rx) && s >= 0 && int(s) < len(m.rx[t]) {
|
||||||
|
m.rx[t][s].Inc(i)
|
||||||
|
} else if m.rxUnknown != nil {
|
||||||
|
m.rxUnknown.Inc(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (m *MessageMetrics) Tx(t NebulaMessageType, s NebulaMessageSubType, i int64) {
|
||||||
|
if m != nil {
|
||||||
|
if t >= 0 && int(t) < len(m.tx) && s >= 0 && int(s) < len(m.tx[t]) {
|
||||||
|
m.tx[t][s].Inc(i)
|
||||||
|
} else if m.txUnknown != nil {
|
||||||
|
m.txUnknown.Inc(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMessageMetrics() *MessageMetrics {
|
||||||
|
gen := func(t string) [][]metrics.Counter {
|
||||||
|
return [][]metrics.Counter{
|
||||||
|
{
|
||||||
|
metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.handshake_ixpsk0", t), nil),
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.recv_error", t), nil)},
|
||||||
|
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.lighthouse", t), nil)},
|
||||||
|
{
|
||||||
|
metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.test_request", t), nil),
|
||||||
|
metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.test_response", t), nil),
|
||||||
|
},
|
||||||
|
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.close_tunnel", t), nil)},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &MessageMetrics{
|
||||||
|
rx: gen("rx"),
|
||||||
|
tx: gen("tx"),
|
||||||
|
|
||||||
|
rxUnknown: metrics.GetOrRegisterCounter("messages.rx.other", nil),
|
||||||
|
txUnknown: metrics.GetOrRegisterCounter("messages.tx.other", nil),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Historically we only recorded recv_error, so this is backwards compat
|
||||||
|
func newMessageMetricsOnlyRecvError() *MessageMetrics {
|
||||||
|
gen := func(t string) [][]metrics.Counter {
|
||||||
|
return [][]metrics.Counter{
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.recv_error", t), nil)},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &MessageMetrics{
|
||||||
|
rx: gen("rx"),
|
||||||
|
tx: gen("tx"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLighthouseMetrics() *MessageMetrics {
|
||||||
|
gen := func(t string) [][]metrics.Counter {
|
||||||
|
h := make([][]metrics.Counter, len(NebulaMeta_MessageType_name))
|
||||||
|
used := []NebulaMeta_MessageType{
|
||||||
|
NebulaMeta_HostQuery,
|
||||||
|
NebulaMeta_HostQueryReply,
|
||||||
|
NebulaMeta_HostUpdateNotification,
|
||||||
|
NebulaMeta_HostPunchNotification,
|
||||||
|
}
|
||||||
|
for _, i := range used {
|
||||||
|
h[i] = []metrics.Counter{metrics.GetOrRegisterCounter(fmt.Sprintf("lighthouse.%s.%s", t, i.String()), nil)}
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
return &MessageMetrics{
|
||||||
|
rx: gen("rx"),
|
||||||
|
tx: gen("tx"),
|
||||||
|
|
||||||
|
rxUnknown: metrics.GetOrRegisterCounter("lighthouse.rx.other", nil),
|
||||||
|
txUnknown: metrics.GetOrRegisterCounter("lighthouse.tx.other", nil),
|
||||||
|
}
|
||||||
|
}
|
||||||
113
outside.go
113
outside.go
@@ -2,18 +2,14 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/flynn/noise"
|
"github.com/flynn/noise"
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
// "github.com/google/gopacket"
|
|
||||||
// "github.com/google/gopacket/layers"
|
|
||||||
// "encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -43,92 +39,15 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
|
|||||||
ci = hostinfo.ConnectionState
|
ci = hostinfo.ConnectionState
|
||||||
}
|
}
|
||||||
|
|
||||||
switch header.Type {
|
handle := f.handlers[header.Version][header.Type][header.Subtype]
|
||||||
case message:
|
|
||||||
if !f.handleEncrypted(ci, addr, header) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb)
|
if handle == nil {
|
||||||
|
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
|
||||||
|
|
||||||
case lightHouse:
|
|
||||||
if !f.handleEncrypted(ci, addr, header) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger().WithError(err).WithField("udpAddr", addr).
|
|
||||||
WithField("packet", packet).
|
|
||||||
Error("Failed to decrypt lighthouse packet")
|
|
||||||
|
|
||||||
//TODO: maybe after build 64 is out? 06/14/2018 - NB
|
|
||||||
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f.lightHouse.HandleRequest(addr, hostinfo.hostId, d, hostinfo.GetCert(), f)
|
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
|
||||||
|
|
||||||
case test:
|
|
||||||
if !f.handleEncrypted(ci, addr, header) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger().WithError(err).WithField("udpAddr", addr).
|
|
||||||
WithField("packet", packet).
|
|
||||||
Error("Failed to decrypt test packet")
|
|
||||||
|
|
||||||
//TODO: maybe after build 64 is out? 06/14/2018 - NB
|
|
||||||
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if header.Subtype == testRequest {
|
|
||||||
// This testRequest might be from TryPromoteBest, so we should roam
|
|
||||||
// to the new IP address before responding
|
|
||||||
f.handleHostRoaming(hostinfo, addr)
|
|
||||||
f.send(test, testReply, ci, hostinfo, hostinfo.remote, d, nb, out)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
|
||||||
|
|
||||||
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
|
|
||||||
// are unauthenticated
|
|
||||||
|
|
||||||
case handshake:
|
|
||||||
HandleIncomingHandshake(f, addr, packet, header, hostinfo)
|
|
||||||
return
|
|
||||||
|
|
||||||
case recvError:
|
|
||||||
// TODO: Remove this with recv_error deprecation
|
|
||||||
f.handleRecvError(addr, header)
|
|
||||||
return
|
|
||||||
|
|
||||||
case closeTunnel:
|
|
||||||
if !f.handleEncrypted(ci, addr, header) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hostinfo.logger().WithField("udpAddr", addr).
|
|
||||||
Info("Close tunnel received, tearing down.")
|
|
||||||
|
|
||||||
f.closeTunnel(hostinfo)
|
|
||||||
return
|
|
||||||
|
|
||||||
default:
|
|
||||||
hostinfo.logger().Debugf("Unexpected packet received from %s", addr)
|
hostinfo.logger().Debugf("Unexpected packet received from %s", addr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.handleHostRoaming(hostinfo, addr)
|
handle(hostinfo, ci, addr, header, out, packet, fwPacket, nb)
|
||||||
|
|
||||||
f.connectionManager.In(hostinfo.hostId)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) closeTunnel(hostInfo *HostInfo) {
|
func (f *Interface) closeTunnel(hostInfo *HostInfo) {
|
||||||
@@ -256,7 +175,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
|
func (f *Interface) decryptTo(write func([]byte) error, hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
|
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
|
||||||
@@ -280,21 +199,25 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs) {
|
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs)
|
||||||
hostinfo.logger().WithField("fwPacket", fwPacket).
|
if dropReason != nil {
|
||||||
Debugln("dropping inbound packet")
|
if l.Level >= logrus.DebugLevel {
|
||||||
|
hostinfo.logger().WithField("fwPacket", fwPacket).
|
||||||
|
WithField("reason", dropReason).
|
||||||
|
Debugln("dropping inbound packet")
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.connectionManager.In(hostinfo.hostId)
|
f.connectionManager.In(hostinfo.hostId)
|
||||||
err = f.inside.WriteRaw(out)
|
err = write(out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("Failed to write to tun")
|
l.WithError(err).Error("Failed to write to tun")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
|
func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
|
||||||
f.metricTxRecvError.Inc(1)
|
f.messageMetrics.Tx(recvError, 0, 1)
|
||||||
|
|
||||||
//TODO: this should be a signed message so we can trust that we should drop the index
|
//TODO: this should be a signed message so we can trust that we should drop the index
|
||||||
b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0)
|
b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0)
|
||||||
@@ -307,8 +230,6 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
|
func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
|
||||||
f.metricRxRecvError.Inc(1)
|
|
||||||
|
|
||||||
// This flag is to stop caring about recv_error from old versions
|
// This flag is to stop caring about recv_error from old versions
|
||||||
// This should go away when the old version is gone from prod
|
// This should go away when the old version is gone from prod
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Level >= logrus.DebugLevel {
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_newPacket(t *testing.T) {
|
func Test_newPacket(t *testing.T) {
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewPunchyFromConfig(t *testing.T) {
|
func TestNewPunchyFromConfig(t *testing.T) {
|
||||||
|
|||||||
56
ssh.go
56
ssh.go
@@ -5,8 +5,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/sshd"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
@@ -14,6 +12,9 @@ import (
|
|||||||
"runtime/pprof"
|
"runtime/pprof"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/sshd"
|
||||||
)
|
)
|
||||||
|
|
||||||
type sshListHostMapFlags struct {
|
type sshListHostMapFlags struct {
|
||||||
@@ -65,10 +66,11 @@ func configSSH(ssh *sshd.SSHServer, c *Config) error {
|
|||||||
return fmt.Errorf("sshd.listen must be provided")
|
return fmt.Errorf("sshd.listen must be provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
port := strings.Split(listen, ":")
|
_, port, err := net.SplitHostPort(listen)
|
||||||
if len(port) < 2 {
|
if err != nil {
|
||||||
return fmt.Errorf("sshd.listen does not have a port")
|
return fmt.Errorf("invalid sshd.listen address: %s", err)
|
||||||
} else if port[1] == "22" {
|
}
|
||||||
|
if port == "22" {
|
||||||
return fmt.Errorf("sshd.listen can not use port 22")
|
return fmt.Errorf("sshd.listen can not use port 22")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -461,7 +463,12 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
|
|||||||
return w.WriteLine("No vpn ip was provided")
|
return w.WriteLine("No vpn ip was provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnIp := ip2int(net.ParseIP(a[0]))
|
parsedIp := net.ParseIP(a[0])
|
||||||
|
if parsedIp == nil {
|
||||||
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
vpnIp := ip2int(parsedIp)
|
||||||
if vpnIp == 0 {
|
if vpnIp == 0 {
|
||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
@@ -481,7 +488,12 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
|
|||||||
return w.WriteLine("No vpn ip was provided")
|
return w.WriteLine("No vpn ip was provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnIp := ip2int(net.ParseIP(a[0]))
|
parsedIp := net.ParseIP(a[0])
|
||||||
|
if parsedIp == nil {
|
||||||
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
vpnIp := ip2int(parsedIp)
|
||||||
if vpnIp == 0 {
|
if vpnIp == 0 {
|
||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
@@ -519,7 +531,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
|
|||||||
return w.WriteLine("No vpn ip was provided")
|
return w.WriteLine("No vpn ip was provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnIp := ip2int(net.ParseIP(a[0]))
|
parsedIp := net.ParseIP(a[0])
|
||||||
|
if parsedIp == nil {
|
||||||
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
vpnIp := ip2int(parsedIp)
|
||||||
if vpnIp == 0 {
|
if vpnIp == 0 {
|
||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
@@ -571,7 +588,12 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
|
|||||||
return w.WriteLine("Address could not be parsed")
|
return w.WriteLine("Address could not be parsed")
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnIp := ip2int(net.ParseIP(a[0]))
|
parsedIp := net.ParseIP(a[0])
|
||||||
|
if parsedIp == nil {
|
||||||
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
vpnIp := ip2int(parsedIp)
|
||||||
if vpnIp == 0 {
|
if vpnIp == 0 {
|
||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
@@ -647,7 +669,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
|
|||||||
|
|
||||||
cert := ifce.certState.certificate
|
cert := ifce.certState.certificate
|
||||||
if len(a) > 0 {
|
if len(a) > 0 {
|
||||||
vpnIp := ip2int(net.ParseIP(a[0]))
|
parsedIp := net.ParseIP(a[0])
|
||||||
|
if parsedIp == nil {
|
||||||
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
vpnIp := ip2int(parsedIp)
|
||||||
if vpnIp == 0 {
|
if vpnIp == 0 {
|
||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
@@ -694,7 +721,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
|
|||||||
return w.WriteLine("No vpn ip was provided")
|
return w.WriteLine("No vpn ip was provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnIp := ip2int(net.ParseIP(a[0]))
|
parsedIp := net.ParseIP(a[0])
|
||||||
|
if parsedIp == nil {
|
||||||
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
vpnIp := ip2int(parsedIp)
|
||||||
if vpnIp == 0 {
|
if vpnIp == 0 {
|
||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,9 +4,10 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/armon/go-radix"
|
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/armon/go-radix"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CommandFlags is a function called before help or command execution to parse command line flags
|
// CommandFlags is a function called before help or command execution to parse command line flags
|
||||||
|
|||||||
@@ -2,10 +2,11 @@ package sshd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
"github.com/armon/go-radix"
|
"github.com/armon/go-radix"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
"net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type SSHServer struct {
|
type SSHServer struct {
|
||||||
|
|||||||
@@ -2,13 +2,14 @@ package sshd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/anmitsu/go-shlex"
|
"github.com/anmitsu/go-shlex"
|
||||||
"github.com/armon/go-radix"
|
"github.com/armon/go-radix"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
"golang.org/x/crypto/ssh/terminal"
|
"golang.org/x/crypto/ssh/terminal"
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type session struct {
|
type session struct {
|
||||||
|
|||||||
11
stats.go
11
stats.go
@@ -3,15 +3,16 @@ package nebula
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/cyberdelia/go-metrics-graphite"
|
|
||||||
mp "github.com/nbrownus/go-metrics-prometheus"
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
|
||||||
"github.com/rcrowley/go-metrics"
|
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
graphite "github.com/cyberdelia/go-metrics-graphite"
|
||||||
|
mp "github.com/nbrownus/go-metrics-prometheus"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
|
"github.com/rcrowley/go-metrics"
|
||||||
)
|
)
|
||||||
|
|
||||||
func startStats(c *Config, configTest bool) error {
|
func startStats(c *Config, configTest bool) error {
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewTimerWheel(t *testing.T) {
|
func TestNewTimerWheel(t *testing.T) {
|
||||||
|
|||||||
76
tun_android.go
Normal file
76
tun_android.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Tun struct {
|
||||||
|
io.ReadWriteCloser
|
||||||
|
fd int
|
||||||
|
Device string
|
||||||
|
Cidr *net.IPNet
|
||||||
|
MaxMTU int
|
||||||
|
DefaultMTU int
|
||||||
|
TXQueueLen int
|
||||||
|
Routes []route
|
||||||
|
UnsafeRoutes []route
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
|
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||||
|
|
||||||
|
ifce = &Tun{
|
||||||
|
ReadWriteCloser: file,
|
||||||
|
fd: int(file.Fd()),
|
||||||
|
Device: "android",
|
||||||
|
Cidr: cidr,
|
||||||
|
DefaultMTU: defaultMTU,
|
||||||
|
TXQueueLen: txQueueLen,
|
||||||
|
Routes: routes,
|
||||||
|
UnsafeRoutes: unsafeRoutes,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
|
return nil, fmt.Errorf("newTun not supported in Android")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Tun) WriteRaw(b []byte) error {
|
||||||
|
var nn int
|
||||||
|
for {
|
||||||
|
max := len(b)
|
||||||
|
n, err := unix.Write(c.fd, b[nn:max])
|
||||||
|
if n > 0 {
|
||||||
|
nn += n
|
||||||
|
}
|
||||||
|
if nn == len(b) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if n == 0 {
|
||||||
|
return io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Tun) Activate() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Tun) CidrNet() *net.IPNet {
|
||||||
|
return c.Cidr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Tun) DeviceName() string {
|
||||||
|
return c.Device
|
||||||
|
}
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
// +build !ios
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -20,8 +22,9 @@ type Tun struct {
|
|||||||
|
|
||||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
if len(routes) > 0 {
|
if len(routes) > 0 {
|
||||||
return nil, fmt.Errorf("Route MTU not supported in Darwin")
|
return nil, fmt.Errorf("route MTU not supported in Darwin")
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate()
|
// NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate()
|
||||||
return &Tun{
|
return &Tun{
|
||||||
Cidr: cidr,
|
Cidr: cidr,
|
||||||
@@ -30,13 +33,17 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
|
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Tun) Activate() error {
|
func (c *Tun) Activate() error {
|
||||||
var err error
|
var err error
|
||||||
c.Interface, err = water.New(water.Config{
|
c.Interface, err = water.New(water.Config{
|
||||||
DeviceType: water.TUN,
|
DeviceType: water.TUN,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Activate failed: %v", err)
|
return fmt.Errorf("activate failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Device = c.Interface.Name()
|
c.Device = c.Interface.Name()
|
||||||
@@ -61,6 +68,14 @@ func (c *Tun) Activate() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Tun) CidrNet() *net.IPNet {
|
||||||
|
return c.Cidr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Tun) DeviceName() string {
|
||||||
|
return c.Device
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Tun) WriteRaw(b []byte) error {
|
func (c *Tun) WriteRaw(b []byte) error {
|
||||||
_, err := c.Write(b)
|
_, err := c.Write(b)
|
||||||
return err
|
return err
|
||||||
|
|||||||
74
tun_disabled.go
Normal file
74
tun_disabled.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type disabledTun struct {
|
||||||
|
block chan struct{}
|
||||||
|
cidr *net.IPNet
|
||||||
|
logger *log.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDisabledTun(cidr *net.IPNet, l *log.Logger) *disabledTun {
|
||||||
|
return &disabledTun{
|
||||||
|
cidr: cidr,
|
||||||
|
block: make(chan struct{}),
|
||||||
|
logger: l,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*disabledTun) Activate() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *disabledTun) CidrNet() *net.IPNet {
|
||||||
|
return t.cidr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*disabledTun) DeviceName() string {
|
||||||
|
return "disabled"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *disabledTun) Read(b []byte) (int, error) {
|
||||||
|
<-t.block
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *disabledTun) Write(b []byte) (int, error) {
|
||||||
|
t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *disabledTun) WriteRaw(b []byte) error {
|
||||||
|
_, err := t.Write(b)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *disabledTun) Close() error {
|
||||||
|
if t.block != nil {
|
||||||
|
close(t.block)
|
||||||
|
t.block = nil
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type prettyPacket []byte
|
||||||
|
|
||||||
|
func (p prettyPacket) String() string {
|
||||||
|
var s strings.Builder
|
||||||
|
|
||||||
|
for i, b := range p {
|
||||||
|
if i > 0 && i%8 == 0 {
|
||||||
|
s.WriteString(" ")
|
||||||
|
}
|
||||||
|
s.WriteString(fmt.Sprintf("%02x ", b))
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.String()
|
||||||
|
}
|
||||||
89
tun_freebsd.go
Normal file
89
tun_freebsd.go
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||||
|
|
||||||
|
type Tun struct {
|
||||||
|
Device string
|
||||||
|
Cidr *net.IPNet
|
||||||
|
MTU int
|
||||||
|
UnsafeRoutes []route
|
||||||
|
|
||||||
|
io.ReadWriteCloser
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
|
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
|
if len(routes) > 0 {
|
||||||
|
return nil, fmt.Errorf("Route MTU not supported in FreeBSD")
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(deviceName, "/dev/") {
|
||||||
|
deviceName = strings.TrimPrefix(deviceName, "/dev/")
|
||||||
|
}
|
||||||
|
if !deviceNameRE.MatchString(deviceName) {
|
||||||
|
return nil, fmt.Errorf("tun.dev must match `tun[0-9]+`")
|
||||||
|
}
|
||||||
|
return &Tun{
|
||||||
|
Device: deviceName,
|
||||||
|
Cidr: cidr,
|
||||||
|
MTU: defaultMTU,
|
||||||
|
UnsafeRoutes: unsafeRoutes,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Tun) Activate() error {
|
||||||
|
var err error
|
||||||
|
c.ReadWriteCloser, err = os.OpenFile("/dev/"+c.Device, os.O_RDWR, 0)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Activate failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO use syscalls instead of exec.Command
|
||||||
|
l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String())
|
||||||
|
if err = exec.Command("/sbin/ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil {
|
||||||
|
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||||
|
}
|
||||||
|
l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device)
|
||||||
|
if err = exec.Command("/sbin/route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil {
|
||||||
|
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||||
|
}
|
||||||
|
l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU))
|
||||||
|
if err = exec.Command("/sbin/ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil {
|
||||||
|
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||||
|
}
|
||||||
|
// Unsafe path routes
|
||||||
|
for _, r := range c.UnsafeRoutes {
|
||||||
|
l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device)
|
||||||
|
if err = exec.Command("/sbin/route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil {
|
||||||
|
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.route.String(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Tun) CidrNet() *net.IPNet {
|
||||||
|
return c.Cidr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Tun) DeviceName() string {
|
||||||
|
return c.Device
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Tun) WriteRaw(b []byte) error {
|
||||||
|
_, err := c.Write(b)
|
||||||
|
return err
|
||||||
|
}
|
||||||
113
tun_ios.go
Normal file
113
tun_ios.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
// +build ios
|
||||||
|
|
||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Tun struct {
|
||||||
|
io.ReadWriteCloser
|
||||||
|
Device string
|
||||||
|
Cidr *net.IPNet
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
|
return nil, fmt.Errorf("newTun not supported in iOS")
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
|
if len(routes) > 0 {
|
||||||
|
return nil, fmt.Errorf("route MTU not supported in Darwin")
|
||||||
|
}
|
||||||
|
|
||||||
|
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
||||||
|
ifce = &Tun{
|
||||||
|
Cidr: cidr,
|
||||||
|
Device: "iOS",
|
||||||
|
ReadWriteCloser: &tunReadCloser{f: file},
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Tun) Activate() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Tun) WriteRaw(b []byte) error {
|
||||||
|
_, err := c.Write(b)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// The following is hoisted up from water, we do this so we can inject our own fd on iOS
|
||||||
|
type tunReadCloser struct {
|
||||||
|
f io.ReadWriteCloser
|
||||||
|
|
||||||
|
rMu sync.Mutex
|
||||||
|
rBuf []byte
|
||||||
|
|
||||||
|
wMu sync.Mutex
|
||||||
|
wBuf []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tunReadCloser) Read(to []byte) (int, error) {
|
||||||
|
t.rMu.Lock()
|
||||||
|
defer t.rMu.Unlock()
|
||||||
|
|
||||||
|
if cap(t.rBuf) < len(to)+4 {
|
||||||
|
t.rBuf = make([]byte, len(to)+4)
|
||||||
|
}
|
||||||
|
t.rBuf = t.rBuf[:len(to)+4]
|
||||||
|
|
||||||
|
n, err := t.f.Read(t.rBuf)
|
||||||
|
copy(to, t.rBuf[4:])
|
||||||
|
return n - 4, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tunReadCloser) Write(from []byte) (int, error) {
|
||||||
|
|
||||||
|
if len(from) == 0 {
|
||||||
|
return 0, syscall.EIO
|
||||||
|
}
|
||||||
|
|
||||||
|
t.wMu.Lock()
|
||||||
|
defer t.wMu.Unlock()
|
||||||
|
|
||||||
|
if cap(t.wBuf) < len(from)+4 {
|
||||||
|
t.wBuf = make([]byte, len(from)+4)
|
||||||
|
}
|
||||||
|
t.wBuf = t.wBuf[:len(from)+4]
|
||||||
|
|
||||||
|
// Determine the IP Family for the NULL L2 Header
|
||||||
|
ipVer := from[0] >> 4
|
||||||
|
if ipVer == 4 {
|
||||||
|
t.wBuf[3] = syscall.AF_INET
|
||||||
|
} else if ipVer == 6 {
|
||||||
|
t.wBuf[3] = syscall.AF_INET6
|
||||||
|
} else {
|
||||||
|
return 0, errors.New("unable to determine IP version from packet")
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(t.wBuf[4:], from)
|
||||||
|
|
||||||
|
n, err := t.f.Write(t.wBuf)
|
||||||
|
return n - 4, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tunReadCloser) Close() error {
|
||||||
|
return t.f.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Tun) CidrNet() *net.IPNet {
|
||||||
|
return c.Cidr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Tun) DeviceName() string {
|
||||||
|
return c.Device
|
||||||
|
}
|
||||||
43
tun_linux.go
43
tun_linux.go
@@ -1,3 +1,5 @@
|
|||||||
|
// +build !android
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -75,6 +77,23 @@ type ifreqQLEN struct {
|
|||||||
pad [8]byte
|
pad [8]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
|
|
||||||
|
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||||
|
|
||||||
|
ifce = &Tun{
|
||||||
|
ReadWriteCloser: file,
|
||||||
|
fd: int(file.Fd()),
|
||||||
|
Device: "tun0",
|
||||||
|
Cidr: cidr,
|
||||||
|
DefaultMTU: defaultMTU,
|
||||||
|
TXQueueLen: txQueueLen,
|
||||||
|
Routes: routes,
|
||||||
|
UnsafeRoutes: unsafeRoutes,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -216,6 +235,7 @@ func (c Tun) Activate() error {
|
|||||||
LinkIndex: link.Attrs().Index,
|
LinkIndex: link.Attrs().Index,
|
||||||
Dst: dr,
|
Dst: dr,
|
||||||
MTU: c.DefaultMTU,
|
MTU: c.DefaultMTU,
|
||||||
|
AdvMSS: c.advMSS(route{}),
|
||||||
Scope: unix.RT_SCOPE_LINK,
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
Src: c.Cidr.IP,
|
Src: c.Cidr.IP,
|
||||||
Protocol: unix.RTPROT_KERNEL,
|
Protocol: unix.RTPROT_KERNEL,
|
||||||
@@ -233,6 +253,7 @@ func (c Tun) Activate() error {
|
|||||||
LinkIndex: link.Attrs().Index,
|
LinkIndex: link.Attrs().Index,
|
||||||
Dst: r.route,
|
Dst: r.route,
|
||||||
MTU: r.mtu,
|
MTU: r.mtu,
|
||||||
|
AdvMSS: c.advMSS(r),
|
||||||
Scope: unix.RT_SCOPE_LINK,
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -248,6 +269,7 @@ func (c Tun) Activate() error {
|
|||||||
LinkIndex: link.Attrs().Index,
|
LinkIndex: link.Attrs().Index,
|
||||||
Dst: r.route,
|
Dst: r.route,
|
||||||
MTU: r.mtu,
|
MTU: r.mtu,
|
||||||
|
AdvMSS: c.advMSS(r),
|
||||||
Scope: unix.RT_SCOPE_LINK,
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -265,3 +287,24 @@ func (c Tun) Activate() error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Tun) CidrNet() *net.IPNet {
|
||||||
|
return c.Cidr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Tun) DeviceName() string {
|
||||||
|
return c.Device
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Tun) advMSS(r route) int {
|
||||||
|
mtu := r.mtu
|
||||||
|
if r.mtu == 0 {
|
||||||
|
mtu = c.DefaultMTU
|
||||||
|
}
|
||||||
|
|
||||||
|
// We only need to set advmss if the route MTU does not match the device MTU
|
||||||
|
if mtu != c.MaxMTU {
|
||||||
|
return mtu - 40
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|||||||
31
tun_linux_test.go
Normal file
31
tun_linux_test.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
var runAdvMSSTests = []struct {
|
||||||
|
name string
|
||||||
|
tun Tun
|
||||||
|
r route
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
// Standard case, default MTU is the device max MTU
|
||||||
|
{"default", Tun{DefaultMTU: 1440, MaxMTU: 1440}, route{}, 0},
|
||||||
|
{"default-min", Tun{DefaultMTU: 1440, MaxMTU: 1440}, route{mtu: 1440}, 0},
|
||||||
|
{"default-low", Tun{DefaultMTU: 1440, MaxMTU: 1440}, route{mtu: 1200}, 1160},
|
||||||
|
|
||||||
|
// Case where we have a route MTU set higher than the default
|
||||||
|
{"route", Tun{DefaultMTU: 1440, MaxMTU: 8941}, route{}, 1400},
|
||||||
|
{"route-min", Tun{DefaultMTU: 1440, MaxMTU: 8941}, route{mtu: 1440}, 1400},
|
||||||
|
{"route-high", Tun{DefaultMTU: 1440, MaxMTU: 8941}, route{mtu: 8941}, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTunAdvMSS(t *testing.T) {
|
||||||
|
for _, tt := range runAdvMSSTests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
o := tt.tun.advMSS(tt.r)
|
||||||
|
if o != tt.expected {
|
||||||
|
t.Errorf("got %d, want %d", o, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -18,9 +18,13 @@ type Tun struct {
|
|||||||
*water.Interface
|
*water.Interface
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
|
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
||||||
|
}
|
||||||
|
|
||||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||||
if len(routes) > 0 {
|
if len(routes) > 0 {
|
||||||
return nil, fmt.Errorf("Route MTU not supported in Windows")
|
return nil, fmt.Errorf("route MTU not supported in Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
|
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
|
||||||
@@ -84,6 +88,14 @@ func (c *Tun) Activate() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Tun) CidrNet() *net.IPNet {
|
||||||
|
return c.Cidr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Tun) DeviceName() string {
|
||||||
|
return c.Device
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Tun) WriteRaw(b []byte) error {
|
func (c *Tun) WriteRaw(b []byte) error {
|
||||||
_, err := c.Write(b)
|
_, err := c.Write(b)
|
||||||
return err
|
return err
|
||||||
|
|||||||
36
udp_android.go
Normal file
36
udp_android.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewListenConfig(multi bool) net.ListenConfig {
|
||||||
|
return net.ListenConfig{
|
||||||
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
|
if multi {
|
||||||
|
var controlErr error
|
||||||
|
err := c.Control(func(fd uintptr) {
|
||||||
|
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
|
||||||
|
controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if controlErr != nil {
|
||||||
|
return controlErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *udpConn) Rebind() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -32,3 +32,12 @@ func NewListenConfig(multi bool) net.ListenConfig {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *udpConn) Rebind() error {
|
||||||
|
file, err := u.File()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return syscall.SetsockoptInt(int(file.Fd()), unix.IPPROTO_IP, unix.IP_BOUND_IF, 0)
|
||||||
|
}
|
||||||
|
|||||||
38
udp_freebsd.go
Normal file
38
udp_freebsd.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewListenConfig(multi bool) net.ListenConfig {
|
||||||
|
return net.ListenConfig{
|
||||||
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
|
if multi {
|
||||||
|
var controlErr error
|
||||||
|
err := c.Control(func(fd uintptr) {
|
||||||
|
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
|
||||||
|
controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if controlErr != nil {
|
||||||
|
return controlErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *udpConn) Rebind() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
// +build !linux
|
// +build !linux android
|
||||||
|
|
||||||
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
|
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
|
||||||
// means it can be used on platforms like Darwin and Windows.
|
// means it can be used on platforms like Darwin and Windows.
|
||||||
@@ -65,6 +65,17 @@ func (ua *udpAddr) Equals(t *udpAddr) bool {
|
|||||||
return ua.IP.Equal(t.IP) && ua.Port == t.Port
|
return ua.IP.Equal(t.IP) && ua.Port == t.Port
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ua *udpAddr) Copy() udpAddr {
|
||||||
|
nu := udpAddr{net.UDPAddr{
|
||||||
|
Port: ua.Port,
|
||||||
|
Zone: ua.Zone,
|
||||||
|
IP: make(net.IP, len(ua.IP)),
|
||||||
|
}}
|
||||||
|
|
||||||
|
copy(nu.IP, ua.IP)
|
||||||
|
return nu
|
||||||
|
}
|
||||||
|
|
||||||
func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error {
|
func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error {
|
||||||
_, err := uc.UDPConn.WriteToUDP(b, &addr.UDPAddr)
|
_, err := uc.UDPConn.WriteToUDP(b, &addr.UDPAddr)
|
||||||
return err
|
return err
|
||||||
|
|||||||
59
udp_linux.go
59
udp_linux.go
@@ -1,3 +1,5 @@
|
|||||||
|
// +build !android
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -69,8 +71,10 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
|
|||||||
var lip [4]byte
|
var lip [4]byte
|
||||||
copy(lip[:], net.ParseIP(ip).To4())
|
copy(lip[:], net.ParseIP(ip).To4())
|
||||||
|
|
||||||
if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
|
if multi {
|
||||||
return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err)
|
if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = unix.Bind(fd, &unix.SockaddrInet4{Addr: lip, Port: port}); err != nil {
|
if err = unix.Bind(fd, &unix.SockaddrInet4{Addr: lip, Port: port}); err != nil {
|
||||||
@@ -85,6 +89,14 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
|
|||||||
return &udpConn{sysFd: fd}, err
|
return &udpConn{sysFd: fd}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *udpConn) Rebind() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ua *udpAddr) Copy() udpAddr {
|
||||||
|
return *ua
|
||||||
|
}
|
||||||
|
|
||||||
func (u *udpConn) SetRecvBuffer(n int) error {
|
func (u *udpConn) SetRecvBuffer(n int) error {
|
||||||
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
|
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
|
||||||
}
|
}
|
||||||
@@ -137,9 +149,13 @@ func (u *udpConn) ListenOut(f *Interface) {
|
|||||||
//TODO: should we track this?
|
//TODO: should we track this?
|
||||||
//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
|
//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
|
||||||
msgs, buffers, names := u.PrepareRawMessages(f.udpBatchSize)
|
msgs, buffers, names := u.PrepareRawMessages(f.udpBatchSize)
|
||||||
|
read := u.ReadMulti
|
||||||
|
if f.udpBatchSize == 1 {
|
||||||
|
read = u.ReadSingle
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, err := u.ReadMulti(msgs)
|
n, err := read(msgs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Error("Failed to read packets")
|
l.WithError(err).Error("Failed to read packets")
|
||||||
continue
|
continue
|
||||||
@@ -155,34 +171,24 @@ func (u *udpConn) ListenOut(f *Interface) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) Read(addr *udpAddr, b []byte) ([]byte, error) {
|
func (u *udpConn) ReadSingle(msgs []rawMessage) (int, error) {
|
||||||
var rsa rawSockaddrAny
|
|
||||||
var rLen = unix.SizeofSockaddrAny
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, _, err := unix.Syscall6(
|
n, _, err := unix.Syscall6(
|
||||||
unix.SYS_RECVFROM,
|
unix.SYS_RECVMSG,
|
||||||
uintptr(u.sysFd),
|
uintptr(u.sysFd),
|
||||||
uintptr(unsafe.Pointer(&b[0])),
|
uintptr(unsafe.Pointer(&(msgs[0].Hdr))),
|
||||||
uintptr(len(b)),
|
0,
|
||||||
uintptr(0),
|
0,
|
||||||
uintptr(unsafe.Pointer(&rsa)),
|
0,
|
||||||
uintptr(unsafe.Pointer(&rLen)),
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != 0 {
|
if err != 0 {
|
||||||
return nil, &net.OpError{Op: "read", Err: err}
|
return 0, &net.OpError{Op: "recvmsg", Err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
if rsa.Addr.Family == unix.AF_INET {
|
msgs[0].Len = uint32(n)
|
||||||
addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1])
|
return 1, nil
|
||||||
addr.IP = uint32(rsa.Addr.Data[2])<<24 + uint32(rsa.Addr.Data[3])<<16 + uint32(rsa.Addr.Data[4])<<8 + uint32(rsa.Addr.Data[5])
|
|
||||||
} else {
|
|
||||||
addr.Port = 0
|
|
||||||
addr.IP = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return b[:n], nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -280,13 +286,6 @@ func (ua *udpAddr) Equals(t *udpAddr) bool {
|
|||||||
return ua.IP == t.IP && ua.Port == t.Port
|
return ua.IP == t.IP && ua.Port == t.Port
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ua *udpAddr) Copy() *udpAddr {
|
|
||||||
return &udpAddr{
|
|
||||||
Port: ua.Port,
|
|
||||||
IP: ua.IP,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ua *udpAddr) String() string {
|
func (ua *udpAddr) String() string {
|
||||||
return fmt.Sprintf("%s:%v", int2ip(ua.IP), ua.Port)
|
return fmt.Sprintf("%s:%v", int2ip(ua.IP), ua.Port)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
// +build linux
|
// +build linux
|
||||||
// +build 386 amd64p32 arm mips mipsle
|
// +build 386 amd64p32 arm mips mipsle
|
||||||
|
// +build !android
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
// +build linux
|
// +build linux
|
||||||
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x
|
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x
|
||||||
|
// +build !android
|
||||||
|
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
|
|||||||
@@ -20,3 +20,7 @@ func NewListenConfig(multi bool) net.ListenConfig {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *udpConn) Rebind() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
130
util/assert.go
Normal file
130
util/assert.go
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory
|
||||||
|
// There is currently a special case for `time.loc` (as this code traverses into unexported fields)
|
||||||
|
func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) {
|
||||||
|
v1 := reflect.ValueOf(a)
|
||||||
|
v2 := reflect.ValueOf(b)
|
||||||
|
|
||||||
|
if !assert.Equal(t, v1.Type(), v2.Type()) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
traverseDeepCopy(t, v1, v2, v1.Type().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func traverseDeepCopy(t *testing.T, v1 reflect.Value, v2 reflect.Value, name string) bool {
|
||||||
|
switch v1.Kind() {
|
||||||
|
case reflect.Array:
|
||||||
|
for i := 0; i < v1.Len(); i++ {
|
||||||
|
if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
|
||||||
|
case reflect.Slice:
|
||||||
|
if v1.IsNil() || v2.IsNil() {
|
||||||
|
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil %+v, %+v", name, v1, v2)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !assert.Equal(t, v1.Len(), v2.Len(), "%s did not have the same length", name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// A slice with cap 0
|
||||||
|
if v1.Cap() != 0 && !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same slice %v == %v", name, v1.Pointer(), v2.Pointer()) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
v1c := v1.Cap()
|
||||||
|
v2c := v2.Cap()
|
||||||
|
if v1c > 0 && v2c > 0 && v1.Slice(0, v1c).Slice(v1c-1, v1c-1).Pointer() == v2.Slice(0, v2c).Slice(v2c-1, v2c-1).Pointer() {
|
||||||
|
return assert.Fail(t, "", "%s share some underlying memory", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < v1.Len(); i++ {
|
||||||
|
if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
|
||||||
|
case reflect.Interface:
|
||||||
|
if v1.IsNil() || v2.IsNil() {
|
||||||
|
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
|
||||||
|
}
|
||||||
|
return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
|
||||||
|
|
||||||
|
case reflect.Ptr:
|
||||||
|
local := reflect.ValueOf(time.Local).Pointer()
|
||||||
|
if local == v1.Pointer() && local == v2.Pointer() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s points to the same memory", name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
|
||||||
|
|
||||||
|
case reflect.Struct:
|
||||||
|
for i, n := 0, v1.NumField(); i < n; i++ {
|
||||||
|
if !traverseDeepCopy(t, v1.Field(i), v2.Field(i), name+"."+v1.Type().Field(i).Name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
|
||||||
|
case reflect.Map:
|
||||||
|
if v1.IsNil() || v2.IsNil() {
|
||||||
|
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !assert.Equal(t, v1.Len(), v2.Len(), "%s are not the same length", name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same memory", name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, k := range v1.MapKeys() {
|
||||||
|
val1 := v1.MapIndex(k)
|
||||||
|
val2 := v2.MapIndex(k)
|
||||||
|
if !assert.True(t, val1.IsValid(), "%s is an invalid key in %s", k, name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !assert.True(t, val2.IsValid(), "%s is an invalid key in %s", k, name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !traverseDeepCopy(t, val1, val2, name+fmt.Sprintf("%s[%s]", name, k)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
|
||||||
|
default:
|
||||||
|
if v1.CanInterface() && v2.CanInterface() {
|
||||||
|
return assert.Equal(t, v1.Interface(), v2.Interface(), "%s was not equal", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
e1 := reflect.NewAt(v1.Type(), unsafe.Pointer(v1.UnsafeAddr())).Elem().Interface()
|
||||||
|
e2 := reflect.NewAt(v2.Type(), unsafe.Pointer(v2.UnsafeAddr())).Elem().Interface()
|
||||||
|
|
||||||
|
return assert.Equal(t, e1, e2, "%s (unexported) was not equal", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user