Compare commits

..

51 Commits

Author SHA1 Message Date
Wade Simmons
c71c84882e v1.3.0 (#268)
Update the CHANGELOG for Nebula v1.3.0

Co-authored-by: forfuncsake <drussell@slack-corp.com>
2020-09-22 12:21:12 -04:00
Darren Hoo
0010db46e4 Fix a data race on message counter (#284)
3. ==================
WARNING: DATA RACE
Write at 0x00c00030e020 by goroutine 17:
  sync/atomic.AddInt64()
      runtime/race_amd64.s:276 +0xb
  github.com/slackhq/nebula.(*Interface).sendNoMetrics()
      github.com/slackhq/nebula/inside.go:226 +0x9c
  github.com/slackhq/nebula.(*Interface).send()
      github.com/slackhq/nebula/inside.go:214 +0x149
  github.com/slackhq/nebula.(*Interface).readOutsidePackets()
      github.com/slackhq/nebula/outside.go:94 +0x1213
  github.com/slackhq/nebula.(*udpConn).ListenOut()
      github.com/slackhq/nebula/udp_generic.go:109 +0x3b5
  github.com/slackhq/nebula.(*Interface).listenOut()
      github.com/slackhq/nebula/interface.go:147 +0x15e

Previous read at 0x00c00030e020 by goroutine 18:
  github.com/slackhq/nebula.(*Interface).consumeInsidePacket()
      github.com/slackhq/nebula/inside.go:58 +0x892
  github.com/slackhq/nebula.(*Interface).listenIn()
      github.com/slackhq/nebula/interface.go:164 +0x178
2020-09-21 21:41:46 -04:00
Nathan Brown
68e3e84fdc More like a library (#279) 2020-09-18 09:20:09 -05:00
Brian Luong
6238f1550b Handle panic when invalid IP entered in sshd (#296) 2020-09-18 10:10:25 -04:00
forfuncsake
50b04413c7 Block nebula ssh server from listening on port 22 (#266)
Port 22 is blocked as a safety mechanism. In a case where nebula is
started before sshd, a system may be rendered unreachable if nebula
is holding the system ssh port and there is no other connectivity.

This commit enforces the restriction, which could previously be worked
around by listening on an IPv6 address, e.g.  "[::]:22".
2020-09-15 09:57:32 -04:00
CzBiX
ef498a31da Add disable_timestamp option (#288) 2020-09-09 07:42:11 -04:00
forfuncsake
2e5a477a50 Align linux UDP performance optimizations with configuration (#275)
* Remove unused (*udpConn).Read method

* Align linux UDP performance optimizations with configuration

While attempting to run nebula on an older Synology NAS, it became
apparent that some of the performance optimizations effectively
block support for older kernels. The recvmmsg syscall was added in
linux kernel 2.6.33, and the Synology DS212j (among other models)
is pinned to 2.6.32.12.

Similarly, SO_REUSEPORT was added to the kernel in the 3.9 cycle.
While this option has been backported into some older trees, it
is also missing from the Synology kernel.

This commit allows nebula to be run on linux devices with older
kernels if the config options are set up with a single listener
and a UDP batch size of 1.
2020-08-13 08:24:05 +10:00
Wade Simmons
32fe9bfe75 Use Go 1.15 (#277)
Update all CI checks and release process to use the latest patch version
of go1.15.
2020-08-12 16:16:21 -04:00
forfuncsake
9b8b3c478b Support startup without a tun device (#269)
This commit adds support for Nebula to be started without creating
a tun device. A node started in this mode still has a full "control
plane", but no effective "data plane". Its use is suited to a
lighthouse that has no need to partake in the mesh VPN.

Consequently, creation of the tun device is the only reason nebula
neesd to be started with elevated privileged, so this example
lighthouse can also be run as a non-root user.
2020-08-10 09:15:55 -04:00
Michael Hardy
7b3f23d9a1 Start nebula after the network is up (#270) 2020-08-07 11:33:48 -05:00
forfuncsake
25964b54f6 Use inclusive terminology for cert blocking (#272) 2020-08-06 11:17:47 +10:00
Wade Simmons
ac557f381b drop unroutable packets (#267)
Currently, if a packet arrives on the tun device with a destination that
is not a routable Nebula IP, `queryUnsafeRoute` converts that IP to
0.0.0.0 and we store that packet and try to look up that IP with the
lighthouse. This doesn't make any sense to do, if we get a packet that
is unroutable we should just drop it.

Note, we have a few configurable options like `drop_local_broadcast`
and `drop_multicast` which do this for a few specific types, but since
no packets like this will send correctly I think we should just drop
anything that is unroutable.
2020-08-04 22:59:04 -04:00
Wade Simmons
a54f3fc681 fix fast handshake trigger for static hosts (#265)
We are currently triggering a fast handshake for static hosts right
inside HandshakeManager.AddVpnIP, but this can actually trigger before
we have generated the handshake packet to use. Instead, we should be
triggering right after we call ixHandshakeStage0 in getOrHandshake
(which generates the handshake packet)
2020-08-02 20:59:50 -04:00
Alan Lam
5545cff6ef log remote certificate fingerprint on handshakes (#262) 2020-07-31 18:54:51 -04:00
Wade Simmons
f3a6d8d990 Preserve conntrack table during firewall rules reload (SIGHUP) (#233)
Currently, we drop the conntrack table when firewall rules change during a SIGHUP reload. This means responses to inflight HTTP requests can be dropped, among other issues. This change copies the conntrack table over to the new firewall (it holds the conntrack mutex lock during this process, to be safe).

This change also records which firewall rules hash each conntrack entry used, so that we can re-verify the rules after the new firewall has been loaded.
2020-07-31 18:53:36 -04:00
forfuncsake
9b06748506 Make Interface.Inside an interface type (#252)
This commit updates the Interface.Inside type to be a new interface
type instead of a *Tun. This will allow for an inside interface
that does not use a tun device, such as a single-binary client that
can run without elevated privileges.
2020-07-28 08:53:16 -04:00
Wade Simmons
4756c9613d trigger handshakes when lighthouse reply arrives (#246)
Currently, we wait until the next timer tick to act on the lighthouse's
reply to our HostQuery. This means we can easily add hundreds of
milliseconds of unnecessary delay to the handshake. To fix this, we
can introduce a channel to trigger an outbound handshake without waiting
for the next timer tick.

A few samples of cold ping time between two hosts that require a
lighthouse lookup:

    before (v1.2.0):

    time=156 ms
    time=252 ms
    time=12.6 ms
    time=301 ms
    time=352 ms
    time=49.4 ms
    time=150 ms
    time=13.5 ms
    time=8.24 ms
    time=161 ms
    time=355 ms

    after:

    time=3.53 ms
    time=3.14 ms
    time=3.08 ms
    time=3.92 ms
    time=7.78 ms
    time=3.59 ms
    time=3.07 ms
    time=3.22 ms
    time=3.12 ms
    time=3.08 ms
    time=8.04 ms

I recommend reviewing this PR by looking at each commit individually, as
some refactoring was required that makes the diff a bit confusing when
combined together.
2020-07-22 10:35:10 -04:00
Nathan Brown
4645e6034b Fix up the tun for android (#249) 2020-07-01 10:20:52 -05:00
Wade Simmons
aba42f9fa6 enforce the use of goimports (#248)
* enforce the use of goimports

Instead of enforcing `gofmt`, enforce `goimports`, which also asserts
a separate section for non-builtin packages.

* run `goimports` everywhere

* exclude generated .pb.go files
2020-06-30 18:53:30 -04:00
Nathan Brown
41578ca971 Be more like a library to support mobile (#247) 2020-06-30 13:48:58 -05:00
Wade Simmons
1ea8847085 linux: set advmss correctly when route MTU is used (#245)
If different mtus are specified for different routes, we should set
advmss on each route because Linux does a poor job of selecting the
default (from ip-route(8)):

    advmss NUMBER (Linux 2.3.15+ only)
           the MSS ('Maximal Segment Size') to advertise to these destinations when estab‐
           lishing TCP connections. If it is not given, Linux uses a default value calcu‐
           lated from the first hop device MTU.  (If the path to these destination is asym‐
           metric, this guess may be wrong.)

Note that the default value is calculated from the first hop *device
MTU*, not the *route MTU*. In practice this is usually ok as long as the
other side of the tunnel has the mtu configured exactly the same, but we
should probably just set advmss correctly on these routes.
2020-06-26 13:47:21 -04:00
Wade Simmons
55858c64cc smoke test: test firewall inbound / outbound (#240)
Test that basic inbound / outbound firewall rules work during the smoke
test. This change sets an inbound firewall rule on host3, and a new
host4 with outbound firewall rules. It also tests that conntrack allows
packets once the connection has been established.
2020-06-26 13:46:51 -04:00
Wade Simmons
e94c6b0125 mips-softfloat (#231)
This makes GOARM more generic and does GOMIPS in a similar way to
support mips-softfloat. We also set `-ldflags "-s -w"` for
mips-softfloat to give the best chance of the binary working on these
small devices.
2020-06-26 13:46:23 -04:00
Wade Simmons
b37a91cfbc add meta packet statistics (#230)
This change add more metrics around "meta" (non "message" type packets).
For lighthouse packets, we also record statistics around the specific
lighthouse meta type.

We don't keep statistics for the "message" type so that we don't slow
down the fast path (and you can just look at metrics on the tun
interface to find that information).
2020-06-26 13:45:48 -04:00
David Sonder
3212b769d4 fix typo in conntrack section in examples/config.yml (#236)
the rest of the conntrack values match the default
2020-06-26 11:08:22 -05:00
Patrick Bogen
ecf0e5a9f6 drop packets even if we aren't going to emit Debug logs about it (#239)
* drop packets even if we aren't going to emit Debug logs about it

* smallify change
2020-06-10 16:55:49 -05:00
Wade Simmons
ff13aba8fc allow go test -bench=. to run (#234)
This benchmark had an Errorf at the end, lets remove it so the
benchmarks all run.
2020-05-27 16:52:34 -04:00
Mateusz Kwiatkowski
cc03ff9e9a Unbreak building for FreeBSD (#103)
Add support for freebsd. You have to set `tun.dev` in your config. The second pass of this would be to remove the exec calls and use ioctl(2) and route(4) instead, but we can do that in a second PR.

Co-authored-by: Wade Simmons <wade@wades.im>
2020-05-26 22:23:23 -04:00
Patrick Bogen
363c836422 log the reason for fw drops (#220)
* log the reason for fw drops

* only prepare log if we will end up sending it
2020-04-10 10:57:21 -07:00
Wade Simmons
fb252db4a1 v1.2.0 (#215)
Add descriptions for all commits since v1.1.0
2020-04-08 19:52:24 -04:00
Wade Simmons
4f6313ebd3 fix config name for {remote,local}_allow_list (#219)
This config option should be snake_case, not camelCase.
2020-04-08 16:20:12 -04:00
Wade Simmons
0a474e757b Add lighthouse.{remoteAllowList,localAllowList} (#217)
These settings make it possible to blacklist / whitelist IP addresses
that are used for remote connections.

`lighthouse.remoteAllowList` filters which remote IPs are allow when
fetching from the lighthouse (or, if you are the lighthouse, which IPs
you store and forward to querying hosts). By default, any remote IPs are
allowed. You can provide CIDRs here with `true` to allow and `false` to
deny. The most specific CIDR rule applies to each remote.  If all rules
are "allow", the default will be "deny", and vice-versa. If both "allow"
and "deny" rules are present, then you MUST set a rule for "0.0.0.0/0"
as the default.

    lighthouse:
      remoteAllowList:
        # Example to block IPs from this subnet from being used for remote IPs.
        "172.16.0.0/12": false

        # A more complicated example, allow public IPs but only private IPs from a specific subnet
        "0.0.0.0/0": true
        "10.0.0.0/8": false
        "10.42.42.0/24": true

`lighthouse.localAllowList` has the same logic as above, but it applies
to the local addresses we advertise to the lighthouse. Additionally, you
can specify an `interfaces` map of regular expressions to match against
interface names. The regexp must match the entire name. All interface
rules must be either true or false (and the default rule will be the
inverse). CIDR rules are matched after interface name rules.

Default is all local IP addresses.

    lighthouse:
      localAllowList:
        # Example to blacklist docker interfaces.
        interfaces:
          'docker.*': false

        # Example to only advertise IPs in this subnet to the lighthouse.
        "10.0.0.0/8": true
2020-04-08 15:36:43 -04:00
Nathan Brown
7cd342c7ab Add a systemd unit for arch and a wireshark dissector (#216) 2020-04-06 18:47:32 -07:00
Wade Simmons
7cdbb14a18 Better config test (#177)
* Better config test

Previously, when using the config test option `-test`, we quit fairly
earlier in the process and would not catch a variety of additional
parsing errors (such as lighthouse IP addresses, local_range, the new
check to make sure static hosts are in the certificate's subnet, etc).

* run config test as part of smoke test

* don't need privileges for configtest

Co-authored-by: Nathan Brown <nate@slack-corp.com>
2020-04-06 11:35:32 -07:00
Wade Simmons
b4f2f7ce4e log certName alongside vpnIp (#200)
This change adds a new helper, `(*HostInfo).logger()`, that starts a new
logrus.Entry with `vpnIp` and `certName`. We don't use the helper inside
of handshake_ix though since the certificate has not been attached to
the HostInfo yet.

Fixes: #84
2020-04-06 11:34:00 -07:00
Alex
ff64d1f952 unsafe_routes mtu (#209) 2020-04-06 11:33:30 -07:00
Felix Yan
9e2ff7df57 Correct typos in noise.go (#205) 2020-03-30 11:23:55 -07:00
Ryan Huber
1297090af3 add configurable punching delay because of race-condition-y conntracks (#210)
* add configurable punching delay because of race-condition-y conntracks

* add changelog

* fix tests

* only do one punch per query

* Coalesce punchy config

* It is not is not set

* Add tests

Co-authored-by: Nate Brown <nbrown.us@gmail.com>
2020-03-27 11:26:39 -07:00
Wade Simmons
add1b21777 only create a CIDRTree for each host if necessary (#198)
A CIDRTree can be expensive to create, so only do it if we need
it. If the remote host only has one IP address and no subnets, just do
an exact IP match instead.

Fixes: #171
2020-03-02 16:21:33 -05:00
Wade Simmons
1cb3201b5e Github Actions: cache modules and only run when necessary (#197)
This PR does two things:

- Only run the tests when relevant files change.
- Cache the Go Modules directory between runs, so they don't have to redownload everything everytime (go.sum is the cache key). Pretty much straight from the examples: https://github.com/actions/cache/blob/master/examples.md#go---modules
2020-03-02 16:21:19 -05:00
Ryan Huber
41968551f9 clarify that lighthouse IP should be nebula range (#196) 2020-02-28 11:35:55 -08:00
Wade Simmons
8548ac3c31 build and test with go1.14 (#195)
- https://golang.org/doc/go1.14

I did a performance sanity check in Docker, and performance seems about
the same (perhaps slightly higher).
2020-02-27 15:48:39 -05:00
Wade Simmons
fb9b36f677 allow any config file name if specified directly (#189)
Currently, we require that config file names end with `.yml` or `.yaml`.
This is because if the user points `-config` at a directory of files, we
only want to use the YAML files in that directory.

But this makes it more difficult to use the `-test -config` option
because config management tools might not have an extension on the file
when preparing a new config file. This change makes it so that if you
point `-config file` directly at a file, it uses it no matter what the
extension is.
2020-02-26 15:38:56 -05:00
Sebastien Bariteau
4d1928f1e3 Support unsafe_routes on Windows (#184)
* Support unsafe_routes on Windows

* Full path to route executable

* Escape string properly
2020-02-26 15:23:16 -05:00
Ryan Huber
a91a40212d check that packet isn't bound for my vpn ip (#192) 2020-02-21 16:49:54 -08:00
Wade Simmons
179a369130 add configuration options for HandshakeManager (#179)
This change exposes the current constants we have defined for the handshake
manager as configuration options. This will allow us to test and tweak
with different intervals and wait rotations.

    # Handshake Manger Settings
    handshakes:
      # Total time to try a handshake = sequence of `try_interval * retries`
      # With 100ms interval and 20 retries it is 23.5 seconds
      try_interval: 100ms
      retries: 20

      # wait_rotation is the number of handshake attempts to do before starting to try non-local IP addresses
      wait_rotation: 5
2020-02-21 16:25:11 -05:00
Wade Simmons
df69371620 use absolute paths on darwin and windows (#191)
We want to make sure to use the system binaries, and not whatever is in
the PATH.
2020-02-21 15:25:33 -05:00
Wade Simmons
eda344d88f add logging.timestamp_format config option (#187)
This change introduces logging.timestamp_format, which allows
configuration of the Logrus TimestampFormat setting. The primary purpose
of this change was to allow logging with millisecond precision. The
default for `text` and `json` formats remains the same for backwards
compatibility.

timestamp format is specified in Go time format, see:

 - https://golang.org/pkg/time/#pkg-constants

Default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339)
Default when `format: text`:
  when TTY attached: seconds since beginning of execution
  otherwise: "2006-01-02T15:04:05Z07:00" (RFC3339)

As an example, to log as RFC3339 with millisecond precision, set to:

    logging:
        timestamp_format: "2006-01-02T15:04:05.000Z07:00"
2020-02-21 15:25:00 -05:00
Wade Simmons
065e2ff88a update golang.org/x/crypto (#188)
This version contains a fix for CVE-2020-9283, a remote crash bug:

- https://groups.google.com/forum/#!msg/golang-announce/3L45YRc91SY/ywEPcKLnGQAJ
2020-02-20 14:49:55 -05:00
Nathan Brown
45a5de2719 Print the udp listen address on startup (#181) 2020-02-06 21:17:43 -08:00
Wade Simmons
2d24ef7166 validate lighthouses and static hosts are in our subnet (#170)
Validate all lighthouse.hosts and static_host_map VPN IPs are in the
subnet defined in our cert. Exit with a fatal error if they are not in
our subnet, as this is an invalid configuration (we will not have the
proper routes set up to communicate with these hosts).

This error case could occur for the following invalid example:

    nebula-cert sign -name "lighthouse" -ip "10.0.1.1/24"
    nebula-cert sign -name "host" -ip "10.0.2.1/24"

    config.yaml:

        static_host_map:
            "10.0.1.1": ["lighthouse.local:4242"]
        lighthouse:
          hosts:
            - "10.0.1.1"

We will now return a fatal error for this config, since `10.0.1.1` is
not in the host cert's subnet of `10.0.2.1/24`
2020-01-20 15:52:55 -05:00
82 changed files with 3425 additions and 540 deletions

View File

@@ -4,6 +4,9 @@ on:
branches:
- master
pull_request:
paths:
- '.github/workflows/gofmt.yml'
- '**.go'
jobs:
gofmt:
@@ -11,19 +14,31 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Set up Go 1.13
- name: Set up Go 1.15
uses: actions/setup-go@v1
with:
go-version: 1.13
go-version: 1.15
id: go
- name: Check out code into the Go module directory
uses: actions/checkout@v1
- uses: actions/cache@v1
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-gofmt-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gofmt-
- name: Install goimports
run: |
go get golang.org/x/tools/cmd/goimports
go build golang.org/x/tools/cmd/goimports
- name: gofmt
run: |
if [ "$(find . -iname '*.go' | xargs gofmt -l)" ]
if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs ./goimports -l)" ]
then
find . -iname '*.go' | xargs gofmt -d
find . -iname '*.go' | grep -v '\.pb\.go$' | xargs ./goimports -d
exit 1
fi

View File

@@ -10,17 +10,17 @@ jobs:
name: Build Linux All
runs-on: ubuntu-latest
steps:
- name: Set up Go 1.13
- name: Set up Go 1.15
uses: actions/setup-go@v1
with:
go-version: 1.13
go-version: 1.15
- name: Checkout code
uses: actions/checkout@v2
- name: Build
run: |
make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" release-linux
make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" release-linux release-freebsd
mkdir release
mv build/*.tar.gz release
@@ -34,10 +34,10 @@ jobs:
name: Build Windows amd64
runs-on: windows-latest
steps:
- name: Set up Go 1.13
- name: Set up Go 1.15
uses: actions/setup-go@v1
with:
go-version: 1.13
go-version: 1.15
- name: Checkout code
uses: actions/checkout@v2
@@ -58,10 +58,10 @@ jobs:
name: Build Darwin amd64
runs-on: macOS-latest
steps:
- name: Set up Go 1.13
- name: Set up Go 1.15
uses: actions/setup-go@v1
with:
go-version: 1.13
go-version: 1.15
- name: Checkout code
uses: actions/checkout@v2
@@ -278,3 +278,23 @@ jobs:
asset_path: ./linux-latest/nebula-linux-mips64le.tar.gz
asset_name: nebula-linux-mips64le.tar.gz
asset_content_type: application/gzip
- name: Upload linux-mips-softfloat
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-linux-mips-softfloat.tar.gz
asset_name: nebula-linux-mips-softfloat.tar.gz
asset_content_type: application/gzip
- name: Upload freebsd-amd64
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-freebsd-amd64.tar.gz
asset_name: nebula-freebsd-amd64.tar.gz
asset_content_type: application/gzip

View File

@@ -4,22 +4,36 @@ on:
branches:
- master
pull_request:
paths:
- '.github/workflows/smoke**'
- '**Makefile'
- '**.go'
- '**.proto'
- 'go.mod'
- 'go.sum'
jobs:
smoke:
name: Run 3 node smoke test
name: Run multi node smoke test
runs-on: ubuntu-latest
steps:
- name: Set up Go 1.13
- name: Set up Go 1.15
uses: actions/setup-go@v1
with:
go-version: 1.13
go-version: 1.15
id: go
- name: Check out code into the Go module directory
uses: actions/checkout@v1
- uses: actions/cache@v1
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: build
run: make

View File

@@ -11,14 +11,29 @@ mkdir ./build
cp ../../../../nebula .
cp ../../../../nebula-cert .
HOST="lighthouse1" AM_LIGHTHOUSE=true ../genconfig.sh >lighthouse1.yml
HOST="host2" LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" ../genconfig.sh >host2.yml
HOST="host3" LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" ../genconfig.sh >host3.yml
HOST="lighthouse1" \
AM_LIGHTHOUSE=true \
../genconfig.sh >lighthouse1.yml
HOST="host2" \
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
../genconfig.sh >host2.yml
HOST="host3" \
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
../genconfig.sh >host3.yml
HOST="host4" \
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
../genconfig.sh >host4.yml
./nebula-cert ca -name "Smoke Test"
./nebula-cert sign -name "lighthouse1" -ip "192.168.100.1/24"
./nebula-cert sign -name "host2" -ip "192.168.100.2/24"
./nebula-cert sign -name "host3" -ip "192.168.100.3/24"
./nebula-cert sign -name "lighthouse1" -groups "lighthouse,lighthouse1" -ip "192.168.100.1/24"
./nebula-cert sign -name "host2" -groups "host,host2" -ip "192.168.100.2/24"
./nebula-cert sign -name "host3" -groups "host,host3" -ip "192.168.100.3/24"
./nebula-cert sign -name "host4" -groups "host,host4" -ip "192.168.100.4/24"
)
docker build -t nebula:smoke .

View File

@@ -2,6 +2,7 @@
set -e
FIREWALL_ALL='[{"port": "any", "proto": "any", "host": "any"}]'
if [ "$STATIC_HOSTS" ] || [ "$LIGHTHOUSES" ]
then
@@ -48,13 +49,6 @@ tun:
dev: ${TUN_DEV:-nebula1}
firewall:
outbound:
- port: any
proto: any
host: any
inbound:
- port: any
proto: any
host: any
outbound: ${OUTBOUND:-$FIREWALL_ALL}
inbound: ${INBOUND:-$FIREWALL_ALL}
EOF

View File

@@ -2,12 +2,19 @@
set -e -x
docker run --name lighthouse1 --rm nebula:smoke -config lighthouse1.yml -test
docker run --name host2 --rm nebula:smoke -config host2.yml -test
docker run --name host3 --rm nebula:smoke -config host3.yml -test
docker run --name host4 --rm nebula:smoke -config host4.yml -test
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config lighthouse1.yml &
sleep 1
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host2.yml &
sleep 1
docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host3.yml &
sleep 1
docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host4.yml &
sleep 1
set +x
echo
@@ -23,7 +30,8 @@ echo " *** Testing ping from host2"
echo
set -x
docker exec host2 ping -c1 192.168.100.1
docker exec host2 ping -c1 192.168.100.3
# Should fail because not allowed by host3 inbound firewall
! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
set +x
echo
@@ -32,3 +40,24 @@ echo
set -x
docker exec host3 ping -c1 192.168.100.1
docker exec host3 ping -c1 192.168.100.2
set +x
echo
echo " *** Testing ping from host4"
echo
set -x
docker exec host4 ping -c1 192.168.100.1
# Should fail because not allowed by host4 outbound firewall
! docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1
! docker exec host4 ping -c1 192.168.100.3 -w5 || exit 1
set +x
echo
echo " *** Testing conntrack"
echo
set -x
# host2 can ping host3 now that host3 pinged it first
docker exec host2 ping -c1 192.168.100.3
# host4 can ping host2 once conntrack established
docker exec host2 ping -c1 192.168.100.4
docker exec host4 ping -c1 192.168.100.2

View File

@@ -4,6 +4,13 @@ on:
branches:
- master
pull_request:
paths:
- '.github/workflows/test.yml'
- '**Makefile'
- '**.go'
- '**.proto'
- 'go.mod'
- 'go.sum'
jobs:
test-linux:
@@ -11,15 +18,22 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Set up Go 1.13
- name: Set up Go 1.15
uses: actions/setup-go@v1
with:
go-version: 1.13
go-version: 1.15
id: go
- name: Check out code into the Go module directory
uses: actions/checkout@v1
- uses: actions/cache@v1
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Build
run: make all
@@ -34,15 +48,22 @@ jobs:
os: [windows-latest, macOS-latest]
steps:
- name: Set up Go 1.13
- name: Set up Go 1.15
uses: actions/setup-go@v1
with:
go-version: 1.13
go-version: 1.15
id: go
- name: Check out code into the Go module directory
uses: actions/checkout@v1
- uses: actions/cache@v1
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Build nebula
run: go build ./cmd/nebula

View File

@@ -7,6 +7,139 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [1.3.0] - 2020-09-22
### Added
- You can emit statistics about non-message packets by setting the option
`stats.message_metrics`. You can similarly emit detailed statistics about
lighthouse packets by setting the option `stats.lighthouse_metrics`. See
the example config for more details. (#230)
- We now support freebsd/amd64. This is experimental, please give us feedback.
(#103)
- We now release a binary for `linux/mips-softfloat` which has also been
stripped to reduce filesize and hopefully have a better chance on running on
small mips devices. (#231)
- You can set `tun.disabled` to true to run a standalone lighthouse without a
tun device (and thus, without root). (#269)
- You can set `logging.disable_timestamp` to remove timestamps from log lines,
which is useful when output is redirected to a logging system that already
adds timestamps. (#288)
### Changed
- Handshakes should now trigger faster, as we try to be proactive with sending
them instead of waiting for the next timer tick in most cases. (#246, #265)
- Previously, we would drop the conntrack table whenever firewall rules were
changed during a SIGHUP. Now, we will maintain the table and just validate
that an entry still matches with the new rule set. (#233)
- Debug logs for firewall drops now include the reason. (#220, #239)
- Logs for handshakes now include the fingerprint of the remote host. (#262)
- Config item `pki.blacklist` is now `pki.blocklist`. (#272)
- Better support for older Linux kernels. We now only set `SO_REUSEPORT` if
`tun.routines` is greater than 1 (default is 1). We also only use the
`recvmmsg` syscall if `listen.batch` is greater than 1 (default is 64).
(#275)
- It is possible to run Nebula as a library inside of another process now.
Note that this is still experimental and the internal APIs around this might
change in minor version releases. (#279)
### Deprecated
- `pki.blacklist` is deprecated in favor of `pki.blocklist` with the same
functionality. Existing configs will continue to load for this release to
allow for migrations. (#272)
### Fixed
- `advmss` is now set correctly for each route table entry when `tun.routes`
is configured to have some routes with higher MTU. (#245)
- Packets that arrive on the tun device with an unroutable destination IP are
now dropped correctly, instead of wasting time making queries to the
lighthouses for IP `0.0.0.0` (#267)
## [1.2.0] - 2020-04-08
### Added
- Add `logging.timestamp_format` config option. The primary purpose of this
change is to allow logging timestamps with millisecond precision. (#187)
- Support `unsafe_routes` on Windows. (#184)
- Add `lighthouse.remote_allow_list` to filter which subnets we will use to
handshake with other hosts. See the example config for more details. (#217)
- Add `lighthouse.local_allow_list` to filter which local IP addresses and/or
interfaces we advertise to the lighthouses. See the example config for more
details. (#217)
- Wireshark dissector plugin. Add this file in `dist/wireshark` to your
Wireshark plugins folder to see Nebula packet headers decoded. (#216)
- systemd unit for Arch, so it can be built entirely from this repo. (#216)
### Changed
- Added a delay to punching via lighthouse signal to deal with race conditions
in some linux conntrack implementations. (#210)
See deprecated, this also adds a new `punchy.delay` option that defaults to `1s`.
- Validate all `lighthouse.hosts` and `static_host_map` VPN IPs are in the
subnet defined in our cert. Exit with a fatal error if they are not in our
subnet, as this is an invalid configuration (we will not have the proper
routes set up to communicate with these hosts). (#170)
- Use absolute paths to system binaries on macOS and Windows. (#191)
- Add configuration options for `handshakes`. This includes options to tweak
`try_interval`, `retries` and `wait_rotation`. See example config for
descriptions. (#179)
- Allow `-config` file to not end in `.yaml` or `yml`. Useful when using
`-test` and automated tools like Ansible that create temporary files without
suffixes. (#189)
- The config test mode, `-test`, is now more thorough and catches more parsing
issues. (#177)
- Various documentation and example fixes. (#196)
- Improved log messages. (#181, #200)
- Dependencies updated. (#188)
### Deprecated
- `punchy`, `punch_back` configuration options have been collapsed under the
now top level `punchy` config directive. (#210)
`punchy.punch` - This is the old `punchy` option. Should we perform NAT hole
punching (default false)?
`punchy.respond` - This is the old `punch_back` option. Should we respond to
hole punching by hole punching back (default false)?
### Fixed
- Reduce memory allocations when not using `unsafe_routes`. (#198)
- Ignore packets from self to self. (#192)
- MTU fixed for `unsafe_routes`. (#209)
## [1.1.0] - 2020-01-17
### Added
@@ -47,6 +180,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Initial public release.
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.1.0...HEAD
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.3.0...HEAD
[1.3.0]: https://github.com/slackhq/nebula/releases/tag/v1.3.0
[1.2.0]: https://github.com/slackhq/nebula/releases/tag/v1.2.0
[1.1.0]: https://github.com/slackhq/nebula/releases/tag/v1.1.0
[1.0.0]: https://github.com/slackhq/nebula/releases/tag/v1.0.0

View File

@@ -3,6 +3,8 @@ BUILD_NUMBER ?= dev+$(shell date -u '+%Y%m%d%H%M%S')
GO111MODULE = on
export GO111MODULE
LDFLAGS = -X main.Build=$(BUILD_NUMBER)
ALL_LINUX = linux-amd64 \
linux-386 \
linux-ppc64le \
@@ -13,10 +15,12 @@ ALL_LINUX = linux-amd64 \
linux-mips \
linux-mipsle \
linux-mips64 \
linux-mips64le
linux-mips64le \
linux-mips-softfloat
ALL = $(ALL_LINUX) \
darwin-amd64 \
freebsd-amd64 \
windows-amd64
all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert)
@@ -25,31 +29,40 @@ release: $(ALL:%=build/nebula-%.tar.gz)
release-linux: $(ALL_LINUX:%=build/nebula-%.tar.gz)
release-freebsd: build/nebula-freebsd-amd64.tar.gz
bin-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe
mv $? .
bin-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert
mv $? .
bin-freebsd: build/freebsd-amd64/nebula build/freebsd-amd64/nebula-cert
mv $? .
bin:
go build -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" -o ./nebula ${NEBULA_CMD_PATH}
go build -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" -o ./nebula-cert ./cmd/nebula-cert
go build -trimpath -ldflags "$(LDFLAGS)" -o ./nebula ${NEBULA_CMD_PATH}
go build -trimpath -ldflags "$(LDFLAGS)" -o ./nebula-cert ./cmd/nebula-cert
install:
go install -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" ${NEBULA_CMD_PATH}
go install -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula-cert
go install -trimpath -ldflags "$(LDFLAGS)" ${NEBULA_CMD_PATH}
go install -trimpath -ldflags "$(LDFLAGS)" ./cmd/nebula-cert
build/linux-arm-%: GOENV += GOARM=$(word 3, $(subst -, ,$*))
build/linux-mips-%: GOENV += GOMIPS=$(word 3, $(subst -, ,$*))
# Build an extra small binary for mips-softfloat
build/linux-mips-softfloat/%: LDFLAGS += -s -w
build/%/nebula: .FORCE
GOOS=$(firstword $(subst -, , $*)) \
GOARCH=$(word 2, $(subst -, ,$*)) \
GOARM=$(word 3, $(subst -, ,$*)) \
go build -trimpath -o $@ -ldflags "-X main.Build=$(BUILD_NUMBER)" ${NEBULA_CMD_PATH}
GOARCH=$(word 2, $(subst -, ,$*)) $(GOENV) \
go build -trimpath -o $@ -ldflags "$(LDFLAGS)" ${NEBULA_CMD_PATH}
build/%/nebula-cert: .FORCE
GOOS=$(firstword $(subst -, , $*)) \
GOARCH=$(word 2, $(subst -, ,$*)) \
GOARM=$(word 3, $(subst -, ,$*)) \
go build -trimpath -o $@ -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula-cert
GOARCH=$(word 2, $(subst -, ,$*)) $(GOENV) \
go build -trimpath -o $@ -ldflags "$(LDFLAGS)" ./cmd/nebula-cert
build/%/nebula.exe: build/%/nebula
mv $< $@

48
allow_list.go Normal file
View File

@@ -0,0 +1,48 @@
package nebula
import (
"fmt"
"regexp"
)
type AllowList struct {
// The values of this cidrTree are `bool`, signifying allow/deny
cidrTree *CIDRTree
// To avoid ambiguity, all rules must be true, or all rules must be false.
nameRules []AllowListNameRule
}
type AllowListNameRule struct {
Name *regexp.Regexp
Allow bool
}
func (al *AllowList) Allow(ip uint32) bool {
if al == nil {
return true
}
result := al.cidrTree.MostSpecificContains(ip)
switch v := result.(type) {
case bool:
return v
default:
panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
}
}
func (al *AllowList) AllowName(name string) bool {
if al == nil || len(al.nameRules) == 0 {
return true
}
for _, rule := range al.nameRules {
if rule.Name.MatchString(name) {
return rule.Allow
}
}
// If no rules match, return the default, which is the inverse of the rules
return !al.nameRules[0].Allow
}

47
allow_list_test.go Normal file
View File

@@ -0,0 +1,47 @@
package nebula
import (
"net"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
)
func TestAllowList_Allow(t *testing.T) {
assert.Equal(t, true, ((*AllowList)(nil)).Allow(ip2int(net.ParseIP("1.1.1.1"))))
tree := NewCIDRTree()
tree.AddCIDR(getCIDR("0.0.0.0/0"), true)
tree.AddCIDR(getCIDR("10.0.0.0/8"), false)
tree.AddCIDR(getCIDR("10.42.42.0/24"), true)
al := &AllowList{cidrTree: tree}
assert.Equal(t, true, al.Allow(ip2int(net.ParseIP("1.1.1.1"))))
assert.Equal(t, false, al.Allow(ip2int(net.ParseIP("10.0.0.4"))))
assert.Equal(t, true, al.Allow(ip2int(net.ParseIP("10.42.42.42"))))
}
func TestAllowList_AllowName(t *testing.T) {
assert.Equal(t, true, ((*AllowList)(nil)).AllowName("docker0"))
rules := []AllowListNameRule{
{Name: regexp.MustCompile("^docker.*$"), Allow: false},
{Name: regexp.MustCompile("^tun.*$"), Allow: false},
}
al := &AllowList{nameRules: rules}
assert.Equal(t, false, al.AllowName("docker0"))
assert.Equal(t, false, al.AllowName("tun0"))
assert.Equal(t, true, al.AllowName("eth0"))
rules = []AllowListNameRule{
{Name: regexp.MustCompile("^eth.*$"), Allow: true},
{Name: regexp.MustCompile("^ens.*$"), Allow: true},
}
al = &AllowList{nameRules: rules}
assert.Equal(t, false, al.AllowName("docker0"))
assert.Equal(t, true, al.AllowName("eth0"))
assert.Equal(t, true, al.AllowName("ens5"))
}

View File

@@ -212,10 +212,10 @@ func TestBitsLostCounter(t *testing.T) {
func BenchmarkBits(b *testing.B) {
z := NewBits(10)
for n := 0; n < b.N; n++ {
for i, _ := range z.bits {
for i := range z.bits {
z.bits[i] = true
}
for i, _ := range z.bits {
for i := range z.bits {
z.bits[i] = false
}

12
cert.go
View File

@@ -149,10 +149,16 @@ func loadCAFromConfig(c *Config) (*cert.NebulaCAPool, error) {
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
}
// pki.blacklist entered the scene at about the same time we aliased x509 to pki, not supporting backwards compat
for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
l.WithField("fingerprint", fp).Infof("Blocklisting cert")
CAs.BlocklistFingerprint(fp)
}
// Support deprecated config for at leaast one minor release to allow for migrations
for _, fp := range c.GetStringSlice("pki.blacklist", []string{}) {
l.WithField("fingerprint", fp).Infof("Blacklisting cert")
CAs.BlacklistFingerprint(fp)
l.WithField("fingerprint", fp).Infof("Blocklisting cert")
l.Warn("pki.blacklist is deprecated and will not be supported in a future release. Please migrate your config to use pki.blocklist")
CAs.BlocklistFingerprint(fp)
}
return CAs, nil

View File

@@ -8,14 +8,14 @@ import (
type NebulaCAPool struct {
CAs map[string]*NebulaCertificate
certBlacklist map[string]struct{}
certBlocklist map[string]struct{}
}
// NewCAPool creates a CAPool
func NewCAPool() *NebulaCAPool {
ca := NebulaCAPool{
CAs: make(map[string]*NebulaCertificate),
certBlacklist: make(map[string]struct{}),
certBlocklist: make(map[string]struct{}),
}
return &ca
@@ -67,24 +67,24 @@ func (ncp *NebulaCAPool) AddCACertificate(pemBytes []byte) ([]byte, error) {
return pemBytes, nil
}
// BlacklistFingerprint adds a cert fingerprint to the blacklist
func (ncp *NebulaCAPool) BlacklistFingerprint(f string) {
ncp.certBlacklist[f] = struct{}{}
// BlocklistFingerprint adds a cert fingerprint to the blocklist
func (ncp *NebulaCAPool) BlocklistFingerprint(f string) {
ncp.certBlocklist[f] = struct{}{}
}
// ResetCertBlacklist removes all previously blacklisted cert fingerprints
func (ncp *NebulaCAPool) ResetCertBlacklist() {
ncp.certBlacklist = make(map[string]struct{})
// ResetCertBlocklist removes all previously blocklisted cert fingerprints
func (ncp *NebulaCAPool) ResetCertBlocklist() {
ncp.certBlocklist = make(map[string]struct{})
}
// IsBlacklisted returns true if the fingerprint fails to generate or has been explicitly blacklisted
func (ncp *NebulaCAPool) IsBlacklisted(c *NebulaCertificate) bool {
// IsBlocklisted returns true if the fingerprint fails to generate or has been explicitly blocklisted
func (ncp *NebulaCAPool) IsBlocklisted(c *NebulaCertificate) bool {
h, err := c.Sha256Sum()
if err != nil {
return true
}
if _, ok := ncp.certBlacklist[h]; ok {
if _, ok := ncp.certBlocklist[h]; ok {
return true
}

View File

@@ -1,18 +1,18 @@
package cert
import (
"bytes"
"crypto"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"net"
"time"
"bytes"
"encoding/json"
"github.com/golang/protobuf/proto"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
@@ -244,10 +244,10 @@ func (nc *NebulaCertificate) Expired(t time.Time) bool {
return nc.Details.NotBefore.After(t) || nc.Details.NotAfter.Before(t)
}
// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blacklist, etc)
// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) {
if ncp.IsBlacklisted(nc) {
return false, fmt.Errorf("certificate has been blacklisted")
if ncp.IsBlocklisted(nc) {
return false, fmt.Errorf("certificate has been blocked")
}
signer, err := ncp.GetCAForCert(nc)
@@ -468,6 +468,63 @@ func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) {
return json.Marshal(jc)
}
//func (nc *NebulaCertificate) Copy() *NebulaCertificate {
// r, err := nc.Marshal()
// if err != nil {
// //TODO
// return nil
// }
//
// c, err := UnmarshalNebulaCertificate(r)
// return c
//}
func (nc *NebulaCertificate) Copy() *NebulaCertificate {
c := &NebulaCertificate{
Details: NebulaCertificateDetails{
Name: nc.Details.Name,
Groups: make([]string, len(nc.Details.Groups)),
Ips: make([]*net.IPNet, len(nc.Details.Ips)),
Subnets: make([]*net.IPNet, len(nc.Details.Subnets)),
NotBefore: nc.Details.NotBefore,
NotAfter: nc.Details.NotAfter,
PublicKey: make([]byte, len(nc.Details.PublicKey)),
IsCA: nc.Details.IsCA,
Issuer: nc.Details.Issuer,
InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)),
},
Signature: make([]byte, len(nc.Signature)),
}
copy(c.Signature, nc.Signature)
copy(c.Details.Groups, nc.Details.Groups)
copy(c.Details.PublicKey, nc.Details.PublicKey)
for i, p := range nc.Details.Ips {
c.Details.Ips[i] = &net.IPNet{
IP: make(net.IP, len(p.IP)),
Mask: make(net.IPMask, len(p.Mask)),
}
copy(c.Details.Ips[i].IP, p.IP)
copy(c.Details.Ips[i].Mask, p.Mask)
}
for i, p := range nc.Details.Subnets {
c.Details.Subnets[i] = &net.IPNet{
IP: make(net.IP, len(p.IP)),
Mask: make(net.IPMask, len(p.Mask)),
}
copy(c.Details.Subnets[i].IP, p.IP)
copy(c.Details.Subnets[i].Mask, p.Mask)
}
for g := range nc.Details.InvertedGroups {
c.Details.InvertedGroups[g] = struct{}{}
}
return c
}
func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool {
for _, net := range rootIps {
if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) {

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/golang/protobuf/proto"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
@@ -172,13 +173,13 @@ func TestNebulaCertificate_Verify(t *testing.T) {
f, err := c.Sha256Sum()
assert.Nil(t, err)
caPool.BlacklistFingerprint(f)
caPool.BlocklistFingerprint(f)
v, err := c.Verify(time.Now(), caPool)
assert.False(t, v)
assert.EqualError(t, err, "certificate has been blacklisted")
assert.EqualError(t, err, "certificate has been blocked")
caPool.ResetCertBlacklist()
caPool.ResetCertBlocklist()
v, err = c.Verify(time.Now(), caPool)
assert.True(t, v)
assert.Nil(t, err)
@@ -487,6 +488,17 @@ func TestMarshalingNebulaCertificateConsistency(t *testing.T) {
assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
}
func TestNebulaCertificate_Copy(t *testing.T) {
ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
assert.Nil(t, err)
c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
assert.Nil(t, err)
cc := c.Copy()
util.AssertDeepCopyEqual(t, c, cc)
}
func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if before.IsZero() {
@@ -498,11 +510,12 @@ func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
nc := &NebulaCertificate{
Details: NebulaCertificateDetails{
Name: "test ca",
NotBefore: before,
NotAfter: after,
PublicKey: pub,
IsCA: true,
Name: "test ca",
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
IsCA: true,
InvertedGroups: make(map[string]struct{}),
},
}
@@ -544,17 +557,17 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
if len(ips) == 0 {
ips = []*net.IPNet{
{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
{IP: net.ParseIP("10.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
{IP: net.ParseIP("10.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
{IP: net.ParseIP("10.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
}
}
if len(subnets) == 0 {
subnets = []*net.IPNet{
{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
{IP: net.ParseIP("9.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
{IP: net.ParseIP("9.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
{IP: net.ParseIP("9.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
}
}
@@ -562,15 +575,16 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
nc := &NebulaCertificate{
Details: NebulaCertificateDetails{
Name: "testing",
Ips: ips,
Subnets: subnets,
Groups: groups,
NotBefore: before,
NotAfter: after,
PublicKey: pub,
IsCA: false,
Issuer: issuer,
Name: "testing",
Ips: ips,
Subnets: subnets,
Groups: groups,
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
IsCA: false,
Issuer: issuer,
InvertedGroups: make(map[string]struct{}),
},
}

View File

@@ -3,10 +3,11 @@ package main
import (
"bytes"
"errors"
"github.com/stretchr/testify/assert"
"io"
"os"
"testing"
"github.com/stretchr/testify/assert"
)
//TODO: all flag parsing continueOnError will print to stderr on its own currently

View File

@@ -4,11 +4,12 @@ import (
"encoding/json"
"flag"
"fmt"
"github.com/slackhq/nebula/cert"
"io"
"io/ioutil"
"os"
"strings"
"github.com/slackhq/nebula/cert"
)
type printFlags struct {

View File

@@ -2,12 +2,13 @@ package main
import (
"bytes"
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
"io/ioutil"
"os"
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
)
func Test_printSummary(t *testing.T) {

View File

@@ -3,12 +3,13 @@ package main
import (
"flag"
"fmt"
"github.com/slackhq/nebula/cert"
"io"
"io/ioutil"
"os"
"strings"
"time"
"github.com/slackhq/nebula/cert"
)
type verifyFlags struct {

View File

@@ -3,13 +3,14 @@ package main
import (
"bytes"
"crypto/rand"
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ed25519"
"io/ioutil"
"os"
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ed25519"
)
func Test_verifySummary(t *testing.T) {

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"os"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
)
@@ -45,5 +46,30 @@ func main() {
os.Exit(1)
}
nebula.Main(*configPath, *configTest, Build)
config := nebula.NewConfig()
err := config.Load(*configPath)
if err != nil {
fmt.Printf("failed to load config: %s", err)
os.Exit(1)
}
l := logrus.New()
l.Out = os.Stdout
c, err := nebula.Main(config, *configTest, Build, l, nil)
switch v := err.(type) {
case nebula.ContextualError:
v.Log(l)
os.Exit(1)
case error:
l.WithError(err).Error("Failed to start")
os.Exit(1)
}
if !*configTest {
c.Start()
c.ShutdownBlock()
}
os.Exit(0)
}

View File

@@ -1,44 +1,53 @@
package main
import (
"fmt"
"log"
"os"
"path/filepath"
"github.com/kardianos/service"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
)
var logger service.Logger
type program struct {
exit chan struct{}
configPath *string
configTest *bool
build string
control *nebula.Control
}
func (p *program) Start(s service.Service) error {
logger.Info("Nebula service starting.")
p.exit = make(chan struct{})
// Start should not block.
go p.run()
return nil
}
logger.Info("Nebula service starting.")
func (p *program) run() error {
nebula.Main(*p.configPath, *p.configTest, Build)
config := nebula.NewConfig()
err := config.Load(*p.configPath)
if err != nil {
return fmt.Errorf("failed to load config: %s", err)
}
l := logrus.New()
l.Out = os.Stdout
p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
if err != nil {
return err
}
p.control.Start()
return nil
}
func (p *program) Stop(s service.Service) error {
logger.Info("Nebula service stopping.")
close(p.exit)
p.control.Stop()
return nil
}
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {
if *configPath == "" {
ex, err := os.Executable()
if err != nil {

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"os"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
)
@@ -39,5 +40,30 @@ func main() {
os.Exit(1)
}
nebula.Main(*configPath, *configTest, Build)
config := nebula.NewConfig()
err := config.Load(*configPath)
if err != nil {
fmt.Printf("failed to load config: %s", err)
os.Exit(1)
}
l := logrus.New()
l.Out = os.Stdout
c, err := nebula.Main(config, *configTest, Build, l, nil)
switch v := err.(type) {
case nebula.ContextualError:
v.Log(l)
os.Exit(1)
case error:
l.WithError(err).Error("Failed to start")
os.Exit(1)
}
if !*configTest {
c.Start()
c.ShutdownBlock()
}
os.Exit(0)
}

188
config.go
View File

@@ -1,19 +1,23 @@
package nebula
import (
"errors"
"fmt"
"github.com/imdario/mergo"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
"io/ioutil"
"net"
"os"
"os/signal"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"syscall"
"time"
"github.com/imdario/mergo"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
)
type Config struct {
@@ -35,7 +39,7 @@ func (c *Config) Load(path string) error {
c.path = path
c.files = make([]string, 0)
err := c.resolve(path)
err := c.resolve(path, true)
if err != nil {
return err
}
@@ -54,6 +58,13 @@ func (c *Config) Load(path string) error {
return nil
}
func (c *Config) LoadString(raw string) error {
if raw == "" {
return errors.New("Empty configuration")
}
return c.parseRaw([]byte(raw))
}
// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
// here should decide if they need to make a change to the current process before making the change. HasChanged can be
// used to help decide if a change is necessary.
@@ -213,10 +224,137 @@ func (c *Config) GetDuration(k string, d time.Duration) time.Duration {
return v
}
func (c *Config) GetAllowList(k string, allowInterfaces bool) (*AllowList, error) {
r := c.Get(k)
if r == nil {
return nil, nil
}
rawMap, ok := r.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, r)
}
tree := NewCIDRTree()
var nameRules []AllowListNameRule
firstValue := true
allValuesMatch := true
defaultSet := false
var allValues bool
for rawKey, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
// Special rule for interface names
if rawCIDR == "interfaces" {
if !allowInterfaces {
return nil, fmt.Errorf("config `%s` does not support `interfaces`", k)
}
var err error
nameRules, err = c.getAllowListInterfaces(k, rawValue)
if err != nil {
return nil, err
}
continue
}
value, ok := rawValue.(bool)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
}
_, cidr, err := net.ParseCIDR(rawCIDR)
if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
}
// TODO: should we error on duplicate CIDRs in the config?
tree.AddCIDR(cidr, value)
if firstValue {
allValues = value
firstValue = false
} else {
if value != allValues {
allValuesMatch = false
}
}
// Check if this is 0.0.0.0/0
bits, size := cidr.Mask.Size()
if bits == 0 && size == 32 {
defaultSet = true
}
}
if !defaultSet {
if allValuesMatch {
_, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
tree.AddCIDR(zeroCIDR, !allValues)
} else {
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
}
}
return &AllowList{cidrTree: tree, nameRules: nameRules}, nil
}
func (c *Config) getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
var nameRules []AllowListNameRule
rawRules, ok := v.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
}
firstEntry := true
var allValues bool
for rawName, rawAllow := range rawRules {
name, ok := rawName.(string)
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
}
allow, ok := rawAllow.(bool)
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
}
nameRE, err := regexp.Compile("^" + name + "$")
if err != nil {
return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err)
}
nameRules = append(nameRules, AllowListNameRule{
Name: nameRE,
Allow: allow,
})
if firstEntry {
allValues = allow
firstEntry = false
} else {
if allow != allValues {
return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k)
}
}
}
return nameRules, nil
}
func (c *Config) Get(k string) interface{} {
return c.get(k, c.Settings)
}
func (c *Config) IsSet(k string) bool {
return c.get(k, c.Settings) != nil
}
func (c *Config) get(k string, v interface{}) interface{} {
parts := strings.Split(k, ".")
for _, p := range parts {
@@ -234,14 +372,16 @@ func (c *Config) get(k string, v interface{}) interface{} {
return v
}
func (c *Config) resolve(path string) error {
// direct signifies if this is the config path directly specified by the user,
// versus a file/dir found by recursing into that path
func (c *Config) resolve(path string, direct bool) error {
i, err := os.Stat(path)
if err != nil {
return nil
}
if !i.IsDir() {
c.addFile(path)
c.addFile(path, direct)
return nil
}
@@ -251,7 +391,7 @@ func (c *Config) resolve(path string) error {
}
for _, p := range paths {
err := c.resolve(filepath.Join(path, p))
err := c.resolve(filepath.Join(path, p), false)
if err != nil {
return err
}
@@ -260,10 +400,10 @@ func (c *Config) resolve(path string) error {
return nil
}
func (c *Config) addFile(path string) error {
func (c *Config) addFile(path string, direct bool) error {
ext := filepath.Ext(path)
if ext != ".yaml" && ext != ".yml" {
if !direct && ext != ".yaml" && ext != ".yml" {
return nil
}
@@ -276,6 +416,18 @@ func (c *Config) addFile(path string) error {
return nil
}
func (c *Config) parseRaw(b []byte) error {
var m map[interface{}]interface{}
err := yaml.Unmarshal(b, &m)
if err != nil {
return err
}
c.Settings = m
return nil
}
func (c *Config) parse() error {
var m map[interface{}]interface{}
@@ -328,12 +480,26 @@ func configLogger(c *Config) error {
}
l.SetLevel(logLevel)
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
timestampFormat := c.GetString("logging.timestamp_format", "")
fullTimestamp := (timestampFormat != "")
if timestampFormat == "" {
timestampFormat = time.RFC3339
}
logFormat := strings.ToLower(c.GetString("logging.format", "text"))
switch logFormat {
case "text":
l.Formatter = &logrus.TextFormatter{}
l.Formatter = &logrus.TextFormatter{
TimestampFormat: timestampFormat,
FullTimestamp: fullTimestamp,
DisableTimestamp: disableTimestamp,
}
case "json":
l.Formatter = &logrus.JSONFormatter{}
l.Formatter = &logrus.JSONFormatter{
TimestampFormat: timestampFormat,
DisableTimestamp: disableTimestamp,
}
default:
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
}

View File

@@ -1,12 +1,13 @@
package nebula
import (
"github.com/stretchr/testify/assert"
"io/ioutil"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestConfig_Load(t *testing.T) {
@@ -86,6 +87,76 @@ func TestConfig_GetBool(t *testing.T) {
assert.Equal(t, false, c.GetBool("bool", true))
}
func TestConfig_GetAllowList(t *testing.T) {
c := NewConfig()
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0": true,
}
r, err := c.GetAllowList("allowlist", false)
assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0")
assert.Nil(t, r)
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0/16": "abc",
}
r, err = c.GetAllowList("allowlist", false)
assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0/16": true,
"10.0.0.0/8": false,
}
r, err = c.GetAllowList("allowlist", false)
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
"10.42.42.0/24": true,
}
r, err = c.GetAllowList("allowlist", false)
if assert.NoError(t, err) {
assert.NotNil(t, r)
}
// Test interface names
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: false,
},
}
r, err = c.GetAllowList("allowlist", false)
assert.EqualError(t, err, "config `allowlist` does not support `interfaces`")
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: "foo",
},
}
r, err = c.GetAllowList("allowlist", true)
assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: false,
`eth.*`: true,
},
}
r, err = c.GetAllowList("allowlist", true)
assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: false,
},
}
r, err = c.GetAllowList("allowlist", true)
if assert.NoError(t, err) {
assert.NotNil(t, r)
}
}
func TestConfig_HasChanged(t *testing.T) {
// No reload has occurred, return false
c := NewConfig()

View File

@@ -182,7 +182,7 @@ func (n *connectionManager) HandleMonitorTick(now time.Time) {
continue
}
l.WithField("vpnIp", IntIp(vpnIP)).
hostinfo.logger().
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
Debug("Tunnel status")
@@ -191,7 +191,7 @@ func (n *connectionManager) HandleMonitorTick(now time.Time) {
n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
} else {
l.Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
hostinfo.logger().Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
}
n.AddPendingDeletion(vpnIP)
}
@@ -233,7 +233,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
cn = hostinfo.ConnectionState.peerCert.Details.Name
}
l.WithField("vpnIp", IntIp(vpnIP)).
hostinfo.logger().
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
WithField("certName", cn).
Info("Tunnel status")

View File

@@ -28,7 +28,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
rawCertificateNoKey: []byte{},
}
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false)
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
ifce := &Interface{
hostMap: hostMap,
inside: &Tun{},
@@ -36,7 +36,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
certState: cs,
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}),
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
}
now := time.Now()
@@ -91,7 +91,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
rawCertificateNoKey: []byte{},
}
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false)
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
ifce := &Interface{
hostMap: hostMap,
inside: &Tun{},
@@ -99,7 +99,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
certState: cs,
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}),
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
}
now := time.Now()

169
control.go Normal file
View File

@@ -0,0 +1,169 @@
package nebula
import (
"net"
"os"
"os/signal"
"syscall"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
)
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
type Control struct {
f *Interface
l *logrus.Logger
}
type ControlHostInfo struct {
VpnIP net.IP `json:"vpnIp"`
LocalIndex uint32 `json:"localIndex"`
RemoteIndex uint32 `json:"remoteIndex"`
RemoteAddrs []udpAddr `json:"remoteAddrs"`
CachedPackets int `json:"cachedPackets"`
Cert *cert.NebulaCertificate `json:"cert"`
MessageCounter uint64 `json:"messageCounter"`
CurrentRemote udpAddr `json:"currentRemote"`
}
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
func (c *Control) Start() {
c.f.run()
}
// Stop signals nebula to shutdown, returns after the shutdown is complete
func (c *Control) Stop() {
//TODO: stop tun and udp routines, the lock on hostMap effectively does that though
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
c.f.hostMap.Lock()
for _, h := range c.f.hostMap.Hosts {
if h.ConnectionState.ready {
c.f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message")
}
}
c.f.hostMap.Unlock()
c.l.Info("Goodbye")
}
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
func (c *Control) ShutdownBlock() {
sigChan := make(chan os.Signal)
signal.Notify(sigChan, syscall.SIGTERM)
signal.Notify(sigChan, syscall.SIGINT)
rawSig := <-sigChan
sig := rawSig.String()
c.l.WithField("signal", sig).Info("Caught signal, shutting down")
c.Stop()
}
// RebindUDPServer asks the UDP listener to rebind it's listener. Mainly used on mobile clients when interfaces change
func (c *Control) RebindUDPServer() {
_ = c.f.outside.Rebind()
}
// ListHostmap returns details about the actual or pending (handshaking) hostmap
func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
var hm *HostMap
if pendingMap {
hm = c.f.handshakeManager.pendingHostMap
} else {
hm = c.f.hostMap
}
hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Hosts))
i := 0
for _, v := range hm.Hosts {
hosts[i] = copyHostInfo(v)
i++
}
hm.RUnlock()
return hosts
}
// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found
func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInfo {
var hm *HostMap
if pending {
hm = c.f.handshakeManager.pendingHostMap
} else {
hm = c.f.hostMap
}
h, err := hm.QueryVpnIP(vpnIP)
if err != nil {
return nil
}
ch := copyHostInfo(h)
return &ch
}
// SetRemoteForTunnel forces a tunnel to use a specific remote
func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInfo {
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
if err != nil {
return nil
}
hostInfo.SetRemote(addr.Copy())
ch := copyHostInfo(hostInfo)
return &ch
}
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
func (c *Control) CloseTunnel(vpnIP uint32, localOnly bool) bool {
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
if err != nil {
return false
}
if !localOnly {
c.f.send(
closeTunnel,
0,
hostInfo.ConnectionState,
hostInfo,
hostInfo.remote,
[]byte{},
make([]byte, 12, 12),
make([]byte, mtu),
)
}
c.f.closeTunnel(hostInfo)
return true
}
func copyHostInfo(h *HostInfo) ControlHostInfo {
addrs := h.RemoteUDPAddrs()
chi := ControlHostInfo{
VpnIP: int2ip(h.hostId),
LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId,
RemoteAddrs: make([]udpAddr, len(addrs), len(addrs)),
CachedPackets: len(h.packetStore),
MessageCounter: *h.ConnectionState.messageCounter,
}
if c := h.GetCert(); c != nil {
chi.Cert = c.Copy()
}
if h.remote != nil {
chi.CurrentRemote = *h.remote
}
for i, addr := range addrs {
chi.RemoteAddrs[i] = addr.Copy()
}
return chi
}

111
control_test.go Normal file
View File

@@ -0,0 +1,111 @@
package nebula
import (
"net"
"reflect"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
)
func TestControl_GetHostInfoByVpnIP(t *testing.T) {
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller
hm := NewHostMap("test", &net.IPNet{}, make([]*net.IPNet, 0))
remote1 := NewUDPAddr(100, 4444)
remote2 := NewUDPAddr(101, 4444)
ipNet := net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
}
ipNet2 := net.IPNet{
IP: net.IPv4(1, 2, 3, 5),
Mask: net.IPMask{255, 255, 255, 0},
}
crt := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "test",
Ips: []*net.IPNet{&ipNet},
Subnets: []*net.IPNet{},
Groups: []string{"default-group"},
NotBefore: time.Unix(1, 0),
NotAfter: time.Unix(2, 0),
PublicKey: []byte{5, 6, 7, 8},
IsCA: false,
Issuer: "the-issuer",
InvertedGroups: map[string]struct{}{"default-group": {}},
},
Signature: []byte{1, 2, 1, 2, 1, 3},
}
counter := uint64(0)
remotes := []*HostInfoDest{NewHostInfoDest(remote1), NewHostInfoDest(remote2)}
hm.Add(ip2int(ipNet.IP), &HostInfo{
remote: remote1,
Remotes: remotes,
ConnectionState: &ConnectionState{
peerCert: crt,
messageCounter: &counter,
},
remoteIndexId: 200,
localIndexId: 201,
hostId: ip2int(ipNet.IP),
})
hm.Add(ip2int(ipNet2.IP), &HostInfo{
remote: remote1,
Remotes: remotes,
ConnectionState: &ConnectionState{
peerCert: nil,
messageCounter: &counter,
},
remoteIndexId: 200,
localIndexId: 201,
hostId: ip2int(ipNet2.IP),
})
c := Control{
f: &Interface{
hostMap: hm,
},
l: logrus.New(),
}
thi := c.GetHostInfoByVpnIP(ip2int(ipNet.IP), false)
expectedInfo := ControlHostInfo{
VpnIP: net.IPv4(1, 2, 3, 4).To4(),
LocalIndex: 201,
RemoteIndex: 200,
RemoteAddrs: []udpAddr{*remote1, *remote2},
CachedPackets: 0,
Cert: crt.Copy(),
MessageCounter: 0,
CurrentRemote: *NewUDPAddr(100, 4444),
}
// Make sure we don't have any unexpected fields
assertFields(t, []string{"VpnIP", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
util.AssertDeepCopyEqual(t, &expectedInfo, thi)
// Make sure we don't panic if the host info doesn't have a cert yet
assert.NotPanics(t, func() {
thi = c.GetHostInfoByVpnIP(ip2int(ipNet2.IP), false)
})
}
func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
val := reflect.ValueOf(actualStruct).Elem()
fields := make([]string, val.NumField())
for i := 0; i < val.NumField(); i++ {
fields[i] = val.Type().Field(i).Name
}
assert.Equal(t, expected, fields)
}

15
dist/arch/nebula.service vendored Normal file
View File

@@ -0,0 +1,15 @@
[Unit]
Description=nebula
Wants=basic.target network-online.target
After=basic.target network.target network-online.target
[Service]
SyslogIdentifier=nebula
StandardOutput=syslog
StandardError=syslog
ExecReload=/bin/kill -HUP $MAINPID
ExecStart=/usr/bin/nebula -config /etc/nebula/config.yml
Restart=always
[Install]
WantedBy=multi-user.target

113
dist/wireshark/nebula.lua vendored Normal file
View File

@@ -0,0 +1,113 @@
local nebula = Proto("nebula", "nebula")
local default_settings = {
port = 4242,
all_ports = false,
}
nebula.prefs.port = Pref.uint("Port number", default_settings.port, "The UDP port number for Nebula")
nebula.prefs.all_ports = Pref.bool("All ports", default_settings.all_ports, "Assume nebula packets on any port, useful when dealing with hole punching")
local pf_version = ProtoField.new("version", "nebula.version", ftypes.UINT8, nil, base.DEC, 0xF0)
local pf_type = ProtoField.new("type", "nebula.type", ftypes.UINT8, {
[0] = "handshake",
[1] = "message",
[2] = "recvError",
[3] = "lightHouse",
[4] = "test",
[5] = "closeTunnel",
}, base.DEC, 0x0F)
local pf_subtype = ProtoField.new("subtype", "nebula.subtype", ftypes.UINT8, nil, base.DEC)
local pf_subtype_test = ProtoField.new("subtype", "nebula.subtype", ftypes.UINT8, {
[0] = "request",
[1] = "reply",
}, base.DEC)
local pf_subtype_handshake = ProtoField.new("subtype", "nebula.subtype", ftypes.UINT8, {
[0] = "ix_psk0",
}, base.DEC)
local pf_reserved = ProtoField.new("reserved", "nebula.reserved", ftypes.UINT16, nil, base.HEX)
local pf_remote_index = ProtoField.new("remote index", "nebula.remote_index", ftypes.UINT32, nil, base.DEC)
local pf_message_counter = ProtoField.new("counter", "nebula.counter", ftypes.UINT64, nil, base.DEC)
local pf_payload = ProtoField.new("payload", "nebula.payload", ftypes.BYTES, nil, base.NONE)
nebula.fields = { pf_version, pf_type, pf_subtype, pf_subtype_handshake, pf_subtype_test, pf_reserved, pf_remote_index, pf_message_counter, pf_payload }
local ef_holepunch = ProtoExpert.new("nebula.holepunch.expert", "Nebula hole punch packet", expert.group.PROTOCOL, expert.severity.NOTE)
local ef_punchy = ProtoExpert.new("nebula.punchy.expert", "Nebula punchy keepalive packet", expert.group.PROTOCOL, expert.severity.NOTE)
nebula.experts = { ef_holepunch, ef_punchy }
local type_field = Field.new("nebula.type")
local subtype_field = Field.new("nebula.subtype")
function nebula.dissector(tvbuf, pktinfo, root)
-- set the protocol column to show our protocol name
pktinfo.cols.protocol:set("NEBULA")
local pktlen = tvbuf:reported_length_remaining()
local tree = root:add(nebula, tvbuf:range(0,pktlen))
if pktlen == 0 then
tree:add_proto_expert_info(ef_holepunch)
pktinfo.cols.info:append(" (holepunch)")
return
elseif pktlen == 1 then
tree:add_proto_expert_info(ef_punchy)
pktinfo.cols.info:append(" (punchy)")
return
end
tree:add(pf_version, tvbuf:range(0,1))
local type = tree:add(pf_type, tvbuf:range(0,1))
local nebula_type = bit32.band(tvbuf:range(0,1):uint(), 0x0F)
if nebula_type == 0 then
local stage = tvbuf(8,8):uint64()
tree:add(pf_subtype_handshake, tvbuf:range(1,1))
type:append_text(" stage " .. stage)
pktinfo.cols.info:append(" (" .. type_field().display .. ", stage " .. stage .. ", " .. subtype_field().display .. ")")
elseif nebula_type == 4 then
tree:add(pf_subtype_test, tvbuf:range(1,1))
pktinfo.cols.info:append(" (" .. type_field().display .. ", " .. subtype_field().display .. ")")
else
tree:add(pf_subtype, tvbuf:range(1,1))
pktinfo.cols.info:append(" (" .. type_field().display .. ")")
end
tree:add(pf_reserved, tvbuf:range(2,2))
tree:add(pf_remote_index, tvbuf:range(4,4))
tree:add(pf_message_counter, tvbuf:range(8,8))
tree:add(pf_payload, tvbuf:range(16,tvbuf:len() - 16))
end
function nebula.prefs_changed()
if default_settings.all_ports == nebula.prefs.all_ports and default_settings.port == nebula.prefs.port then
-- Nothing changed, bail
return
end
-- Remove our old dissector
DissectorTable.get("udp.port"):remove_all(nebula)
if nebula.prefs.all_ports and default_settings.all_ports ~= nebula.prefs.all_ports then
default_settings.all_port = nebula.prefs.all_ports
for i=0, 65535 do
DissectorTable.get("udp.port"):add(i, nebula)
end
-- no need to establish again on specific ports
return
end
if default_settings.all_ports ~= nebula.prefs.all_ports then
-- Add our new port dissector
default_settings.port = nebula.prefs.port
DissectorTable.get("udp.port"):add(default_settings.port, nebula)
end
end
DissectorTable.get("udp.port"):add(default_settings.port, nebula)

View File

@@ -7,8 +7,8 @@ pki:
ca: /etc/nebula/ca.crt
cert: /etc/nebula/host.crt
key: /etc/nebula/host.key
#blacklist is a list of certificate fingerprints that we will refuse to talk to
#blacklist:
#blocklist is a list of certificate fingerprints that we will refuse to talk to
#blocklist:
# - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
@@ -36,9 +36,41 @@ lighthouse:
interval: 60
# hosts is a list of lighthouse hosts this node should report to and query from
# IMPORTANT: THIS SHOULD BE EMPTY ON LIGHTHOUSE NODES
# IMPORTANT2: THIS SHOULD BE LIGHTHOUSES' NEBULA IPs, NOT LIGHTHOUSES' REAL ROUTABLE IPs
hosts:
- "192.168.100.1"
# remote_allow_list allows you to control ip ranges that this node will
# consider when handshaking to another node. By default, any remote IPs are
# allowed. You can provide CIDRs here with `true` to allow and `false` to
# deny. The most specific CIDR rule applies to each remote. If all rules are
# "allow", the default will be "deny", and vice-versa. If both "allow" and
# "deny" rules are present, then you MUST set a rule for "0.0.0.0/0" as the
# default.
#remote_allow_list:
# Example to block IPs from this subnet from being used for remote IPs.
#"172.16.0.0/12": false
# A more complicated example, allow public IPs but only private IPs from a specific subnet
#"0.0.0.0/0": true
#"10.0.0.0/8": false
#"10.42.42.0/24": true
# local_allow_list allows you to filter which local IP addresses we advertise
# to the lighthouses. This uses the same logic as `remote_allow_list`, but
# additionally, you can specify an `interfaces` map of regular expressions
# to match against interface names. The regexp must match the entire name.
# All interface rules must be either true or false (and the default will be
# the inverse). CIDR rules are matched after interface name rules.
# Default is all local IP addresses.
#local_allow_list:
# Example to block tun0 and all docker interfaces.
#interfaces:
#tun0: false
#'docker.*': false
# Example to only advertise this subnet to the lighthouse.
#"10.0.0.0/8": true
# Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined,
# however using port 0 will dynamically assign a port and is recommended for roaming nodes.
listen:
@@ -54,11 +86,17 @@ listen:
#read_buffer: 10485760
#write_buffer: 10485760
# Punchy continues to punch inbound/outbound at a regular interval to avoid expiration of firewall nat mappings
punchy: true
# punch_back means that a node you are trying to reach will connect back out to you if your hole punching fails
# this is extremely useful if one node is behind a difficult nat, such as symmetric
#punch_back: true
punchy:
# Continues to punch inbound/outbound at a regular interval to avoid expiration of firewall nat mappings
punch: true
# respond means that a node you are trying to reach will connect back out to you if your hole punching fails
# this is extremely useful if one node is behind a difficult nat, such as a symmetric NAT
# Default is false
#respond: true
# delays a punch response for misbehaving NATs, default is 1 second, respond must be true to take effect
#delay: 1s
# Cipher allows you to choose between the available ciphers for your network.
# IMPORTANT: this value must be identical on ALL NODES/LIGHTHOUSES. We do not/will not support use of different ciphers simultaneously!
@@ -86,6 +124,8 @@ punchy: true
# Configure the private interface. Note: addr is baked into the nebula certificate
tun:
# When tun is disabled, a lighthouse can be started without a local tun interface (and therefore without root)
disabled: false
# Name of the device
dev: nebula1
# Toggles forwarding of local broadcast packets, the address of which depends on the ip/mask encoded in pki.cert
@@ -116,6 +156,16 @@ logging:
level: info
# json or text formats currently available. Default is text
format: text
# Disable timestamp logging. useful when output is redirected to logging system that already adds timestamps. Default is false
#disable_timestamp: true
# timestamp format is specified in Go time format, see:
# https://golang.org/pkg/time/#pkg-constants
# default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339)
# default when `format: text`:
# when TTY attached: seconds since beginning of execution
# otherwise: "2006-01-02T15:04:05Z07:00" (RFC3339)
# As an example, to log as RFC3339 with millisecond precision, set to:
#timestamp_format: "2006-01-02T15:04:05.000Z07:00"
#stats:
#type: graphite
@@ -131,10 +181,31 @@ logging:
#subsystem: nebula
#interval: 10s
# enables counter metrics for meta packets
# e.g.: `messages.tx.handshake`
# NOTE: `message.{tx,rx}.recv_error` is always emitted
#message_metrics: false
# enables detailed counter metrics for lighthouse packets
# e.g.: `lighthouse.rx.HostQuery`
#lighthouse_metrics: false
# Handshake Manger Settings
#handshakes:
# Total time to try a handshake = sequence of `try_interval * retries`
# With 100ms interval and 20 retries it is 23.5 seconds
#try_interval: 100ms
#retries: 20
# wait_rotation is the number of handshake attempts to do before starting to try non-local IP addresses
#wait_rotation: 5
# trigger_buffer is the size of the buffer channel for quickly sending handshakes
# after receiving the response for lighthouse queries
#trigger_buffer: 64
# Nebula security group configuration
firewall:
conntrack:
tcp_timeout: 120h
tcp_timeout: 12m
udp_timeout: 3m
default_timeout: 10m
max_connections: 100000

View File

@@ -1,21 +1,21 @@
package nebula
import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net"
"sync"
"time"
"crypto/sha256"
"encoding/hex"
"errors"
"reflect"
"strconv"
"strings"
"sync"
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
)
@@ -38,13 +38,19 @@ type FirewallInterface interface {
type conn struct {
Expires time.Time // Time when this conntrack entry will expire
Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
Sent time.Time // If tcp rtt tracking is enabled this will be when Seq was last set
Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
// record why the original connection passed the firewall, so we can re-validate
// after ruleset changes. Note, rulesVersion is a uint16 so that these two
// fields pack for free after the uint32 above
incoming bool
rulesVersion uint16
}
// TODO: need conntrack max tracked connections handling
type Firewall struct {
Conns map[FirewallPacket]*conn
Conntrack *FirewallConntrack
InRules *FirewallTable
OutRules *FirewallTable
@@ -55,18 +61,23 @@ type Firewall struct {
UDPTimeout time.Duration //linux: 180s max
DefaultTimeout time.Duration //linux: 600s
TimerWheel *TimerWheel
// Used to ensure we don't emit local packets for ips we don't own
localIps *CIDRTree
connMutex sync.Mutex
rules string
rules string
rulesVersion uint16
trackTCPRTT bool
metricTCPRTT metrics.Histogram
}
type FirewallConntrack struct {
sync.Mutex
Conns map[FirewallPacket]*conn
TimerWheel *TimerWheel
}
type FirewallTable struct {
TCP firewallPort
UDP firewallPort
@@ -172,10 +183,12 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
}
return &Firewall{
Conns: make(map[FirewallPacket]*conn),
Conntrack: &FirewallConntrack{
Conns: make(map[FirewallPacket]*conn),
TimerWheel: NewTimerWheel(min, max),
},
InRules: newFirewallTable(),
OutRules: newFirewallTable(),
TimerWheel: NewTimerWheel(min, max),
TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout,
@@ -208,11 +221,17 @@ func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, er
// AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
// https://github.com/golang/go/issues/14131
sIp := ""
if ip != nil {
sIp = ip.String()
}
// We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf(
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, ip, caName, caSha,
incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha,
)
f.rules += ruleString + "\n"
@@ -220,7 +239,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
if !incoming {
direction = "outgoing"
}
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": ip, "caName": caName, "caSha": caSha}).
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
Info("Firewall rule added")
var (
@@ -347,20 +366,33 @@ func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterfa
return nil
}
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool {
var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) error {
// Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(packet, fp, incoming) {
return false
if f.inConns(packet, fp, incoming, h, caPool) {
return nil
}
// Make sure remote address matches nebula certificate
if h.remoteCidr.Contains(fp.RemoteIP) == nil {
return true
if remoteCidr := h.remoteCidr; remoteCidr != nil {
if remoteCidr.Contains(fp.RemoteIP) == nil {
return ErrInvalidRemoteIP
}
} else {
// Simple case: Certificate has one IP and no subnets
if fp.RemoteIP != h.hostId {
return ErrInvalidRemoteIP
}
}
// Make sure we are supposed to be handling this local ip address
if f.localIps.Contains(fp.LocalIP) == nil {
return true
return ErrInvalidLocalIP
}
table := f.OutRules
@@ -370,13 +402,13 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *Host
// We now know which firewall table to check against
if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) {
return true
return ErrNoMatchingRule
}
// We always want to conntrack since it is a faster operation
f.addConn(packet, fp, incoming)
return false
return nil
}
// Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new
@@ -386,26 +418,66 @@ func (f *Firewall) Destroy() {
}
func (f *Firewall) EmitStats() {
conntrackCount := len(f.Conns)
conntrack := f.Conntrack
conntrack.Lock()
conntrackCount := len(conntrack.Conns)
conntrack.Unlock()
metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount))
metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
}
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool {
f.connMutex.Lock()
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool {
conntrack := f.Conntrack
conntrack.Lock()
// Purge every time we test
ep, has := f.TimerWheel.Purge()
ep, has := conntrack.TimerWheel.Purge()
if has {
f.evict(ep)
}
c, ok := f.Conns[fp]
c, ok := conntrack.Conns[fp]
if !ok {
f.connMutex.Unlock()
conntrack.Unlock()
return false
}
if c.rulesVersion != f.rulesVersion {
// This conntrack entry was for an older rule set, validate
// it still passes with the current rule set
table := f.OutRules
if c.incoming {
table = f.InRules
}
// We now know which firewall table to check against
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
if l.Level >= logrus.DebugLevel {
h.logger().
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion).
Debugln("dropping old conntrack entry, does not match new ruleset")
}
delete(conntrack.Conns, fp)
conntrack.Unlock()
return false
}
if l.Level >= logrus.DebugLevel {
h.logger().
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion).
Debugln("keeping old conntrack entry, does match new ruleset")
}
c.rulesVersion = f.rulesVersion
}
switch fp.Protocol {
case fwProtoTCP:
c.Expires = time.Now().Add(f.TCPTimeout)
@@ -420,7 +492,7 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool
c.Expires = time.Now().Add(f.DefaultTimeout)
}
f.connMutex.Unlock()
conntrack.Unlock()
return true
}
@@ -441,14 +513,19 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
timeout = f.DefaultTimeout
}
f.connMutex.Lock()
if _, ok := f.Conns[fp]; !ok {
f.TimerWheel.Add(fp, timeout)
conntrack := f.Conntrack
conntrack.Lock()
if _, ok := conntrack.Conns[fp]; !ok {
conntrack.TimerWheel.Add(fp, timeout)
}
// Record which rulesVersion allowed this connection, so we can retest after
// firewall reload
c.incoming = incoming
c.rulesVersion = f.rulesVersion
c.Expires = time.Now().Add(timeout)
f.Conns[fp] = c
f.connMutex.Unlock()
conntrack.Conns[fp] = c
conntrack.Unlock()
}
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
@@ -456,7 +533,8 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
func (f *Firewall) evict(p FirewallPacket) {
//TODO: report a stat if the tcp rtt tracking was never resolved?
// Are we still tracking this conn?
t, ok := f.Conns[p]
conntrack := f.Conntrack
t, ok := conntrack.Conns[p]
if !ok {
return
}
@@ -465,12 +543,12 @@ func (f *Firewall) evict(p FirewallPacket) {
// Timeout is in the future, re-add the timer
if newT > 0 {
f.TimerWheel.Add(p, newT)
conntrack.TimerWheel.Add(p, newT)
return
}
// This conn is done
delete(f.Conns, p)
delete(conntrack.Conns, p)
}
func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {

View File

@@ -17,37 +17,39 @@ import (
func TestNewFirewall(t *testing.T) {
c := &cert.NebulaCertificate{}
fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
assert.NotNil(t, fw.Conns)
conntrack := fw.Conntrack
assert.NotNil(t, conntrack)
assert.NotNil(t, conntrack.Conns)
assert.NotNil(t, conntrack.TimerWheel)
assert.NotNil(t, fw.InRules)
assert.NotNil(t, fw.OutRules)
assert.NotNil(t, fw.TimerWheel)
assert.Equal(t, time.Second, fw.TCPTimeout)
assert.Equal(t, time.Minute, fw.UDPTimeout)
assert.Equal(t, time.Hour, fw.DefaultTimeout)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
}
func TestFirewall_AddRule(t *testing.T) {
@@ -171,6 +173,7 @@ func TestFirewall_Drop(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c,
},
hostId: ip2int(ipNet.IP),
}
h.CreateRemoteCIDR(&c)
@@ -179,44 +182,44 @@ func TestFirewall_Drop(t *testing.T) {
cp := cert.NewCAPool()
// Drop outbound
assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
// Allow outbound because conntrack
assert.False(t, fw.Drop([]byte{}, p, false, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
// test remote mismatch
oldRemote := p.RemoteIP
p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrInvalidRemoteIP)
p.RemoteIP = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
}
func BenchmarkFirewallTable_match(b *testing.B) {
@@ -344,6 +347,7 @@ func TestFirewall_Drop2(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c,
},
hostId: ip2int(ipNet.IP),
}
h.CreateRemoteCIDR(&c)
@@ -366,10 +370,10 @@ func TestFirewall_Drop2(t *testing.T) {
cp := cert.NewCAPool()
// h1/c1 lacks the proper groups
assert.True(t, fw.Drop([]byte{}, p, true, &h1, cp))
assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp), ErrNoMatchingRule)
// c has the proper groups
resetConntrack(fw)
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
}
func TestFirewall_Drop3(t *testing.T) {
@@ -410,6 +414,7 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c1,
},
hostId: ip2int(ipNet.IP),
}
h1.CreateRemoteCIDR(&c1)
@@ -424,6 +429,7 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c2,
},
hostId: ip2int(ipNet.IP),
}
h2.CreateRemoteCIDR(&c2)
@@ -438,6 +444,7 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c3,
},
hostId: ip2int(ipNet.IP),
}
h3.CreateRemoteCIDR(&c3)
@@ -447,13 +454,81 @@ func TestFirewall_Drop3(t *testing.T) {
cp := cert.NewCAPool()
// c1 should pass because host match
assert.False(t, fw.Drop([]byte{}, p, true, &h1, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp))
// c2 should pass because ca sha match
resetConntrack(fw)
assert.False(t, fw.Drop([]byte{}, p, true, &h2, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp))
// c3 should fail because no match
resetConntrack(fw)
assert.True(t, fw.Drop([]byte{}, p, true, &h3, cp))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp), ErrNoMatchingRule)
}
func TestFirewall_DropConntrackReload(t *testing.T) {
ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob)
defer l.SetOutput(out)
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
ip2int(net.IPv4(1, 2, 3, 4)),
10,
90,
fwProtoUDP,
false,
}
ipNet := net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
}
c := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "host1",
Ips: []*net.IPNet{&ipNet},
Groups: []string{"default-group"},
InvertedGroups: map[string]struct{}{"default-group": {}},
Issuer: "signer-shasum",
},
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c,
},
hostId: ip2int(ipNet.IP),
}
h.CreateRemoteCIDR(&c)
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
// Allow outbound because conntrack
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
oldFw := fw
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
// Allow outbound because conntrack and new rules allow port 10
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
oldFw = fw
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
// Drop outbound because conntrack doesn't match new ruleset
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
}
func BenchmarkLookup(b *testing.B) {
@@ -856,7 +931,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
}
func resetConntrack(fw *Firewall) {
fw.connMutex.Lock()
fw.Conns = map[FirewallPacket]*conn{}
fw.connMutex.Unlock()
fw.Conntrack.Lock()
fw.Conntrack.Conns = map[FirewallPacket]*conn{}
fw.Conntrack.Unlock()
}

4
go.mod
View File

@@ -22,10 +22,10 @@ require (
github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563
github.com/sirupsen/logrus v1.4.2
github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b
github.com/stretchr/testify v1.4.0
github.com/stretchr/testify v1.6.1
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553
golang.org/x/sys v0.0.0-20191210023423-ac6580df4449
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect

10
go.sum
View File

@@ -103,8 +103,8 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a h1:Bt1IVPhiCDMqwGrc2nnbIN4QKvJGx6SK2NzWBmW00ao=
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
@@ -112,8 +112,8 @@ github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY=
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 h1:ULYEB3JvPRE/IfO+9uO7vKV/xzVTO7XPAwm8xbf4w2g=
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975 h1:/Tl7pH94bvbAAHBdZJT947M/+gp0+CqQXDtMRC0fseo=
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@@ -152,3 +152,5 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -13,6 +13,11 @@ func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Head
// return
//}
if !f.lightHouse.remoteAllowList.Allow(udp2ipInt(addr)) {
l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return
}
tearDown := false
switch h.Subtype {
case handshakeIXPSK0:

View File

@@ -1,11 +1,10 @@
package nebula
import (
"bytes"
"sync/atomic"
"time"
"bytes"
"github.com/flynn/noise"
"github.com/golang/protobuf/proto"
)
@@ -98,6 +97,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hostinfo, _ := f.handshakeManager.pendingHostMap.QueryReverseIndex(hs.Details.InitiatorIndex)
if hostinfo != nil && bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) {
if msg, ok := hostinfo.HandshakePacket[2]; ok {
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err := f.outside.WriteTo(msg, addr)
if err != nil {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
@@ -125,10 +125,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
return true
}
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
myIndex, err := generateIndex()
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
return true
}
@@ -136,11 +140,15 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hostinfo, err = f.handshakeManager.AddIndex(myIndex, ci)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Error adding index to connection manager")
return true
}
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message received")
@@ -152,6 +160,8 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hsBytes, err := proto.Marshal(hs)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return true
}
@@ -160,12 +170,16 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return true
}
if f.hostMap.CheckHandshakeCompleteIP(vpnIP) && vpnIP < ip2int(f.certState.certificate.Details.Ips[0].IP) {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Prevented a handshake race")
@@ -184,14 +198,19 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hostinfo.HandshakePacket[2] = make([]byte, len(msg))
copy(hostinfo.HandshakePacket[2], msg)
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err := f.outside.WriteTo(msg, addr)
if err != nil {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake")
} else {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake message sent")
@@ -214,6 +233,8 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
ho, err := f.hostMap.QueryVpnIP(vpnIP)
if err == nil && ho.localIndexId != 0 {
l.WithField("vpnIp", vpnIP).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("action", "removing stale index").
WithField("index", ho.localIndexId).
Debug("Handshake processing")
@@ -226,6 +247,8 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hostinfo.handshakeComplete()
} else {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
return true
@@ -284,9 +307,13 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
return true
}
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("durationNs", duration).
@@ -324,6 +351,8 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
ho, err := f.hostMap.QueryVpnIP(vpnIP)
if err == nil && ho.localIndexId != 0 {
l.WithField("vpnIp", vpnIP).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("action", "removing stale index").
WithField("index", ho.localIndexId).
Debug("Handshake processing")
@@ -337,6 +366,8 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
f.metricHandshakes.Update(duration)
} else {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
return true

View File

@@ -13,39 +13,76 @@ import (
const (
// Total time to try a handshake = sequence of HandshakeTryInterval * HandshakeRetries
// With 100ms interval and 20 retries is 23.5 seconds
HandshakeTryInterval = time.Millisecond * 100
HandshakeRetries = 20
// HandshakeWaitRotation is the number of handshake attempts to do before starting to use other ips addresses
HandshakeWaitRotation = 5
DefaultHandshakeTryInterval = time.Millisecond * 100
DefaultHandshakeRetries = 20
// DefaultHandshakeWaitRotation is the number of handshake attempts to do before starting to use other ips addresses
DefaultHandshakeWaitRotation = 5
DefaultHandshakeTriggerBuffer = 64
)
var (
defaultHandshakeConfig = HandshakeConfig{
tryInterval: DefaultHandshakeTryInterval,
retries: DefaultHandshakeRetries,
waitRotation: DefaultHandshakeWaitRotation,
triggerBuffer: DefaultHandshakeTriggerBuffer,
}
)
type HandshakeConfig struct {
tryInterval time.Duration
retries int
waitRotation int
triggerBuffer int
messageMetrics *MessageMetrics
}
type HandshakeManager struct {
pendingHostMap *HostMap
mainHostMap *HostMap
lightHouse *LightHouse
outside *udpConn
config HandshakeConfig
// can be used to trigger outbound handshake for the given vpnIP
trigger chan uint32
OutboundHandshakeTimer *SystemTimerWheel
InboundHandshakeTimer *SystemTimerWheel
messageMetrics *MessageMetrics
}
func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn) *HandshakeManager {
func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
return &HandshakeManager{
pendingHostMap: NewHostMap("pending", tunCidr, preferredRanges),
mainHostMap: mainHostMap,
lightHouse: lightHouse,
outside: outside,
OutboundHandshakeTimer: NewSystemTimerWheel(HandshakeTryInterval, HandshakeTryInterval*HandshakeRetries),
InboundHandshakeTimer: NewSystemTimerWheel(HandshakeTryInterval, HandshakeTryInterval*HandshakeRetries),
config: config,
trigger: make(chan uint32, config.triggerBuffer),
OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
messageMetrics: config.messageMetrics,
}
}
func (c *HandshakeManager) Run(f EncWriter) {
clockSource := time.Tick(HandshakeTryInterval)
for now := range clockSource {
c.NextOutboundHandshakeTimerTick(now, f)
c.NextInboundHandshakeTimerTick(now)
clockSource := time.Tick(c.config.tryInterval)
for {
select {
case vpnIP := <-c.trigger:
l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
c.handleOutbound(vpnIP, f, true)
case now := <-clockSource:
c.NextOutboundHandshakeTimerTick(now, f)
c.NextInboundHandshakeTimerTick(now)
}
}
}
@@ -57,68 +94,86 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWr
break
}
vpnIP := ep.(uint32)
c.handleOutbound(vpnIP, f, false)
}
}
index, err := c.pendingHostMap.GetIndexByVpnIP(vpnIP)
if err != nil {
continue
func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseTriggered bool) {
index, err := c.pendingHostMap.GetIndexByVpnIP(vpnIP)
if err != nil {
return
}
hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP)
if err != nil {
return
}
// If we haven't finished the handshake and we haven't hit max retries, query
// lighthouse and then send the handshake packet again.
if hostinfo.HandshakeCounter < c.config.retries && !hostinfo.HandshakeComplete {
if hostinfo.remote == nil {
// We continue to query the lighthouse because hosts may
// come online during handshake retries. If the query
// succeeds (no error), add the lighthouse info to hostinfo
ips := c.lightHouse.QueryCache(vpnIP)
// If we have no responses yet, or only one IP (the host hadn't
// finished reporting its own IPs yet), then send another query to
// the LH.
if len(ips) <= 1 {
ips, err = c.lightHouse.Query(vpnIP, f)
}
if err == nil {
for _, ip := range ips {
hostinfo.AddRemote(ip)
}
hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges)
}
} else if lighthouseTriggered {
// We were triggered by a lighthouse HostQueryReply packet, but
// we have already picked a remote for this host (this can happen
// if we are configured with multiple lighthouses). So we can skip
// this trigger and let the timerwheel handle the rest of the
// process
return
}
hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP)
if err != nil {
continue
hostinfo.HandshakeCounter++
// We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through
// all the others until we can stand up a connection.
if hostinfo.HandshakeCounter > c.config.waitRotation {
hostinfo.rotateRemote()
}
// If we haven't finished the handshake and we haven't hit max retries, query
// lighthouse and then send the handshake packet again.
if hostinfo.HandshakeCounter < HandshakeRetries && !hostinfo.HandshakeComplete {
if hostinfo.remote == nil {
// We continue to query the lighthouse because hosts may
// come online during handshake retries. If the query
// succeeds (no error), add the lighthouse info to hostinfo
ips, err := c.lightHouse.Query(vpnIP, f)
if err == nil {
for _, ip := range ips {
hostinfo.AddRemote(ip)
}
hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges)
}
// Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation
if hostinfo.HandshakeReady && hostinfo.remote != nil {
c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
if err != nil {
hostinfo.logger().WithField("udpAddr", hostinfo.remote).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake message")
} else {
//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
// keep the real packet struct around for logging purposes
hostinfo.logger().WithField("udpAddr", hostinfo.remote).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message sent")
}
}
hostinfo.HandshakeCounter++
// We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through
// all the others until we can stand up a connection.
if hostinfo.HandshakeCounter > HandshakeWaitRotation {
hostinfo.rotateRemote()
}
// Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation
if hostinfo.HandshakeReady && hostinfo.remote != nil {
err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
if err != nil {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", hostinfo.remote).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake message")
} else {
//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
// keep the real packet struct around for logging purposes
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", hostinfo.remote).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message sent")
}
}
// Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try
// Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try
if !lighthouseTriggered {
//l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter))
c.OutboundHandshakeTimer.Add(vpnIP, HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter))
} else {
c.pendingHostMap.DeleteVpnIP(vpnIP)
c.pendingHostMap.DeleteIndex(index)
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
}
} else {
c.pendingHostMap.DeleteVpnIP(vpnIP)
c.pendingHostMap.DeleteIndex(index)
}
}
@@ -144,7 +199,8 @@ func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
hostinfo := c.pendingHostMap.AddVpnIP(vpnIP)
// We lock here and use an array to insert items to prevent locking the
// main receive thread for very long by waiting to add items to the pending map
c.OutboundHandshakeTimer.Add(vpnIP, HandshakeTryInterval)
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval)
return hostinfo
}

View File

@@ -21,7 +21,7 @@ func Test_NewHandshakeManagerIndex(t *testing.T) {
preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{})
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextInboundHandshakeTimerTick(now)
@@ -37,8 +37,8 @@ func Test_NewHandshakeManagerIndex(t *testing.T) {
// Adding something to pending should not affect the main hostmap
assert.Len(t, mainHM.Indexes, 0)
// Jump ahead 8 seconds
for i := 1; i <= HandshakeRetries; i++ {
next_tick := now.Add(HandshakeTryInterval * time.Duration(i))
for i := 1; i <= DefaultHandshakeRetries; i++ {
next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i))
blah.NextInboundHandshakeTimerTick(next_tick)
}
// Confirm they are still in the pending index list
@@ -63,7 +63,7 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{})
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
@@ -81,8 +81,8 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
// Jump ahead `HandshakeRetries` ticks
cumulative := time.Duration(0)
for i := 0; i <= HandshakeRetries+1; i++ {
cumulative += time.Duration(i)*HandshakeTryInterval + 1
for i := 0; i <= DefaultHandshakeRetries+1; i++ {
cumulative += time.Duration(i)*DefaultHandshakeTryInterval + 1
next_tick := now.Add(cumulative)
//l.Infoln(next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick, mw)
@@ -93,7 +93,7 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v))
}
// Jump ahead 1 more second
cumulative += time.Duration(HandshakeRetries+1) * HandshakeTryInterval
cumulative += time.Duration(DefaultHandshakeRetries+1) * DefaultHandshakeTryInterval
next_tick := now.Add(cumulative)
//l.Infoln(next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick, mw)
@@ -103,6 +103,56 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
}
}
func Test_NewHandshakeManagerTrigger(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ip := ip2int(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
lh := &LightHouse{}
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
blah.AddVpnIP(ip)
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
// Trigger the same method the channel will
blah.handleOutbound(ip, mw, true)
// Make sure the trigger doesn't schedule another timer entry
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
hi := blah.pendingHostMap.Hosts[ip]
assert.Nil(t, hi.remote)
lh.addrMap = map[uint32][]udpAddr{
ip: {*NewUDPAddrFromString("10.1.1.1:4242")},
}
// This should trigger the hostmap to populate the hostinfo
blah.handleOutbound(ip, mw, true)
assert.NotNil(t, hi.remote)
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
}
func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
for _, i := range tw.wheel {
n := i.Head
for n != nil {
c++
n = n.Next
}
}
return c
}
func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
@@ -112,7 +162,7 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{})
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
@@ -125,8 +175,8 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
// Jump ahead `HandshakeRetries` ticks. Eviction should happen in pending
// but not main hostmap
cumulative := time.Duration(0)
for i := 1; i <= HandshakeRetries+2; i++ {
cumulative += HandshakeTryInterval * time.Duration(i)
for i := 1; i <= DefaultHandshakeRetries+2; i++ {
cumulative += DefaultHandshakeTryInterval * time.Duration(i)
next_tick := now.Add(cumulative)
blah.NextOutboundHandshakeTimerTick(next_tick, mw)
}
@@ -161,7 +211,7 @@ func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{})
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextInboundHandshakeTimerTick(now)
@@ -171,12 +221,12 @@ func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
blah.pendingHostMap.AddVpnIPHostInfo(101010, hostinfo)
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(101010))
for i := 1; i <= HandshakeRetries+2; i++ {
next_tick := now.Add(HandshakeTryInterval * time.Duration(i))
for i := 1; i <= DefaultHandshakeRetries+2; i++ {
next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i))
blah.NextInboundHandshakeTimerTick(next_tick)
}
next_tick := now.Add(HandshakeTryInterval*HandshakeRetries + 3)
next_tick := now.Add(DefaultHandshakeTryInterval*DefaultHandshakeRetries + 3)
blah.NextInboundHandshakeTimerTick(next_tick)
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010))
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234))

View File

@@ -1,9 +1,10 @@
package nebula
import (
"github.com/stretchr/testify/assert"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
type headerTest struct {

View File

@@ -30,6 +30,7 @@ type HostMap struct {
vpnCIDR *net.IPNet
defaultRoute uint32
unsafeRoutes *CIDRTree
metricsEnabled bool
}
type HostInfo struct {
@@ -384,8 +385,16 @@ func (hm *HostMap) PunchList() []*udpAddr {
}
func (hm *HostMap) Punchy(conn *udpConn) {
var metricsTxPunchy metrics.Counter
if hm.metricsEnabled {
metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
} else {
metricsTxPunchy = metrics.NilCounter{}
}
for {
for _, addr := range hm.PunchList() {
metricsTxPunchy.Inc(1)
conn.WriteTo([]byte{1}, addr)
}
time.Sleep(time.Second * 30)
@@ -532,13 +541,13 @@ func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, pac
copy(tempPacket, packet)
//l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket)
i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket})
l.WithField("vpnIp", IntIp(i.hostId)).
i.logger().
WithField("length", len(i.packetStore)).
WithField("stored", true).
Debugf("Packet store")
} else if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(i.hostId)).
i.logger().
WithField("length", len(i.packetStore)).
WithField("stored", false).
Debugf("Packet store")
@@ -556,7 +565,7 @@ func (i *HostInfo) handshakeComplete() {
//TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen.
// Clamping it to 2 gets us out of the woods for now
*i.ConnectionState.messageCounter = 2
l.WithField("vpnIp", IntIp(i.hostId)).Debugf("Sending %d stored packets", len(i.packetStore))
i.logger().Debugf("Sending %d stored packets", len(i.packetStore))
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for _, cp := range i.packetStore {
@@ -623,6 +632,11 @@ func (i *HostInfo) RecvErrorExceeded() bool {
}
func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
if len(c.Details.Ips) == 1 && len(c.Details.Subnets) == 0 {
// Simple case, no CIDRTree needed
return
}
remoteCidr := NewCIDRTree()
for _, ip := range c.Details.Ips {
remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
@@ -634,6 +648,22 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
i.remoteCidr = remoteCidr
}
func (i *HostInfo) logger() *logrus.Entry {
if i == nil {
return logrus.NewEntry(l)
}
li := l.WithField("vpnIp", IntIp(i.hostId))
if connState := i.ConnectionState; connState != nil {
if peerCert := connState.peerCert; peerCert != nil {
li = li.WithField("certName", peerCert.Details.Name)
}
}
return li
}
//########################
func NewHostInfoDest(addr *udpAddr) *HostInfoDest {
@@ -734,11 +764,16 @@ func (d *HostInfoDest) ProbeReceived(probeCount int) {
// Utility functions
func localIps() *[]net.IP {
func localIps(allowList *AllowList) *[]net.IP {
//FIXME: This function is pretty garbage
var ips []net.IP
ifaces, _ := net.Interfaces()
for _, i := range ifaces {
allow := allowList.AllowName(i.Name)
l.WithField("interfaceName", i.Name).WithField("allow", allow).Debug("localAllowList.AllowName")
if !allow {
continue
}
addrs, _ := i.Addrs()
for _, addr := range addrs {
var ip net.IP
@@ -750,6 +785,12 @@ func localIps() *[]net.IP {
ip = v.IP
}
if ip.To4() != nil && ip.IsLoopback() == false {
allow := allowList.Allow(ip2int(ip))
l.WithField("localIp", ip).WithField("allow", allow).Debug("localAllowList.Allow")
if !allow {
continue
}
ips = append(ips, ip)
}
}

View File

@@ -161,6 +161,4 @@ func BenchmarkHostmappromote2(b *testing.B) {
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), g)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
}
b.Errorf("hi")
}

View File

@@ -19,12 +19,25 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
return
}
// Ignore packets from self to self
if fwPacket.RemoteIP == f.lightHouse.myIp {
return
}
// Ignore broadcast packets
if f.dropMulticast && isMulticast(fwPacket.RemoteIP) {
return
}
hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
if hostinfo == nil {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
}
return
}
ci := hostinfo.ConnectionState
if ci.ready == false {
@@ -39,21 +52,28 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
ci.queueLock.Unlock()
}
if !f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs) {
f.send(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out)
if f.lightHouse != nil && *ci.messageCounter%5000 == 0 {
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs)
if dropReason == nil {
mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out)
if f.lightHouse != nil && mc%5000 == 0 {
f.lightHouse.Query(fwPacket.RemoteIP, f)
}
} else if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket).
hostinfo.logger().
WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping outbound packet")
}
}
// getOrHandshake returns nil if the vpnIp is not routable
func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
if f.hostMap.vpnCIDR.Contains(int2ip(vpnIp)) == false {
vpnIp = f.hostMap.queryUnsafeRoute(vpnIp)
if vpnIp == 0 {
return nil
}
}
hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f)
@@ -86,6 +106,15 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
ixHandshakeStage0(f, vpnIp, hostinfo)
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
//xx_handshakeStage0(f, ip, hostinfo)
// If this is a static host, we don't need to wait for the HostQueryReply
// We can trigger the handshake right now
if _, ok := f.lightHouse.staticList[vpnIp]; ok {
select {
case f.handshakeManager.trigger <- vpnIp:
default:
}
}
}
return hostinfo
@@ -100,12 +129,17 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
}
// check if packet is in outbound fw rules
if f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs) {
l.WithField("fwPacket", fp).Debugln("dropping cached packet")
dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs)
if dropReason != nil {
if l.Level >= logrus.DebugLevel {
l.WithField("fwPacket", fp).
WithField("reason", dropReason).
Debugln("dropping cached packet")
}
return
}
f.send(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
if f.lightHouse != nil && *hostInfo.ConnectionState.messageCounter%5000 == 0 {
f.lightHouse.Query(fp.RemoteIP, f)
}
@@ -114,6 +148,13 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
hostInfo := f.getOrHandshake(vpnIp)
if hostInfo == nil {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(vpnIp)).
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
}
return
}
if !hostInfo.ConnectionState.ready {
// Because we might be sending stored packets, lock here to stop new things going to
@@ -138,6 +179,13 @@ func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
// SendMessageToAll handles real ip:port lookup and sends to all known addresses for vpnIp
func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
hostInfo := f.getOrHandshake(vpnIp)
if hostInfo == nil {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(vpnIp)).
Debugln("dropping SendMessageToAll, vpnIp not in our CIDR or in unsafe routes")
}
return
}
if hostInfo.ConnectionState.ready == false {
// Because we might be sending stored packets, lock here to stop new things going to
@@ -162,9 +210,14 @@ func (f *Interface) sendMessageToAll(t NebulaMessageType, st NebulaMessageSubTyp
}
func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) {
f.messageMetrics.Tx(t, st, 1)
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out)
}
func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) uint64 {
if ci.eKey == nil {
//TODO: log warning
return
return 0
}
var err error
@@ -180,18 +233,19 @@ func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *Conne
//TODO: see above note on lock
//ci.writeLock.Unlock()
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).
hostinfo.logger().WithError(err).
WithField("udpAddr", remote).WithField("counter", c).
WithField("attemptedCounter", ci.messageCounter).
Error("Failed to encrypt outgoing packet")
return
return c
}
err = f.outside.WriteTo(out, remote)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).
hostinfo.logger().WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
}
return c
}
func isMulticast(ip uint32) bool {

View File

@@ -2,6 +2,8 @@ package nebula
import (
"errors"
"io"
"net"
"os"
"time"
@@ -10,10 +12,18 @@ import (
const mtu = 9001
type Inside interface {
io.ReadWriteCloser
Activate() error
CidrNet() *net.IPNet
DeviceName() string
WriteRaw([]byte) error
}
type InterfaceConfig struct {
HostMap *HostMap
Outside *udpConn
Inside *Tun
Inside Inside
certState *CertState
Cipher string
Firewall *Firewall
@@ -25,12 +35,16 @@ type InterfaceConfig struct {
DropLocalBroadcast bool
DropMulticast bool
UDPBatchSize int
udpQueues int
tunQueues int
MessageMetrics *MessageMetrics
version string
}
type Interface struct {
hostMap *HostMap
outside *udpConn
inside *Tun
inside Inside
certState *CertState
cipher string
firewall *Firewall
@@ -43,11 +57,12 @@ type Interface struct {
dropLocalBroadcast bool
dropMulticast bool
udpBatchSize int
udpQueues int
tunQueues int
version string
metricRxRecvError metrics.Counter
metricTxRecvError metrics.Counter
metricHandshakes metrics.Histogram
metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
}
func NewInterface(c *InterfaceConfig) (*Interface, error) {
@@ -79,10 +94,12 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
dropLocalBroadcast: c.DropLocalBroadcast,
dropMulticast: c.DropMulticast,
udpBatchSize: c.UDPBatchSize,
udpQueues: c.udpQueues,
tunQueues: c.tunQueues,
version: c.version,
metricRxRecvError: metrics.GetOrRegisterCounter("messages.rx.recv_error", nil),
metricTxRecvError: metrics.GetOrRegisterCounter("messages.tx.recv_error", nil),
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
messageMetrics: c.MessageMetrics,
}
ifce.connectionManager = newConnectionManager(ifce, c.checkInterval, c.pendingDeletionInterval)
@@ -90,24 +107,28 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
return ifce, nil
}
func (f *Interface) Run(tunRoutines, udpRoutines int, buildVersion string) {
func (f *Interface) run() {
// actually turn on tun dev
if err := f.inside.Activate(); err != nil {
l.Fatal(err)
}
f.version = buildVersion
l.WithField("interface", f.inside.Device).WithField("network", f.inside.Cidr.String()).
WithField("build", buildVersion).
addr, err := f.outside.LocalAddr()
if err != nil {
l.WithError(err).Error("Failed to get udp listen address")
}
l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
WithField("build", f.version).WithField("udpAddr", addr).
Info("Nebula interface is active")
// Launch n queues to read packets from udp
for i := 0; i < udpRoutines; i++ {
for i := 0; i < f.udpQueues; i++ {
go f.listenOut(i)
}
// Launch n queues to read packets from tun dev
for i := 0; i < tunRoutines; i++ {
for i := 0; i < f.tunQueues; i++ {
go f.listenIn(i)
}
}
@@ -205,11 +226,28 @@ func (f *Interface) reloadFirewall(c *Config) {
}
oldFw := f.firewall
conntrack := oldFw.Conntrack
conntrack.Lock()
defer conntrack.Unlock()
fw.rulesVersion = oldFw.rulesVersion + 1
// If rulesVersion is back to zero, we have wrapped all the way around. Be
// safe and just reset conntrack in this case.
if fw.rulesVersion == 0 {
l.WithField("firewallHash", fw.GetRuleHash()).
WithField("oldFirewallHash", oldFw.GetRuleHash()).
WithField("rulesVersion", fw.rulesVersion).
Warn("firewall rulesVersion has overflowed, resetting conntrack")
} else {
fw.Conntrack = conntrack
}
f.firewall = fw
oldFw.Destroy()
l.WithField("firewallHash", fw.GetRuleHash()).
WithField("oldFirewallHash", oldFw.GetRuleHash()).
WithField("rulesVersion", fw.rulesVersion).
Info("New firewall has been installed")
}

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/golang/protobuf/proto"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
)
@@ -19,6 +20,19 @@ type LightHouse struct {
// Local cache of answers from light houses
addrMap map[uint32][]udpAddr
// filters remote addresses allowed for each host
// - When we are a lighthouse, this filters what addresses we store and
// respond with.
// - When we are not a lighthouse, this filters which addresses we accept
// from lighthouses.
remoteAllowList *AllowList
// filters local addresses that we advertise to lighthouses
localAllowList *AllowList
// used to trigger the HandshakeManager when we receive HostQueryReply
handshakeTrigger chan<- uint32
// staticList exists to avoid having a bool in each addrMap entry
// since static should be rare
staticList map[uint32]struct{}
@@ -26,6 +40,10 @@ type LightHouse struct {
interval int
nebulaPort int
punchBack bool
punchDelay time.Duration
metrics *MessageMetrics
metricHolepunchTx metrics.Counter
}
type EncWriter interface {
@@ -33,7 +51,7 @@ type EncWriter interface {
SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
}
func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort int, pc *udpConn, punchBack bool) *LightHouse {
func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort int, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
h := LightHouse{
amLighthouse: amLighthouse,
myIp: myIp,
@@ -44,6 +62,15 @@ func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, n
interval: interval,
punchConn: pc,
punchBack: punchBack,
punchDelay: punchDelay,
}
if metricsEnabled {
h.metrics = newLighthouseMetrics()
h.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil)
} else {
h.metricHolepunchTx = metrics.NilCounter{}
}
for _, ip := range ips {
@@ -53,6 +80,20 @@ func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, n
return &h
}
func (lh *LightHouse) SetRemoteAllowList(allowList *AllowList) {
lh.Lock()
defer lh.Unlock()
lh.remoteAllowList = allowList
}
func (lh *LightHouse) SetLocalAllowList(allowList *AllowList) {
lh.Lock()
defer lh.Unlock()
lh.localAllowList = allowList
}
func (lh *LightHouse) ValidateLHStaticEntries() error {
for lhIP, _ := range lh.lighthouses {
if _, ok := lh.staticList[lhIP]; !ok {
@@ -85,6 +126,7 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
return
}
lh.metricTx(NebulaMeta_HostQuery, int64(len(lh.lighthouses)))
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for n := range lh.lighthouses {
@@ -133,6 +175,13 @@ func (lh *LightHouse) AddRemote(vpnIP uint32, toIp *udpAddr, static bool) {
return
}
}
allow := lh.remoteAllowList.Allow(udp2ipInt(toIp))
l.WithField("remoteIp", toIp).WithField("allow", allow).Debug("remoteAllowList.Allow")
if !allow {
return
}
//l.Debugf("Adding reply of %s as %s\n", IntIp(vpnIP), toIp)
if static {
lh.staticList[vpnIP] = struct{}{}
@@ -201,7 +250,7 @@ func (lh *LightHouse) LhUpdateWorker(f EncWriter) {
for {
ipp := []*IpAndPort{}
for _, e := range *localIps() {
for _, e := range *localIps(lh.localAllowList) {
// Only add IPs that aren't my VPN/tun IP
if ip2int(e) != lh.myIp {
ipp = append(ipp, &IpAndPort{Ip: ip2int(e), Port: uint32(lh.nebulaPort)})
@@ -216,6 +265,7 @@ func (lh *LightHouse) LhUpdateWorker(f EncWriter) {
},
}
lh.metricTx(NebulaMeta_HostUpdateNotification, int64(len(lh.lighthouses)))
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for vpnIp := range lh.lighthouses {
@@ -248,6 +298,8 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
return
}
lh.metricRx(n.Type, 1)
switch n.Type {
case NebulaMeta_HostQuery:
// Exit if we don't answer queries
@@ -275,6 +327,7 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
return
}
lh.metricTx(NebulaMeta_HostQueryReply, 1)
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, reply, make([]byte, 12, 12), make([]byte, mtu))
// This signals the other side to punch some zero byte udp packets
@@ -293,6 +346,7 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
},
}
reply, _ := proto.Marshal(answer)
lh.metricTx(NebulaMeta_HostPunchNotification, 1)
f.SendMessageToVpnIp(lightHouse, 0, n.Details.VpnIp, reply, make([]byte, 12, 12), make([]byte, mtu))
}
//fmt.Println(reply, remoteaddr)
@@ -307,6 +361,11 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
ans := NewUDPAddr(a.Ip, uint16(a.Port))
lh.AddRemote(n.Details.VpnIp, ans, false)
}
// Non-blocking attempt to trigger, skip if it would block
select {
case lh.handshakeTrigger <- n.Details.VpnIp:
default:
}
case NebulaMeta_HostUpdateNotification:
//Simple check that the host sent this not someone else
@@ -328,10 +387,9 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
for _, a := range n.Details.IpAndPorts {
vpnPeer := NewUDPAddr(a.Ip, uint16(a.Port))
go func() {
for i := 0; i < 5; i++ {
lh.punchConn.WriteTo(empty, vpnPeer)
time.Sleep(time.Second * 1)
}
time.Sleep(lh.punchDelay)
lh.metricHolepunchTx.Inc(1)
lh.punchConn.WriteTo(empty, vpnPeer)
}()
l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
@@ -349,6 +407,13 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
}
}
func (lh *LightHouse) metricRx(t NebulaMeta_MessageType, i int64) {
lh.metrics.Rx(NebulaMessageType(t), 0, i)
}
func (lh *LightHouse) metricTx(t NebulaMeta_MessageType, i int64) {
lh.metrics.Tx(NebulaMessageType(t), 0, i)
}
/*
func (f *Interface) sendPathCheck(ci *ConnectionState, endpoint *net.UDPAddr, counter int) {
c := ci.messageCounter

View File

@@ -4,7 +4,7 @@ import (
"net"
"testing"
proto "github.com/golang/protobuf/proto"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
)
@@ -52,7 +52,7 @@ func Test_lhStaticMapping(t *testing.T) {
udpServer, _ := NewListener("0.0.0.0", 0, true)
meh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false)
meh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true)
err := meh.ValidateLHStaticEntries()
assert.Nil(t, err)
@@ -60,7 +60,7 @@ func Test_lhStaticMapping(t *testing.T) {
lh2 := "10.128.0.3"
lh2IP := net.ParseIP(lh2)
meh = NewLightHouse(true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false)
meh = NewLightHouse(true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true)
err = meh.ValidateLHStaticEntries()
assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")

39
logger.go Normal file
View File

@@ -0,0 +1,39 @@
package nebula
import (
"errors"
"github.com/sirupsen/logrus"
)
type ContextualError struct {
RealError error
Fields map[string]interface{}
Context string
}
func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
return ContextualError{Context: msg, Fields: fields, RealError: realError}
}
func (ce ContextualError) Error() string {
if ce.RealError == nil {
return ce.Context
}
return ce.RealError.Error()
}
func (ce ContextualError) Unwrap() error {
if ce.RealError == nil {
return errors.New(ce.Context)
}
return ce.RealError
}
func (ce *ContextualError) Log(lr *logrus.Logger) {
if ce.RealError != nil {
lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
} else {
lr.WithFields(ce.Fields).Error(ce.Context)
}
}

67
logger_test.go Normal file
View File

@@ -0,0 +1,67 @@
package nebula
import (
"errors"
"testing"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)
type TestLogWriter struct {
Logs []string
}
func NewTestLogWriter() *TestLogWriter {
return &TestLogWriter{Logs: make([]string, 0)}
}
func (tl *TestLogWriter) Write(p []byte) (n int, err error) {
tl.Logs = append(tl.Logs, string(p))
return len(p), nil
}
func (tl *TestLogWriter) Reset() {
tl.Logs = tl.Logs[:0]
}
func TestContextualError_Log(t *testing.T) {
l := logrus.New()
l.Formatter = &logrus.TextFormatter{
DisableTimestamp: true,
DisableColors: true,
}
tl := NewTestLogWriter()
l.Out = tl
// Test a full context line
tl.Reset()
e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs)
// Test a line with an error and msg but no fields
tl.Reset()
e = NewContextualError("test message", nil, errors.New("error"))
e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\" error=error\n"}, tl.Logs)
// Test just a context and fields
tl.Reset()
e = NewContextualError("test message", m{"field": "1"}, nil)
e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\" field=1\n"}, tl.Logs)
// Test just a context
tl.Reset()
e = NewContextualError("test message", nil, nil)
e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\"\n"}, tl.Logs)
// Test just an error
tl.Reset()
e = NewContextualError("", nil, errors.New("error"))
e.Log(l)
assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs)
}

217
main.go
View File

@@ -4,11 +4,8 @@ import (
"encoding/binary"
"fmt"
"net"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
"github.com/sirupsen/logrus"
@@ -16,36 +13,31 @@ import (
"gopkg.in/yaml.v2"
)
// The caller should provide a real logger, we have one just in case
var l = logrus.New()
type m map[string]interface{}
func Main(configPath string, configTest bool, buildVersion string) {
l.Out = os.Stdout
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) {
l = logger
l.Formatter = &logrus.TextFormatter{
FullTimestamp: true,
}
config := NewConfig()
err := config.Load(configPath)
if err != nil {
l.WithError(err).Error("Failed to load config")
os.Exit(1)
}
// Print the config if in test, the exit comes later
if configTest {
b, err := yaml.Marshal(config.Settings)
if err != nil {
l.Println(err)
os.Exit(1)
return nil, err
}
// Print the final config
l.Println(string(b))
}
err = configLogger(config)
err := configLogger(config)
if err != nil {
l.WithError(err).Error("Failed to configure the logger")
return nil, NewContextualError("Failed to configure the logger", nil, err)
}
config.RegisterReloadCallback(func(c *Config) {
@@ -59,20 +51,20 @@ func Main(configPath string, configTest bool, buildVersion string) {
trustedCAs, err = loadCAFromConfig(config)
if err != nil {
//The errors coming out of loadCA are already nicely formatted
l.WithError(err).Fatal("Failed to load ca from config")
return nil, NewContextualError("Failed to load ca from config", nil, err)
}
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
cs, err := NewCertStateFromConfig(config)
if err != nil {
//The errors coming out of NewCertStateFromConfig are already nicely formatted
l.WithError(err).Fatal("Failed to load certificate from config")
return nil, NewContextualError("Failed to load certificate from config", nil, err)
}
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
fw, err := NewFirewallFromConfig(cs.certificate, config)
if err != nil {
l.WithError(err).Fatal("Error while loading firewall rules")
return nil, NewContextualError("Error while loading firewall rules", nil, err)
}
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
@@ -80,11 +72,11 @@ func Main(configPath string, configTest bool, buildVersion string) {
tunCidr := cs.certificate.Details.Ips[0]
routes, err := parseRoutes(config, tunCidr)
if err != nil {
l.WithError(err).Fatal("Could not parse tun.routes")
return nil, NewContextualError("Could not parse tun.routes", nil, err)
}
unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
if err != nil {
l.WithError(err).Fatal("Could not parse tun.unsafe_routes")
return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
}
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
@@ -92,7 +84,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
if config.GetBool("sshd.enabled", false) {
err = configSSH(ssh, config)
if err != nil {
l.WithError(err).Fatal("Error while configuring the sshd")
return nil, NewContextualError("Error while configuring the sshd", nil, err)
}
}
@@ -101,32 +93,49 @@ func Main(configPath string, configTest bool, buildVersion string) {
// tun config, listeners, anything modifying the computer should be below
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
if configTest {
os.Exit(0)
}
var tun Inside
if !configTest {
config.CatchHUP()
config.CatchHUP()
switch {
case config.GetBool("tun.disabled", false):
tun = newDisabledTun(tunCidr, l)
case tunFd != nil:
tun, err = newTunFromFd(
*tunFd,
tunCidr,
config.GetInt("tun.mtu", DEFAULT_MTU),
routes,
unsafeRoutes,
config.GetInt("tun.tx_queue", 500),
)
default:
tun, err = newTun(
config.GetString("tun.dev", ""),
tunCidr,
config.GetInt("tun.mtu", DEFAULT_MTU),
routes,
unsafeRoutes,
config.GetInt("tun.tx_queue", 500),
)
}
// set up our tun dev
tun, err := newTun(
config.GetString("tun.dev", ""),
tunCidr,
config.GetInt("tun.mtu", DEFAULT_MTU),
routes,
unsafeRoutes,
config.GetInt("tun.tx_queue", 500),
)
if err != nil {
l.WithError(err).Fatal("Failed to get a tun/tap device")
if err != nil {
return nil, NewContextualError("Failed to get a tun/tap device", nil, err)
}
}
// set up our UDP listener
udpQueues := config.GetInt("listen.routines", 1)
udpServer, err := NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1)
if err != nil {
l.WithError(err).Fatal("Failed to open udp listener")
var udpServer *udpConn
if !configTest {
udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1)
if err != nil {
return nil, NewContextualError("Failed to open udp listener", nil, err)
}
udpServer.reloadConfig(config)
}
udpServer.reloadConfig(config)
// Set up my internal host map
var preferredRanges []*net.IPNet
@@ -136,7 +145,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
for _, rawPreferredRange := range rawPreferredRanges {
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
if err != nil {
l.WithError(err).Fatal("Failed to parse preferred ranges")
return nil, NewContextualError("Failed to parse preferred ranges", nil, err)
}
preferredRanges = append(preferredRanges, preferredRange)
}
@@ -149,7 +158,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
if rawLocalRange != "" {
_, localRange, err := net.ParseCIDR(rawLocalRange)
if err != nil {
l.WithError(err).Fatal("Failed to parse local range")
return nil, NewContextualError("Failed to parse local_range", nil, err)
}
// Check if the entry for local_range was already specified in
@@ -169,6 +178,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
hostMap := NewHostMap("main", tunCidr, preferredRanges)
hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
hostMap.addUnsafeRoutes(&unsafeRoutes)
hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
@@ -177,23 +187,22 @@ func Main(configPath string, configTest bool, buildVersion string) {
go hostMap.Promoter(config.GetInt("promoter.interval"))
*/
punchy := config.GetBool("punchy", false)
if punchy == true {
punchy := NewPunchyFromConfig(config)
if punchy.Punch && !configTest {
l.Info("UDP hole punching enabled")
go hostMap.Punchy(udpServer)
}
port := config.GetInt("listen.port", 0)
// If port is dynamic, discover it
if port == 0 {
if port == 0 && !configTest {
uPort, err := udpServer.LocalAddr()
if err != nil {
l.WithError(err).Fatal("Failed to get listening port")
return nil, NewContextualError("Failed to get listening port", nil, err)
}
port = int(uPort.Port)
}
punchBack := config.GetBool("punch_back", false)
amLighthouse := config.GetBool("lighthouse.am_lighthouse", false)
// warn if am_lighthouse is enabled but upstream lighthouses exists
@@ -206,7 +215,10 @@ func Main(configPath string, configTest bool, buildVersion string) {
for i, host := range rawLighthouseHosts {
ip := net.ParseIP(host)
if ip == nil {
l.WithField("host", host).Fatalf("Unable to parse lighthouse host entry %v", i+1)
return nil, NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
}
if !tunCidr.Contains(ip) {
return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
}
lighthouseHosts[i] = ip2int(ip)
}
@@ -219,12 +231,29 @@ func Main(configPath string, configTest bool, buildVersion string) {
config.GetInt("lighthouse.interval", 10),
port,
udpServer,
punchBack,
punchy.Respond,
punchy.Delay,
config.GetBool("stats.lighthouse_metrics", false),
)
remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false)
if err != nil {
return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
}
lightHouse.SetRemoteAllowList(remoteAllowList)
localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true)
if err != nil {
return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
}
lightHouse.SetLocalAllowList(localAllowList)
//TODO: Move all of this inside functions in lighthouse.go
for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
if !tunCidr.Contains(vpnIp) {
return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
}
vals, ok := v.([]interface{})
if ok {
for _, v := range vals {
@@ -234,7 +263,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
ip := addr.IP
port, err := strconv.Atoi(parts[1])
if err != nil {
l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v)
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
}
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
}
@@ -247,7 +276,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
ip := addr.IP
port, err := strconv.Atoi(parts[1])
if err != nil {
l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v)
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
}
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
}
@@ -259,7 +288,24 @@ func Main(configPath string, configTest bool, buildVersion string) {
l.WithError(err).Error("Lighthouse unreachable")
}
handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpServer)
var messageMetrics *MessageMetrics
if config.GetBool("stats.message_metrics", false) {
messageMetrics = newMessageMetrics()
} else {
messageMetrics = newMessageMetricsOnlyRecvError()
}
handshakeConfig := HandshakeConfig{
tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries),
waitRotation: config.GetInt("handshakes.wait_rotation", DefaultHandshakeWaitRotation),
triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
messageMetrics: messageMetrics,
}
handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpServer, handshakeConfig)
lightHouse.handshakeTrigger = handshakeManager.trigger
//TODO: These will be reused for psk
//handshakeMACKey := config.GetString("handshake_mac.key", "")
@@ -283,37 +329,47 @@ func Main(configPath string, configTest bool, buildVersion string) {
DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false),
DropMulticast: config.GetBool("tun.drop_multicast", false),
UDPBatchSize: config.GetInt("listen.batch", 64),
udpQueues: udpQueues,
tunQueues: config.GetInt("tun.routines", 1),
MessageMetrics: messageMetrics,
version: buildVersion,
}
switch ifConfig.Cipher {
case "aes":
noiseEndiannes = binary.BigEndian
noiseEndianness = binary.BigEndian
case "chachapoly":
noiseEndiannes = binary.LittleEndian
noiseEndianness = binary.LittleEndian
default:
l.Fatalf("Unknown cipher: %v", ifConfig.Cipher)
return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
}
ifce, err := NewInterface(ifConfig)
if err != nil {
l.WithError(err).Fatal("Failed to initialize interface")
var ifce *Interface
if !configTest {
ifce, err = NewInterface(ifConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize interface: %s", err)
}
ifce.RegisterConfigChangeCallbacks(config)
go handshakeManager.Run(ifce)
go lightHouse.LhUpdateWorker(ifce)
}
ifce.RegisterConfigChangeCallbacks(config)
go handshakeManager.Run(ifce)
go lightHouse.LhUpdateWorker(ifce)
err = startStats(config)
err = startStats(config, configTest)
if err != nil {
l.WithError(err).Fatal("Failed to start stats emitter")
return nil, NewContextualError("Failed to start stats emitter", nil, err)
}
if configTest {
return nil, nil
}
//TODO: check if we _should_ be emitting stats
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
ifce.Run(config.GetInt("tun.routines", 1), udpQueues, buildVersion)
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
if amLighthouse && serveDns {
@@ -321,30 +377,5 @@ func Main(configPath string, configTest bool, buildVersion string) {
go dnsMain(hostMap, config)
}
// Just sit here and be friendly, main thread.
shutdownBlock(ifce)
}
func shutdownBlock(ifce *Interface) {
var sigChan = make(chan os.Signal)
signal.Notify(sigChan, syscall.SIGTERM)
signal.Notify(sigChan, syscall.SIGINT)
sig := <-sigChan
l.WithField("signal", sig).Info("Caught signal, shutting down")
//TODO: stop tun and udp routines, the lock on hostMap does effectively does that though
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
ifce.hostMap.Lock()
for _, h := range ifce.hostMap.Hosts {
if h.ConnectionState.ready {
ifce.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message")
}
}
ifce.hostMap.Unlock()
l.WithField("signal", sig).Info("Goodbye")
os.Exit(0)
return &Control{ifce, l}, nil
}

97
message_metrics.go Normal file
View File

@@ -0,0 +1,97 @@
package nebula
import (
"fmt"
"github.com/rcrowley/go-metrics"
)
type MessageMetrics struct {
rx [][]metrics.Counter
tx [][]metrics.Counter
rxUnknown metrics.Counter
txUnknown metrics.Counter
}
func (m *MessageMetrics) Rx(t NebulaMessageType, s NebulaMessageSubType, i int64) {
if m != nil {
if t >= 0 && int(t) < len(m.rx) && s >= 0 && int(s) < len(m.rx[t]) {
m.rx[t][s].Inc(i)
} else if m.rxUnknown != nil {
m.rxUnknown.Inc(i)
}
}
}
func (m *MessageMetrics) Tx(t NebulaMessageType, s NebulaMessageSubType, i int64) {
if m != nil {
if t >= 0 && int(t) < len(m.tx) && s >= 0 && int(s) < len(m.tx[t]) {
m.tx[t][s].Inc(i)
} else if m.txUnknown != nil {
m.txUnknown.Inc(i)
}
}
}
func newMessageMetrics() *MessageMetrics {
gen := func(t string) [][]metrics.Counter {
return [][]metrics.Counter{
{
metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.handshake_ixpsk0", t), nil),
},
nil,
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.recv_error", t), nil)},
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.lighthouse", t), nil)},
{
metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.test_request", t), nil),
metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.test_response", t), nil),
},
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.close_tunnel", t), nil)},
}
}
return &MessageMetrics{
rx: gen("rx"),
tx: gen("tx"),
rxUnknown: metrics.GetOrRegisterCounter("messages.rx.other", nil),
txUnknown: metrics.GetOrRegisterCounter("messages.tx.other", nil),
}
}
// Historically we only recorded recv_error, so this is backwards compat
func newMessageMetricsOnlyRecvError() *MessageMetrics {
gen := func(t string) [][]metrics.Counter {
return [][]metrics.Counter{
nil,
nil,
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.recv_error", t), nil)},
}
}
return &MessageMetrics{
rx: gen("rx"),
tx: gen("tx"),
}
}
func newLighthouseMetrics() *MessageMetrics {
gen := func(t string) [][]metrics.Counter {
h := make([][]metrics.Counter, len(NebulaMeta_MessageType_name))
used := []NebulaMeta_MessageType{
NebulaMeta_HostQuery,
NebulaMeta_HostQueryReply,
NebulaMeta_HostUpdateNotification,
NebulaMeta_HostPunchNotification,
}
for _, i := range used {
h[i] = []metrics.Counter{metrics.GetOrRegisterCounter(fmt.Sprintf("lighthouse.%s.%s", t, i.String()), nil)}
}
return h
}
return &MessageMetrics{
rx: gen("rx"),
tx: gen("tx"),
rxUnknown: metrics.GetOrRegisterCounter("lighthouse.rx.other", nil),
txUnknown: metrics.GetOrRegisterCounter("lighthouse.tx.other", nil),
}
}

View File

@@ -8,11 +8,11 @@ import (
"github.com/flynn/noise"
)
type endiannes interface {
type endianness interface {
PutUint64(b []byte, v uint64)
}
var noiseEndiannes endiannes = binary.BigEndian
var noiseEndianness endianness = binary.BigEndian
type NebulaCipherState struct {
c noise.Cipher
@@ -37,7 +37,7 @@ func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, n
nb[1] = 0
nb[2] = 0
nb[3] = 0
noiseEndiannes.PutUint64(nb[4:], n)
noiseEndianness.PutUint64(nb[4:], n)
out = s.c.(cipher.AEAD).Seal(out, nb, plaintext, ad)
//l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext))
return out, nil
@@ -52,7 +52,7 @@ func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64,
nb[1] = 0
nb[2] = 0
nb[3] = 0
noiseEndiannes.PutUint64(nb[4:], n)
noiseEndianness.PutUint64(nb[4:], n)
return s.c.(cipher.AEAD).Open(out, nb, ciphertext, ad)
} else {
return []byte{}, nil

View File

@@ -2,18 +2,14 @@ package nebula
import (
"encoding/binary"
"errors"
"fmt"
"time"
"github.com/flynn/noise"
"github.com/golang/protobuf/proto"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
// "github.com/google/gopacket"
// "github.com/google/gopacket/layers"
// "encoding/binary"
"errors"
"fmt"
"time"
"golang.org/x/net/ipv4"
)
@@ -54,13 +50,14 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
// Fallthrough to the bottom to record incoming traffic
case lightHouse:
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
if !f.handleEncrypted(ci, addr, header) {
return
}
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
if err != nil {
l.WithError(err).WithField("udpAddr", addr).WithField("vpnIp", IntIp(hostinfo.hostId)).
hostinfo.logger().WithError(err).WithField("udpAddr", addr).
WithField("packet", packet).
Error("Failed to decrypt lighthouse packet")
@@ -74,13 +71,14 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
// Fallthrough to the bottom to record incoming traffic
case test:
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
if !f.handleEncrypted(ci, addr, header) {
return
}
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
if err != nil {
l.WithError(err).WithField("udpAddr", addr).WithField("vpnIp", IntIp(hostinfo.hostId)).
hostinfo.logger().WithError(err).WithField("udpAddr", addr).
WithField("packet", packet).
Error("Failed to decrypt test packet")
@@ -102,27 +100,31 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
// are unauthenticated
case handshake:
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
HandleIncomingHandshake(f, addr, packet, header, hostinfo)
return
case recvError:
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
// TODO: Remove this with recv_error deprecation
f.handleRecvError(addr, header)
return
case closeTunnel:
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
if !f.handleEncrypted(ci, addr, header) {
return
}
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
hostinfo.logger().WithField("udpAddr", addr).
Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo)
return
default:
l.Debugf("Unexpected packet received from %s", addr)
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
hostinfo.logger().Debugf("Unexpected packet received from %s", addr)
return
}
@@ -142,15 +144,19 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) {
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
if hostDidRoam(hostinfo.remote, addr) {
if !f.lightHouse.remoteAllowList.Allow(udp2ipInt(addr)) {
hostinfo.logger().WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
return
}
if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSupressSeconds*time.Second {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
Debugf("Supressing roam back to previous remote for %d seconds", RoamingSupressSeconds)
}
return
}
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now()
remoteCopy := *hostinfo.remote
@@ -244,7 +250,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
}
if !hostinfo.ConnectionState.window.Update(mc) {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("header", header).
hostinfo.logger().WithField("header", header).
Debugln("dropping out of window packet")
return nil, errors.New("out of window packet")
}
@@ -257,7 +263,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).Error("Failed to decrypt packet")
hostinfo.logger().WithError(err).Error("Failed to decrypt packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(hostinfo.remote, header.RemoteIndex)
return
@@ -265,20 +271,24 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
err = newPacket(out, true, fwPacket)
if err != nil {
l.WithError(err).WithField("packet", out).WithField("hostInfo", IntIp(hostinfo.hostId)).
hostinfo.logger().WithError(err).WithField("packet", out).
Warnf("Error while validating inbound packet")
return
}
if !hostinfo.ConnectionState.window.Update(messageCounter) {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket).
hostinfo.logger().WithField("fwPacket", fwPacket).
Debugln("dropping out of window packet")
return
}
if f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs) {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket).
Debugln("dropping inbound packet")
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs)
if dropReason != nil {
if l.Level >= logrus.DebugLevel {
hostinfo.logger().WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping inbound packet")
}
return
}
@@ -290,7 +300,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
}
func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
f.metricTxRecvError.Inc(1)
f.messageMetrics.Tx(recvError, 0, 1)
//TODO: this should be a signed message so we can trust that we should drop the index
b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0)
@@ -303,8 +313,6 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
}
func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
f.metricRxRecvError.Inc(1)
// This flag is to stop caring about recv_error from old versions
// This should go away when the old version is gone from prod
if l.Level >= logrus.DebugLevel {

View File

@@ -1,10 +1,11 @@
package nebula
import (
"github.com/stretchr/testify/assert"
"golang.org/x/net/ipv4"
"net"
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/net/ipv4"
)
func Test_newPacket(t *testing.T) {

30
punchy.go Normal file
View File

@@ -0,0 +1,30 @@
package nebula
import "time"
type Punchy struct {
Punch bool
Respond bool
Delay time.Duration
}
func NewPunchyFromConfig(c *Config) *Punchy {
p := &Punchy{}
if c.IsSet("punchy.punch") {
p.Punch = c.GetBool("punchy.punch", false)
} else {
// Deprecated fallback
p.Punch = c.GetBool("punchy", false)
}
if c.IsSet("punchy.respond") {
p.Respond = c.GetBool("punchy.respond", false)
} else {
// Deprecated fallback
p.Respond = c.GetBool("punch_back", false)
}
p.Delay = c.GetDuration("punchy.delay", time.Second)
return p
}

44
punchy_test.go Normal file
View File

@@ -0,0 +1,44 @@
package nebula
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestNewPunchyFromConfig(t *testing.T) {
c := NewConfig()
// Test defaults
p := NewPunchyFromConfig(c)
assert.Equal(t, false, p.Punch)
assert.Equal(t, false, p.Respond)
assert.Equal(t, time.Second, p.Delay)
// punchy deprecation
c.Settings["punchy"] = true
p = NewPunchyFromConfig(c)
assert.Equal(t, true, p.Punch)
// punchy.punch
c.Settings["punchy"] = map[interface{}]interface{}{"punch": true}
p = NewPunchyFromConfig(c)
assert.Equal(t, true, p.Punch)
// punch_back deprecation
c.Settings["punch_back"] = true
p = NewPunchyFromConfig(c)
assert.Equal(t, true, p.Respond)
// punchy.respond
c.Settings["punchy"] = map[interface{}]interface{}{"respond": true}
c.Settings["punch_back"] = false
p = NewPunchyFromConfig(c)
assert.Equal(t, true, p.Respond)
// punchy.delay
c.Settings["punchy"] = map[interface{}]interface{}{"delay": "1m"}
p = NewPunchyFromConfig(c)
assert.Equal(t, time.Minute, p.Delay)
}

56
ssh.go
View File

@@ -5,8 +5,6 @@ import (
"encoding/json"
"flag"
"fmt"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/sshd"
"io/ioutil"
"net"
"os"
@@ -14,6 +12,9 @@ import (
"runtime/pprof"
"strings"
"syscall"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/sshd"
)
type sshListHostMapFlags struct {
@@ -65,10 +66,11 @@ func configSSH(ssh *sshd.SSHServer, c *Config) error {
return fmt.Errorf("sshd.listen must be provided")
}
port := strings.Split(listen, ":")
if len(port) < 2 {
return fmt.Errorf("sshd.listen does not have a port")
} else if port[1] == "22" {
_, port, err := net.SplitHostPort(listen)
if err != nil {
return fmt.Errorf("invalid sshd.listen address: %s", err)
}
if port == "22" {
return fmt.Errorf("sshd.listen can not use port 22")
}
@@ -461,7 +463,12 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
return w.WriteLine("No vpn ip was provided")
}
vpnIp := ip2int(net.ParseIP(a[0]))
parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
@@ -481,7 +488,12 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine("No vpn ip was provided")
}
vpnIp := ip2int(net.ParseIP(a[0]))
parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
@@ -519,7 +531,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine("No vpn ip was provided")
}
vpnIp := ip2int(net.ParseIP(a[0]))
parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
@@ -571,7 +588,12 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine("Address could not be parsed")
}
vpnIp := ip2int(net.ParseIP(a[0]))
parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
@@ -647,7 +669,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
cert := ifce.certState.certificate
if len(a) > 0 {
vpnIp := ip2int(net.ParseIP(a[0]))
parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
@@ -694,7 +721,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine("No vpn ip was provided")
}
vpnIp := ip2int(net.ParseIP(a[0]))
parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}

View File

@@ -4,9 +4,10 @@ import (
"errors"
"flag"
"fmt"
"github.com/armon/go-radix"
"sort"
"strings"
"github.com/armon/go-radix"
)
// CommandFlags is a function called before help or command execution to parse command line flags

View File

@@ -2,10 +2,11 @@ package sshd
import (
"fmt"
"net"
"github.com/armon/go-radix"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"net"
)
type SSHServer struct {

View File

@@ -2,13 +2,14 @@ package sshd
import (
"fmt"
"sort"
"strings"
"github.com/anmitsu/go-shlex"
"github.com/armon/go-radix"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
"sort"
"strings"
)
type session struct {

View File

@@ -3,18 +3,19 @@ package nebula
import (
"errors"
"fmt"
"github.com/cyberdelia/go-metrics-graphite"
mp "github.com/nbrownus/go-metrics-prometheus"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rcrowley/go-metrics"
"log"
"net"
"net/http"
"time"
graphite "github.com/cyberdelia/go-metrics-graphite"
mp "github.com/nbrownus/go-metrics-prometheus"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rcrowley/go-metrics"
)
func startStats(c *Config) error {
func startStats(c *Config, configTest bool) error {
mType := c.GetString("stats.type", "")
if mType == "" || mType == "none" {
return nil
@@ -27,9 +28,9 @@ func startStats(c *Config) error {
switch mType {
case "graphite":
startGraphiteStats(interval, c)
startGraphiteStats(interval, c, configTest)
case "prometheus":
startPrometheusStats(interval, c)
startPrometheusStats(interval, c, configTest)
default:
return fmt.Errorf("stats.type was not understood: %s", mType)
}
@@ -43,7 +44,7 @@ func startStats(c *Config) error {
return nil
}
func startGraphiteStats(i time.Duration, c *Config) error {
func startGraphiteStats(i time.Duration, c *Config, configTest bool) error {
proto := c.GetString("stats.protocol", "tcp")
host := c.GetString("stats.host", "")
if host == "" {
@@ -57,11 +58,13 @@ func startGraphiteStats(i time.Duration, c *Config) error {
}
l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr)
go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr)
if !configTest {
go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr)
}
return nil
}
func startPrometheusStats(i time.Duration, c *Config) error {
func startPrometheusStats(i time.Duration, c *Config, configTest bool) error {
namespace := c.GetString("stats.namespace", "")
subsystem := c.GetString("stats.subsystem", "")
@@ -79,11 +82,13 @@ func startPrometheusStats(i time.Duration, c *Config) error {
pClient := mp.NewPrometheusProvider(metrics.DefaultRegistry, namespace, subsystem, pr, i)
go pClient.UpdatePrometheusMetrics()
go func() {
l.Infof("Prometheus stats listening on %s at %s", listen, path)
http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l}))
log.Fatal(http.ListenAndServe(listen, nil))
}()
if !configTest {
go func() {
l.Infof("Prometheus stats listening on %s at %s", listen, path)
http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l}))
log.Fatal(http.ListenAndServe(listen, nil))
}()
}
return nil
}

View File

@@ -1,9 +1,10 @@
package nebula
import (
"github.com/stretchr/testify/assert"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestNewTimerWheel(t *testing.T) {

76
tun_android.go Normal file
View File

@@ -0,0 +1,76 @@
package nebula
import (
"fmt"
"io"
"net"
"os"
"golang.org/x/sys/unix"
)
type Tun struct {
io.ReadWriteCloser
fd int
Device string
Cidr *net.IPNet
MaxMTU int
DefaultMTU int
TXQueueLen int
Routes []route
UnsafeRoutes []route
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
ifce = &Tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
Device: "android",
Cidr: cidr,
DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen,
Routes: routes,
UnsafeRoutes: unsafeRoutes,
}
return
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTun not supported in Android")
}
func (c *Tun) WriteRaw(b []byte) error {
var nn int
for {
max := len(b)
n, err := unix.Write(c.fd, b[nn:max])
if n > 0 {
nn += n
}
if nn == len(b) {
return err
}
if err != nil {
return err
}
if n == 0 {
return io.ErrUnexpectedEOF
}
}
}
func (c Tun) Activate() error {
return nil
}
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}

View File

@@ -132,7 +132,7 @@ func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) {
via, ok := rVia.(string)
if !ok {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: %v", i+1, err)
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia)
}
nVia := net.ParseIP(via)
@@ -147,6 +147,7 @@ func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) {
r := route{
via: &nVia,
mtu: mtu,
}
_, r.route, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))

View File

@@ -1,3 +1,5 @@
// +build !ios
package nebula
import (
@@ -20,8 +22,9 @@ type Tun struct {
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in Darwin")
return nil, fmt.Errorf("route MTU not supported in Darwin")
}
// NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate()
return &Tun{
Cidr: cidr,
@@ -30,30 +33,34 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
}, nil
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}
func (c *Tun) Activate() error {
var err error
c.Interface, err = water.New(water.Config{
DeviceType: water.TUN,
})
if err != nil {
return fmt.Errorf("Activate failed: %v", err)
return fmt.Errorf("activate failed: %v", err)
}
c.Device = c.Interface.Name()
// TODO use syscalls instead of exec.Command
if err = exec.Command("ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil {
if err = exec.Command("/sbin/ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
if err = exec.Command("route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil {
if err = exec.Command("/sbin/route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
if err = exec.Command("ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil {
if err = exec.Command("/sbin/ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
for _, r := range c.UnsafeRoutes {
if err = exec.Command("route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil {
if err = exec.Command("/sbin/route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.route.String(), err)
}
}
@@ -61,6 +68,14 @@ func (c *Tun) Activate() error {
return nil
}
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}
func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b)
return err

74
tun_disabled.go Normal file
View File

@@ -0,0 +1,74 @@
package nebula
import (
"fmt"
"io"
"net"
"strings"
log "github.com/sirupsen/logrus"
)
type disabledTun struct {
block chan struct{}
cidr *net.IPNet
logger *log.Logger
}
func newDisabledTun(cidr *net.IPNet, l *log.Logger) *disabledTun {
return &disabledTun{
cidr: cidr,
block: make(chan struct{}),
logger: l,
}
}
func (*disabledTun) Activate() error {
return nil
}
func (t *disabledTun) CidrNet() *net.IPNet {
return t.cidr
}
func (*disabledTun) DeviceName() string {
return "disabled"
}
func (t *disabledTun) Read(b []byte) (int, error) {
<-t.block
return 0, io.EOF
}
func (t *disabledTun) Write(b []byte) (int, error) {
t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
return len(b), nil
}
func (t *disabledTun) WriteRaw(b []byte) error {
_, err := t.Write(b)
return err
}
func (t *disabledTun) Close() error {
if t.block != nil {
close(t.block)
t.block = nil
}
return nil
}
type prettyPacket []byte
func (p prettyPacket) String() string {
var s strings.Builder
for i, b := range p {
if i > 0 && i%8 == 0 {
s.WriteString(" ")
}
s.WriteString(fmt.Sprintf("%02x ", b))
}
return s.String()
}

89
tun_freebsd.go Normal file
View File

@@ -0,0 +1,89 @@
package nebula
import (
"fmt"
"io"
"net"
"os"
"os/exec"
"regexp"
"strconv"
"strings"
)
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
type Tun struct {
Device string
Cidr *net.IPNet
MTU int
UnsafeRoutes []route
io.ReadWriteCloser
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in FreeBSD")
}
if strings.HasPrefix(deviceName, "/dev/") {
deviceName = strings.TrimPrefix(deviceName, "/dev/")
}
if !deviceNameRE.MatchString(deviceName) {
return nil, fmt.Errorf("tun.dev must match `tun[0-9]+`")
}
return &Tun{
Device: deviceName,
Cidr: cidr,
MTU: defaultMTU,
UnsafeRoutes: unsafeRoutes,
}, nil
}
func (c *Tun) Activate() error {
var err error
c.ReadWriteCloser, err = os.OpenFile("/dev/"+c.Device, os.O_RDWR, 0)
if err != nil {
return fmt.Errorf("Activate failed: %v", err)
}
// TODO use syscalls instead of exec.Command
l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String())
if err = exec.Command("/sbin/ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU))
if err = exec.Command("/sbin/ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
for _, r := range c.UnsafeRoutes {
l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.route.String(), err)
}
}
return nil
}
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}
func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b)
return err
}

113
tun_ios.go Normal file
View File

@@ -0,0 +1,113 @@
// +build ios
package nebula
import (
"errors"
"fmt"
"io"
"net"
"os"
"sync"
"syscall"
)
type Tun struct {
io.ReadWriteCloser
Device string
Cidr *net.IPNet
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTun not supported in iOS")
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("route MTU not supported in Darwin")
}
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
ifce = &Tun{
Cidr: cidr,
Device: "iOS",
ReadWriteCloser: &tunReadCloser{f: file},
}
return
}
func (c *Tun) Activate() error {
return nil
}
func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b)
return err
}
// The following is hoisted up from water, we do this so we can inject our own fd on iOS
type tunReadCloser struct {
f io.ReadWriteCloser
rMu sync.Mutex
rBuf []byte
wMu sync.Mutex
wBuf []byte
}
func (t *tunReadCloser) Read(to []byte) (int, error) {
t.rMu.Lock()
defer t.rMu.Unlock()
if cap(t.rBuf) < len(to)+4 {
t.rBuf = make([]byte, len(to)+4)
}
t.rBuf = t.rBuf[:len(to)+4]
n, err := t.f.Read(t.rBuf)
copy(to, t.rBuf[4:])
return n - 4, err
}
func (t *tunReadCloser) Write(from []byte) (int, error) {
if len(from) == 0 {
return 0, syscall.EIO
}
t.wMu.Lock()
defer t.wMu.Unlock()
if cap(t.wBuf) < len(from)+4 {
t.wBuf = make([]byte, len(from)+4)
}
t.wBuf = t.wBuf[:len(from)+4]
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
t.wBuf[3] = syscall.AF_INET
} else if ipVer == 6 {
t.wBuf[3] = syscall.AF_INET6
} else {
return 0, errors.New("unable to determine IP version from packet")
}
copy(t.wBuf[4:], from)
n, err := t.f.Write(t.wBuf)
return n - 4, err
}
func (t *tunReadCloser) Close() error {
return t.f.Close()
}
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}

View File

@@ -1,3 +1,5 @@
// +build !android
package nebula
import (
@@ -75,6 +77,23 @@ type ifreqQLEN struct {
pad [8]byte
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
ifce = &Tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
Device: "tun0",
Cidr: cidr,
DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen,
Routes: routes,
UnsafeRoutes: unsafeRoutes,
}
return
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
@@ -216,6 +235,7 @@ func (c Tun) Activate() error {
LinkIndex: link.Attrs().Index,
Dst: dr,
MTU: c.DefaultMTU,
AdvMSS: c.advMSS(route{}),
Scope: unix.RT_SCOPE_LINK,
Src: c.Cidr.IP,
Protocol: unix.RTPROT_KERNEL,
@@ -233,6 +253,7 @@ func (c Tun) Activate() error {
LinkIndex: link.Attrs().Index,
Dst: r.route,
MTU: r.mtu,
AdvMSS: c.advMSS(r),
Scope: unix.RT_SCOPE_LINK,
}
@@ -248,6 +269,7 @@ func (c Tun) Activate() error {
LinkIndex: link.Attrs().Index,
Dst: r.route,
MTU: r.mtu,
AdvMSS: c.advMSS(r),
Scope: unix.RT_SCOPE_LINK,
}
@@ -265,3 +287,24 @@ func (c Tun) Activate() error {
return nil
}
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}
func (c Tun) advMSS(r route) int {
mtu := r.mtu
if r.mtu == 0 {
mtu = c.DefaultMTU
}
// We only need to set advmss if the route MTU does not match the device MTU
if mtu != c.MaxMTU {
return mtu - 40
}
return 0
}

31
tun_linux_test.go Normal file
View File

@@ -0,0 +1,31 @@
package nebula
import "testing"
var runAdvMSSTests = []struct {
name string
tun Tun
r route
expected int
}{
// Standard case, default MTU is the device max MTU
{"default", Tun{DefaultMTU: 1440, MaxMTU: 1440}, route{}, 0},
{"default-min", Tun{DefaultMTU: 1440, MaxMTU: 1440}, route{mtu: 1440}, 0},
{"default-low", Tun{DefaultMTU: 1440, MaxMTU: 1440}, route{mtu: 1200}, 1160},
// Case where we have a route MTU set higher than the default
{"route", Tun{DefaultMTU: 1440, MaxMTU: 8941}, route{}, 1400},
{"route-min", Tun{DefaultMTU: 1440, MaxMTU: 8941}, route{mtu: 1440}, 1400},
{"route-high", Tun{DefaultMTU: 1440, MaxMTU: 8941}, route{mtu: 8941}, 0},
}
func TestTunAdvMSS(t *testing.T) {
for _, tt := range runAdvMSSTests {
t.Run(tt.name, func(t *testing.T) {
o := tt.tun.advMSS(tt.r)
if o != tt.expected {
t.Errorf("got %d, want %d", o, tt.expected)
}
})
}
}

View File

@@ -1,9 +1,11 @@
package nebula
import (
"github.com/stretchr/testify/assert"
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_parseRoutes(t *testing.T) {
@@ -100,3 +102,126 @@ func Test_parseRoutes(t *testing.T) {
t.Fatal("Did not see both routes")
}
}
func Test_parseUnsafeRoutes(t *testing.T) {
c := NewConfig()
_, n, _ := net.ParseCIDR("10.0.0.0/24")
// test no routes config
routes, err := parseUnsafeRoutes(c, n)
assert.Nil(t, err)
assert.Len(t, routes, 0)
// not an array
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "tun.unsafe_routes is not an array")
// no routes
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, err)
assert.Len(t, routes, 0)
// weird route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
// no via
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
// invalid via
for _, invalidValue := range []interface{}{
127, false, nil, 1.0, []string{"1", "2"},
} {
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
}
// unparsable via
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: nope")
// missing route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
// unparsable route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: invalid CIDR address: nope")
// within network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24")
// below network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Len(t, routes, 1)
assert.Nil(t, err)
// above network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Len(t, routes, 1)
assert.Nil(t, err)
// no mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Len(t, routes, 1)
assert.Equal(t, DEFAULT_MTU, routes[0].mtu)
// bad mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
// low mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
// happy case
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29"},
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32"},
}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, err)
assert.Len(t, routes, 2)
tested := 0
for _, r := range routes {
if r.mtu == 8000 {
assert.Equal(t, "1.0.0.1/32", r.route.String())
tested++
} else {
assert.Equal(t, 9000, r.mtu)
assert.Equal(t, "1.0.0.0/29", r.route.String())
tested++
}
}
if tested != 2 {
t.Fatal("Did not see both unsafe_routes")
}
}

View File

@@ -4,29 +4,34 @@ import (
"fmt"
"net"
"os/exec"
"strconv"
"github.com/songgao/water"
)
type Tun struct {
Device string
Cidr *net.IPNet
MTU int
Device string
Cidr *net.IPNet
MTU int
UnsafeRoutes []route
*water.Interface
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in Windows")
}
if len(unsafeRoutes) > 0 {
return nil, fmt.Errorf("unsafeRoutes not supported in Windows")
return nil, fmt.Errorf("route MTU not supported in Windows")
}
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
return &Tun{
Cidr: cidr,
MTU: defaultMTU,
Cidr: cidr,
MTU: defaultMTU,
UnsafeRoutes: unsafeRoutes,
}, nil
}
@@ -47,7 +52,7 @@ func (c *Tun) Activate() error {
// TODO use syscalls instead of exec.Command
err = exec.Command(
"netsh", "interface", "ipv4", "set", "address",
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address",
fmt.Sprintf("name=%s", c.Device),
"source=static",
fmt.Sprintf("addr=%s", c.Cidr.IP),
@@ -58,7 +63,7 @@ func (c *Tun) Activate() error {
return fmt.Errorf("failed to run 'netsh' to set address: %s", err)
}
err = exec.Command(
"netsh", "interface", "ipv4", "set", "interface",
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "interface",
c.Device,
fmt.Sprintf("mtu=%d", c.MTU),
).Run()
@@ -66,9 +71,31 @@ func (c *Tun) Activate() error {
return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err)
}
iface, err := net.InterfaceByName(c.Device)
if err != nil {
return fmt.Errorf("failed to find interface named %s: %v", c.Device, err)
}
for _, r := range c.UnsafeRoutes {
err = exec.Command(
"C:\\Windows\\System32\\route.exe", "add", r.route.String(), r.via.String(), "IF", strconv.Itoa(iface.Index),
).Run()
if err != nil {
return fmt.Errorf("failed to add the unsafe_route %s: %v", r.route.String(), err)
}
}
return nil
}
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}
func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b)
return err

36
udp_android.go Normal file
View File

@@ -0,0 +1,36 @@
package nebula
import (
"fmt"
"net"
"syscall"
"golang.org/x/sys/unix"
)
func NewListenConfig(multi bool) net.ListenConfig {
return net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
if multi {
var controlErr error
err := c.Control(func(fd uintptr) {
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err)
return
}
})
if err != nil {
return err
}
if controlErr != nil {
return controlErr
}
}
return nil
},
}
}
func (u *udpConn) Rebind() error {
return nil
}

View File

@@ -32,3 +32,12 @@ func NewListenConfig(multi bool) net.ListenConfig {
},
}
}
func (u *udpConn) Rebind() error {
file, err := u.File()
if err != nil {
return err
}
return syscall.SetsockoptInt(int(file.Fd()), unix.IPPROTO_IP, unix.IP_BOUND_IF, 0)
}

38
udp_freebsd.go Normal file
View File

@@ -0,0 +1,38 @@
package nebula
// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig
import (
"fmt"
"net"
"syscall"
"golang.org/x/sys/unix"
)
func NewListenConfig(multi bool) net.ListenConfig {
return net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
if multi {
var controlErr error
err := c.Control(func(fd uintptr) {
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err)
return
}
})
if err != nil {
return err
}
if controlErr != nil {
return controlErr
}
}
return nil
},
}
}
func (u *udpConn) Rebind() error {
return nil
}

View File

@@ -1,4 +1,4 @@
// +build !linux
// +build !linux android
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
// means it can be used on platforms like Darwin and Windows.
@@ -65,6 +65,17 @@ func (ua *udpAddr) Equals(t *udpAddr) bool {
return ua.IP.Equal(t.IP) && ua.Port == t.Port
}
func (ua *udpAddr) Copy() udpAddr {
nu := udpAddr{net.UDPAddr{
Port: ua.Port,
Zone: ua.Zone,
IP: make(net.IP, len(ua.IP)),
}}
copy(nu.IP, ua.IP)
return nu
}
func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error {
_, err := uc.UDPConn.WriteToUDP(b, &addr.UDPAddr)
return err

View File

@@ -1,3 +1,5 @@
// +build !android
package nebula
import (
@@ -69,8 +71,10 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
var lip [4]byte
copy(lip[:], net.ParseIP(ip).To4())
if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err)
if multi {
if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err)
}
}
if err = unix.Bind(fd, &unix.SockaddrInet4{Addr: lip, Port: port}); err != nil {
@@ -85,6 +89,14 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
return &udpConn{sysFd: fd}, err
}
func (u *udpConn) Rebind() error {
return nil
}
func (ua *udpAddr) Copy() udpAddr {
return *ua
}
func (u *udpConn) SetRecvBuffer(n int) error {
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
}
@@ -137,9 +149,13 @@ func (u *udpConn) ListenOut(f *Interface) {
//TODO: should we track this?
//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
msgs, buffers, names := u.PrepareRawMessages(f.udpBatchSize)
read := u.ReadMulti
if f.udpBatchSize == 1 {
read = u.ReadSingle
}
for {
n, err := u.ReadMulti(msgs)
n, err := read(msgs)
if err != nil {
l.WithError(err).Error("Failed to read packets")
continue
@@ -155,34 +171,24 @@ func (u *udpConn) ListenOut(f *Interface) {
}
}
func (u *udpConn) Read(addr *udpAddr, b []byte) ([]byte, error) {
var rsa rawSockaddrAny
var rLen = unix.SizeofSockaddrAny
func (u *udpConn) ReadSingle(msgs []rawMessage) (int, error) {
for {
n, _, err := unix.Syscall6(
unix.SYS_RECVFROM,
unix.SYS_RECVMSG,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&b[0])),
uintptr(len(b)),
uintptr(0),
uintptr(unsafe.Pointer(&rsa)),
uintptr(unsafe.Pointer(&rLen)),
uintptr(unsafe.Pointer(&(msgs[0].Hdr))),
0,
0,
0,
0,
)
if err != 0 {
return nil, &net.OpError{Op: "read", Err: err}
return 0, &net.OpError{Op: "recvmsg", Err: err}
}
if rsa.Addr.Family == unix.AF_INET {
addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1])
addr.IP = uint32(rsa.Addr.Data[2])<<24 + uint32(rsa.Addr.Data[3])<<16 + uint32(rsa.Addr.Data[4])<<8 + uint32(rsa.Addr.Data[5])
} else {
addr.Port = 0
addr.IP = 0
}
return b[:n], nil
msgs[0].Len = uint32(n)
return 1, nil
}
}
@@ -280,13 +286,6 @@ func (ua *udpAddr) Equals(t *udpAddr) bool {
return ua.IP == t.IP && ua.Port == t.Port
}
func (ua *udpAddr) Copy() *udpAddr {
return &udpAddr{
Port: ua.Port,
IP: ua.IP,
}
}
func (ua *udpAddr) String() string {
return fmt.Sprintf("%s:%v", int2ip(ua.IP), ua.Port)
}

View File

@@ -1,5 +1,6 @@
// +build linux
// +build 386 amd64p32 arm mips mipsle
// +build !android
package nebula

View File

@@ -1,5 +1,6 @@
// +build linux
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x
// +build !android
package nebula

View File

@@ -20,3 +20,7 @@ func NewListenConfig(multi bool) net.ListenConfig {
},
}
}
func (u *udpConn) Rebind() error {
return nil
}

130
util/assert.go Normal file
View File

@@ -0,0 +1,130 @@
package util
import (
"fmt"
"reflect"
"testing"
"time"
"unsafe"
"github.com/stretchr/testify/assert"
)
// AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory
// There is currently a special case for `time.loc` (as this code traverses into unexported fields)
func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) {
v1 := reflect.ValueOf(a)
v2 := reflect.ValueOf(b)
if !assert.Equal(t, v1.Type(), v2.Type()) {
return
}
traverseDeepCopy(t, v1, v2, v1.Type().String())
}
func traverseDeepCopy(t *testing.T, v1 reflect.Value, v2 reflect.Value, name string) bool {
switch v1.Kind() {
case reflect.Array:
for i := 0; i < v1.Len(); i++ {
if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
return false
}
}
return true
case reflect.Slice:
if v1.IsNil() || v2.IsNil() {
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil %+v, %+v", name, v1, v2)
}
if !assert.Equal(t, v1.Len(), v2.Len(), "%s did not have the same length", name) {
return false
}
// A slice with cap 0
if v1.Cap() != 0 && !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same slice %v == %v", name, v1.Pointer(), v2.Pointer()) {
return false
}
v1c := v1.Cap()
v2c := v2.Cap()
if v1c > 0 && v2c > 0 && v1.Slice(0, v1c).Slice(v1c-1, v1c-1).Pointer() == v2.Slice(0, v2c).Slice(v2c-1, v2c-1).Pointer() {
return assert.Fail(t, "", "%s share some underlying memory", name)
}
for i := 0; i < v1.Len(); i++ {
if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
return false
}
}
return true
case reflect.Interface:
if v1.IsNil() || v2.IsNil() {
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
}
return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
case reflect.Ptr:
local := reflect.ValueOf(time.Local).Pointer()
if local == v1.Pointer() && local == v2.Pointer() {
return true
}
if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s points to the same memory", name) {
return false
}
return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
case reflect.Struct:
for i, n := 0, v1.NumField(); i < n; i++ {
if !traverseDeepCopy(t, v1.Field(i), v2.Field(i), name+"."+v1.Type().Field(i).Name) {
return false
}
}
return true
case reflect.Map:
if v1.IsNil() || v2.IsNil() {
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
}
if !assert.Equal(t, v1.Len(), v2.Len(), "%s are not the same length", name) {
return false
}
if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same memory", name) {
return false
}
for _, k := range v1.MapKeys() {
val1 := v1.MapIndex(k)
val2 := v2.MapIndex(k)
if !assert.True(t, val1.IsValid(), "%s is an invalid key in %s", k, name) {
return false
}
if !assert.True(t, val2.IsValid(), "%s is an invalid key in %s", k, name) {
return false
}
if !traverseDeepCopy(t, val1, val2, name+fmt.Sprintf("%s[%s]", name, k)) {
return false
}
}
return true
default:
if v1.CanInterface() && v2.CanInterface() {
return assert.Equal(t, v1.Interface(), v2.Interface(), "%s was not equal", name)
}
e1 := reflect.NewAt(v1.Type(), unsafe.Pointer(v1.UnsafeAddr())).Elem().Interface()
e2 := reflect.NewAt(v2.Type(), unsafe.Pointer(v2.UnsafeAddr())).Elem().Interface()
return assert.Equal(t, e1, e2, "%s (unexported) was not equal", name)
}
}