Compare commits

..

38 Commits

Author SHA1 Message Date
Dave Russell
db11e2f1af Revert "smoke test"
This reverts commit fa034a6d83.
2020-10-03 00:09:18 +10:00
Dave Russell
2ee428b067 Hook send should use a code path that actually firewalls
This change enforces that outbound hook traffic will actually be checked
by the firewall and added to the conntrack if allowed.
2020-10-02 23:42:20 +10:00
Dave Russell
e9657d571e control->Send: Also set the src port
With the source port also set, we only need to enable inbound
firewall rules on the 'server' side of the connection, as
the conntrack will allow replies.
2020-10-02 22:25:31 +10:00
Dave Russell
3cebf38504 The custom message packet sender needs a dest port
Source/Dest ports are required for the nebula firewall on the
receiving side, allow the port to be configured so that it can
be matched to specific rules as required.
2020-10-02 20:46:08 +10:00
Dave Russell
ae3ee42469 Provide hooks for custom message packet handlers
This commit augments the Control API by providing new methods to
inject message packets destined peer nodes, and/or to intercept
message packets of a custom message subtype that are received from
peer nodes.
2020-09-28 22:31:19 +10:00
Dave Russell
fa034a6d83 smoke test 2020-09-27 22:43:24 +10:00
Dave Russell
55d72ac46f Tighten up the inside handlers with a bit of DRY 2020-09-27 22:37:20 +10:00
Dave Russell
2c931d5691 Move inside packet handlers into map
This commit moves the inside packet handlers into a map of functions
from the large switch statement. The functions are mapped by packet
protocol version, type and subtype; which makes it simpler to inject
either a new protocol version and/or custom handlers.
2020-09-27 22:04:14 +10:00
Ryan Huber
0d6b55e495 Bring in the new version of kardianos/service and output logfiles on osx (#303)
* this brings in the new version of kardianos/service which properly
outputs logs from launchd services

* add go sum

* is it really this easy?

* Update CHANGELOG.md
2020-09-24 15:34:08 -07:00
Wade Simmons
c71c84882e v1.3.0 (#268)
Update the CHANGELOG for Nebula v1.3.0

Co-authored-by: forfuncsake <drussell@slack-corp.com>
2020-09-22 12:21:12 -04:00
Darren Hoo
0010db46e4 Fix a data race on message counter (#284)
3. ==================
WARNING: DATA RACE
Write at 0x00c00030e020 by goroutine 17:
  sync/atomic.AddInt64()
      runtime/race_amd64.s:276 +0xb
  github.com/slackhq/nebula.(*Interface).sendNoMetrics()
      github.com/slackhq/nebula/inside.go:226 +0x9c
  github.com/slackhq/nebula.(*Interface).send()
      github.com/slackhq/nebula/inside.go:214 +0x149
  github.com/slackhq/nebula.(*Interface).readOutsidePackets()
      github.com/slackhq/nebula/outside.go:94 +0x1213
  github.com/slackhq/nebula.(*udpConn).ListenOut()
      github.com/slackhq/nebula/udp_generic.go:109 +0x3b5
  github.com/slackhq/nebula.(*Interface).listenOut()
      github.com/slackhq/nebula/interface.go:147 +0x15e

Previous read at 0x00c00030e020 by goroutine 18:
  github.com/slackhq/nebula.(*Interface).consumeInsidePacket()
      github.com/slackhq/nebula/inside.go:58 +0x892
  github.com/slackhq/nebula.(*Interface).listenIn()
      github.com/slackhq/nebula/interface.go:164 +0x178
2020-09-21 21:41:46 -04:00
Nathan Brown
68e3e84fdc More like a library (#279) 2020-09-18 09:20:09 -05:00
Brian Luong
6238f1550b Handle panic when invalid IP entered in sshd (#296) 2020-09-18 10:10:25 -04:00
forfuncsake
50b04413c7 Block nebula ssh server from listening on port 22 (#266)
Port 22 is blocked as a safety mechanism. In a case where nebula is
started before sshd, a system may be rendered unreachable if nebula
is holding the system ssh port and there is no other connectivity.

This commit enforces the restriction, which could previously be worked
around by listening on an IPv6 address, e.g.  "[::]:22".
2020-09-15 09:57:32 -04:00
CzBiX
ef498a31da Add disable_timestamp option (#288) 2020-09-09 07:42:11 -04:00
forfuncsake
2e5a477a50 Align linux UDP performance optimizations with configuration (#275)
* Remove unused (*udpConn).Read method

* Align linux UDP performance optimizations with configuration

While attempting to run nebula on an older Synology NAS, it became
apparent that some of the performance optimizations effectively
block support for older kernels. The recvmmsg syscall was added in
linux kernel 2.6.33, and the Synology DS212j (among other models)
is pinned to 2.6.32.12.

Similarly, SO_REUSEPORT was added to the kernel in the 3.9 cycle.
While this option has been backported into some older trees, it
is also missing from the Synology kernel.

This commit allows nebula to be run on linux devices with older
kernels if the config options are set up with a single listener
and a UDP batch size of 1.
2020-08-13 08:24:05 +10:00
Wade Simmons
32fe9bfe75 Use Go 1.15 (#277)
Update all CI checks and release process to use the latest patch version
of go1.15.
2020-08-12 16:16:21 -04:00
forfuncsake
9b8b3c478b Support startup without a tun device (#269)
This commit adds support for Nebula to be started without creating
a tun device. A node started in this mode still has a full "control
plane", but no effective "data plane". Its use is suited to a
lighthouse that has no need to partake in the mesh VPN.

Consequently, creation of the tun device is the only reason nebula
neesd to be started with elevated privileged, so this example
lighthouse can also be run as a non-root user.
2020-08-10 09:15:55 -04:00
Michael Hardy
7b3f23d9a1 Start nebula after the network is up (#270) 2020-08-07 11:33:48 -05:00
forfuncsake
25964b54f6 Use inclusive terminology for cert blocking (#272) 2020-08-06 11:17:47 +10:00
Wade Simmons
ac557f381b drop unroutable packets (#267)
Currently, if a packet arrives on the tun device with a destination that
is not a routable Nebula IP, `queryUnsafeRoute` converts that IP to
0.0.0.0 and we store that packet and try to look up that IP with the
lighthouse. This doesn't make any sense to do, if we get a packet that
is unroutable we should just drop it.

Note, we have a few configurable options like `drop_local_broadcast`
and `drop_multicast` which do this for a few specific types, but since
no packets like this will send correctly I think we should just drop
anything that is unroutable.
2020-08-04 22:59:04 -04:00
Wade Simmons
a54f3fc681 fix fast handshake trigger for static hosts (#265)
We are currently triggering a fast handshake for static hosts right
inside HandshakeManager.AddVpnIP, but this can actually trigger before
we have generated the handshake packet to use. Instead, we should be
triggering right after we call ixHandshakeStage0 in getOrHandshake
(which generates the handshake packet)
2020-08-02 20:59:50 -04:00
Alan Lam
5545cff6ef log remote certificate fingerprint on handshakes (#262) 2020-07-31 18:54:51 -04:00
Wade Simmons
f3a6d8d990 Preserve conntrack table during firewall rules reload (SIGHUP) (#233)
Currently, we drop the conntrack table when firewall rules change during a SIGHUP reload. This means responses to inflight HTTP requests can be dropped, among other issues. This change copies the conntrack table over to the new firewall (it holds the conntrack mutex lock during this process, to be safe).

This change also records which firewall rules hash each conntrack entry used, so that we can re-verify the rules after the new firewall has been loaded.
2020-07-31 18:53:36 -04:00
forfuncsake
9b06748506 Make Interface.Inside an interface type (#252)
This commit updates the Interface.Inside type to be a new interface
type instead of a *Tun. This will allow for an inside interface
that does not use a tun device, such as a single-binary client that
can run without elevated privileges.
2020-07-28 08:53:16 -04:00
Wade Simmons
4756c9613d trigger handshakes when lighthouse reply arrives (#246)
Currently, we wait until the next timer tick to act on the lighthouse's
reply to our HostQuery. This means we can easily add hundreds of
milliseconds of unnecessary delay to the handshake. To fix this, we
can introduce a channel to trigger an outbound handshake without waiting
for the next timer tick.

A few samples of cold ping time between two hosts that require a
lighthouse lookup:

    before (v1.2.0):

    time=156 ms
    time=252 ms
    time=12.6 ms
    time=301 ms
    time=352 ms
    time=49.4 ms
    time=150 ms
    time=13.5 ms
    time=8.24 ms
    time=161 ms
    time=355 ms

    after:

    time=3.53 ms
    time=3.14 ms
    time=3.08 ms
    time=3.92 ms
    time=7.78 ms
    time=3.59 ms
    time=3.07 ms
    time=3.22 ms
    time=3.12 ms
    time=3.08 ms
    time=8.04 ms

I recommend reviewing this PR by looking at each commit individually, as
some refactoring was required that makes the diff a bit confusing when
combined together.
2020-07-22 10:35:10 -04:00
Nathan Brown
4645e6034b Fix up the tun for android (#249) 2020-07-01 10:20:52 -05:00
Wade Simmons
aba42f9fa6 enforce the use of goimports (#248)
* enforce the use of goimports

Instead of enforcing `gofmt`, enforce `goimports`, which also asserts
a separate section for non-builtin packages.

* run `goimports` everywhere

* exclude generated .pb.go files
2020-06-30 18:53:30 -04:00
Nathan Brown
41578ca971 Be more like a library to support mobile (#247) 2020-06-30 13:48:58 -05:00
Wade Simmons
1ea8847085 linux: set advmss correctly when route MTU is used (#245)
If different mtus are specified for different routes, we should set
advmss on each route because Linux does a poor job of selecting the
default (from ip-route(8)):

    advmss NUMBER (Linux 2.3.15+ only)
           the MSS ('Maximal Segment Size') to advertise to these destinations when estab‐
           lishing TCP connections. If it is not given, Linux uses a default value calcu‐
           lated from the first hop device MTU.  (If the path to these destination is asym‐
           metric, this guess may be wrong.)

Note that the default value is calculated from the first hop *device
MTU*, not the *route MTU*. In practice this is usually ok as long as the
other side of the tunnel has the mtu configured exactly the same, but we
should probably just set advmss correctly on these routes.
2020-06-26 13:47:21 -04:00
Wade Simmons
55858c64cc smoke test: test firewall inbound / outbound (#240)
Test that basic inbound / outbound firewall rules work during the smoke
test. This change sets an inbound firewall rule on host3, and a new
host4 with outbound firewall rules. It also tests that conntrack allows
packets once the connection has been established.
2020-06-26 13:46:51 -04:00
Wade Simmons
e94c6b0125 mips-softfloat (#231)
This makes GOARM more generic and does GOMIPS in a similar way to
support mips-softfloat. We also set `-ldflags "-s -w"` for
mips-softfloat to give the best chance of the binary working on these
small devices.
2020-06-26 13:46:23 -04:00
Wade Simmons
b37a91cfbc add meta packet statistics (#230)
This change add more metrics around "meta" (non "message" type packets).
For lighthouse packets, we also record statistics around the specific
lighthouse meta type.

We don't keep statistics for the "message" type so that we don't slow
down the fast path (and you can just look at metrics on the tun
interface to find that information).
2020-06-26 13:45:48 -04:00
David Sonder
3212b769d4 fix typo in conntrack section in examples/config.yml (#236)
the rest of the conntrack values match the default
2020-06-26 11:08:22 -05:00
Patrick Bogen
ecf0e5a9f6 drop packets even if we aren't going to emit Debug logs about it (#239)
* drop packets even if we aren't going to emit Debug logs about it

* smallify change
2020-06-10 16:55:49 -05:00
Wade Simmons
ff13aba8fc allow go test -bench=. to run (#234)
This benchmark had an Errorf at the end, lets remove it so the
benchmarks all run.
2020-05-27 16:52:34 -04:00
Mateusz Kwiatkowski
cc03ff9e9a Unbreak building for FreeBSD (#103)
Add support for freebsd. You have to set `tun.dev` in your config. The second pass of this would be to remove the exec calls and use ioctl(2) and route(4) instead, but we can do that in a second PR.

Co-authored-by: Wade Simmons <wade@wades.im>
2020-05-26 22:23:23 -04:00
Patrick Bogen
363c836422 log the reason for fw drops (#220)
* log the reason for fw drops

* only prepare log if we will end up sending it
2020-04-10 10:57:21 -07:00
75 changed files with 2515 additions and 531 deletions

View File

@@ -14,19 +14,31 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Set up Go 1.14 - name: Set up Go 1.15
uses: actions/setup-go@v1 uses: actions/setup-go@v1
with: with:
go-version: 1.14 go-version: 1.15
id: go id: go
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
uses: actions/checkout@v1 uses: actions/checkout@v1
- uses: actions/cache@v1
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-gofmt-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gofmt-
- name: Install goimports
run: |
go get golang.org/x/tools/cmd/goimports
go build golang.org/x/tools/cmd/goimports
- name: gofmt - name: gofmt
run: | run: |
if [ "$(find . -iname '*.go' | xargs gofmt -l)" ] if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs ./goimports -l)" ]
then then
find . -iname '*.go' | xargs gofmt -d find . -iname '*.go' | grep -v '\.pb\.go$' | xargs ./goimports -d
exit 1 exit 1
fi fi

View File

@@ -10,17 +10,17 @@ jobs:
name: Build Linux All name: Build Linux All
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Set up Go 1.14 - name: Set up Go 1.15
uses: actions/setup-go@v1 uses: actions/setup-go@v1
with: with:
go-version: 1.14 go-version: 1.15
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v2
- name: Build - name: Build
run: | run: |
make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" release-linux make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" release-linux release-freebsd
mkdir release mkdir release
mv build/*.tar.gz release mv build/*.tar.gz release
@@ -34,10 +34,10 @@ jobs:
name: Build Windows amd64 name: Build Windows amd64
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- name: Set up Go 1.14 - name: Set up Go 1.15
uses: actions/setup-go@v1 uses: actions/setup-go@v1
with: with:
go-version: 1.14 go-version: 1.15
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v2
@@ -58,10 +58,10 @@ jobs:
name: Build Darwin amd64 name: Build Darwin amd64
runs-on: macOS-latest runs-on: macOS-latest
steps: steps:
- name: Set up Go 1.14 - name: Set up Go 1.15
uses: actions/setup-go@v1 uses: actions/setup-go@v1
with: with:
go-version: 1.14 go-version: 1.15
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v2
@@ -278,3 +278,23 @@ jobs:
asset_path: ./linux-latest/nebula-linux-mips64le.tar.gz asset_path: ./linux-latest/nebula-linux-mips64le.tar.gz
asset_name: nebula-linux-mips64le.tar.gz asset_name: nebula-linux-mips64le.tar.gz
asset_content_type: application/gzip asset_content_type: application/gzip
- name: Upload linux-mips-softfloat
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-linux-mips-softfloat.tar.gz
asset_name: nebula-linux-mips-softfloat.tar.gz
asset_content_type: application/gzip
- name: Upload freebsd-amd64
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-freebsd-amd64.tar.gz
asset_name: nebula-freebsd-amd64.tar.gz
asset_content_type: application/gzip

View File

@@ -14,14 +14,14 @@ on:
jobs: jobs:
smoke: smoke:
name: Run 3 node smoke test name: Run multi node smoke test
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Set up Go 1.14 - name: Set up Go 1.15
uses: actions/setup-go@v1 uses: actions/setup-go@v1
with: with:
go-version: 1.14 go-version: 1.15
id: go id: go
- name: Check out code into the Go module directory - name: Check out code into the Go module directory

View File

@@ -11,14 +11,29 @@ mkdir ./build
cp ../../../../nebula . cp ../../../../nebula .
cp ../../../../nebula-cert . cp ../../../../nebula-cert .
HOST="lighthouse1" AM_LIGHTHOUSE=true ../genconfig.sh >lighthouse1.yml HOST="lighthouse1" \
HOST="host2" LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" ../genconfig.sh >host2.yml AM_LIGHTHOUSE=true \
HOST="host3" LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" ../genconfig.sh >host3.yml ../genconfig.sh >lighthouse1.yml
HOST="host2" \
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
../genconfig.sh >host2.yml
HOST="host3" \
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
../genconfig.sh >host3.yml
HOST="host4" \
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
../genconfig.sh >host4.yml
./nebula-cert ca -name "Smoke Test" ./nebula-cert ca -name "Smoke Test"
./nebula-cert sign -name "lighthouse1" -ip "192.168.100.1/24" ./nebula-cert sign -name "lighthouse1" -groups "lighthouse,lighthouse1" -ip "192.168.100.1/24"
./nebula-cert sign -name "host2" -ip "192.168.100.2/24" ./nebula-cert sign -name "host2" -groups "host,host2" -ip "192.168.100.2/24"
./nebula-cert sign -name "host3" -ip "192.168.100.3/24" ./nebula-cert sign -name "host3" -groups "host,host3" -ip "192.168.100.3/24"
./nebula-cert sign -name "host4" -groups "host,host4" -ip "192.168.100.4/24"
) )
docker build -t nebula:smoke . docker build -t nebula:smoke .

View File

@@ -2,6 +2,7 @@
set -e set -e
FIREWALL_ALL='[{"port": "any", "proto": "any", "host": "any"}]'
if [ "$STATIC_HOSTS" ] || [ "$LIGHTHOUSES" ] if [ "$STATIC_HOSTS" ] || [ "$LIGHTHOUSES" ]
then then
@@ -48,13 +49,6 @@ tun:
dev: ${TUN_DEV:-nebula1} dev: ${TUN_DEV:-nebula1}
firewall: firewall:
outbound: outbound: ${OUTBOUND:-$FIREWALL_ALL}
- port: any inbound: ${INBOUND:-$FIREWALL_ALL}
proto: any
host: any
inbound:
- port: any
proto: any
host: any
EOF EOF

View File

@@ -5,6 +5,7 @@ set -e -x
docker run --name lighthouse1 --rm nebula:smoke -config lighthouse1.yml -test docker run --name lighthouse1 --rm nebula:smoke -config lighthouse1.yml -test
docker run --name host2 --rm nebula:smoke -config host2.yml -test docker run --name host2 --rm nebula:smoke -config host2.yml -test
docker run --name host3 --rm nebula:smoke -config host3.yml -test docker run --name host3 --rm nebula:smoke -config host3.yml -test
docker run --name host4 --rm nebula:smoke -config host4.yml -test
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config lighthouse1.yml & docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config lighthouse1.yml &
sleep 1 sleep 1
@@ -12,6 +13,8 @@ docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN -
sleep 1 sleep 1
docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host3.yml & docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host3.yml &
sleep 1 sleep 1
docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host4.yml &
sleep 1
set +x set +x
echo echo
@@ -27,7 +30,8 @@ echo " *** Testing ping from host2"
echo echo
set -x set -x
docker exec host2 ping -c1 192.168.100.1 docker exec host2 ping -c1 192.168.100.1
docker exec host2 ping -c1 192.168.100.3 # Should fail because not allowed by host3 inbound firewall
! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
set +x set +x
echo echo
@@ -36,3 +40,24 @@ echo
set -x set -x
docker exec host3 ping -c1 192.168.100.1 docker exec host3 ping -c1 192.168.100.1
docker exec host3 ping -c1 192.168.100.2 docker exec host3 ping -c1 192.168.100.2
set +x
echo
echo " *** Testing ping from host4"
echo
set -x
docker exec host4 ping -c1 192.168.100.1
# Should fail because not allowed by host4 outbound firewall
! docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1
! docker exec host4 ping -c1 192.168.100.3 -w5 || exit 1
set +x
echo
echo " *** Testing conntrack"
echo
set -x
# host2 can ping host3 now that host3 pinged it first
docker exec host2 ping -c1 192.168.100.3
# host4 can ping host2 once conntrack established
docker exec host2 ping -c1 192.168.100.4
docker exec host4 ping -c1 192.168.100.2

View File

@@ -18,10 +18,10 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Set up Go 1.14 - name: Set up Go 1.15
uses: actions/setup-go@v1 uses: actions/setup-go@v1
with: with:
go-version: 1.14 go-version: 1.15
id: go id: go
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
@@ -48,10 +48,10 @@ jobs:
os: [windows-latest, macOS-latest] os: [windows-latest, macOS-latest]
steps: steps:
- name: Set up Go 1.14 - name: Set up Go 1.15
uses: actions/setup-go@v1 uses: actions/setup-go@v1
with: with:
go-version: 1.14 go-version: 1.15
id: go id: go
- name: Check out code into the Go module directory - name: Check out code into the Go module directory

View File

@@ -7,6 +7,73 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
### Changed
- Updated the kardianos/service go library from 1.0.0 to 1.1.0, which
now creates launchd plist to write stdout/stderr to files by default.
## [1.3.0] - 2020-09-22
### Added
- You can emit statistics about non-message packets by setting the option
`stats.message_metrics`. You can similarly emit detailed statistics about
lighthouse packets by setting the option `stats.lighthouse_metrics`. See
the example config for more details. (#230)
- We now support freebsd/amd64. This is experimental, please give us feedback.
(#103)
- We now release a binary for `linux/mips-softfloat` which has also been
stripped to reduce filesize and hopefully have a better chance on running on
small mips devices. (#231)
- You can set `tun.disabled` to true to run a standalone lighthouse without a
tun device (and thus, without root). (#269)
- You can set `logging.disable_timestamp` to remove timestamps from log lines,
which is useful when output is redirected to a logging system that already
adds timestamps. (#288)
### Changed
- Handshakes should now trigger faster, as we try to be proactive with sending
them instead of waiting for the next timer tick in most cases. (#246, #265)
- Previously, we would drop the conntrack table whenever firewall rules were
changed during a SIGHUP. Now, we will maintain the table and just validate
that an entry still matches with the new rule set. (#233)
- Debug logs for firewall drops now include the reason. (#220, #239)
- Logs for handshakes now include the fingerprint of the remote host. (#262)
- Config item `pki.blacklist` is now `pki.blocklist`. (#272)
- Better support for older Linux kernels. We now only set `SO_REUSEPORT` if
`tun.routines` is greater than 1 (default is 1). We also only use the
`recvmmsg` syscall if `listen.batch` is greater than 1 (default is 64).
(#275)
- It is possible to run Nebula as a library inside of another process now.
Note that this is still experimental and the internal APIs around this might
change in minor version releases. (#279)
### Deprecated
- `pki.blacklist` is deprecated in favor of `pki.blocklist` with the same
functionality. Existing configs will continue to load for this release to
allow for migrations. (#272)
### Fixed
- `advmss` is now set correctly for each route table entry when `tun.routes`
is configured to have some routes with higher MTU. (#245)
- Packets that arrive on the tun device with an unroutable destination IP are
now dropped correctly, instead of wasting time making queries to the
lighthouses for IP `0.0.0.0` (#267)
## [1.2.0] - 2020-04-08 ## [1.2.0] - 2020-04-08
### Added ### Added
@@ -118,7 +185,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Initial public release. - Initial public release.
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.2.0...HEAD [Unreleased]: https://github.com/slackhq/nebula/compare/v1.3.0...HEAD
[1.3.0]: https://github.com/slackhq/nebula/releases/tag/v1.3.0
[1.2.0]: https://github.com/slackhq/nebula/releases/tag/v1.2.0 [1.2.0]: https://github.com/slackhq/nebula/releases/tag/v1.2.0
[1.1.0]: https://github.com/slackhq/nebula/releases/tag/v1.1.0 [1.1.0]: https://github.com/slackhq/nebula/releases/tag/v1.1.0
[1.0.0]: https://github.com/slackhq/nebula/releases/tag/v1.0.0 [1.0.0]: https://github.com/slackhq/nebula/releases/tag/v1.0.0

View File

@@ -3,6 +3,8 @@ BUILD_NUMBER ?= dev+$(shell date -u '+%Y%m%d%H%M%S')
GO111MODULE = on GO111MODULE = on
export GO111MODULE export GO111MODULE
LDFLAGS = -X main.Build=$(BUILD_NUMBER)
ALL_LINUX = linux-amd64 \ ALL_LINUX = linux-amd64 \
linux-386 \ linux-386 \
linux-ppc64le \ linux-ppc64le \
@@ -13,10 +15,12 @@ ALL_LINUX = linux-amd64 \
linux-mips \ linux-mips \
linux-mipsle \ linux-mipsle \
linux-mips64 \ linux-mips64 \
linux-mips64le linux-mips64le \
linux-mips-softfloat
ALL = $(ALL_LINUX) \ ALL = $(ALL_LINUX) \
darwin-amd64 \ darwin-amd64 \
freebsd-amd64 \
windows-amd64 windows-amd64
all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert) all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert)
@@ -25,31 +29,40 @@ release: $(ALL:%=build/nebula-%.tar.gz)
release-linux: $(ALL_LINUX:%=build/nebula-%.tar.gz) release-linux: $(ALL_LINUX:%=build/nebula-%.tar.gz)
release-freebsd: build/nebula-freebsd-amd64.tar.gz
bin-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe bin-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe
mv $? . mv $? .
bin-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert bin-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert
mv $? . mv $? .
bin-freebsd: build/freebsd-amd64/nebula build/freebsd-amd64/nebula-cert
mv $? .
bin: bin:
go build -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" -o ./nebula ${NEBULA_CMD_PATH} go build -trimpath -ldflags "$(LDFLAGS)" -o ./nebula ${NEBULA_CMD_PATH}
go build -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" -o ./nebula-cert ./cmd/nebula-cert go build -trimpath -ldflags "$(LDFLAGS)" -o ./nebula-cert ./cmd/nebula-cert
install: install:
go install -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" ${NEBULA_CMD_PATH} go install -trimpath -ldflags "$(LDFLAGS)" ${NEBULA_CMD_PATH}
go install -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula-cert go install -trimpath -ldflags "$(LDFLAGS)" ./cmd/nebula-cert
build/linux-arm-%: GOENV += GOARM=$(word 3, $(subst -, ,$*))
build/linux-mips-%: GOENV += GOMIPS=$(word 3, $(subst -, ,$*))
# Build an extra small binary for mips-softfloat
build/linux-mips-softfloat/%: LDFLAGS += -s -w
build/%/nebula: .FORCE build/%/nebula: .FORCE
GOOS=$(firstword $(subst -, , $*)) \ GOOS=$(firstword $(subst -, , $*)) \
GOARCH=$(word 2, $(subst -, ,$*)) \ GOARCH=$(word 2, $(subst -, ,$*)) $(GOENV) \
GOARM=$(word 3, $(subst -, ,$*)) \ go build -trimpath -o $@ -ldflags "$(LDFLAGS)" ${NEBULA_CMD_PATH}
go build -trimpath -o $@ -ldflags "-X main.Build=$(BUILD_NUMBER)" ${NEBULA_CMD_PATH}
build/%/nebula-cert: .FORCE build/%/nebula-cert: .FORCE
GOOS=$(firstword $(subst -, , $*)) \ GOOS=$(firstword $(subst -, , $*)) \
GOARCH=$(word 2, $(subst -, ,$*)) \ GOARCH=$(word 2, $(subst -, ,$*)) $(GOENV) \
GOARM=$(word 3, $(subst -, ,$*)) \ go build -trimpath -o $@ -ldflags "$(LDFLAGS)" ./cmd/nebula-cert
go build -trimpath -o $@ -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula-cert
build/%/nebula.exe: build/%/nebula build/%/nebula.exe: build/%/nebula
mv $< $@ mv $< $@

View File

@@ -212,10 +212,10 @@ func TestBitsLostCounter(t *testing.T) {
func BenchmarkBits(b *testing.B) { func BenchmarkBits(b *testing.B) {
z := NewBits(10) z := NewBits(10)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
for i, _ := range z.bits { for i := range z.bits {
z.bits[i] = true z.bits[i] = true
} }
for i, _ := range z.bits { for i := range z.bits {
z.bits[i] = false z.bits[i] = false
} }

12
cert.go
View File

@@ -149,10 +149,16 @@ func loadCAFromConfig(c *Config) (*cert.NebulaCAPool, error) {
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err) return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
} }
// pki.blacklist entered the scene at about the same time we aliased x509 to pki, not supporting backwards compat for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
l.WithField("fingerprint", fp).Infof("Blocklisting cert")
CAs.BlocklistFingerprint(fp)
}
// Support deprecated config for at leaast one minor release to allow for migrations
for _, fp := range c.GetStringSlice("pki.blacklist", []string{}) { for _, fp := range c.GetStringSlice("pki.blacklist", []string{}) {
l.WithField("fingerprint", fp).Infof("Blacklisting cert") l.WithField("fingerprint", fp).Infof("Blocklisting cert")
CAs.BlacklistFingerprint(fp) l.Warn("pki.blacklist is deprecated and will not be supported in a future release. Please migrate your config to use pki.blocklist")
CAs.BlocklistFingerprint(fp)
} }
return CAs, nil return CAs, nil

View File

@@ -8,14 +8,14 @@ import (
type NebulaCAPool struct { type NebulaCAPool struct {
CAs map[string]*NebulaCertificate CAs map[string]*NebulaCertificate
certBlacklist map[string]struct{} certBlocklist map[string]struct{}
} }
// NewCAPool creates a CAPool // NewCAPool creates a CAPool
func NewCAPool() *NebulaCAPool { func NewCAPool() *NebulaCAPool {
ca := NebulaCAPool{ ca := NebulaCAPool{
CAs: make(map[string]*NebulaCertificate), CAs: make(map[string]*NebulaCertificate),
certBlacklist: make(map[string]struct{}), certBlocklist: make(map[string]struct{}),
} }
return &ca return &ca
@@ -67,24 +67,24 @@ func (ncp *NebulaCAPool) AddCACertificate(pemBytes []byte) ([]byte, error) {
return pemBytes, nil return pemBytes, nil
} }
// BlacklistFingerprint adds a cert fingerprint to the blacklist // BlocklistFingerprint adds a cert fingerprint to the blocklist
func (ncp *NebulaCAPool) BlacklistFingerprint(f string) { func (ncp *NebulaCAPool) BlocklistFingerprint(f string) {
ncp.certBlacklist[f] = struct{}{} ncp.certBlocklist[f] = struct{}{}
} }
// ResetCertBlacklist removes all previously blacklisted cert fingerprints // ResetCertBlocklist removes all previously blocklisted cert fingerprints
func (ncp *NebulaCAPool) ResetCertBlacklist() { func (ncp *NebulaCAPool) ResetCertBlocklist() {
ncp.certBlacklist = make(map[string]struct{}) ncp.certBlocklist = make(map[string]struct{})
} }
// IsBlacklisted returns true if the fingerprint fails to generate or has been explicitly blacklisted // IsBlocklisted returns true if the fingerprint fails to generate or has been explicitly blocklisted
func (ncp *NebulaCAPool) IsBlacklisted(c *NebulaCertificate) bool { func (ncp *NebulaCAPool) IsBlocklisted(c *NebulaCertificate) bool {
h, err := c.Sha256Sum() h, err := c.Sha256Sum()
if err != nil { if err != nil {
return true return true
} }
if _, ok := ncp.certBlacklist[h]; ok { if _, ok := ncp.certBlocklist[h]; ok {
return true return true
} }

View File

@@ -1,18 +1,18 @@
package cert package cert
import ( import (
"bytes"
"crypto" "crypto"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"net" "net"
"time" "time"
"bytes"
"encoding/json"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"golang.org/x/crypto/curve25519" "golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
@@ -244,10 +244,10 @@ func (nc *NebulaCertificate) Expired(t time.Time) bool {
return nc.Details.NotBefore.After(t) || nc.Details.NotAfter.Before(t) return nc.Details.NotBefore.After(t) || nc.Details.NotAfter.Before(t)
} }
// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blacklist, etc) // Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) { func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) {
if ncp.IsBlacklisted(nc) { if ncp.IsBlocklisted(nc) {
return false, fmt.Errorf("certificate has been blacklisted") return false, fmt.Errorf("certificate has been blocked")
} }
signer, err := ncp.GetCAForCert(nc) signer, err := ncp.GetCAForCert(nc)
@@ -468,6 +468,63 @@ func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) {
return json.Marshal(jc) return json.Marshal(jc)
} }
//func (nc *NebulaCertificate) Copy() *NebulaCertificate {
// r, err := nc.Marshal()
// if err != nil {
// //TODO
// return nil
// }
//
// c, err := UnmarshalNebulaCertificate(r)
// return c
//}
func (nc *NebulaCertificate) Copy() *NebulaCertificate {
c := &NebulaCertificate{
Details: NebulaCertificateDetails{
Name: nc.Details.Name,
Groups: make([]string, len(nc.Details.Groups)),
Ips: make([]*net.IPNet, len(nc.Details.Ips)),
Subnets: make([]*net.IPNet, len(nc.Details.Subnets)),
NotBefore: nc.Details.NotBefore,
NotAfter: nc.Details.NotAfter,
PublicKey: make([]byte, len(nc.Details.PublicKey)),
IsCA: nc.Details.IsCA,
Issuer: nc.Details.Issuer,
InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)),
},
Signature: make([]byte, len(nc.Signature)),
}
copy(c.Signature, nc.Signature)
copy(c.Details.Groups, nc.Details.Groups)
copy(c.Details.PublicKey, nc.Details.PublicKey)
for i, p := range nc.Details.Ips {
c.Details.Ips[i] = &net.IPNet{
IP: make(net.IP, len(p.IP)),
Mask: make(net.IPMask, len(p.Mask)),
}
copy(c.Details.Ips[i].IP, p.IP)
copy(c.Details.Ips[i].Mask, p.Mask)
}
for i, p := range nc.Details.Subnets {
c.Details.Subnets[i] = &net.IPNet{
IP: make(net.IP, len(p.IP)),
Mask: make(net.IPMask, len(p.Mask)),
}
copy(c.Details.Subnets[i].IP, p.IP)
copy(c.Details.Subnets[i].Mask, p.Mask)
}
for g := range nc.Details.InvertedGroups {
c.Details.InvertedGroups[g] = struct{}{}
}
return c
}
func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool { func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool {
for _, net := range rootIps { for _, net := range rootIps {
if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) { if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) {

View File

@@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/crypto/curve25519" "golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
@@ -172,13 +173,13 @@ func TestNebulaCertificate_Verify(t *testing.T) {
f, err := c.Sha256Sum() f, err := c.Sha256Sum()
assert.Nil(t, err) assert.Nil(t, err)
caPool.BlacklistFingerprint(f) caPool.BlocklistFingerprint(f)
v, err := c.Verify(time.Now(), caPool) v, err := c.Verify(time.Now(), caPool)
assert.False(t, v) assert.False(t, v)
assert.EqualError(t, err, "certificate has been blacklisted") assert.EqualError(t, err, "certificate has been blocked")
caPool.ResetCertBlacklist() caPool.ResetCertBlocklist()
v, err = c.Verify(time.Now(), caPool) v, err = c.Verify(time.Now(), caPool)
assert.True(t, v) assert.True(t, v)
assert.Nil(t, err) assert.Nil(t, err)
@@ -487,6 +488,17 @@ func TestMarshalingNebulaCertificateConsistency(t *testing.T) {
assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
} }
func TestNebulaCertificate_Copy(t *testing.T) {
ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
assert.Nil(t, err)
c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
assert.Nil(t, err)
cc := c.Copy()
util.AssertDeepCopyEqual(t, c, cc)
}
func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) { func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
pub, priv, err := ed25519.GenerateKey(rand.Reader) pub, priv, err := ed25519.GenerateKey(rand.Reader)
if before.IsZero() { if before.IsZero() {
@@ -499,10 +511,11 @@ func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
nc := &NebulaCertificate{ nc := &NebulaCertificate{
Details: NebulaCertificateDetails{ Details: NebulaCertificateDetails{
Name: "test ca", Name: "test ca",
NotBefore: before, NotBefore: time.Unix(before.Unix(), 0),
NotAfter: after, NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub, PublicKey: pub,
IsCA: true, IsCA: true,
InvertedGroups: make(map[string]struct{}),
}, },
} }
@@ -544,17 +557,17 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
if len(ips) == 0 { if len(ips) == 0 {
ips = []*net.IPNet{ ips = []*net.IPNet{
{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, {IP: net.ParseIP("10.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, {IP: net.ParseIP("10.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, {IP: net.ParseIP("10.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
} }
} }
if len(subnets) == 0 { if len(subnets) == 0 {
subnets = []*net.IPNet{ subnets = []*net.IPNet{
{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, {IP: net.ParseIP("9.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, {IP: net.ParseIP("9.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, {IP: net.ParseIP("9.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
} }
} }
@@ -566,11 +579,12 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
Ips: ips, Ips: ips,
Subnets: subnets, Subnets: subnets,
Groups: groups, Groups: groups,
NotBefore: before, NotBefore: time.Unix(before.Unix(), 0),
NotAfter: after, NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub, PublicKey: pub,
IsCA: false, IsCA: false,
Issuer: issuer, Issuer: issuer,
InvertedGroups: make(map[string]struct{}),
}, },
} }

View File

@@ -3,10 +3,11 @@ package main
import ( import (
"bytes" "bytes"
"errors" "errors"
"github.com/stretchr/testify/assert"
"io" "io"
"os" "os"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
//TODO: all flag parsing continueOnError will print to stderr on its own currently //TODO: all flag parsing continueOnError will print to stderr on its own currently

View File

@@ -4,11 +4,12 @@ import (
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"github.com/slackhq/nebula/cert"
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
"strings" "strings"
"github.com/slackhq/nebula/cert"
) )
type printFlags struct { type printFlags struct {

View File

@@ -2,12 +2,13 @@ package main
import ( import (
"bytes" "bytes"
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
"io/ioutil" "io/ioutil"
"os" "os"
"testing" "testing"
"time" "time"
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
) )
func Test_printSummary(t *testing.T) { func Test_printSummary(t *testing.T) {

View File

@@ -3,12 +3,13 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"github.com/slackhq/nebula/cert"
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
"strings" "strings"
"time" "time"
"github.com/slackhq/nebula/cert"
) )
type verifyFlags struct { type verifyFlags struct {

View File

@@ -3,13 +3,14 @@ package main
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ed25519"
"io/ioutil" "io/ioutil"
"os" "os"
"testing" "testing"
"time" "time"
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ed25519"
) )
func Test_verifySummary(t *testing.T) { func Test_verifySummary(t *testing.T) {

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
) )
@@ -45,5 +46,30 @@ func main() {
os.Exit(1) os.Exit(1)
} }
nebula.Main(*configPath, *configTest, Build) config := nebula.NewConfig()
err := config.Load(*configPath)
if err != nil {
fmt.Printf("failed to load config: %s", err)
os.Exit(1)
}
l := logrus.New()
l.Out = os.Stdout
c, err := nebula.Main(config, *configTest, Build, l, nil)
switch v := err.(type) {
case nebula.ContextualError:
v.Log(l)
os.Exit(1)
case error:
l.WithError(err).Error("Failed to start")
os.Exit(1)
}
if !*configTest {
c.Start()
c.ShutdownBlock()
}
os.Exit(0)
} }

View File

@@ -1,44 +1,53 @@
package main package main
import ( import (
"fmt"
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
"github.com/kardianos/service" "github.com/kardianos/service"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
) )
var logger service.Logger var logger service.Logger
type program struct { type program struct {
exit chan struct{}
configPath *string configPath *string
configTest *bool configTest *bool
build string build string
control *nebula.Control
} }
func (p *program) Start(s service.Service) error { func (p *program) Start(s service.Service) error {
logger.Info("Nebula service starting.")
p.exit = make(chan struct{})
// Start should not block. // Start should not block.
go p.run() logger.Info("Nebula service starting.")
return nil
config := nebula.NewConfig()
err := config.Load(*p.configPath)
if err != nil {
return fmt.Errorf("failed to load config: %s", err)
} }
func (p *program) run() error { l := logrus.New()
nebula.Main(*p.configPath, *p.configTest, Build) l.Out = os.Stdout
p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
if err != nil {
return err
}
p.control.Start()
return nil return nil
} }
func (p *program) Stop(s service.Service) error { func (p *program) Stop(s service.Service) error {
logger.Info("Nebula service stopping.") logger.Info("Nebula service stopping.")
close(p.exit) p.control.Stop()
return nil return nil
} }
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) { func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {
if *configPath == "" { if *configPath == "" {
ex, err := os.Executable() ex, err := os.Executable()
if err != nil { if err != nil {

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
) )
@@ -39,5 +40,30 @@ func main() {
os.Exit(1) os.Exit(1)
} }
nebula.Main(*configPath, *configTest, Build) config := nebula.NewConfig()
err := config.Load(*configPath)
if err != nil {
fmt.Printf("failed to load config: %s", err)
os.Exit(1)
}
l := logrus.New()
l.Out = os.Stdout
c, err := nebula.Main(config, *configTest, Build, l, nil)
switch v := err.(type) {
case nebula.ContextualError:
v.Log(l)
os.Exit(1)
case error:
l.WithError(err).Error("Failed to start")
os.Exit(1)
}
if !*configTest {
c.Start()
c.ShutdownBlock()
}
os.Exit(0)
} }

View File

@@ -1,10 +1,8 @@
package nebula package nebula
import ( import (
"errors"
"fmt" "fmt"
"github.com/imdario/mergo"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
"io/ioutil" "io/ioutil"
"net" "net"
"os" "os"
@@ -16,6 +14,10 @@ import (
"strings" "strings"
"syscall" "syscall"
"time" "time"
"github.com/imdario/mergo"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
) )
type Config struct { type Config struct {
@@ -56,6 +58,13 @@ func (c *Config) Load(path string) error {
return nil return nil
} }
func (c *Config) LoadString(raw string) error {
if raw == "" {
return errors.New("Empty configuration")
}
return c.parseRaw([]byte(raw))
}
// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered // RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
// here should decide if they need to make a change to the current process before making the change. HasChanged can be // here should decide if they need to make a change to the current process before making the change. HasChanged can be
// used to help decide if a change is necessary. // used to help decide if a change is necessary.
@@ -407,6 +416,18 @@ func (c *Config) addFile(path string, direct bool) error {
return nil return nil
} }
func (c *Config) parseRaw(b []byte) error {
var m map[interface{}]interface{}
err := yaml.Unmarshal(b, &m)
if err != nil {
return err
}
c.Settings = m
return nil
}
func (c *Config) parse() error { func (c *Config) parse() error {
var m map[interface{}]interface{} var m map[interface{}]interface{}
@@ -459,6 +480,7 @@ func configLogger(c *Config) error {
} }
l.SetLevel(logLevel) l.SetLevel(logLevel)
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
timestampFormat := c.GetString("logging.timestamp_format", "") timestampFormat := c.GetString("logging.timestamp_format", "")
fullTimestamp := (timestampFormat != "") fullTimestamp := (timestampFormat != "")
if timestampFormat == "" { if timestampFormat == "" {
@@ -471,10 +493,12 @@ func configLogger(c *Config) error {
l.Formatter = &logrus.TextFormatter{ l.Formatter = &logrus.TextFormatter{
TimestampFormat: timestampFormat, TimestampFormat: timestampFormat,
FullTimestamp: fullTimestamp, FullTimestamp: fullTimestamp,
DisableTimestamp: disableTimestamp,
} }
case "json": case "json":
l.Formatter = &logrus.JSONFormatter{ l.Formatter = &logrus.JSONFormatter{
TimestampFormat: timestampFormat, TimestampFormat: timestampFormat,
DisableTimestamp: disableTimestamp,
} }
default: default:
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"}) return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})

View File

@@ -1,12 +1,13 @@
package nebula package nebula
import ( import (
"github.com/stretchr/testify/assert"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
) )
func TestConfig_Load(t *testing.T) { func TestConfig_Load(t *testing.T) {

View File

@@ -28,7 +28,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
rawCertificateNoKey: []byte{}, rawCertificateNoKey: []byte{},
} }
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1) lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &Tun{}, inside: &Tun{},
@@ -91,7 +91,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
rawCertificateNoKey: []byte{}, rawCertificateNoKey: []byte{},
} }
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1) lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &Tun{}, inside: &Tun{},

216
control.go Normal file
View File

@@ -0,0 +1,216 @@
package nebula
import (
"encoding/binary"
"fmt"
"net"
"os"
"os/signal"
"syscall"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"golang.org/x/net/ipv4"
)
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
type Control struct {
f *Interface
l *logrus.Logger
}
type ControlHostInfo struct {
VpnIP net.IP `json:"vpnIp"`
LocalIndex uint32 `json:"localIndex"`
RemoteIndex uint32 `json:"remoteIndex"`
RemoteAddrs []udpAddr `json:"remoteAddrs"`
CachedPackets int `json:"cachedPackets"`
Cert *cert.NebulaCertificate `json:"cert"`
MessageCounter uint64 `json:"messageCounter"`
CurrentRemote udpAddr `json:"currentRemote"`
}
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
func (c *Control) Start() {
c.f.run()
}
// Stop signals nebula to shutdown, returns after the shutdown is complete
func (c *Control) Stop() {
//TODO: stop tun and udp routines, the lock on hostMap effectively does that though
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
c.f.hostMap.Lock()
for _, h := range c.f.hostMap.Hosts {
if h.ConnectionState.ready {
c.f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message")
}
}
c.f.hostMap.Unlock()
c.l.Info("Goodbye")
}
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
func (c *Control) ShutdownBlock() {
sigChan := make(chan os.Signal)
signal.Notify(sigChan, syscall.SIGTERM)
signal.Notify(sigChan, syscall.SIGINT)
rawSig := <-sigChan
sig := rawSig.String()
c.l.WithField("signal", sig).Info("Caught signal, shutting down")
c.Stop()
}
// RebindUDPServer asks the UDP listener to rebind it's listener. Mainly used on mobile clients when interfaces change
func (c *Control) RebindUDPServer() {
_ = c.f.outside.Rebind()
}
// ListHostmap returns details about the actual or pending (handshaking) hostmap
func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
var hm *HostMap
if pendingMap {
hm = c.f.handshakeManager.pendingHostMap
} else {
hm = c.f.hostMap
}
hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Hosts))
i := 0
for _, v := range hm.Hosts {
hosts[i] = copyHostInfo(v)
i++
}
hm.RUnlock()
return hosts
}
// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found
func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInfo {
var hm *HostMap
if pending {
hm = c.f.handshakeManager.pendingHostMap
} else {
hm = c.f.hostMap
}
h, err := hm.QueryVpnIP(vpnIP)
if err != nil {
return nil
}
ch := copyHostInfo(h)
return &ch
}
// SetRemoteForTunnel forces a tunnel to use a specific remote
func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInfo {
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
if err != nil {
return nil
}
hostInfo.SetRemote(addr.Copy())
ch := copyHostInfo(hostInfo)
return &ch
}
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
func (c *Control) CloseTunnel(vpnIP uint32, localOnly bool) bool {
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
if err != nil {
return false
}
if !localOnly {
c.f.send(
closeTunnel,
0,
hostInfo.ConnectionState,
hostInfo,
hostInfo.remote,
[]byte{},
make([]byte, 12, 12),
make([]byte, mtu),
)
}
c.f.closeTunnel(hostInfo)
return true
}
func copyHostInfo(h *HostInfo) ControlHostInfo {
addrs := h.RemoteUDPAddrs()
chi := ControlHostInfo{
VpnIP: int2ip(h.hostId),
LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId,
RemoteAddrs: make([]udpAddr, len(addrs), len(addrs)),
CachedPackets: len(h.packetStore),
MessageCounter: *h.ConnectionState.messageCounter,
}
if c := h.GetCert(); c != nil {
chi.Cert = c.Copy()
}
if h.remote != nil {
chi.CurrentRemote = *h.remote
}
for i, addr := range addrs {
chi.RemoteAddrs[i] = addr.Copy()
}
return chi
}
// Hook provides the ability to hook into the network path for a particular
// message sub type. Any received message of that subtype that is allowed by
// the firewall will be written to the provided write func instead of the
// inside interface.
// TODO: make this an io.Writer
func (c *Control) Hook(t NebulaMessageSubType, w func([]byte) error) error {
if t == 0 {
return fmt.Errorf("non-default message subtype must be specified")
}
if _, ok := c.f.handlers[Version][message][t]; ok {
return fmt.Errorf("message subtype %d already hooked", t)
}
c.f.handlers[Version][message][t] = c.f.newHook(w)
return nil
}
// Send provides the ability to send arbitrary message packets to peer nodes.
// The provided payload will be encapsulated in a Nebula Firewall packet
// (IPv4 plus ports) from the node IP to the provided destination nebula IP.
// Any protocol handling above layer 3 (IP) must be managed by the caller.
func (c *Control) Send(ip uint32, port uint16, st NebulaMessageSubType, payload []byte) {
headerLen := ipv4.HeaderLen + minFwPacketLen
length := headerLen + len(payload)
packet := make([]byte, length)
packet[0] = 0x45 // IPv4 HL=20
packet[9] = 114 // Declare as arbitrary 0-hop protocol
binary.BigEndian.PutUint16(packet[2:4], uint16(length))
binary.BigEndian.PutUint32(packet[12:16], ip2int(c.f.inside.CidrNet().IP.To4()))
binary.BigEndian.PutUint32(packet[16:20], ip)
// Set identical values for src and dst port as they're only
// used for nebula firewall rule/conntrack matching.
binary.BigEndian.PutUint16(packet[20:22], port)
binary.BigEndian.PutUint16(packet[22:24], port)
copy(packet[headerLen:], payload)
fp := &FirewallPacket{}
nb := make([]byte, 12)
out := make([]byte, mtu)
c.f.consumeInsidePacket(st, packet, fp, nb, out)
}

111
control_test.go Normal file
View File

@@ -0,0 +1,111 @@
package nebula
import (
"net"
"reflect"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
)
func TestControl_GetHostInfoByVpnIP(t *testing.T) {
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller
hm := NewHostMap("test", &net.IPNet{}, make([]*net.IPNet, 0))
remote1 := NewUDPAddr(100, 4444)
remote2 := NewUDPAddr(101, 4444)
ipNet := net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
}
ipNet2 := net.IPNet{
IP: net.IPv4(1, 2, 3, 5),
Mask: net.IPMask{255, 255, 255, 0},
}
crt := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "test",
Ips: []*net.IPNet{&ipNet},
Subnets: []*net.IPNet{},
Groups: []string{"default-group"},
NotBefore: time.Unix(1, 0),
NotAfter: time.Unix(2, 0),
PublicKey: []byte{5, 6, 7, 8},
IsCA: false,
Issuer: "the-issuer",
InvertedGroups: map[string]struct{}{"default-group": {}},
},
Signature: []byte{1, 2, 1, 2, 1, 3},
}
counter := uint64(0)
remotes := []*HostInfoDest{NewHostInfoDest(remote1), NewHostInfoDest(remote2)}
hm.Add(ip2int(ipNet.IP), &HostInfo{
remote: remote1,
Remotes: remotes,
ConnectionState: &ConnectionState{
peerCert: crt,
messageCounter: &counter,
},
remoteIndexId: 200,
localIndexId: 201,
hostId: ip2int(ipNet.IP),
})
hm.Add(ip2int(ipNet2.IP), &HostInfo{
remote: remote1,
Remotes: remotes,
ConnectionState: &ConnectionState{
peerCert: nil,
messageCounter: &counter,
},
remoteIndexId: 200,
localIndexId: 201,
hostId: ip2int(ipNet2.IP),
})
c := Control{
f: &Interface{
hostMap: hm,
},
l: logrus.New(),
}
thi := c.GetHostInfoByVpnIP(ip2int(ipNet.IP), false)
expectedInfo := ControlHostInfo{
VpnIP: net.IPv4(1, 2, 3, 4).To4(),
LocalIndex: 201,
RemoteIndex: 200,
RemoteAddrs: []udpAddr{*remote1, *remote2},
CachedPackets: 0,
Cert: crt.Copy(),
MessageCounter: 0,
CurrentRemote: *NewUDPAddr(100, 4444),
}
// Make sure we don't have any unexpected fields
assertFields(t, []string{"VpnIP", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
util.AssertDeepCopyEqual(t, &expectedInfo, thi)
// Make sure we don't panic if the host info doesn't have a cert yet
assert.NotPanics(t, func() {
thi = c.GetHostInfoByVpnIP(ip2int(ipNet2.IP), false)
})
}
func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
val := reflect.ValueOf(actualStruct).Elem()
fields := make([]string, val.NumField())
for i := 0; i < val.NumField(); i++ {
fields[i] = val.Type().Field(i).Name
}
assert.Equal(t, expected, fields)
}

View File

@@ -1,7 +1,7 @@
[Unit] [Unit]
Description=nebula Description=nebula
Wants=basic.target Wants=basic.target network-online.target
After=basic.target network.target After=basic.target network.target network-online.target
[Service] [Service]
SyslogIdentifier=nebula SyslogIdentifier=nebula

View File

@@ -7,8 +7,8 @@ pki:
ca: /etc/nebula/ca.crt ca: /etc/nebula/ca.crt
cert: /etc/nebula/host.crt cert: /etc/nebula/host.crt
key: /etc/nebula/host.key key: /etc/nebula/host.key
#blacklist is a list of certificate fingerprints that we will refuse to talk to #blocklist is a list of certificate fingerprints that we will refuse to talk to
#blacklist: #blocklist:
# - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72 # - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network). # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
@@ -64,7 +64,7 @@ lighthouse:
# the inverse). CIDR rules are matched after interface name rules. # the inverse). CIDR rules are matched after interface name rules.
# Default is all local IP addresses. # Default is all local IP addresses.
#local_allow_list: #local_allow_list:
# Example to blacklist tun0 and all docker interfaces. # Example to block tun0 and all docker interfaces.
#interfaces: #interfaces:
#tun0: false #tun0: false
#'docker.*': false #'docker.*': false
@@ -124,6 +124,8 @@ punchy:
# Configure the private interface. Note: addr is baked into the nebula certificate # Configure the private interface. Note: addr is baked into the nebula certificate
tun: tun:
# When tun is disabled, a lighthouse can be started without a local tun interface (and therefore without root)
disabled: false
# Name of the device # Name of the device
dev: nebula1 dev: nebula1
# Toggles forwarding of local broadcast packets, the address of which depends on the ip/mask encoded in pki.cert # Toggles forwarding of local broadcast packets, the address of which depends on the ip/mask encoded in pki.cert
@@ -154,6 +156,8 @@ logging:
level: info level: info
# json or text formats currently available. Default is text # json or text formats currently available. Default is text
format: text format: text
# Disable timestamp logging. useful when output is redirected to logging system that already adds timestamps. Default is false
#disable_timestamp: true
# timestamp format is specified in Go time format, see: # timestamp format is specified in Go time format, see:
# https://golang.org/pkg/time/#pkg-constants # https://golang.org/pkg/time/#pkg-constants
# default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339) # default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339)
@@ -177,6 +181,15 @@ logging:
#subsystem: nebula #subsystem: nebula
#interval: 10s #interval: 10s
# enables counter metrics for meta packets
# e.g.: `messages.tx.handshake`
# NOTE: `message.{tx,rx}.recv_error` is always emitted
#message_metrics: false
# enables detailed counter metrics for lighthouse packets
# e.g.: `lighthouse.rx.HostQuery`
#lighthouse_metrics: false
# Handshake Manger Settings # Handshake Manger Settings
#handshakes: #handshakes:
# Total time to try a handshake = sequence of `try_interval * retries` # Total time to try a handshake = sequence of `try_interval * retries`
@@ -185,11 +198,14 @@ logging:
#retries: 20 #retries: 20
# wait_rotation is the number of handshake attempts to do before starting to try non-local IP addresses # wait_rotation is the number of handshake attempts to do before starting to try non-local IP addresses
#wait_rotation: 5 #wait_rotation: 5
# trigger_buffer is the size of the buffer channel for quickly sending handshakes
# after receiving the response for lighthouse queries
#trigger_buffer: 64
# Nebula security group configuration # Nebula security group configuration
firewall: firewall:
conntrack: conntrack:
tcp_timeout: 120h tcp_timeout: 12m
udp_timeout: 3m udp_timeout: 3m
default_timeout: 10m default_timeout: 10m
max_connections: 100000 max_connections: 100000

View File

@@ -1,21 +1,21 @@
package nebula package nebula
import ( import (
"crypto/sha256"
"encoding/binary" "encoding/binary"
"encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net" "net"
"sync"
"time"
"crypto/sha256"
"encoding/hex"
"errors"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"sync"
"time"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
) )
@@ -38,13 +38,19 @@ type FirewallInterface interface {
type conn struct { type conn struct {
Expires time.Time // Time when this conntrack entry will expire Expires time.Time // Time when this conntrack entry will expire
Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
Sent time.Time // If tcp rtt tracking is enabled this will be when Seq was last set Sent time.Time // If tcp rtt tracking is enabled this will be when Seq was last set
Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
// record why the original connection passed the firewall, so we can re-validate
// after ruleset changes. Note, rulesVersion is a uint16 so that these two
// fields pack for free after the uint32 above
incoming bool
rulesVersion uint16
} }
// TODO: need conntrack max tracked connections handling // TODO: need conntrack max tracked connections handling
type Firewall struct { type Firewall struct {
Conns map[FirewallPacket]*conn Conntrack *FirewallConntrack
InRules *FirewallTable InRules *FirewallTable
OutRules *FirewallTable OutRules *FirewallTable
@@ -55,18 +61,23 @@ type Firewall struct {
UDPTimeout time.Duration //linux: 180s max UDPTimeout time.Duration //linux: 180s max
DefaultTimeout time.Duration //linux: 600s DefaultTimeout time.Duration //linux: 600s
TimerWheel *TimerWheel
// Used to ensure we don't emit local packets for ips we don't own // Used to ensure we don't emit local packets for ips we don't own
localIps *CIDRTree localIps *CIDRTree
connMutex sync.Mutex
rules string rules string
rulesVersion uint16
trackTCPRTT bool trackTCPRTT bool
metricTCPRTT metrics.Histogram metricTCPRTT metrics.Histogram
} }
type FirewallConntrack struct {
sync.Mutex
Conns map[FirewallPacket]*conn
TimerWheel *TimerWheel
}
type FirewallTable struct { type FirewallTable struct {
TCP firewallPort TCP firewallPort
UDP firewallPort UDP firewallPort
@@ -172,10 +183,12 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
} }
return &Firewall{ return &Firewall{
Conntrack: &FirewallConntrack{
Conns: make(map[FirewallPacket]*conn), Conns: make(map[FirewallPacket]*conn),
TimerWheel: NewTimerWheel(min, max),
},
InRules: newFirewallTable(), InRules: newFirewallTable(),
OutRules: newFirewallTable(), OutRules: newFirewallTable(),
TimerWheel: NewTimerWheel(min, max),
TCPTimeout: tcpTimeout, TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout, UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout, DefaultTimeout: defaultTimeout,
@@ -208,11 +221,17 @@ func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, er
// AddRule properly creates the in memory rule structure for a firewall table. // AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error { func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
// https://github.com/golang/go/issues/14131
sIp := ""
if ip != nil {
sIp = ip.String()
}
// We need this rule string because we generate a hash. Removing this will break firewall reload. // We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf( ruleString := fmt.Sprintf(
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s", "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, ip, caName, caSha, incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha,
) )
f.rules += ruleString + "\n" f.rules += ruleString + "\n"
@@ -220,7 +239,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
if !incoming { if !incoming {
direction = "outgoing" direction = "outgoing"
} }
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": ip, "caName": caName, "caSha": caSha}). l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
Info("Firewall rule added") Info("Firewall rule added")
var ( var (
@@ -347,27 +366,33 @@ func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterfa
return nil return nil
} }
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool { var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) error {
// Check if we spoke to this tuple, if we did then allow this packet // Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(packet, fp, incoming) { if f.inConns(packet, fp, incoming, h, caPool) {
return false return nil
} }
// Make sure remote address matches nebula certificate // Make sure remote address matches nebula certificate
if remoteCidr := h.remoteCidr; remoteCidr != nil { if remoteCidr := h.remoteCidr; remoteCidr != nil {
if remoteCidr.Contains(fp.RemoteIP) == nil { if remoteCidr.Contains(fp.RemoteIP) == nil {
return true return ErrInvalidRemoteIP
} }
} else { } else {
// Simple case: Certificate has one IP and no subnets // Simple case: Certificate has one IP and no subnets
if fp.RemoteIP != h.hostId { if fp.RemoteIP != h.hostId {
return true return ErrInvalidRemoteIP
} }
} }
// Make sure we are supposed to be handling this local ip address // Make sure we are supposed to be handling this local ip address
if f.localIps.Contains(fp.LocalIP) == nil { if f.localIps.Contains(fp.LocalIP) == nil {
return true return ErrInvalidLocalIP
} }
table := f.OutRules table := f.OutRules
@@ -377,13 +402,13 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *Host
// We now know which firewall table to check against // We now know which firewall table to check against
if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) { if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) {
return true return ErrNoMatchingRule
} }
// We always want to conntrack since it is a faster operation // We always want to conntrack since it is a faster operation
f.addConn(packet, fp, incoming) f.addConn(packet, fp, incoming)
return false return nil
} }
// Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new // Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new
@@ -393,26 +418,66 @@ func (f *Firewall) Destroy() {
} }
func (f *Firewall) EmitStats() { func (f *Firewall) EmitStats() {
conntrackCount := len(f.Conns) conntrack := f.Conntrack
conntrack.Lock()
conntrackCount := len(conntrack.Conns)
conntrack.Unlock()
metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount)) metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount))
metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
} }
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool { func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool {
f.connMutex.Lock() conntrack := f.Conntrack
conntrack.Lock()
// Purge every time we test // Purge every time we test
ep, has := f.TimerWheel.Purge() ep, has := conntrack.TimerWheel.Purge()
if has { if has {
f.evict(ep) f.evict(ep)
} }
c, ok := f.Conns[fp] c, ok := conntrack.Conns[fp]
if !ok { if !ok {
f.connMutex.Unlock() conntrack.Unlock()
return false return false
} }
if c.rulesVersion != f.rulesVersion {
// This conntrack entry was for an older rule set, validate
// it still passes with the current rule set
table := f.OutRules
if c.incoming {
table = f.InRules
}
// We now know which firewall table to check against
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
if l.Level >= logrus.DebugLevel {
h.logger().
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion).
Debugln("dropping old conntrack entry, does not match new ruleset")
}
delete(conntrack.Conns, fp)
conntrack.Unlock()
return false
}
if l.Level >= logrus.DebugLevel {
h.logger().
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion).
Debugln("keeping old conntrack entry, does match new ruleset")
}
c.rulesVersion = f.rulesVersion
}
switch fp.Protocol { switch fp.Protocol {
case fwProtoTCP: case fwProtoTCP:
c.Expires = time.Now().Add(f.TCPTimeout) c.Expires = time.Now().Add(f.TCPTimeout)
@@ -427,7 +492,7 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool
c.Expires = time.Now().Add(f.DefaultTimeout) c.Expires = time.Now().Add(f.DefaultTimeout)
} }
f.connMutex.Unlock() conntrack.Unlock()
return true return true
} }
@@ -448,14 +513,19 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
timeout = f.DefaultTimeout timeout = f.DefaultTimeout
} }
f.connMutex.Lock() conntrack := f.Conntrack
if _, ok := f.Conns[fp]; !ok { conntrack.Lock()
f.TimerWheel.Add(fp, timeout) if _, ok := conntrack.Conns[fp]; !ok {
conntrack.TimerWheel.Add(fp, timeout)
} }
// Record which rulesVersion allowed this connection, so we can retest after
// firewall reload
c.incoming = incoming
c.rulesVersion = f.rulesVersion
c.Expires = time.Now().Add(timeout) c.Expires = time.Now().Add(timeout)
f.Conns[fp] = c conntrack.Conns[fp] = c
f.connMutex.Unlock() conntrack.Unlock()
} }
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
@@ -463,7 +533,8 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
func (f *Firewall) evict(p FirewallPacket) { func (f *Firewall) evict(p FirewallPacket) {
//TODO: report a stat if the tcp rtt tracking was never resolved? //TODO: report a stat if the tcp rtt tracking was never resolved?
// Are we still tracking this conn? // Are we still tracking this conn?
t, ok := f.Conns[p] conntrack := f.Conntrack
t, ok := conntrack.Conns[p]
if !ok { if !ok {
return return
} }
@@ -472,12 +543,12 @@ func (f *Firewall) evict(p FirewallPacket) {
// Timeout is in the future, re-add the timer // Timeout is in the future, re-add the timer
if newT > 0 { if newT > 0 {
f.TimerWheel.Add(p, newT) conntrack.TimerWheel.Add(p, newT)
return return
} }
// This conn is done // This conn is done
delete(f.Conns, p) delete(conntrack.Conns, p)
} }
func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {

View File

@@ -17,37 +17,39 @@ import (
func TestNewFirewall(t *testing.T) { func TestNewFirewall(t *testing.T) {
c := &cert.NebulaCertificate{} c := &cert.NebulaCertificate{}
fw := NewFirewall(time.Second, time.Minute, time.Hour, c) fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
assert.NotNil(t, fw.Conns) conntrack := fw.Conntrack
assert.NotNil(t, conntrack)
assert.NotNil(t, conntrack.Conns)
assert.NotNil(t, conntrack.TimerWheel)
assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.InRules)
assert.NotNil(t, fw.OutRules) assert.NotNil(t, fw.OutRules)
assert.NotNil(t, fw.TimerWheel)
assert.Equal(t, time.Second, fw.TCPTimeout) assert.Equal(t, time.Second, fw.TCPTimeout)
assert.Equal(t, time.Minute, fw.UDPTimeout) assert.Equal(t, time.Minute, fw.UDPTimeout)
assert.Equal(t, time.Hour, fw.DefaultTimeout) assert.Equal(t, time.Hour, fw.DefaultTimeout)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Second, time.Hour, time.Minute, c) fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Hour, time.Second, time.Minute, c) fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Hour, time.Minute, time.Second, c) fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Minute, time.Hour, time.Second, c) fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Minute, time.Second, time.Hour, c) fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
} }
func TestFirewall_AddRule(t *testing.T) { func TestFirewall_AddRule(t *testing.T) {
@@ -180,44 +182,44 @@ func TestFirewall_Drop(t *testing.T) {
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
assert.True(t, fw.Drop([]byte{}, p, false, &h, cp)) assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
// Allow inbound // Allow inbound
resetConntrack(fw) resetConntrack(fw)
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
// Allow outbound because conntrack // Allow outbound because conntrack
assert.False(t, fw.Drop([]byte{}, p, false, &h, cp)) assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
// test remote mismatch // test remote mismatch
oldRemote := p.RemoteIP oldRemote := p.RemoteIP
p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10)) p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
assert.True(t, fw.Drop([]byte{}, p, false, &h, cp)) assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrInvalidRemoteIP)
p.RemoteIP = oldRemote p.RemoteIP = oldRemote
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
assert.True(t, fw.Drop([]byte{}, p, true, &h, cp)) assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
assert.True(t, fw.Drop([]byte{}, p, true, &h, cp)) assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
} }
func BenchmarkFirewallTable_match(b *testing.B) { func BenchmarkFirewallTable_match(b *testing.B) {
@@ -368,10 +370,10 @@ func TestFirewall_Drop2(t *testing.T) {
cp := cert.NewCAPool() cp := cert.NewCAPool()
// h1/c1 lacks the proper groups // h1/c1 lacks the proper groups
assert.True(t, fw.Drop([]byte{}, p, true, &h1, cp)) assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp), ErrNoMatchingRule)
// c has the proper groups // c has the proper groups
resetConntrack(fw) resetConntrack(fw)
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
} }
func TestFirewall_Drop3(t *testing.T) { func TestFirewall_Drop3(t *testing.T) {
@@ -452,13 +454,81 @@ func TestFirewall_Drop3(t *testing.T) {
cp := cert.NewCAPool() cp := cert.NewCAPool()
// c1 should pass because host match // c1 should pass because host match
assert.False(t, fw.Drop([]byte{}, p, true, &h1, cp)) assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp))
// c2 should pass because ca sha match // c2 should pass because ca sha match
resetConntrack(fw) resetConntrack(fw)
assert.False(t, fw.Drop([]byte{}, p, true, &h2, cp)) assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp))
// c3 should fail because no match // c3 should fail because no match
resetConntrack(fw) resetConntrack(fw)
assert.True(t, fw.Drop([]byte{}, p, true, &h3, cp)) assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp), ErrNoMatchingRule)
}
func TestFirewall_DropConntrackReload(t *testing.T) {
ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob)
defer l.SetOutput(out)
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
ip2int(net.IPv4(1, 2, 3, 4)),
10,
90,
fwProtoUDP,
false,
}
ipNet := net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
}
c := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "host1",
Ips: []*net.IPNet{&ipNet},
Groups: []string{"default-group"},
InvertedGroups: map[string]struct{}{"default-group": {}},
Issuer: "signer-shasum",
},
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c,
},
hostId: ip2int(ipNet.IP),
}
h.CreateRemoteCIDR(&c)
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
// Allow outbound because conntrack
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
oldFw := fw
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
// Allow outbound because conntrack and new rules allow port 10
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
oldFw = fw
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
// Drop outbound because conntrack doesn't match new ruleset
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
} }
func BenchmarkLookup(b *testing.B) { func BenchmarkLookup(b *testing.B) {
@@ -861,7 +931,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
} }
func resetConntrack(fw *Firewall) { func resetConntrack(fw *Firewall) {
fw.connMutex.Lock() fw.Conntrack.Lock()
fw.Conns = map[FirewallPacket]*conn{} fw.Conntrack.Conns = map[FirewallPacket]*conn{}
fw.connMutex.Unlock() fw.Conntrack.Unlock()
} }

4
go.mod
View File

@@ -11,7 +11,7 @@ require (
github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6 github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6
github.com/golang/protobuf v1.3.2 github.com/golang/protobuf v1.3.2
github.com/imdario/mergo v0.3.8 github.com/imdario/mergo v0.3.8
github.com/kardianos/service v1.0.0 github.com/kardianos/service v1.1.0
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
github.com/kr/pretty v0.1.0 // indirect github.com/kr/pretty v0.1.0 // indirect
github.com/miekg/dns v1.1.25 github.com/miekg/dns v1.1.25
@@ -22,7 +22,7 @@ require (
github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563 github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563
github.com/sirupsen/logrus v1.4.2 github.com/sirupsen/logrus v1.4.2
github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b
github.com/stretchr/testify v1.4.0 github.com/stretchr/testify v1.6.1
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975 golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975

10
go.sum
View File

@@ -46,6 +46,8 @@ github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/u
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/kardianos/service v1.0.0 h1:HgQS3mFfOlyntWX8Oke98JcJLqt1DBcHR4kxShpYef0= github.com/kardianos/service v1.0.0 h1:HgQS3mFfOlyntWX8Oke98JcJLqt1DBcHR4kxShpYef0=
github.com/kardianos/service v1.0.0/go.mod h1:8CzDhVuCuugtsHyZoTvsOBuvonN/UDBvl0kH+BUxvbo= github.com/kardianos/service v1.0.0/go.mod h1:8CzDhVuCuugtsHyZoTvsOBuvonN/UDBvl0kH+BUxvbo=
github.com/kardianos/service v1.1.0 h1:QV2SiEeWK42P0aEmGcsAgjApw/lRxkwopvT+Gu6t1/0=
github.com/kardianos/service v1.1.0/go.mod h1:RrJI2xn5vve/r32U5suTbeaSGoMU6GbNPoj36CVYcHc=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
@@ -103,8 +105,8 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a h1:Bt1IVPhiCDMqwGrc2nnbIN4QKvJGx6SK2NzWBmW00ao= github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a h1:Bt1IVPhiCDMqwGrc2nnbIN4QKvJGx6SK2NzWBmW00ao=
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk= github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k= github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
@@ -112,8 +114,6 @@ github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY= golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY=
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 h1:ULYEB3JvPRE/IfO+9uO7vKV/xzVTO7XPAwm8xbf4w2g=
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975 h1:/Tl7pH94bvbAAHBdZJT947M/+gp0+CqQXDtMRC0fseo= golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975 h1:/Tl7pH94bvbAAHBdZJT947M/+gp0+CqQXDtMRC0fseo=
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -154,3 +154,5 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

83
handler.go Normal file
View File

@@ -0,0 +1,83 @@
package nebula
func (f *Interface) newHook(w func([]byte) error) InsideHandler {
fn := func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.decryptTo(w, hostInfo, header.MessageCounter, out, packet, fwPacket, nb)
}
return f.encrypted(fn)
}
func (f *Interface) encrypted(h InsideHandler) InsideHandler {
return func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
if !f.handleEncrypted(ci, addr, header) {
return
}
h(hostInfo, ci, addr, header, out, packet, fwPacket, nb)
f.handleHostRoaming(hostInfo, addr)
f.connectionManager.In(hostInfo.hostId)
}
}
func (f *Interface) rxMetrics(h InsideHandler) InsideHandler {
return func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
h(hostInfo, ci, addr, header, out, packet, fwPacket, nb)
}
}
func (f *Interface) handleMessagePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.decryptTo(f.inside.WriteRaw, hostInfo, header.MessageCounter, out, packet, fwPacket, nb)
}
func (f *Interface) handleLighthousePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
d, err := f.decrypt(hostInfo, header.MessageCounter, out, packet, header, nb)
if err != nil {
hostInfo.logger().WithError(err).WithField("udpAddr", addr).
WithField("packet", packet).
Error("Failed to decrypt lighthouse packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
return
}
f.lightHouse.HandleRequest(addr, hostInfo.hostId, d, hostInfo.GetCert(), f)
}
func (f *Interface) handleTestPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
d, err := f.decrypt(hostInfo, header.MessageCounter, out, packet, header, nb)
if err != nil {
hostInfo.logger().WithError(err).WithField("udpAddr", addr).
WithField("packet", packet).
Error("Failed to decrypt test packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
return
}
if header.Subtype == testRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostInfo, addr)
f.send(test, testReply, ci, hostInfo, hostInfo.remote, d, nb, out)
}
}
func (f *Interface) handleHandshakePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
HandleIncomingHandshake(f, addr, packet, header, hostInfo)
}
func (f *Interface) handleRecvErrorPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
// TODO: Remove this with recv_error deprecation
f.handleRecvError(addr, header)
}
func (f *Interface) handleCloseTunnelPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
hostInfo.logger().WithField("udpAddr", addr).
Info("Close tunnel received, tearing down.")
f.closeTunnel(hostInfo)
}

View File

@@ -1,11 +1,10 @@
package nebula package nebula
import ( import (
"bytes"
"sync/atomic" "sync/atomic"
"time" "time"
"bytes"
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
) )
@@ -98,6 +97,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hostinfo, _ := f.handshakeManager.pendingHostMap.QueryReverseIndex(hs.Details.InitiatorIndex) hostinfo, _ := f.handshakeManager.pendingHostMap.QueryReverseIndex(hs.Details.InitiatorIndex)
if hostinfo != nil && bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) { if hostinfo != nil && bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) {
if msg, ok := hostinfo.HandshakePacket[2]; ok { if msg, ok := hostinfo.HandshakePacket[2]; ok {
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err := f.outside.WriteTo(msg, addr) err := f.outside.WriteTo(msg, addr)
if err != nil { if err != nil {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
@@ -126,11 +126,13 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
} }
vpnIP := ip2int(remoteCert.Details.Ips[0].IP) vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
certName := remoteCert.Details.Name certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
myIndex, err := generateIndex() myIndex, err := generateIndex()
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
return true return true
} }
@@ -139,12 +141,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Error adding index to connection manager") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Error adding index to connection manager")
return true return true
} }
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message received") Info("Handshake message received")
@@ -157,6 +161,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return true return true
} }
@@ -166,6 +171,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return true return true
} }
@@ -173,6 +179,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
if f.hostMap.CheckHandshakeCompleteIP(vpnIP) && vpnIP < ip2int(f.certState.certificate.Details.Ips[0].IP) { if f.hostMap.CheckHandshakeCompleteIP(vpnIP) && vpnIP < ip2int(f.certState.certificate.Details.Ips[0].IP) {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Prevented a handshake race") Info("Prevented a handshake race")
@@ -191,16 +198,19 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hostinfo.HandshakePacket[2] = make([]byte, len(msg)) hostinfo.HandshakePacket[2] = make([]byte, len(msg))
copy(hostinfo.HandshakePacket[2], msg) copy(hostinfo.HandshakePacket[2], msg)
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err := f.outside.WriteTo(msg, addr) err := f.outside.WriteTo(msg, addr)
if err != nil { if err != nil {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake") WithError(err).Error("Failed to send handshake")
} else { } else {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake message sent") Info("Handshake message sent")
@@ -224,6 +234,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
if err == nil && ho.localIndexId != 0 { if err == nil && ho.localIndexId != 0 {
l.WithField("vpnIp", vpnIP). l.WithField("vpnIp", vpnIP).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("action", "removing stale index"). WithField("action", "removing stale index").
WithField("index", ho.localIndexId). WithField("index", ho.localIndexId).
Debug("Handshake processing") Debug("Handshake processing")
@@ -237,6 +248,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
} else { } else {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Noise did not arrive at a key") Error("Noise did not arrive at a key")
return true return true
@@ -296,10 +308,12 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
} }
vpnIP := ip2int(remoteCert.Details.Ips[0].IP) vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
certName := remoteCert.Details.Name certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
duration := time.Since(hostinfo.handshakeStart).Nanoseconds() duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("durationNs", duration). WithField("durationNs", duration).
@@ -338,6 +352,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
if err == nil && ho.localIndexId != 0 { if err == nil && ho.localIndexId != 0 {
l.WithField("vpnIp", vpnIP). l.WithField("vpnIp", vpnIP).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("action", "removing stale index"). WithField("action", "removing stale index").
WithField("index", ho.localIndexId). WithField("index", ho.localIndexId).
Debug("Handshake processing") Debug("Handshake processing")
@@ -352,6 +367,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
} else { } else {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key") Error("Noise did not arrive at a key")
return true return true

View File

@@ -17,6 +17,7 @@ const (
DefaultHandshakeRetries = 20 DefaultHandshakeRetries = 20
// DefaultHandshakeWaitRotation is the number of handshake attempts to do before starting to use other ips addresses // DefaultHandshakeWaitRotation is the number of handshake attempts to do before starting to use other ips addresses
DefaultHandshakeWaitRotation = 5 DefaultHandshakeWaitRotation = 5
DefaultHandshakeTriggerBuffer = 64
) )
var ( var (
@@ -24,6 +25,7 @@ var (
tryInterval: DefaultHandshakeTryInterval, tryInterval: DefaultHandshakeTryInterval,
retries: DefaultHandshakeRetries, retries: DefaultHandshakeRetries,
waitRotation: DefaultHandshakeWaitRotation, waitRotation: DefaultHandshakeWaitRotation,
triggerBuffer: DefaultHandshakeTriggerBuffer,
} }
) )
@@ -31,6 +33,9 @@ type HandshakeConfig struct {
tryInterval time.Duration tryInterval time.Duration
retries int retries int
waitRotation int waitRotation int
triggerBuffer int
messageMetrics *MessageMetrics
} }
type HandshakeManager struct { type HandshakeManager struct {
@@ -40,8 +45,13 @@ type HandshakeManager struct {
outside *udpConn outside *udpConn
config HandshakeConfig config HandshakeConfig
// can be used to trigger outbound handshake for the given vpnIP
trigger chan uint32
OutboundHandshakeTimer *SystemTimerWheel OutboundHandshakeTimer *SystemTimerWheel
InboundHandshakeTimer *SystemTimerWheel InboundHandshakeTimer *SystemTimerWheel
messageMetrics *MessageMetrics
} }
func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager { func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
@@ -53,18 +63,28 @@ func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainH
config: config, config: config,
trigger: make(chan uint32, config.triggerBuffer),
OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)), OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)), InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
messageMetrics: config.messageMetrics,
} }
} }
func (c *HandshakeManager) Run(f EncWriter) { func (c *HandshakeManager) Run(f EncWriter) {
clockSource := time.Tick(c.config.tryInterval) clockSource := time.Tick(c.config.tryInterval)
for now := range clockSource { for {
select {
case vpnIP := <-c.trigger:
l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
c.handleOutbound(vpnIP, f, true)
case now := <-clockSource:
c.NextOutboundHandshakeTimerTick(now, f) c.NextOutboundHandshakeTimerTick(now, f)
c.NextInboundHandshakeTimerTick(now) c.NextInboundHandshakeTimerTick(now)
} }
} }
}
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) { func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
c.OutboundHandshakeTimer.advance(now) c.OutboundHandshakeTimer.advance(now)
@@ -74,15 +94,18 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWr
break break
} }
vpnIP := ep.(uint32) vpnIP := ep.(uint32)
c.handleOutbound(vpnIP, f, false)
index, err := c.pendingHostMap.GetIndexByVpnIP(vpnIP) }
if err != nil {
continue
} }
func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseTriggered bool) {
index, err := c.pendingHostMap.GetIndexByVpnIP(vpnIP)
if err != nil {
return
}
hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP) hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP)
if err != nil { if err != nil {
continue return
} }
// If we haven't finished the handshake and we haven't hit max retries, query // If we haven't finished the handshake and we haven't hit max retries, query
@@ -92,13 +115,26 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWr
// We continue to query the lighthouse because hosts may // We continue to query the lighthouse because hosts may
// come online during handshake retries. If the query // come online during handshake retries. If the query
// succeeds (no error), add the lighthouse info to hostinfo // succeeds (no error), add the lighthouse info to hostinfo
ips, err := c.lightHouse.Query(vpnIP, f) ips := c.lightHouse.QueryCache(vpnIP)
// If we have no responses yet, or only one IP (the host hadn't
// finished reporting its own IPs yet), then send another query to
// the LH.
if len(ips) <= 1 {
ips, err = c.lightHouse.Query(vpnIP, f)
}
if err == nil { if err == nil {
for _, ip := range ips { for _, ip := range ips {
hostinfo.AddRemote(ip) hostinfo.AddRemote(ip)
} }
hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges) hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges)
} }
} else if lighthouseTriggered {
// We were triggered by a lighthouse HostQueryReply packet, but
// we have already picked a remote for this host (this can happen
// if we are configured with multiple lighthouses). So we can skip
// this trigger and let the timerwheel handle the rest of the
// process
return
} }
hostinfo.HandshakeCounter++ hostinfo.HandshakeCounter++
@@ -111,6 +147,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWr
// Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation // Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation
if hostinfo.HandshakeReady && hostinfo.remote != nil { if hostinfo.HandshakeReady && hostinfo.remote != nil {
c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote) err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
if err != nil { if err != nil {
hostinfo.logger().WithField("udpAddr", hostinfo.remote). hostinfo.logger().WithField("udpAddr", hostinfo.remote).
@@ -130,14 +167,15 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWr
} }
// Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try // Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try
if !lighthouseTriggered {
//l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter)) //l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter))
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
}
} else { } else {
c.pendingHostMap.DeleteVpnIP(vpnIP) c.pendingHostMap.DeleteVpnIP(vpnIP)
c.pendingHostMap.DeleteIndex(index) c.pendingHostMap.DeleteIndex(index)
} }
} }
}
func (c *HandshakeManager) NextInboundHandshakeTimerTick(now time.Time) { func (c *HandshakeManager) NextInboundHandshakeTimerTick(now time.Time) {
c.InboundHandshakeTimer.advance(now) c.InboundHandshakeTimer.advance(now)
@@ -162,6 +200,7 @@ func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
// We lock here and use an array to insert items to prevent locking the // We lock here and use an array to insert items to prevent locking the
// main receive thread for very long by waiting to add items to the pending map // main receive thread for very long by waiting to add items to the pending map
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval) c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval)
return hostinfo return hostinfo
} }

View File

@@ -103,6 +103,56 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
} }
} }
func Test_NewHandshakeManagerTrigger(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ip := ip2int(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
lh := &LightHouse{}
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
blah.AddVpnIP(ip)
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
// Trigger the same method the channel will
blah.handleOutbound(ip, mw, true)
// Make sure the trigger doesn't schedule another timer entry
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
hi := blah.pendingHostMap.Hosts[ip]
assert.Nil(t, hi.remote)
lh.addrMap = map[uint32][]udpAddr{
ip: {*NewUDPAddrFromString("10.1.1.1:4242")},
}
// This should trigger the hostmap to populate the hostinfo
blah.handleOutbound(ip, mw, true)
assert.NotNil(t, hi.remote)
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
}
func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
for _, i := range tw.wheel {
n := i.Head
for n != nil {
c++
n = n.Next
}
}
return c
}
func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) { func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")

View File

@@ -54,6 +54,7 @@ var typeMap = map[NebulaMessageType]string{
} }
const ( const (
subTypeNone NebulaMessageSubType = 0
testRequest NebulaMessageSubType = 0 testRequest NebulaMessageSubType = 0
testReply NebulaMessageSubType = 1 testReply NebulaMessageSubType = 1
) )

View File

@@ -1,9 +1,10 @@
package nebula package nebula
import ( import (
"github.com/stretchr/testify/assert"
"reflect" "reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
type headerTest struct { type headerTest struct {

View File

@@ -30,6 +30,7 @@ type HostMap struct {
vpnCIDR *net.IPNet vpnCIDR *net.IPNet
defaultRoute uint32 defaultRoute uint32
unsafeRoutes *CIDRTree unsafeRoutes *CIDRTree
metricsEnabled bool
} }
type HostInfo struct { type HostInfo struct {
@@ -384,8 +385,16 @@ func (hm *HostMap) PunchList() []*udpAddr {
} }
func (hm *HostMap) Punchy(conn *udpConn) { func (hm *HostMap) Punchy(conn *udpConn) {
var metricsTxPunchy metrics.Counter
if hm.metricsEnabled {
metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
} else {
metricsTxPunchy = metrics.NilCounter{}
}
for { for {
for _, addr := range hm.PunchList() { for _, addr := range hm.PunchList() {
metricsTxPunchy.Inc(1)
conn.WriteTo([]byte{1}, addr) conn.WriteTo([]byte{1}, addr)
} }
time.Sleep(time.Second * 30) time.Sleep(time.Second * 30)

View File

@@ -161,6 +161,4 @@ func BenchmarkHostmappromote2(b *testing.B) {
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), g) m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), g)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y) m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
} }
b.Errorf("hi")
} }

View File

@@ -7,7 +7,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte) { func (f *Interface) consumeInsidePacket(st NebulaMessageSubType, packet []byte, fwPacket *FirewallPacket, nb, out []byte) {
err := newPacket(packet, false, fwPacket) err := newPacket(packet, false, fwPacket)
if err != nil { if err != nil {
l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
@@ -30,6 +30,14 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
} }
hostinfo := f.getOrHandshake(fwPacket.RemoteIP) hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
if hostinfo == nil {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
}
return
}
ci := hostinfo.ConnectionState ci := hostinfo.ConnectionState
if ci.ready == false { if ci.ready == false {
@@ -37,28 +45,35 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
// the packet queue. // the packet queue.
ci.queueLock.Lock() ci.queueLock.Lock()
if !ci.ready { if !ci.ready {
hostinfo.cachePacket(message, 0, packet, f.sendMessageNow) hostinfo.cachePacket(message, st, packet, f.sendMessageNow)
ci.queueLock.Unlock() ci.queueLock.Unlock()
return return
} }
ci.queueLock.Unlock() ci.queueLock.Unlock()
} }
if !f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs) { dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs)
f.send(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out) if dropReason == nil {
if f.lightHouse != nil && *ci.messageCounter%5000 == 0 { mc := f.sendNoMetrics(message, st, ci, hostinfo, hostinfo.remote, packet, nb, out)
if f.lightHouse != nil && mc%5000 == 0 {
f.lightHouse.Query(fwPacket.RemoteIP, f) f.lightHouse.Query(fwPacket.RemoteIP, f)
} }
} else if l.Level >= logrus.DebugLevel { } else if l.Level >= logrus.DebugLevel {
hostinfo.logger().WithField("fwPacket", fwPacket). hostinfo.logger().
WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping outbound packet") Debugln("dropping outbound packet")
} }
} }
// getOrHandshake returns nil if the vpnIp is not routable
func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo { func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
if f.hostMap.vpnCIDR.Contains(int2ip(vpnIp)) == false { if f.hostMap.vpnCIDR.Contains(int2ip(vpnIp)) == false {
vpnIp = f.hostMap.queryUnsafeRoute(vpnIp) vpnIp = f.hostMap.queryUnsafeRoute(vpnIp)
if vpnIp == 0 {
return nil
}
} }
hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f) hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f)
@@ -91,6 +106,15 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
ixHandshakeStage0(f, vpnIp, hostinfo) ixHandshakeStage0(f, vpnIp, hostinfo)
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us. // FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
//xx_handshakeStage0(f, ip, hostinfo) //xx_handshakeStage0(f, ip, hostinfo)
// If this is a static host, we don't need to wait for the HostQueryReply
// We can trigger the handshake right now
if _, ok := f.lightHouse.staticList[vpnIp]; ok {
select {
case f.handshakeManager.trigger <- vpnIp:
default:
}
}
} }
return hostinfo return hostinfo
@@ -105,12 +129,17 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
} }
// check if packet is in outbound fw rules // check if packet is in outbound fw rules
if f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs) { dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs)
l.WithField("fwPacket", fp).Debugln("dropping cached packet") if dropReason != nil {
if l.Level >= logrus.DebugLevel {
l.WithField("fwPacket", fp).
WithField("reason", dropReason).
Debugln("dropping cached packet")
}
return return
} }
f.send(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out) f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
if f.lightHouse != nil && *hostInfo.ConnectionState.messageCounter%5000 == 0 { if f.lightHouse != nil && *hostInfo.ConnectionState.messageCounter%5000 == 0 {
f.lightHouse.Query(fp.RemoteIP, f) f.lightHouse.Query(fp.RemoteIP, f)
} }
@@ -119,6 +148,13 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
hostInfo := f.getOrHandshake(vpnIp) hostInfo := f.getOrHandshake(vpnIp)
if hostInfo == nil {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(vpnIp)).
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
}
return
}
if !hostInfo.ConnectionState.ready { if !hostInfo.ConnectionState.ready {
// Because we might be sending stored packets, lock here to stop new things going to // Because we might be sending stored packets, lock here to stop new things going to
@@ -143,6 +179,13 @@ func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
// SendMessageToAll handles real ip:port lookup and sends to all known addresses for vpnIp // SendMessageToAll handles real ip:port lookup and sends to all known addresses for vpnIp
func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
hostInfo := f.getOrHandshake(vpnIp) hostInfo := f.getOrHandshake(vpnIp)
if hostInfo == nil {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(vpnIp)).
Debugln("dropping SendMessageToAll, vpnIp not in our CIDR or in unsafe routes")
}
return
}
if hostInfo.ConnectionState.ready == false { if hostInfo.ConnectionState.ready == false {
// Because we might be sending stored packets, lock here to stop new things going to // Because we might be sending stored packets, lock here to stop new things going to
@@ -167,9 +210,14 @@ func (f *Interface) sendMessageToAll(t NebulaMessageType, st NebulaMessageSubTyp
} }
func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) { func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) {
f.messageMetrics.Tx(t, st, 1)
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out)
}
func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) uint64 {
if ci.eKey == nil { if ci.eKey == nil {
//TODO: log warning //TODO: log warning
return return 0
} }
var err error var err error
@@ -189,7 +237,7 @@ func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *Conne
WithField("udpAddr", remote).WithField("counter", c). WithField("udpAddr", remote).WithField("counter", c).
WithField("attemptedCounter", ci.messageCounter). WithField("attemptedCounter", ci.messageCounter).
Error("Failed to encrypt outgoing packet") Error("Failed to encrypt outgoing packet")
return return c
} }
err = f.outside.WriteTo(out, remote) err = f.outside.WriteTo(out, remote)
@@ -197,6 +245,7 @@ func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *Conne
hostinfo.logger().WithError(err). hostinfo.logger().WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet") WithField("udpAddr", remote).Error("Failed to write outgoing packet")
} }
return c
} }
func isMulticast(ip uint32) bool { func isMulticast(ip uint32) bool {

View File

@@ -2,6 +2,8 @@ package nebula
import ( import (
"errors" "errors"
"io"
"net"
"os" "os"
"time" "time"
@@ -10,10 +12,20 @@ import (
const mtu = 9001 const mtu = 9001
type Inside interface {
io.ReadWriteCloser
Activate() error
CidrNet() *net.IPNet
DeviceName() string
WriteRaw([]byte) error
}
type InsideHandler func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fp *FirewallPacket, nb []byte)
type InterfaceConfig struct { type InterfaceConfig struct {
HostMap *HostMap HostMap *HostMap
Outside *udpConn Outside *udpConn
Inside *Tun Inside Inside
certState *CertState certState *CertState
Cipher string Cipher string
Firewall *Firewall Firewall *Firewall
@@ -25,12 +37,16 @@ type InterfaceConfig struct {
DropLocalBroadcast bool DropLocalBroadcast bool
DropMulticast bool DropMulticast bool
UDPBatchSize int UDPBatchSize int
udpQueues int
tunQueues int
MessageMetrics *MessageMetrics
version string
} }
type Interface struct { type Interface struct {
hostMap *HostMap hostMap *HostMap
outside *udpConn outside *udpConn
inside *Tun inside Inside
certState *CertState certState *CertState
cipher string cipher string
firewall *Firewall firewall *Firewall
@@ -43,11 +59,15 @@ type Interface struct {
dropLocalBroadcast bool dropLocalBroadcast bool
dropMulticast bool dropMulticast bool
udpBatchSize int udpBatchSize int
udpQueues int
tunQueues int
version string version string
metricRxRecvError metrics.Counter // handlers are mapped by protocol version -> type -> subtype
metricTxRecvError metrics.Counter handlers map[uint8]map[NebulaMessageType]map[NebulaMessageSubType]InsideHandler
metricHandshakes metrics.Histogram metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
} }
func NewInterface(c *InterfaceConfig) (*Interface, error) { func NewInterface(c *InterfaceConfig) (*Interface, error) {
@@ -79,40 +99,64 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
dropLocalBroadcast: c.DropLocalBroadcast, dropLocalBroadcast: c.DropLocalBroadcast,
dropMulticast: c.DropMulticast, dropMulticast: c.DropMulticast,
udpBatchSize: c.UDPBatchSize, udpBatchSize: c.UDPBatchSize,
udpQueues: c.udpQueues,
tunQueues: c.tunQueues,
version: c.version,
metricRxRecvError: metrics.GetOrRegisterCounter("messages.rx.recv_error", nil),
metricTxRecvError: metrics.GetOrRegisterCounter("messages.tx.recv_error", nil),
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)), metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
messageMetrics: c.MessageMetrics,
} }
ifce.connectionManager = newConnectionManager(ifce, c.checkInterval, c.pendingDeletionInterval) ifce.connectionManager = newConnectionManager(ifce, c.checkInterval, c.pendingDeletionInterval)
ifce.handlers = map[uint8]map[NebulaMessageType]map[NebulaMessageSubType]InsideHandler{
Version: {
handshake: {
handshakeIXPSK0: ifce.rxMetrics(ifce.handleHandshakePacket),
},
message: {
subTypeNone: ifce.encrypted(ifce.handleMessagePacket),
},
recvError: {
subTypeNone: ifce.rxMetrics(ifce.handleRecvErrorPacket),
},
lightHouse: {
subTypeNone: ifce.rxMetrics(ifce.encrypted(ifce.handleLighthousePacket)),
},
test: {
testRequest: ifce.rxMetrics(ifce.encrypted(ifce.handleTestPacket)),
testReply: ifce.rxMetrics(ifce.encrypted(ifce.handleTestPacket)),
},
closeTunnel: {
subTypeNone: ifce.rxMetrics(ifce.encrypted(ifce.handleCloseTunnelPacket)),
},
},
}
return ifce, nil return ifce, nil
} }
func (f *Interface) Run(tunRoutines, udpRoutines int, buildVersion string) { func (f *Interface) run() {
// actually turn on tun dev // actually turn on tun dev
if err := f.inside.Activate(); err != nil { if err := f.inside.Activate(); err != nil {
l.Fatal(err) l.Fatal(err)
} }
f.version = buildVersion
addr, err := f.outside.LocalAddr() addr, err := f.outside.LocalAddr()
if err != nil { if err != nil {
l.WithError(err).Error("Failed to get udp listen address") l.WithError(err).Error("Failed to get udp listen address")
} }
l.WithField("interface", f.inside.Device).WithField("network", f.inside.Cidr.String()). l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
WithField("build", buildVersion).WithField("udpAddr", addr). WithField("build", f.version).WithField("udpAddr", addr).
Info("Nebula interface is active") Info("Nebula interface is active")
// Launch n queues to read packets from udp // Launch n queues to read packets from udp
for i := 0; i < udpRoutines; i++ { for i := 0; i < f.udpQueues; i++ {
go f.listenOut(i) go f.listenOut(i)
} }
// Launch n queues to read packets from tun dev // Launch n queues to read packets from tun dev
for i := 0; i < tunRoutines; i++ { for i := 0; i < f.tunQueues; i++ {
go f.listenIn(i) go f.listenIn(i)
} }
} }
@@ -152,7 +196,7 @@ func (f *Interface) listenIn(i int) {
os.Exit(2) os.Exit(2)
} }
f.consumeInsidePacket(packet[:n], fwPacket, nb, out) f.consumeInsidePacket(subTypeNone, packet[:n], fwPacket, nb, out)
} }
} }
@@ -210,11 +254,28 @@ func (f *Interface) reloadFirewall(c *Config) {
} }
oldFw := f.firewall oldFw := f.firewall
conntrack := oldFw.Conntrack
conntrack.Lock()
defer conntrack.Unlock()
fw.rulesVersion = oldFw.rulesVersion + 1
// If rulesVersion is back to zero, we have wrapped all the way around. Be
// safe and just reset conntrack in this case.
if fw.rulesVersion == 0 {
l.WithField("firewallHash", fw.GetRuleHash()).
WithField("oldFirewallHash", oldFw.GetRuleHash()).
WithField("rulesVersion", fw.rulesVersion).
Warn("firewall rulesVersion has overflowed, resetting conntrack")
} else {
fw.Conntrack = conntrack
}
f.firewall = fw f.firewall = fw
oldFw.Destroy() oldFw.Destroy()
l.WithField("firewallHash", fw.GetRuleHash()). l.WithField("firewallHash", fw.GetRuleHash()).
WithField("oldFirewallHash", oldFw.GetRuleHash()). WithField("oldFirewallHash", oldFw.GetRuleHash()).
WithField("rulesVersion", fw.rulesVersion).
Info("New firewall has been installed") Info("New firewall has been installed")
} }

View File

@@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
) )
@@ -29,6 +30,9 @@ type LightHouse struct {
// filters local addresses that we advertise to lighthouses // filters local addresses that we advertise to lighthouses
localAllowList *AllowList localAllowList *AllowList
// used to trigger the HandshakeManager when we receive HostQueryReply
handshakeTrigger chan<- uint32
// staticList exists to avoid having a bool in each addrMap entry // staticList exists to avoid having a bool in each addrMap entry
// since static should be rare // since static should be rare
staticList map[uint32]struct{} staticList map[uint32]struct{}
@@ -37,6 +41,9 @@ type LightHouse struct {
nebulaPort int nebulaPort int
punchBack bool punchBack bool
punchDelay time.Duration punchDelay time.Duration
metrics *MessageMetrics
metricHolepunchTx metrics.Counter
} }
type EncWriter interface { type EncWriter interface {
@@ -44,7 +51,7 @@ type EncWriter interface {
SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
} }
func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort int, pc *udpConn, punchBack bool, punchDelay time.Duration) *LightHouse { func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort int, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
h := LightHouse{ h := LightHouse{
amLighthouse: amLighthouse, amLighthouse: amLighthouse,
myIp: myIp, myIp: myIp,
@@ -58,6 +65,14 @@ func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, n
punchDelay: punchDelay, punchDelay: punchDelay,
} }
if metricsEnabled {
h.metrics = newLighthouseMetrics()
h.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil)
} else {
h.metricHolepunchTx = metrics.NilCounter{}
}
for _, ip := range ips { for _, ip := range ips {
h.lighthouses[ip] = struct{}{} h.lighthouses[ip] = struct{}{}
} }
@@ -111,6 +126,7 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
return return
} }
lh.metricTx(NebulaMeta_HostQuery, int64(len(lh.lighthouses)))
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
for n := range lh.lighthouses { for n := range lh.lighthouses {
@@ -249,6 +265,7 @@ func (lh *LightHouse) LhUpdateWorker(f EncWriter) {
}, },
} }
lh.metricTx(NebulaMeta_HostUpdateNotification, int64(len(lh.lighthouses)))
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
for vpnIp := range lh.lighthouses { for vpnIp := range lh.lighthouses {
@@ -281,6 +298,8 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
return return
} }
lh.metricRx(n.Type, 1)
switch n.Type { switch n.Type {
case NebulaMeta_HostQuery: case NebulaMeta_HostQuery:
// Exit if we don't answer queries // Exit if we don't answer queries
@@ -308,6 +327,7 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply") l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
return return
} }
lh.metricTx(NebulaMeta_HostQueryReply, 1)
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, reply, make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnIp(lightHouse, 0, vpnIp, reply, make([]byte, 12, 12), make([]byte, mtu))
// This signals the other side to punch some zero byte udp packets // This signals the other side to punch some zero byte udp packets
@@ -326,6 +346,7 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
}, },
} }
reply, _ := proto.Marshal(answer) reply, _ := proto.Marshal(answer)
lh.metricTx(NebulaMeta_HostPunchNotification, 1)
f.SendMessageToVpnIp(lightHouse, 0, n.Details.VpnIp, reply, make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnIp(lightHouse, 0, n.Details.VpnIp, reply, make([]byte, 12, 12), make([]byte, mtu))
} }
//fmt.Println(reply, remoteaddr) //fmt.Println(reply, remoteaddr)
@@ -340,6 +361,11 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
ans := NewUDPAddr(a.Ip, uint16(a.Port)) ans := NewUDPAddr(a.Ip, uint16(a.Port))
lh.AddRemote(n.Details.VpnIp, ans, false) lh.AddRemote(n.Details.VpnIp, ans, false)
} }
// Non-blocking attempt to trigger, skip if it would block
select {
case lh.handshakeTrigger <- n.Details.VpnIp:
default:
}
case NebulaMeta_HostUpdateNotification: case NebulaMeta_HostUpdateNotification:
//Simple check that the host sent this not someone else //Simple check that the host sent this not someone else
@@ -362,6 +388,7 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
vpnPeer := NewUDPAddr(a.Ip, uint16(a.Port)) vpnPeer := NewUDPAddr(a.Ip, uint16(a.Port))
go func() { go func() {
time.Sleep(lh.punchDelay) time.Sleep(lh.punchDelay)
lh.metricHolepunchTx.Inc(1)
lh.punchConn.WriteTo(empty, vpnPeer) lh.punchConn.WriteTo(empty, vpnPeer)
}() }()
@@ -380,6 +407,13 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
} }
} }
func (lh *LightHouse) metricRx(t NebulaMeta_MessageType, i int64) {
lh.metrics.Rx(NebulaMessageType(t), 0, i)
}
func (lh *LightHouse) metricTx(t NebulaMeta_MessageType, i int64) {
lh.metrics.Tx(NebulaMessageType(t), 0, i)
}
/* /*
func (f *Interface) sendPathCheck(ci *ConnectionState, endpoint *net.UDPAddr, counter int) { func (f *Interface) sendPathCheck(ci *ConnectionState, endpoint *net.UDPAddr, counter int) {
c := ci.messageCounter c := ci.messageCounter

View File

@@ -4,7 +4,7 @@ import (
"net" "net"
"testing" "testing"
proto "github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -52,7 +52,7 @@ func Test_lhStaticMapping(t *testing.T) {
udpServer, _ := NewListener("0.0.0.0", 0, true) udpServer, _ := NewListener("0.0.0.0", 0, true)
meh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1) meh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true) meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true)
err := meh.ValidateLHStaticEntries() err := meh.ValidateLHStaticEntries()
assert.Nil(t, err) assert.Nil(t, err)
@@ -60,7 +60,7 @@ func Test_lhStaticMapping(t *testing.T) {
lh2 := "10.128.0.3" lh2 := "10.128.0.3"
lh2IP := net.ParseIP(lh2) lh2IP := net.ParseIP(lh2)
meh = NewLightHouse(true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1) meh = NewLightHouse(true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true) meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true)
err = meh.ValidateLHStaticEntries() err = meh.ValidateLHStaticEntries()
assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry") assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")

39
logger.go Normal file
View File

@@ -0,0 +1,39 @@
package nebula
import (
"errors"
"github.com/sirupsen/logrus"
)
type ContextualError struct {
RealError error
Fields map[string]interface{}
Context string
}
func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
return ContextualError{Context: msg, Fields: fields, RealError: realError}
}
func (ce ContextualError) Error() string {
if ce.RealError == nil {
return ce.Context
}
return ce.RealError.Error()
}
func (ce ContextualError) Unwrap() error {
if ce.RealError == nil {
return errors.New(ce.Context)
}
return ce.RealError
}
func (ce *ContextualError) Log(lr *logrus.Logger) {
if ce.RealError != nil {
lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
} else {
lr.WithFields(ce.Fields).Error(ce.Context)
}
}

67
logger_test.go Normal file
View File

@@ -0,0 +1,67 @@
package nebula
import (
"errors"
"testing"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)
type TestLogWriter struct {
Logs []string
}
func NewTestLogWriter() *TestLogWriter {
return &TestLogWriter{Logs: make([]string, 0)}
}
func (tl *TestLogWriter) Write(p []byte) (n int, err error) {
tl.Logs = append(tl.Logs, string(p))
return len(p), nil
}
func (tl *TestLogWriter) Reset() {
tl.Logs = tl.Logs[:0]
}
func TestContextualError_Log(t *testing.T) {
l := logrus.New()
l.Formatter = &logrus.TextFormatter{
DisableTimestamp: true,
DisableColors: true,
}
tl := NewTestLogWriter()
l.Out = tl
// Test a full context line
tl.Reset()
e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs)
// Test a line with an error and msg but no fields
tl.Reset()
e = NewContextualError("test message", nil, errors.New("error"))
e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\" error=error\n"}, tl.Logs)
// Test just a context and fields
tl.Reset()
e = NewContextualError("test message", m{"field": "1"}, nil)
e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\" field=1\n"}, tl.Logs)
// Test just a context
tl.Reset()
e = NewContextualError("test message", nil, nil)
e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\"\n"}, tl.Logs)
// Test just an error
tl.Reset()
e = NewContextualError("", nil, errors.New("error"))
e.Log(l)
assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs)
}

131
main.go
View File

@@ -4,11 +4,8 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"net" "net"
"os"
"os/signal"
"strconv" "strconv"
"strings" "strings"
"syscall"
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -16,36 +13,31 @@ import (
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
// The caller should provide a real logger, we have one just in case
var l = logrus.New() var l = logrus.New()
type m map[string]interface{} type m map[string]interface{}
func Main(configPath string, configTest bool, buildVersion string) { func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) {
l.Out = os.Stdout l = logger
l.Formatter = &logrus.TextFormatter{ l.Formatter = &logrus.TextFormatter{
FullTimestamp: true, FullTimestamp: true,
} }
config := NewConfig()
err := config.Load(configPath)
if err != nil {
l.WithError(err).Error("Failed to load config")
os.Exit(1)
}
// Print the config if in test, the exit comes later // Print the config if in test, the exit comes later
if configTest { if configTest {
b, err := yaml.Marshal(config.Settings) b, err := yaml.Marshal(config.Settings)
if err != nil { if err != nil {
l.Println(err) return nil, err
os.Exit(1)
} }
// Print the final config
l.Println(string(b)) l.Println(string(b))
} }
err = configLogger(config) err := configLogger(config)
if err != nil { if err != nil {
l.WithError(err).Error("Failed to configure the logger") return nil, NewContextualError("Failed to configure the logger", nil, err)
} }
config.RegisterReloadCallback(func(c *Config) { config.RegisterReloadCallback(func(c *Config) {
@@ -59,20 +51,20 @@ func Main(configPath string, configTest bool, buildVersion string) {
trustedCAs, err = loadCAFromConfig(config) trustedCAs, err = loadCAFromConfig(config)
if err != nil { if err != nil {
//The errors coming out of loadCA are already nicely formatted //The errors coming out of loadCA are already nicely formatted
l.WithError(err).Fatal("Failed to load ca from config") return nil, NewContextualError("Failed to load ca from config", nil, err)
} }
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints") l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
cs, err := NewCertStateFromConfig(config) cs, err := NewCertStateFromConfig(config)
if err != nil { if err != nil {
//The errors coming out of NewCertStateFromConfig are already nicely formatted //The errors coming out of NewCertStateFromConfig are already nicely formatted
l.WithError(err).Fatal("Failed to load certificate from config") return nil, NewContextualError("Failed to load certificate from config", nil, err)
} }
l.WithField("cert", cs.certificate).Debug("Client nebula certificate") l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
fw, err := NewFirewallFromConfig(cs.certificate, config) fw, err := NewFirewallFromConfig(cs.certificate, config)
if err != nil { if err != nil {
l.WithError(err).Fatal("Error while loading firewall rules") return nil, NewContextualError("Error while loading firewall rules", nil, err)
} }
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started") l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
@@ -80,11 +72,11 @@ func Main(configPath string, configTest bool, buildVersion string) {
tunCidr := cs.certificate.Details.Ips[0] tunCidr := cs.certificate.Details.Ips[0]
routes, err := parseRoutes(config, tunCidr) routes, err := parseRoutes(config, tunCidr)
if err != nil { if err != nil {
l.WithError(err).Fatal("Could not parse tun.routes") return nil, NewContextualError("Could not parse tun.routes", nil, err)
} }
unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr) unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
if err != nil { if err != nil {
l.WithError(err).Fatal("Could not parse tun.unsafe_routes") return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
} }
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
@@ -92,7 +84,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
if config.GetBool("sshd.enabled", false) { if config.GetBool("sshd.enabled", false) {
err = configSSH(ssh, config) err = configSSH(ssh, config)
if err != nil { if err != nil {
l.WithError(err).Fatal("Error while configuring the sshd") return nil, NewContextualError("Error while configuring the sshd", nil, err)
} }
} }
@@ -101,11 +93,23 @@ func Main(configPath string, configTest bool, buildVersion string) {
// tun config, listeners, anything modifying the computer should be below // tun config, listeners, anything modifying the computer should be below
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
var tun *Tun var tun Inside
if !configTest { if !configTest {
config.CatchHUP() config.CatchHUP()
// set up our tun dev switch {
case config.GetBool("tun.disabled", false):
tun = newDisabledTun(tunCidr, l)
case tunFd != nil:
tun, err = newTunFromFd(
*tunFd,
tunCidr,
config.GetInt("tun.mtu", DEFAULT_MTU),
routes,
unsafeRoutes,
config.GetInt("tun.tx_queue", 500),
)
default:
tun, err = newTun( tun, err = newTun(
config.GetString("tun.dev", ""), config.GetString("tun.dev", ""),
tunCidr, tunCidr,
@@ -114,8 +118,10 @@ func Main(configPath string, configTest bool, buildVersion string) {
unsafeRoutes, unsafeRoutes,
config.GetInt("tun.tx_queue", 500), config.GetInt("tun.tx_queue", 500),
) )
}
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to get a tun/tap device") return nil, NewContextualError("Failed to get a tun/tap device", nil, err)
} }
} }
@@ -126,7 +132,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
if !configTest { if !configTest {
udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1) udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1)
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to open udp listener") return nil, NewContextualError("Failed to open udp listener", nil, err)
} }
udpServer.reloadConfig(config) udpServer.reloadConfig(config)
} }
@@ -139,7 +145,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
for _, rawPreferredRange := range rawPreferredRanges { for _, rawPreferredRange := range rawPreferredRanges {
_, preferredRange, err := net.ParseCIDR(rawPreferredRange) _, preferredRange, err := net.ParseCIDR(rawPreferredRange)
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to parse preferred ranges") return nil, NewContextualError("Failed to parse preferred ranges", nil, err)
} }
preferredRanges = append(preferredRanges, preferredRange) preferredRanges = append(preferredRanges, preferredRange)
} }
@@ -152,7 +158,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
if rawLocalRange != "" { if rawLocalRange != "" {
_, localRange, err := net.ParseCIDR(rawLocalRange) _, localRange, err := net.ParseCIDR(rawLocalRange)
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to parse local range") return nil, NewContextualError("Failed to parse local_range", nil, err)
} }
// Check if the entry for local_range was already specified in // Check if the entry for local_range was already specified in
@@ -172,6 +178,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
hostMap := NewHostMap("main", tunCidr, preferredRanges) hostMap := NewHostMap("main", tunCidr, preferredRanges)
hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0")))) hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
hostMap.addUnsafeRoutes(&unsafeRoutes) hostMap.addUnsafeRoutes(&unsafeRoutes)
hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created") l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
@@ -191,7 +198,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
if port == 0 && !configTest { if port == 0 && !configTest {
uPort, err := udpServer.LocalAddr() uPort, err := udpServer.LocalAddr()
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to get listening port") return nil, NewContextualError("Failed to get listening port", nil, err)
} }
port = int(uPort.Port) port = int(uPort.Port)
} }
@@ -208,10 +215,10 @@ func Main(configPath string, configTest bool, buildVersion string) {
for i, host := range rawLighthouseHosts { for i, host := range rawLighthouseHosts {
ip := net.ParseIP(host) ip := net.ParseIP(host)
if ip == nil { if ip == nil {
l.WithField("host", host).Fatalf("Unable to parse lighthouse host entry %v", i+1) return nil, NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
} }
if !tunCidr.Contains(ip) { if !tunCidr.Contains(ip) {
l.WithField("vpnIp", ip).WithField("network", tunCidr.String()).Fatalf("lighthouse host is not in our subnet, invalid") return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
} }
lighthouseHosts[i] = ip2int(ip) lighthouseHosts[i] = ip2int(ip)
} }
@@ -226,17 +233,18 @@ func Main(configPath string, configTest bool, buildVersion string) {
udpServer, udpServer,
punchy.Respond, punchy.Respond,
punchy.Delay, punchy.Delay,
config.GetBool("stats.lighthouse_metrics", false),
) )
remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false) remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false)
if err != nil { if err != nil {
l.WithError(err).Fatal("Invalid lighthouse.remote_allow_list") return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
} }
lightHouse.SetRemoteAllowList(remoteAllowList) lightHouse.SetRemoteAllowList(remoteAllowList)
localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true) localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true)
if err != nil { if err != nil {
l.WithError(err).Fatal("Invalid lighthouse.local_allow_list") return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
} }
lightHouse.SetLocalAllowList(localAllowList) lightHouse.SetLocalAllowList(localAllowList)
@@ -244,7 +252,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) { for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
vpnIp := net.ParseIP(fmt.Sprintf("%v", k)) vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
if !tunCidr.Contains(vpnIp) { if !tunCidr.Contains(vpnIp) {
l.WithField("vpnIp", vpnIp).WithField("network", tunCidr.String()).Fatalf("static_host_map key is not in our subnet, invalid") return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
} }
vals, ok := v.([]interface{}) vals, ok := v.([]interface{})
if ok { if ok {
@@ -255,7 +263,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
ip := addr.IP ip := addr.IP
port, err := strconv.Atoi(parts[1]) port, err := strconv.Atoi(parts[1])
if err != nil { if err != nil {
l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v) return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
} }
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true) lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
} }
@@ -268,7 +276,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
ip := addr.IP ip := addr.IP
port, err := strconv.Atoi(parts[1]) port, err := strconv.Atoi(parts[1])
if err != nil { if err != nil {
l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v) return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
} }
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true) lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
} }
@@ -280,13 +288,24 @@ func Main(configPath string, configTest bool, buildVersion string) {
l.WithError(err).Error("Lighthouse unreachable") l.WithError(err).Error("Lighthouse unreachable")
} }
var messageMetrics *MessageMetrics
if config.GetBool("stats.message_metrics", false) {
messageMetrics = newMessageMetrics()
} else {
messageMetrics = newMessageMetricsOnlyRecvError()
}
handshakeConfig := HandshakeConfig{ handshakeConfig := HandshakeConfig{
tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries), retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries),
waitRotation: config.GetInt("handshakes.wait_rotation", DefaultHandshakeWaitRotation), waitRotation: config.GetInt("handshakes.wait_rotation", DefaultHandshakeWaitRotation),
triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
messageMetrics: messageMetrics,
} }
handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpServer, handshakeConfig) handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpServer, handshakeConfig)
lightHouse.handshakeTrigger = handshakeManager.trigger
//TODO: These will be reused for psk //TODO: These will be reused for psk
//handshakeMACKey := config.GetString("handshake_mac.key", "") //handshakeMACKey := config.GetString("handshake_mac.key", "")
@@ -310,6 +329,10 @@ func Main(configPath string, configTest bool, buildVersion string) {
DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false), DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false),
DropMulticast: config.GetBool("tun.drop_multicast", false), DropMulticast: config.GetBool("tun.drop_multicast", false),
UDPBatchSize: config.GetInt("listen.batch", 64), UDPBatchSize: config.GetInt("listen.batch", 64),
udpQueues: udpQueues,
tunQueues: config.GetInt("tun.routines", 1),
MessageMetrics: messageMetrics,
version: buildVersion,
} }
switch ifConfig.Cipher { switch ifConfig.Cipher {
@@ -318,14 +341,14 @@ func Main(configPath string, configTest bool, buildVersion string) {
case "chachapoly": case "chachapoly":
noiseEndianness = binary.LittleEndian noiseEndianness = binary.LittleEndian
default: default:
l.Fatalf("Unknown cipher: %v", ifConfig.Cipher) return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
} }
var ifce *Interface var ifce *Interface
if !configTest { if !configTest {
ifce, err = NewInterface(ifConfig) ifce, err = NewInterface(ifConfig)
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to initialize interface") return nil, fmt.Errorf("failed to initialize interface: %s", err)
} }
ifce.RegisterConfigChangeCallbacks(config) ifce.RegisterConfigChangeCallbacks(config)
@@ -336,18 +359,17 @@ func Main(configPath string, configTest bool, buildVersion string) {
err = startStats(config, configTest) err = startStats(config, configTest)
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to start stats emitter") return nil, NewContextualError("Failed to start stats emitter", nil, err)
} }
if configTest { if configTest {
os.Exit(0) return nil, nil
} }
//TODO: check if we _should_ be emitting stats //TODO: check if we _should_ be emitting stats
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10)) go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce) attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
ifce.Run(config.GetInt("tun.routines", 1), udpQueues, buildVersion)
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host // Start DNS server last to allow using the nebula IP as lighthouse.dns.host
if amLighthouse && serveDns { if amLighthouse && serveDns {
@@ -355,30 +377,5 @@ func Main(configPath string, configTest bool, buildVersion string) {
go dnsMain(hostMap, config) go dnsMain(hostMap, config)
} }
// Just sit here and be friendly, main thread. return &Control{ifce, l}, nil
shutdownBlock(ifce)
}
func shutdownBlock(ifce *Interface) {
var sigChan = make(chan os.Signal)
signal.Notify(sigChan, syscall.SIGTERM)
signal.Notify(sigChan, syscall.SIGINT)
sig := <-sigChan
l.WithField("signal", sig).Info("Caught signal, shutting down")
//TODO: stop tun and udp routines, the lock on hostMap does effectively does that though
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
ifce.hostMap.Lock()
for _, h := range ifce.hostMap.Hosts {
if h.ConnectionState.ready {
ifce.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message")
}
}
ifce.hostMap.Unlock()
l.WithField("signal", sig).Info("Goodbye")
os.Exit(0)
} }

97
message_metrics.go Normal file
View File

@@ -0,0 +1,97 @@
package nebula
import (
"fmt"
"github.com/rcrowley/go-metrics"
)
type MessageMetrics struct {
rx [][]metrics.Counter
tx [][]metrics.Counter
rxUnknown metrics.Counter
txUnknown metrics.Counter
}
func (m *MessageMetrics) Rx(t NebulaMessageType, s NebulaMessageSubType, i int64) {
if m != nil {
if t >= 0 && int(t) < len(m.rx) && s >= 0 && int(s) < len(m.rx[t]) {
m.rx[t][s].Inc(i)
} else if m.rxUnknown != nil {
m.rxUnknown.Inc(i)
}
}
}
func (m *MessageMetrics) Tx(t NebulaMessageType, s NebulaMessageSubType, i int64) {
if m != nil {
if t >= 0 && int(t) < len(m.tx) && s >= 0 && int(s) < len(m.tx[t]) {
m.tx[t][s].Inc(i)
} else if m.txUnknown != nil {
m.txUnknown.Inc(i)
}
}
}
func newMessageMetrics() *MessageMetrics {
gen := func(t string) [][]metrics.Counter {
return [][]metrics.Counter{
{
metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.handshake_ixpsk0", t), nil),
},
nil,
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.recv_error", t), nil)},
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.lighthouse", t), nil)},
{
metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.test_request", t), nil),
metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.test_response", t), nil),
},
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.close_tunnel", t), nil)},
}
}
return &MessageMetrics{
rx: gen("rx"),
tx: gen("tx"),
rxUnknown: metrics.GetOrRegisterCounter("messages.rx.other", nil),
txUnknown: metrics.GetOrRegisterCounter("messages.tx.other", nil),
}
}
// Historically we only recorded recv_error, so this is backwards compat
func newMessageMetricsOnlyRecvError() *MessageMetrics {
gen := func(t string) [][]metrics.Counter {
return [][]metrics.Counter{
nil,
nil,
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.recv_error", t), nil)},
}
}
return &MessageMetrics{
rx: gen("rx"),
tx: gen("tx"),
}
}
func newLighthouseMetrics() *MessageMetrics {
gen := func(t string) [][]metrics.Counter {
h := make([][]metrics.Counter, len(NebulaMeta_MessageType_name))
used := []NebulaMeta_MessageType{
NebulaMeta_HostQuery,
NebulaMeta_HostQueryReply,
NebulaMeta_HostUpdateNotification,
NebulaMeta_HostPunchNotification,
}
for _, i := range used {
h[i] = []metrics.Counter{metrics.GetOrRegisterCounter(fmt.Sprintf("lighthouse.%s.%s", t, i.String()), nil)}
}
return h
}
return &MessageMetrics{
rx: gen("rx"),
tx: gen("tx"),
rxUnknown: metrics.GetOrRegisterCounter("lighthouse.rx.other", nil),
txUnknown: metrics.GetOrRegisterCounter("lighthouse.tx.other", nil),
}
}

View File

@@ -2,18 +2,14 @@ package nebula
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"fmt"
"time"
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
// "github.com/google/gopacket"
// "github.com/google/gopacket/layers"
// "encoding/binary"
"errors"
"fmt"
"time"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
@@ -43,92 +39,15 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
ci = hostinfo.ConnectionState ci = hostinfo.ConnectionState
} }
switch header.Type { handle := f.handlers[header.Version][header.Type][header.Subtype]
case message:
if !f.handleEncrypted(ci, addr, header) {
return
}
f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb) if handle == nil {
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
// Fallthrough to the bottom to record incoming traffic
case lightHouse:
if !f.handleEncrypted(ci, addr, header) {
return
}
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
if err != nil {
hostinfo.logger().WithError(err).WithField("udpAddr", addr).
WithField("packet", packet).
Error("Failed to decrypt lighthouse packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
return
}
f.lightHouse.HandleRequest(addr, hostinfo.hostId, d, hostinfo.GetCert(), f)
// Fallthrough to the bottom to record incoming traffic
case test:
if !f.handleEncrypted(ci, addr, header) {
return
}
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
if err != nil {
hostinfo.logger().WithError(err).WithField("udpAddr", addr).
WithField("packet", packet).
Error("Failed to decrypt test packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
return
}
if header.Subtype == testRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, addr)
f.send(test, testReply, ci, hostinfo, hostinfo.remote, d, nb, out)
}
// Fallthrough to the bottom to record incoming traffic
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
// are unauthenticated
case handshake:
HandleIncomingHandshake(f, addr, packet, header, hostinfo)
return
case recvError:
// TODO: Remove this with recv_error deprecation
f.handleRecvError(addr, header)
return
case closeTunnel:
if !f.handleEncrypted(ci, addr, header) {
return
}
hostinfo.logger().WithField("udpAddr", addr).
Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo)
return
default:
hostinfo.logger().Debugf("Unexpected packet received from %s", addr) hostinfo.logger().Debugf("Unexpected packet received from %s", addr)
return return
} }
f.handleHostRoaming(hostinfo, addr) handle(hostinfo, ci, addr, header, out, packet, fwPacket, nb)
f.connectionManager.In(hostinfo.hostId)
} }
func (f *Interface) closeTunnel(hostInfo *HostInfo) { func (f *Interface) closeTunnel(hostInfo *HostInfo) {
@@ -256,7 +175,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return out, nil return out, nil
} }
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) { func (f *Interface) decryptTo(write func([]byte) error, hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
var err error var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb) out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
@@ -280,21 +199,25 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return return
} }
if f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs) { dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs)
if dropReason != nil {
if l.Level >= logrus.DebugLevel {
hostinfo.logger().WithField("fwPacket", fwPacket). hostinfo.logger().WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping inbound packet") Debugln("dropping inbound packet")
}
return return
} }
f.connectionManager.In(hostinfo.hostId) f.connectionManager.In(hostinfo.hostId)
err = f.inside.WriteRaw(out) err = write(out)
if err != nil { if err != nil {
l.WithError(err).Error("Failed to write to tun") l.WithError(err).Error("Failed to write to tun")
} }
} }
func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) { func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
f.metricTxRecvError.Inc(1) f.messageMetrics.Tx(recvError, 0, 1)
//TODO: this should be a signed message so we can trust that we should drop the index //TODO: this should be a signed message so we can trust that we should drop the index
b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0) b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0)
@@ -307,8 +230,6 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
} }
func (f *Interface) handleRecvError(addr *udpAddr, h *Header) { func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
f.metricRxRecvError.Inc(1)
// This flag is to stop caring about recv_error from old versions // This flag is to stop caring about recv_error from old versions
// This should go away when the old version is gone from prod // This should go away when the old version is gone from prod
if l.Level >= logrus.DebugLevel { if l.Level >= logrus.DebugLevel {

View File

@@ -1,10 +1,11 @@
package nebula package nebula
import ( import (
"github.com/stretchr/testify/assert"
"golang.org/x/net/ipv4"
"net" "net"
"testing" "testing"
"github.com/stretchr/testify/assert"
"golang.org/x/net/ipv4"
) )
func Test_newPacket(t *testing.T) { func Test_newPacket(t *testing.T) {

View File

@@ -1,9 +1,10 @@
package nebula package nebula
import ( import (
"github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
) )
func TestNewPunchyFromConfig(t *testing.T) { func TestNewPunchyFromConfig(t *testing.T) {

56
ssh.go
View File

@@ -5,8 +5,6 @@ import (
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/sshd"
"io/ioutil" "io/ioutil"
"net" "net"
"os" "os"
@@ -14,6 +12,9 @@ import (
"runtime/pprof" "runtime/pprof"
"strings" "strings"
"syscall" "syscall"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/sshd"
) )
type sshListHostMapFlags struct { type sshListHostMapFlags struct {
@@ -65,10 +66,11 @@ func configSSH(ssh *sshd.SSHServer, c *Config) error {
return fmt.Errorf("sshd.listen must be provided") return fmt.Errorf("sshd.listen must be provided")
} }
port := strings.Split(listen, ":") _, port, err := net.SplitHostPort(listen)
if len(port) < 2 { if err != nil {
return fmt.Errorf("sshd.listen does not have a port") return fmt.Errorf("invalid sshd.listen address: %s", err)
} else if port[1] == "22" { }
if port == "22" {
return fmt.Errorf("sshd.listen can not use port 22") return fmt.Errorf("sshd.listen can not use port 22")
} }
@@ -461,7 +463,12 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
return w.WriteLine("No vpn ip was provided") return w.WriteLine("No vpn ip was provided")
} }
vpnIp := ip2int(net.ParseIP(a[0])) parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
if vpnIp == 0 { if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
@@ -481,7 +488,12 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine("No vpn ip was provided") return w.WriteLine("No vpn ip was provided")
} }
vpnIp := ip2int(net.ParseIP(a[0])) parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
if vpnIp == 0 { if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
@@ -519,7 +531,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine("No vpn ip was provided") return w.WriteLine("No vpn ip was provided")
} }
vpnIp := ip2int(net.ParseIP(a[0])) parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
if vpnIp == 0 { if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
@@ -571,7 +588,12 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine("Address could not be parsed") return w.WriteLine("Address could not be parsed")
} }
vpnIp := ip2int(net.ParseIP(a[0])) parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
if vpnIp == 0 { if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
@@ -647,7 +669,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
cert := ifce.certState.certificate cert := ifce.certState.certificate
if len(a) > 0 { if len(a) > 0 {
vpnIp := ip2int(net.ParseIP(a[0])) parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
if vpnIp == 0 { if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
@@ -694,7 +721,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine("No vpn ip was provided") return w.WriteLine("No vpn ip was provided")
} }
vpnIp := ip2int(net.ParseIP(a[0])) parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
if vpnIp == 0 { if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }

View File

@@ -4,9 +4,10 @@ import (
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"github.com/armon/go-radix"
"sort" "sort"
"strings" "strings"
"github.com/armon/go-radix"
) )
// CommandFlags is a function called before help or command execution to parse command line flags // CommandFlags is a function called before help or command execution to parse command line flags

View File

@@ -2,10 +2,11 @@ package sshd
import ( import (
"fmt" "fmt"
"net"
"github.com/armon/go-radix" "github.com/armon/go-radix"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"net"
) )
type SSHServer struct { type SSHServer struct {

View File

@@ -2,13 +2,14 @@ package sshd
import ( import (
"fmt" "fmt"
"sort"
"strings"
"github.com/anmitsu/go-shlex" "github.com/anmitsu/go-shlex"
"github.com/armon/go-radix" "github.com/armon/go-radix"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal" "golang.org/x/crypto/ssh/terminal"
"sort"
"strings"
) )
type session struct { type session struct {

View File

@@ -3,15 +3,16 @@ package nebula
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/cyberdelia/go-metrics-graphite"
mp "github.com/nbrownus/go-metrics-prometheus"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rcrowley/go-metrics"
"log" "log"
"net" "net"
"net/http" "net/http"
"time" "time"
graphite "github.com/cyberdelia/go-metrics-graphite"
mp "github.com/nbrownus/go-metrics-prometheus"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rcrowley/go-metrics"
) )
func startStats(c *Config, configTest bool) error { func startStats(c *Config, configTest bool) error {

View File

@@ -1,9 +1,10 @@
package nebula package nebula
import ( import (
"github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
) )
func TestNewTimerWheel(t *testing.T) { func TestNewTimerWheel(t *testing.T) {

76
tun_android.go Normal file
View File

@@ -0,0 +1,76 @@
package nebula
import (
"fmt"
"io"
"net"
"os"
"golang.org/x/sys/unix"
)
type Tun struct {
io.ReadWriteCloser
fd int
Device string
Cidr *net.IPNet
MaxMTU int
DefaultMTU int
TXQueueLen int
Routes []route
UnsafeRoutes []route
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
ifce = &Tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
Device: "android",
Cidr: cidr,
DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen,
Routes: routes,
UnsafeRoutes: unsafeRoutes,
}
return
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTun not supported in Android")
}
func (c *Tun) WriteRaw(b []byte) error {
var nn int
for {
max := len(b)
n, err := unix.Write(c.fd, b[nn:max])
if n > 0 {
nn += n
}
if nn == len(b) {
return err
}
if err != nil {
return err
}
if n == 0 {
return io.ErrUnexpectedEOF
}
}
}
func (c Tun) Activate() error {
return nil
}
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}

View File

@@ -1,3 +1,5 @@
// +build !ios
package nebula package nebula
import ( import (
@@ -20,8 +22,9 @@ type Tun struct {
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 { if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in Darwin") return nil, fmt.Errorf("route MTU not supported in Darwin")
} }
// NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate() // NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate()
return &Tun{ return &Tun{
Cidr: cidr, Cidr: cidr,
@@ -30,13 +33,17 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
}, nil }, nil
} }
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}
func (c *Tun) Activate() error { func (c *Tun) Activate() error {
var err error var err error
c.Interface, err = water.New(water.Config{ c.Interface, err = water.New(water.Config{
DeviceType: water.TUN, DeviceType: water.TUN,
}) })
if err != nil { if err != nil {
return fmt.Errorf("Activate failed: %v", err) return fmt.Errorf("activate failed: %v", err)
} }
c.Device = c.Interface.Name() c.Device = c.Interface.Name()
@@ -61,6 +68,14 @@ func (c *Tun) Activate() error {
return nil return nil
} }
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}
func (c *Tun) WriteRaw(b []byte) error { func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b) _, err := c.Write(b)
return err return err

74
tun_disabled.go Normal file
View File

@@ -0,0 +1,74 @@
package nebula
import (
"fmt"
"io"
"net"
"strings"
log "github.com/sirupsen/logrus"
)
type disabledTun struct {
block chan struct{}
cidr *net.IPNet
logger *log.Logger
}
func newDisabledTun(cidr *net.IPNet, l *log.Logger) *disabledTun {
return &disabledTun{
cidr: cidr,
block: make(chan struct{}),
logger: l,
}
}
func (*disabledTun) Activate() error {
return nil
}
func (t *disabledTun) CidrNet() *net.IPNet {
return t.cidr
}
func (*disabledTun) DeviceName() string {
return "disabled"
}
func (t *disabledTun) Read(b []byte) (int, error) {
<-t.block
return 0, io.EOF
}
func (t *disabledTun) Write(b []byte) (int, error) {
t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
return len(b), nil
}
func (t *disabledTun) WriteRaw(b []byte) error {
_, err := t.Write(b)
return err
}
func (t *disabledTun) Close() error {
if t.block != nil {
close(t.block)
t.block = nil
}
return nil
}
type prettyPacket []byte
func (p prettyPacket) String() string {
var s strings.Builder
for i, b := range p {
if i > 0 && i%8 == 0 {
s.WriteString(" ")
}
s.WriteString(fmt.Sprintf("%02x ", b))
}
return s.String()
}

89
tun_freebsd.go Normal file
View File

@@ -0,0 +1,89 @@
package nebula
import (
"fmt"
"io"
"net"
"os"
"os/exec"
"regexp"
"strconv"
"strings"
)
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
type Tun struct {
Device string
Cidr *net.IPNet
MTU int
UnsafeRoutes []route
io.ReadWriteCloser
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in FreeBSD")
}
if strings.HasPrefix(deviceName, "/dev/") {
deviceName = strings.TrimPrefix(deviceName, "/dev/")
}
if !deviceNameRE.MatchString(deviceName) {
return nil, fmt.Errorf("tun.dev must match `tun[0-9]+`")
}
return &Tun{
Device: deviceName,
Cidr: cidr,
MTU: defaultMTU,
UnsafeRoutes: unsafeRoutes,
}, nil
}
func (c *Tun) Activate() error {
var err error
c.ReadWriteCloser, err = os.OpenFile("/dev/"+c.Device, os.O_RDWR, 0)
if err != nil {
return fmt.Errorf("Activate failed: %v", err)
}
// TODO use syscalls instead of exec.Command
l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String())
if err = exec.Command("/sbin/ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU))
if err = exec.Command("/sbin/ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
for _, r := range c.UnsafeRoutes {
l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.route.String(), err)
}
}
return nil
}
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}
func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b)
return err
}

113
tun_ios.go Normal file
View File

@@ -0,0 +1,113 @@
// +build ios
package nebula
import (
"errors"
"fmt"
"io"
"net"
"os"
"sync"
"syscall"
)
type Tun struct {
io.ReadWriteCloser
Device string
Cidr *net.IPNet
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTun not supported in iOS")
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("route MTU not supported in Darwin")
}
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
ifce = &Tun{
Cidr: cidr,
Device: "iOS",
ReadWriteCloser: &tunReadCloser{f: file},
}
return
}
func (c *Tun) Activate() error {
return nil
}
func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b)
return err
}
// The following is hoisted up from water, we do this so we can inject our own fd on iOS
type tunReadCloser struct {
f io.ReadWriteCloser
rMu sync.Mutex
rBuf []byte
wMu sync.Mutex
wBuf []byte
}
func (t *tunReadCloser) Read(to []byte) (int, error) {
t.rMu.Lock()
defer t.rMu.Unlock()
if cap(t.rBuf) < len(to)+4 {
t.rBuf = make([]byte, len(to)+4)
}
t.rBuf = t.rBuf[:len(to)+4]
n, err := t.f.Read(t.rBuf)
copy(to, t.rBuf[4:])
return n - 4, err
}
func (t *tunReadCloser) Write(from []byte) (int, error) {
if len(from) == 0 {
return 0, syscall.EIO
}
t.wMu.Lock()
defer t.wMu.Unlock()
if cap(t.wBuf) < len(from)+4 {
t.wBuf = make([]byte, len(from)+4)
}
t.wBuf = t.wBuf[:len(from)+4]
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
t.wBuf[3] = syscall.AF_INET
} else if ipVer == 6 {
t.wBuf[3] = syscall.AF_INET6
} else {
return 0, errors.New("unable to determine IP version from packet")
}
copy(t.wBuf[4:], from)
n, err := t.f.Write(t.wBuf)
return n - 4, err
}
func (t *tunReadCloser) Close() error {
return t.f.Close()
}
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}

View File

@@ -1,3 +1,5 @@
// +build !android
package nebula package nebula
import ( import (
@@ -75,6 +77,23 @@ type ifreqQLEN struct {
pad [8]byte pad [8]byte
} }
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
ifce = &Tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
Device: "tun0",
Cidr: cidr,
DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen,
Routes: routes,
UnsafeRoutes: unsafeRoutes,
}
return
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {
@@ -216,6 +235,7 @@ func (c Tun) Activate() error {
LinkIndex: link.Attrs().Index, LinkIndex: link.Attrs().Index,
Dst: dr, Dst: dr,
MTU: c.DefaultMTU, MTU: c.DefaultMTU,
AdvMSS: c.advMSS(route{}),
Scope: unix.RT_SCOPE_LINK, Scope: unix.RT_SCOPE_LINK,
Src: c.Cidr.IP, Src: c.Cidr.IP,
Protocol: unix.RTPROT_KERNEL, Protocol: unix.RTPROT_KERNEL,
@@ -233,6 +253,7 @@ func (c Tun) Activate() error {
LinkIndex: link.Attrs().Index, LinkIndex: link.Attrs().Index,
Dst: r.route, Dst: r.route,
MTU: r.mtu, MTU: r.mtu,
AdvMSS: c.advMSS(r),
Scope: unix.RT_SCOPE_LINK, Scope: unix.RT_SCOPE_LINK,
} }
@@ -248,6 +269,7 @@ func (c Tun) Activate() error {
LinkIndex: link.Attrs().Index, LinkIndex: link.Attrs().Index,
Dst: r.route, Dst: r.route,
MTU: r.mtu, MTU: r.mtu,
AdvMSS: c.advMSS(r),
Scope: unix.RT_SCOPE_LINK, Scope: unix.RT_SCOPE_LINK,
} }
@@ -265,3 +287,24 @@ func (c Tun) Activate() error {
return nil return nil
} }
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}
func (c Tun) advMSS(r route) int {
mtu := r.mtu
if r.mtu == 0 {
mtu = c.DefaultMTU
}
// We only need to set advmss if the route MTU does not match the device MTU
if mtu != c.MaxMTU {
return mtu - 40
}
return 0
}

31
tun_linux_test.go Normal file
View File

@@ -0,0 +1,31 @@
package nebula
import "testing"
var runAdvMSSTests = []struct {
name string
tun Tun
r route
expected int
}{
// Standard case, default MTU is the device max MTU
{"default", Tun{DefaultMTU: 1440, MaxMTU: 1440}, route{}, 0},
{"default-min", Tun{DefaultMTU: 1440, MaxMTU: 1440}, route{mtu: 1440}, 0},
{"default-low", Tun{DefaultMTU: 1440, MaxMTU: 1440}, route{mtu: 1200}, 1160},
// Case where we have a route MTU set higher than the default
{"route", Tun{DefaultMTU: 1440, MaxMTU: 8941}, route{}, 1400},
{"route-min", Tun{DefaultMTU: 1440, MaxMTU: 8941}, route{mtu: 1440}, 1400},
{"route-high", Tun{DefaultMTU: 1440, MaxMTU: 8941}, route{mtu: 8941}, 0},
}
func TestTunAdvMSS(t *testing.T) {
for _, tt := range runAdvMSSTests {
t.Run(tt.name, func(t *testing.T) {
o := tt.tun.advMSS(tt.r)
if o != tt.expected {
t.Errorf("got %d, want %d", o, tt.expected)
}
})
}
}

View File

@@ -18,9 +18,13 @@ type Tun struct {
*water.Interface *water.Interface
} }
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 { if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in Windows") return nil, fmt.Errorf("route MTU not supported in Windows")
} }
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
@@ -84,6 +88,14 @@ func (c *Tun) Activate() error {
return nil return nil
} }
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}
func (c *Tun) WriteRaw(b []byte) error { func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b) _, err := c.Write(b)
return err return err

36
udp_android.go Normal file
View File

@@ -0,0 +1,36 @@
package nebula
import (
"fmt"
"net"
"syscall"
"golang.org/x/sys/unix"
)
func NewListenConfig(multi bool) net.ListenConfig {
return net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
if multi {
var controlErr error
err := c.Control(func(fd uintptr) {
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err)
return
}
})
if err != nil {
return err
}
if controlErr != nil {
return controlErr
}
}
return nil
},
}
}
func (u *udpConn) Rebind() error {
return nil
}

View File

@@ -32,3 +32,12 @@ func NewListenConfig(multi bool) net.ListenConfig {
}, },
} }
} }
func (u *udpConn) Rebind() error {
file, err := u.File()
if err != nil {
return err
}
return syscall.SetsockoptInt(int(file.Fd()), unix.IPPROTO_IP, unix.IP_BOUND_IF, 0)
}

38
udp_freebsd.go Normal file
View File

@@ -0,0 +1,38 @@
package nebula
// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig
import (
"fmt"
"net"
"syscall"
"golang.org/x/sys/unix"
)
func NewListenConfig(multi bool) net.ListenConfig {
return net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
if multi {
var controlErr error
err := c.Control(func(fd uintptr) {
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err)
return
}
})
if err != nil {
return err
}
if controlErr != nil {
return controlErr
}
}
return nil
},
}
}
func (u *udpConn) Rebind() error {
return nil
}

View File

@@ -1,4 +1,4 @@
// +build !linux // +build !linux android
// udp_generic implements the nebula UDP interface in pure Go stdlib. This // udp_generic implements the nebula UDP interface in pure Go stdlib. This
// means it can be used on platforms like Darwin and Windows. // means it can be used on platforms like Darwin and Windows.
@@ -65,6 +65,17 @@ func (ua *udpAddr) Equals(t *udpAddr) bool {
return ua.IP.Equal(t.IP) && ua.Port == t.Port return ua.IP.Equal(t.IP) && ua.Port == t.Port
} }
func (ua *udpAddr) Copy() udpAddr {
nu := udpAddr{net.UDPAddr{
Port: ua.Port,
Zone: ua.Zone,
IP: make(net.IP, len(ua.IP)),
}}
copy(nu.IP, ua.IP)
return nu
}
func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error { func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error {
_, err := uc.UDPConn.WriteToUDP(b, &addr.UDPAddr) _, err := uc.UDPConn.WriteToUDP(b, &addr.UDPAddr)
return err return err

View File

@@ -1,3 +1,5 @@
// +build !android
package nebula package nebula
import ( import (
@@ -69,9 +71,11 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
var lip [4]byte var lip [4]byte
copy(lip[:], net.ParseIP(ip).To4()) copy(lip[:], net.ParseIP(ip).To4())
if multi {
if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err) return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err)
} }
}
if err = unix.Bind(fd, &unix.SockaddrInet4{Addr: lip, Port: port}); err != nil { if err = unix.Bind(fd, &unix.SockaddrInet4{Addr: lip, Port: port}); err != nil {
return nil, fmt.Errorf("unable to bind to socket: %s", err) return nil, fmt.Errorf("unable to bind to socket: %s", err)
@@ -85,6 +89,14 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
return &udpConn{sysFd: fd}, err return &udpConn{sysFd: fd}, err
} }
func (u *udpConn) Rebind() error {
return nil
}
func (ua *udpAddr) Copy() udpAddr {
return *ua
}
func (u *udpConn) SetRecvBuffer(n int) error { func (u *udpConn) SetRecvBuffer(n int) error {
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
} }
@@ -137,9 +149,13 @@ func (u *udpConn) ListenOut(f *Interface) {
//TODO: should we track this? //TODO: should we track this?
//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015)) //metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
msgs, buffers, names := u.PrepareRawMessages(f.udpBatchSize) msgs, buffers, names := u.PrepareRawMessages(f.udpBatchSize)
read := u.ReadMulti
if f.udpBatchSize == 1 {
read = u.ReadSingle
}
for { for {
n, err := u.ReadMulti(msgs) n, err := read(msgs)
if err != nil { if err != nil {
l.WithError(err).Error("Failed to read packets") l.WithError(err).Error("Failed to read packets")
continue continue
@@ -155,34 +171,24 @@ func (u *udpConn) ListenOut(f *Interface) {
} }
} }
func (u *udpConn) Read(addr *udpAddr, b []byte) ([]byte, error) { func (u *udpConn) ReadSingle(msgs []rawMessage) (int, error) {
var rsa rawSockaddrAny
var rLen = unix.SizeofSockaddrAny
for { for {
n, _, err := unix.Syscall6( n, _, err := unix.Syscall6(
unix.SYS_RECVFROM, unix.SYS_RECVMSG,
uintptr(u.sysFd), uintptr(u.sysFd),
uintptr(unsafe.Pointer(&b[0])), uintptr(unsafe.Pointer(&(msgs[0].Hdr))),
uintptr(len(b)), 0,
uintptr(0), 0,
uintptr(unsafe.Pointer(&rsa)), 0,
uintptr(unsafe.Pointer(&rLen)), 0,
) )
if err != 0 { if err != 0 {
return nil, &net.OpError{Op: "read", Err: err} return 0, &net.OpError{Op: "recvmsg", Err: err}
} }
if rsa.Addr.Family == unix.AF_INET { msgs[0].Len = uint32(n)
addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1]) return 1, nil
addr.IP = uint32(rsa.Addr.Data[2])<<24 + uint32(rsa.Addr.Data[3])<<16 + uint32(rsa.Addr.Data[4])<<8 + uint32(rsa.Addr.Data[5])
} else {
addr.Port = 0
addr.IP = 0
}
return b[:n], nil
} }
} }
@@ -280,13 +286,6 @@ func (ua *udpAddr) Equals(t *udpAddr) bool {
return ua.IP == t.IP && ua.Port == t.Port return ua.IP == t.IP && ua.Port == t.Port
} }
func (ua *udpAddr) Copy() *udpAddr {
return &udpAddr{
Port: ua.Port,
IP: ua.IP,
}
}
func (ua *udpAddr) String() string { func (ua *udpAddr) String() string {
return fmt.Sprintf("%s:%v", int2ip(ua.IP), ua.Port) return fmt.Sprintf("%s:%v", int2ip(ua.IP), ua.Port)
} }

View File

@@ -1,5 +1,6 @@
// +build linux // +build linux
// +build 386 amd64p32 arm mips mipsle // +build 386 amd64p32 arm mips mipsle
// +build !android
package nebula package nebula

View File

@@ -1,5 +1,6 @@
// +build linux // +build linux
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x // +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x
// +build !android
package nebula package nebula

View File

@@ -20,3 +20,7 @@ func NewListenConfig(multi bool) net.ListenConfig {
}, },
} }
} }
func (u *udpConn) Rebind() error {
return nil
}

130
util/assert.go Normal file
View File

@@ -0,0 +1,130 @@
package util
import (
"fmt"
"reflect"
"testing"
"time"
"unsafe"
"github.com/stretchr/testify/assert"
)
// AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory
// There is currently a special case for `time.loc` (as this code traverses into unexported fields)
func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) {
v1 := reflect.ValueOf(a)
v2 := reflect.ValueOf(b)
if !assert.Equal(t, v1.Type(), v2.Type()) {
return
}
traverseDeepCopy(t, v1, v2, v1.Type().String())
}
func traverseDeepCopy(t *testing.T, v1 reflect.Value, v2 reflect.Value, name string) bool {
switch v1.Kind() {
case reflect.Array:
for i := 0; i < v1.Len(); i++ {
if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
return false
}
}
return true
case reflect.Slice:
if v1.IsNil() || v2.IsNil() {
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil %+v, %+v", name, v1, v2)
}
if !assert.Equal(t, v1.Len(), v2.Len(), "%s did not have the same length", name) {
return false
}
// A slice with cap 0
if v1.Cap() != 0 && !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same slice %v == %v", name, v1.Pointer(), v2.Pointer()) {
return false
}
v1c := v1.Cap()
v2c := v2.Cap()
if v1c > 0 && v2c > 0 && v1.Slice(0, v1c).Slice(v1c-1, v1c-1).Pointer() == v2.Slice(0, v2c).Slice(v2c-1, v2c-1).Pointer() {
return assert.Fail(t, "", "%s share some underlying memory", name)
}
for i := 0; i < v1.Len(); i++ {
if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
return false
}
}
return true
case reflect.Interface:
if v1.IsNil() || v2.IsNil() {
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
}
return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
case reflect.Ptr:
local := reflect.ValueOf(time.Local).Pointer()
if local == v1.Pointer() && local == v2.Pointer() {
return true
}
if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s points to the same memory", name) {
return false
}
return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
case reflect.Struct:
for i, n := 0, v1.NumField(); i < n; i++ {
if !traverseDeepCopy(t, v1.Field(i), v2.Field(i), name+"."+v1.Type().Field(i).Name) {
return false
}
}
return true
case reflect.Map:
if v1.IsNil() || v2.IsNil() {
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
}
if !assert.Equal(t, v1.Len(), v2.Len(), "%s are not the same length", name) {
return false
}
if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same memory", name) {
return false
}
for _, k := range v1.MapKeys() {
val1 := v1.MapIndex(k)
val2 := v2.MapIndex(k)
if !assert.True(t, val1.IsValid(), "%s is an invalid key in %s", k, name) {
return false
}
if !assert.True(t, val2.IsValid(), "%s is an invalid key in %s", k, name) {
return false
}
if !traverseDeepCopy(t, val1, val2, name+fmt.Sprintf("%s[%s]", name, k)) {
return false
}
}
return true
default:
if v1.CanInterface() && v2.CanInterface() {
return assert.Equal(t, v1.Interface(), v2.Interface(), "%s was not equal", name)
}
e1 := reflect.NewAt(v1.Type(), unsafe.Pointer(v1.UnsafeAddr())).Elem().Interface()
e2 := reflect.NewAt(v2.Type(), unsafe.Pointer(v2.UnsafeAddr())).Elem().Interface()
return assert.Equal(t, e1, e2, "%s (unexported) was not equal", name)
}
}