Compare commits

..

60 Commits

Author SHA1 Message Date
Dave Russell
db11e2f1af Revert "smoke test"
This reverts commit fa034a6d83.
2020-10-03 00:09:18 +10:00
Dave Russell
2ee428b067 Hook send should use a code path that actually firewalls
This change enforces that outbound hook traffic will actually be checked
by the firewall and added to the conntrack if allowed.
2020-10-02 23:42:20 +10:00
Dave Russell
e9657d571e control->Send: Also set the src port
With the source port also set, we only need to enable inbound
firewall rules on the 'server' side of the connection, as
the conntrack will allow replies.
2020-10-02 22:25:31 +10:00
Dave Russell
3cebf38504 The custom message packet sender needs a dest port
Source/Dest ports are required for the nebula firewall on the
receiving side, allow the port to be configured so that it can
be matched to specific rules as required.
2020-10-02 20:46:08 +10:00
Dave Russell
ae3ee42469 Provide hooks for custom message packet handlers
This commit augments the Control API by providing new methods to
inject message packets destined peer nodes, and/or to intercept
message packets of a custom message subtype that are received from
peer nodes.
2020-09-28 22:31:19 +10:00
Dave Russell
fa034a6d83 smoke test 2020-09-27 22:43:24 +10:00
Dave Russell
55d72ac46f Tighten up the inside handlers with a bit of DRY 2020-09-27 22:37:20 +10:00
Dave Russell
2c931d5691 Move inside packet handlers into map
This commit moves the inside packet handlers into a map of functions
from the large switch statement. The functions are mapped by packet
protocol version, type and subtype; which makes it simpler to inject
either a new protocol version and/or custom handlers.
2020-09-27 22:04:14 +10:00
Ryan Huber
0d6b55e495 Bring in the new version of kardianos/service and output logfiles on osx (#303)
* this brings in the new version of kardianos/service which properly
outputs logs from launchd services

* add go sum

* is it really this easy?

* Update CHANGELOG.md
2020-09-24 15:34:08 -07:00
Wade Simmons
c71c84882e v1.3.0 (#268)
Update the CHANGELOG for Nebula v1.3.0

Co-authored-by: forfuncsake <drussell@slack-corp.com>
2020-09-22 12:21:12 -04:00
Darren Hoo
0010db46e4 Fix a data race on message counter (#284)
3. ==================
WARNING: DATA RACE
Write at 0x00c00030e020 by goroutine 17:
  sync/atomic.AddInt64()
      runtime/race_amd64.s:276 +0xb
  github.com/slackhq/nebula.(*Interface).sendNoMetrics()
      github.com/slackhq/nebula/inside.go:226 +0x9c
  github.com/slackhq/nebula.(*Interface).send()
      github.com/slackhq/nebula/inside.go:214 +0x149
  github.com/slackhq/nebula.(*Interface).readOutsidePackets()
      github.com/slackhq/nebula/outside.go:94 +0x1213
  github.com/slackhq/nebula.(*udpConn).ListenOut()
      github.com/slackhq/nebula/udp_generic.go:109 +0x3b5
  github.com/slackhq/nebula.(*Interface).listenOut()
      github.com/slackhq/nebula/interface.go:147 +0x15e

Previous read at 0x00c00030e020 by goroutine 18:
  github.com/slackhq/nebula.(*Interface).consumeInsidePacket()
      github.com/slackhq/nebula/inside.go:58 +0x892
  github.com/slackhq/nebula.(*Interface).listenIn()
      github.com/slackhq/nebula/interface.go:164 +0x178
2020-09-21 21:41:46 -04:00
Nathan Brown
68e3e84fdc More like a library (#279) 2020-09-18 09:20:09 -05:00
Brian Luong
6238f1550b Handle panic when invalid IP entered in sshd (#296) 2020-09-18 10:10:25 -04:00
forfuncsake
50b04413c7 Block nebula ssh server from listening on port 22 (#266)
Port 22 is blocked as a safety mechanism. In a case where nebula is
started before sshd, a system may be rendered unreachable if nebula
is holding the system ssh port and there is no other connectivity.

This commit enforces the restriction, which could previously be worked
around by listening on an IPv6 address, e.g.  "[::]:22".
2020-09-15 09:57:32 -04:00
CzBiX
ef498a31da Add disable_timestamp option (#288) 2020-09-09 07:42:11 -04:00
forfuncsake
2e5a477a50 Align linux UDP performance optimizations with configuration (#275)
* Remove unused (*udpConn).Read method

* Align linux UDP performance optimizations with configuration

While attempting to run nebula on an older Synology NAS, it became
apparent that some of the performance optimizations effectively
block support for older kernels. The recvmmsg syscall was added in
linux kernel 2.6.33, and the Synology DS212j (among other models)
is pinned to 2.6.32.12.

Similarly, SO_REUSEPORT was added to the kernel in the 3.9 cycle.
While this option has been backported into some older trees, it
is also missing from the Synology kernel.

This commit allows nebula to be run on linux devices with older
kernels if the config options are set up with a single listener
and a UDP batch size of 1.
2020-08-13 08:24:05 +10:00
Wade Simmons
32fe9bfe75 Use Go 1.15 (#277)
Update all CI checks and release process to use the latest patch version
of go1.15.
2020-08-12 16:16:21 -04:00
forfuncsake
9b8b3c478b Support startup without a tun device (#269)
This commit adds support for Nebula to be started without creating
a tun device. A node started in this mode still has a full "control
plane", but no effective "data plane". Its use is suited to a
lighthouse that has no need to partake in the mesh VPN.

Consequently, creation of the tun device is the only reason nebula
neesd to be started with elevated privileged, so this example
lighthouse can also be run as a non-root user.
2020-08-10 09:15:55 -04:00
Michael Hardy
7b3f23d9a1 Start nebula after the network is up (#270) 2020-08-07 11:33:48 -05:00
forfuncsake
25964b54f6 Use inclusive terminology for cert blocking (#272) 2020-08-06 11:17:47 +10:00
Wade Simmons
ac557f381b drop unroutable packets (#267)
Currently, if a packet arrives on the tun device with a destination that
is not a routable Nebula IP, `queryUnsafeRoute` converts that IP to
0.0.0.0 and we store that packet and try to look up that IP with the
lighthouse. This doesn't make any sense to do, if we get a packet that
is unroutable we should just drop it.

Note, we have a few configurable options like `drop_local_broadcast`
and `drop_multicast` which do this for a few specific types, but since
no packets like this will send correctly I think we should just drop
anything that is unroutable.
2020-08-04 22:59:04 -04:00
Wade Simmons
a54f3fc681 fix fast handshake trigger for static hosts (#265)
We are currently triggering a fast handshake for static hosts right
inside HandshakeManager.AddVpnIP, but this can actually trigger before
we have generated the handshake packet to use. Instead, we should be
triggering right after we call ixHandshakeStage0 in getOrHandshake
(which generates the handshake packet)
2020-08-02 20:59:50 -04:00
Alan Lam
5545cff6ef log remote certificate fingerprint on handshakes (#262) 2020-07-31 18:54:51 -04:00
Wade Simmons
f3a6d8d990 Preserve conntrack table during firewall rules reload (SIGHUP) (#233)
Currently, we drop the conntrack table when firewall rules change during a SIGHUP reload. This means responses to inflight HTTP requests can be dropped, among other issues. This change copies the conntrack table over to the new firewall (it holds the conntrack mutex lock during this process, to be safe).

This change also records which firewall rules hash each conntrack entry used, so that we can re-verify the rules after the new firewall has been loaded.
2020-07-31 18:53:36 -04:00
forfuncsake
9b06748506 Make Interface.Inside an interface type (#252)
This commit updates the Interface.Inside type to be a new interface
type instead of a *Tun. This will allow for an inside interface
that does not use a tun device, such as a single-binary client that
can run without elevated privileges.
2020-07-28 08:53:16 -04:00
Wade Simmons
4756c9613d trigger handshakes when lighthouse reply arrives (#246)
Currently, we wait until the next timer tick to act on the lighthouse's
reply to our HostQuery. This means we can easily add hundreds of
milliseconds of unnecessary delay to the handshake. To fix this, we
can introduce a channel to trigger an outbound handshake without waiting
for the next timer tick.

A few samples of cold ping time between two hosts that require a
lighthouse lookup:

    before (v1.2.0):

    time=156 ms
    time=252 ms
    time=12.6 ms
    time=301 ms
    time=352 ms
    time=49.4 ms
    time=150 ms
    time=13.5 ms
    time=8.24 ms
    time=161 ms
    time=355 ms

    after:

    time=3.53 ms
    time=3.14 ms
    time=3.08 ms
    time=3.92 ms
    time=7.78 ms
    time=3.59 ms
    time=3.07 ms
    time=3.22 ms
    time=3.12 ms
    time=3.08 ms
    time=8.04 ms

I recommend reviewing this PR by looking at each commit individually, as
some refactoring was required that makes the diff a bit confusing when
combined together.
2020-07-22 10:35:10 -04:00
Nathan Brown
4645e6034b Fix up the tun for android (#249) 2020-07-01 10:20:52 -05:00
Wade Simmons
aba42f9fa6 enforce the use of goimports (#248)
* enforce the use of goimports

Instead of enforcing `gofmt`, enforce `goimports`, which also asserts
a separate section for non-builtin packages.

* run `goimports` everywhere

* exclude generated .pb.go files
2020-06-30 18:53:30 -04:00
Nathan Brown
41578ca971 Be more like a library to support mobile (#247) 2020-06-30 13:48:58 -05:00
Wade Simmons
1ea8847085 linux: set advmss correctly when route MTU is used (#245)
If different mtus are specified for different routes, we should set
advmss on each route because Linux does a poor job of selecting the
default (from ip-route(8)):

    advmss NUMBER (Linux 2.3.15+ only)
           the MSS ('Maximal Segment Size') to advertise to these destinations when estab‐
           lishing TCP connections. If it is not given, Linux uses a default value calcu‐
           lated from the first hop device MTU.  (If the path to these destination is asym‐
           metric, this guess may be wrong.)

Note that the default value is calculated from the first hop *device
MTU*, not the *route MTU*. In practice this is usually ok as long as the
other side of the tunnel has the mtu configured exactly the same, but we
should probably just set advmss correctly on these routes.
2020-06-26 13:47:21 -04:00
Wade Simmons
55858c64cc smoke test: test firewall inbound / outbound (#240)
Test that basic inbound / outbound firewall rules work during the smoke
test. This change sets an inbound firewall rule on host3, and a new
host4 with outbound firewall rules. It also tests that conntrack allows
packets once the connection has been established.
2020-06-26 13:46:51 -04:00
Wade Simmons
e94c6b0125 mips-softfloat (#231)
This makes GOARM more generic and does GOMIPS in a similar way to
support mips-softfloat. We also set `-ldflags "-s -w"` for
mips-softfloat to give the best chance of the binary working on these
small devices.
2020-06-26 13:46:23 -04:00
Wade Simmons
b37a91cfbc add meta packet statistics (#230)
This change add more metrics around "meta" (non "message" type packets).
For lighthouse packets, we also record statistics around the specific
lighthouse meta type.

We don't keep statistics for the "message" type so that we don't slow
down the fast path (and you can just look at metrics on the tun
interface to find that information).
2020-06-26 13:45:48 -04:00
David Sonder
3212b769d4 fix typo in conntrack section in examples/config.yml (#236)
the rest of the conntrack values match the default
2020-06-26 11:08:22 -05:00
Patrick Bogen
ecf0e5a9f6 drop packets even if we aren't going to emit Debug logs about it (#239)
* drop packets even if we aren't going to emit Debug logs about it

* smallify change
2020-06-10 16:55:49 -05:00
Wade Simmons
ff13aba8fc allow go test -bench=. to run (#234)
This benchmark had an Errorf at the end, lets remove it so the
benchmarks all run.
2020-05-27 16:52:34 -04:00
Mateusz Kwiatkowski
cc03ff9e9a Unbreak building for FreeBSD (#103)
Add support for freebsd. You have to set `tun.dev` in your config. The second pass of this would be to remove the exec calls and use ioctl(2) and route(4) instead, but we can do that in a second PR.

Co-authored-by: Wade Simmons <wade@wades.im>
2020-05-26 22:23:23 -04:00
Patrick Bogen
363c836422 log the reason for fw drops (#220)
* log the reason for fw drops

* only prepare log if we will end up sending it
2020-04-10 10:57:21 -07:00
Wade Simmons
fb252db4a1 v1.2.0 (#215)
Add descriptions for all commits since v1.1.0
2020-04-08 19:52:24 -04:00
Wade Simmons
4f6313ebd3 fix config name for {remote,local}_allow_list (#219)
This config option should be snake_case, not camelCase.
2020-04-08 16:20:12 -04:00
Wade Simmons
0a474e757b Add lighthouse.{remoteAllowList,localAllowList} (#217)
These settings make it possible to blacklist / whitelist IP addresses
that are used for remote connections.

`lighthouse.remoteAllowList` filters which remote IPs are allow when
fetching from the lighthouse (or, if you are the lighthouse, which IPs
you store and forward to querying hosts). By default, any remote IPs are
allowed. You can provide CIDRs here with `true` to allow and `false` to
deny. The most specific CIDR rule applies to each remote.  If all rules
are "allow", the default will be "deny", and vice-versa. If both "allow"
and "deny" rules are present, then you MUST set a rule for "0.0.0.0/0"
as the default.

    lighthouse:
      remoteAllowList:
        # Example to block IPs from this subnet from being used for remote IPs.
        "172.16.0.0/12": false

        # A more complicated example, allow public IPs but only private IPs from a specific subnet
        "0.0.0.0/0": true
        "10.0.0.0/8": false
        "10.42.42.0/24": true

`lighthouse.localAllowList` has the same logic as above, but it applies
to the local addresses we advertise to the lighthouse. Additionally, you
can specify an `interfaces` map of regular expressions to match against
interface names. The regexp must match the entire name. All interface
rules must be either true or false (and the default rule will be the
inverse). CIDR rules are matched after interface name rules.

Default is all local IP addresses.

    lighthouse:
      localAllowList:
        # Example to blacklist docker interfaces.
        interfaces:
          'docker.*': false

        # Example to only advertise IPs in this subnet to the lighthouse.
        "10.0.0.0/8": true
2020-04-08 15:36:43 -04:00
Nathan Brown
7cd342c7ab Add a systemd unit for arch and a wireshark dissector (#216) 2020-04-06 18:47:32 -07:00
Wade Simmons
7cdbb14a18 Better config test (#177)
* Better config test

Previously, when using the config test option `-test`, we quit fairly
earlier in the process and would not catch a variety of additional
parsing errors (such as lighthouse IP addresses, local_range, the new
check to make sure static hosts are in the certificate's subnet, etc).

* run config test as part of smoke test

* don't need privileges for configtest

Co-authored-by: Nathan Brown <nate@slack-corp.com>
2020-04-06 11:35:32 -07:00
Wade Simmons
b4f2f7ce4e log certName alongside vpnIp (#200)
This change adds a new helper, `(*HostInfo).logger()`, that starts a new
logrus.Entry with `vpnIp` and `certName`. We don't use the helper inside
of handshake_ix though since the certificate has not been attached to
the HostInfo yet.

Fixes: #84
2020-04-06 11:34:00 -07:00
Alex
ff64d1f952 unsafe_routes mtu (#209) 2020-04-06 11:33:30 -07:00
Felix Yan
9e2ff7df57 Correct typos in noise.go (#205) 2020-03-30 11:23:55 -07:00
Ryan Huber
1297090af3 add configurable punching delay because of race-condition-y conntracks (#210)
* add configurable punching delay because of race-condition-y conntracks

* add changelog

* fix tests

* only do one punch per query

* Coalesce punchy config

* It is not is not set

* Add tests

Co-authored-by: Nate Brown <nbrown.us@gmail.com>
2020-03-27 11:26:39 -07:00
Wade Simmons
add1b21777 only create a CIDRTree for each host if necessary (#198)
A CIDRTree can be expensive to create, so only do it if we need
it. If the remote host only has one IP address and no subnets, just do
an exact IP match instead.

Fixes: #171
2020-03-02 16:21:33 -05:00
Wade Simmons
1cb3201b5e Github Actions: cache modules and only run when necessary (#197)
This PR does two things:

- Only run the tests when relevant files change.
- Cache the Go Modules directory between runs, so they don't have to redownload everything everytime (go.sum is the cache key). Pretty much straight from the examples: https://github.com/actions/cache/blob/master/examples.md#go---modules
2020-03-02 16:21:19 -05:00
Ryan Huber
41968551f9 clarify that lighthouse IP should be nebula range (#196) 2020-02-28 11:35:55 -08:00
Wade Simmons
8548ac3c31 build and test with go1.14 (#195)
- https://golang.org/doc/go1.14

I did a performance sanity check in Docker, and performance seems about
the same (perhaps slightly higher).
2020-02-27 15:48:39 -05:00
Wade Simmons
fb9b36f677 allow any config file name if specified directly (#189)
Currently, we require that config file names end with `.yml` or `.yaml`.
This is because if the user points `-config` at a directory of files, we
only want to use the YAML files in that directory.

But this makes it more difficult to use the `-test -config` option
because config management tools might not have an extension on the file
when preparing a new config file. This change makes it so that if you
point `-config file` directly at a file, it uses it no matter what the
extension is.
2020-02-26 15:38:56 -05:00
Sebastien Bariteau
4d1928f1e3 Support unsafe_routes on Windows (#184)
* Support unsafe_routes on Windows

* Full path to route executable

* Escape string properly
2020-02-26 15:23:16 -05:00
Ryan Huber
a91a40212d check that packet isn't bound for my vpn ip (#192) 2020-02-21 16:49:54 -08:00
Wade Simmons
179a369130 add configuration options for HandshakeManager (#179)
This change exposes the current constants we have defined for the handshake
manager as configuration options. This will allow us to test and tweak
with different intervals and wait rotations.

    # Handshake Manger Settings
    handshakes:
      # Total time to try a handshake = sequence of `try_interval * retries`
      # With 100ms interval and 20 retries it is 23.5 seconds
      try_interval: 100ms
      retries: 20

      # wait_rotation is the number of handshake attempts to do before starting to try non-local IP addresses
      wait_rotation: 5
2020-02-21 16:25:11 -05:00
Wade Simmons
df69371620 use absolute paths on darwin and windows (#191)
We want to make sure to use the system binaries, and not whatever is in
the PATH.
2020-02-21 15:25:33 -05:00
Wade Simmons
eda344d88f add logging.timestamp_format config option (#187)
This change introduces logging.timestamp_format, which allows
configuration of the Logrus TimestampFormat setting. The primary purpose
of this change was to allow logging with millisecond precision. The
default for `text` and `json` formats remains the same for backwards
compatibility.

timestamp format is specified in Go time format, see:

 - https://golang.org/pkg/time/#pkg-constants

Default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339)
Default when `format: text`:
  when TTY attached: seconds since beginning of execution
  otherwise: "2006-01-02T15:04:05Z07:00" (RFC3339)

As an example, to log as RFC3339 with millisecond precision, set to:

    logging:
        timestamp_format: "2006-01-02T15:04:05.000Z07:00"
2020-02-21 15:25:00 -05:00
Wade Simmons
065e2ff88a update golang.org/x/crypto (#188)
This version contains a fix for CVE-2020-9283, a remote crash bug:

- https://groups.google.com/forum/#!msg/golang-announce/3L45YRc91SY/ywEPcKLnGQAJ
2020-02-20 14:49:55 -05:00
Nathan Brown
45a5de2719 Print the udp listen address on startup (#181) 2020-02-06 21:17:43 -08:00
Wade Simmons
2d24ef7166 validate lighthouses and static hosts are in our subnet (#170)
Validate all lighthouse.hosts and static_host_map VPN IPs are in the
subnet defined in our cert. Exit with a fatal error if they are not in
our subnet, as this is an invalid configuration (we will not have the
proper routes set up to communicate with these hosts).

This error case could occur for the following invalid example:

    nebula-cert sign -name "lighthouse" -ip "10.0.1.1/24"
    nebula-cert sign -name "host" -ip "10.0.2.1/24"

    config.yaml:

        static_host_map:
            "10.0.1.1": ["lighthouse.local:4242"]
        lighthouse:
          hosts:
            - "10.0.1.1"

We will now return a fatal error for this config, since `10.0.1.1` is
not in the host cert's subnet of `10.0.2.1/24`
2020-01-20 15:52:55 -05:00
84 changed files with 3592 additions and 624 deletions

View File

@@ -4,6 +4,9 @@ on:
branches: branches:
- master - master
pull_request: pull_request:
paths:
- '.github/workflows/gofmt.yml'
- '**.go'
jobs: jobs:
gofmt: gofmt:
@@ -11,19 +14,31 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Set up Go 1.13 - name: Set up Go 1.15
uses: actions/setup-go@v1 uses: actions/setup-go@v1
with: with:
go-version: 1.13 go-version: 1.15
id: go id: go
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
uses: actions/checkout@v1 uses: actions/checkout@v1
- uses: actions/cache@v1
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-gofmt-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gofmt-
- name: Install goimports
run: |
go get golang.org/x/tools/cmd/goimports
go build golang.org/x/tools/cmd/goimports
- name: gofmt - name: gofmt
run: | run: |
if [ "$(find . -iname '*.go' | xargs gofmt -l)" ] if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs ./goimports -l)" ]
then then
find . -iname '*.go' | xargs gofmt -d find . -iname '*.go' | grep -v '\.pb\.go$' | xargs ./goimports -d
exit 1 exit 1
fi fi

View File

@@ -10,17 +10,17 @@ jobs:
name: Build Linux All name: Build Linux All
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Set up Go 1.13 - name: Set up Go 1.15
uses: actions/setup-go@v1 uses: actions/setup-go@v1
with: with:
go-version: 1.13 go-version: 1.15
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v2
- name: Build - name: Build
run: | run: |
make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" release-linux make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" release-linux release-freebsd
mkdir release mkdir release
mv build/*.tar.gz release mv build/*.tar.gz release
@@ -34,10 +34,10 @@ jobs:
name: Build Windows amd64 name: Build Windows amd64
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- name: Set up Go 1.13 - name: Set up Go 1.15
uses: actions/setup-go@v1 uses: actions/setup-go@v1
with: with:
go-version: 1.13 go-version: 1.15
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v2
@@ -58,10 +58,10 @@ jobs:
name: Build Darwin amd64 name: Build Darwin amd64
runs-on: macOS-latest runs-on: macOS-latest
steps: steps:
- name: Set up Go 1.13 - name: Set up Go 1.15
uses: actions/setup-go@v1 uses: actions/setup-go@v1
with: with:
go-version: 1.13 go-version: 1.15
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v2
@@ -278,3 +278,23 @@ jobs:
asset_path: ./linux-latest/nebula-linux-mips64le.tar.gz asset_path: ./linux-latest/nebula-linux-mips64le.tar.gz
asset_name: nebula-linux-mips64le.tar.gz asset_name: nebula-linux-mips64le.tar.gz
asset_content_type: application/gzip asset_content_type: application/gzip
- name: Upload linux-mips-softfloat
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-linux-mips-softfloat.tar.gz
asset_name: nebula-linux-mips-softfloat.tar.gz
asset_content_type: application/gzip
- name: Upload freebsd-amd64
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./linux-latest/nebula-freebsd-amd64.tar.gz
asset_name: nebula-freebsd-amd64.tar.gz
asset_content_type: application/gzip

View File

@@ -4,22 +4,36 @@ on:
branches: branches:
- master - master
pull_request: pull_request:
paths:
- '.github/workflows/smoke**'
- '**Makefile'
- '**.go'
- '**.proto'
- 'go.mod'
- 'go.sum'
jobs: jobs:
smoke: smoke:
name: Run 3 node smoke test name: Run multi node smoke test
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Set up Go 1.13 - name: Set up Go 1.15
uses: actions/setup-go@v1 uses: actions/setup-go@v1
with: with:
go-version: 1.13 go-version: 1.15
id: go id: go
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
uses: actions/checkout@v1 uses: actions/checkout@v1
- uses: actions/cache@v1
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: build - name: build
run: make run: make

View File

@@ -11,14 +11,29 @@ mkdir ./build
cp ../../../../nebula . cp ../../../../nebula .
cp ../../../../nebula-cert . cp ../../../../nebula-cert .
HOST="lighthouse1" AM_LIGHTHOUSE=true ../genconfig.sh >lighthouse1.yml HOST="lighthouse1" \
HOST="host2" LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" ../genconfig.sh >host2.yml AM_LIGHTHOUSE=true \
HOST="host3" LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" ../genconfig.sh >host3.yml ../genconfig.sh >lighthouse1.yml
HOST="host2" \
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
../genconfig.sh >host2.yml
HOST="host3" \
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
../genconfig.sh >host3.yml
HOST="host4" \
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
../genconfig.sh >host4.yml
./nebula-cert ca -name "Smoke Test" ./nebula-cert ca -name "Smoke Test"
./nebula-cert sign -name "lighthouse1" -ip "192.168.100.1/24" ./nebula-cert sign -name "lighthouse1" -groups "lighthouse,lighthouse1" -ip "192.168.100.1/24"
./nebula-cert sign -name "host2" -ip "192.168.100.2/24" ./nebula-cert sign -name "host2" -groups "host,host2" -ip "192.168.100.2/24"
./nebula-cert sign -name "host3" -ip "192.168.100.3/24" ./nebula-cert sign -name "host3" -groups "host,host3" -ip "192.168.100.3/24"
./nebula-cert sign -name "host4" -groups "host,host4" -ip "192.168.100.4/24"
) )
docker build -t nebula:smoke . docker build -t nebula:smoke .

View File

@@ -2,6 +2,7 @@
set -e set -e
FIREWALL_ALL='[{"port": "any", "proto": "any", "host": "any"}]'
if [ "$STATIC_HOSTS" ] || [ "$LIGHTHOUSES" ] if [ "$STATIC_HOSTS" ] || [ "$LIGHTHOUSES" ]
then then
@@ -48,13 +49,6 @@ tun:
dev: ${TUN_DEV:-nebula1} dev: ${TUN_DEV:-nebula1}
firewall: firewall:
outbound: outbound: ${OUTBOUND:-$FIREWALL_ALL}
- port: any inbound: ${INBOUND:-$FIREWALL_ALL}
proto: any
host: any
inbound:
- port: any
proto: any
host: any
EOF EOF

View File

@@ -2,12 +2,19 @@
set -e -x set -e -x
docker run --name lighthouse1 --rm nebula:smoke -config lighthouse1.yml -test
docker run --name host2 --rm nebula:smoke -config host2.yml -test
docker run --name host3 --rm nebula:smoke -config host3.yml -test
docker run --name host4 --rm nebula:smoke -config host4.yml -test
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config lighthouse1.yml & docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config lighthouse1.yml &
sleep 1 sleep 1
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host2.yml & docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host2.yml &
sleep 1 sleep 1
docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host3.yml & docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host3.yml &
sleep 1 sleep 1
docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host4.yml &
sleep 1
set +x set +x
echo echo
@@ -23,7 +30,8 @@ echo " *** Testing ping from host2"
echo echo
set -x set -x
docker exec host2 ping -c1 192.168.100.1 docker exec host2 ping -c1 192.168.100.1
docker exec host2 ping -c1 192.168.100.3 # Should fail because not allowed by host3 inbound firewall
! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
set +x set +x
echo echo
@@ -32,3 +40,24 @@ echo
set -x set -x
docker exec host3 ping -c1 192.168.100.1 docker exec host3 ping -c1 192.168.100.1
docker exec host3 ping -c1 192.168.100.2 docker exec host3 ping -c1 192.168.100.2
set +x
echo
echo " *** Testing ping from host4"
echo
set -x
docker exec host4 ping -c1 192.168.100.1
# Should fail because not allowed by host4 outbound firewall
! docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1
! docker exec host4 ping -c1 192.168.100.3 -w5 || exit 1
set +x
echo
echo " *** Testing conntrack"
echo
set -x
# host2 can ping host3 now that host3 pinged it first
docker exec host2 ping -c1 192.168.100.3
# host4 can ping host2 once conntrack established
docker exec host2 ping -c1 192.168.100.4
docker exec host4 ping -c1 192.168.100.2

View File

@@ -4,6 +4,13 @@ on:
branches: branches:
- master - master
pull_request: pull_request:
paths:
- '.github/workflows/test.yml'
- '**Makefile'
- '**.go'
- '**.proto'
- 'go.mod'
- 'go.sum'
jobs: jobs:
test-linux: test-linux:
@@ -11,15 +18,22 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Set up Go 1.13 - name: Set up Go 1.15
uses: actions/setup-go@v1 uses: actions/setup-go@v1
with: with:
go-version: 1.13 go-version: 1.15
id: go id: go
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
uses: actions/checkout@v1 uses: actions/checkout@v1
- uses: actions/cache@v1
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Build - name: Build
run: make all run: make all
@@ -34,15 +48,22 @@ jobs:
os: [windows-latest, macOS-latest] os: [windows-latest, macOS-latest]
steps: steps:
- name: Set up Go 1.13 - name: Set up Go 1.15
uses: actions/setup-go@v1 uses: actions/setup-go@v1
with: with:
go-version: 1.13 go-version: 1.15
id: go id: go
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
uses: actions/checkout@v1 uses: actions/checkout@v1
- uses: actions/cache@v1
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Build nebula - name: Build nebula
run: go build ./cmd/nebula run: go build ./cmd/nebula

View File

@@ -7,6 +7,144 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
### Changed
- Updated the kardianos/service go library from 1.0.0 to 1.1.0, which
now creates launchd plist to write stdout/stderr to files by default.
## [1.3.0] - 2020-09-22
### Added
- You can emit statistics about non-message packets by setting the option
`stats.message_metrics`. You can similarly emit detailed statistics about
lighthouse packets by setting the option `stats.lighthouse_metrics`. See
the example config for more details. (#230)
- We now support freebsd/amd64. This is experimental, please give us feedback.
(#103)
- We now release a binary for `linux/mips-softfloat` which has also been
stripped to reduce filesize and hopefully have a better chance on running on
small mips devices. (#231)
- You can set `tun.disabled` to true to run a standalone lighthouse without a
tun device (and thus, without root). (#269)
- You can set `logging.disable_timestamp` to remove timestamps from log lines,
which is useful when output is redirected to a logging system that already
adds timestamps. (#288)
### Changed
- Handshakes should now trigger faster, as we try to be proactive with sending
them instead of waiting for the next timer tick in most cases. (#246, #265)
- Previously, we would drop the conntrack table whenever firewall rules were
changed during a SIGHUP. Now, we will maintain the table and just validate
that an entry still matches with the new rule set. (#233)
- Debug logs for firewall drops now include the reason. (#220, #239)
- Logs for handshakes now include the fingerprint of the remote host. (#262)
- Config item `pki.blacklist` is now `pki.blocklist`. (#272)
- Better support for older Linux kernels. We now only set `SO_REUSEPORT` if
`tun.routines` is greater than 1 (default is 1). We also only use the
`recvmmsg` syscall if `listen.batch` is greater than 1 (default is 64).
(#275)
- It is possible to run Nebula as a library inside of another process now.
Note that this is still experimental and the internal APIs around this might
change in minor version releases. (#279)
### Deprecated
- `pki.blacklist` is deprecated in favor of `pki.blocklist` with the same
functionality. Existing configs will continue to load for this release to
allow for migrations. (#272)
### Fixed
- `advmss` is now set correctly for each route table entry when `tun.routes`
is configured to have some routes with higher MTU. (#245)
- Packets that arrive on the tun device with an unroutable destination IP are
now dropped correctly, instead of wasting time making queries to the
lighthouses for IP `0.0.0.0` (#267)
## [1.2.0] - 2020-04-08
### Added
- Add `logging.timestamp_format` config option. The primary purpose of this
change is to allow logging timestamps with millisecond precision. (#187)
- Support `unsafe_routes` on Windows. (#184)
- Add `lighthouse.remote_allow_list` to filter which subnets we will use to
handshake with other hosts. See the example config for more details. (#217)
- Add `lighthouse.local_allow_list` to filter which local IP addresses and/or
interfaces we advertise to the lighthouses. See the example config for more
details. (#217)
- Wireshark dissector plugin. Add this file in `dist/wireshark` to your
Wireshark plugins folder to see Nebula packet headers decoded. (#216)
- systemd unit for Arch, so it can be built entirely from this repo. (#216)
### Changed
- Added a delay to punching via lighthouse signal to deal with race conditions
in some linux conntrack implementations. (#210)
See deprecated, this also adds a new `punchy.delay` option that defaults to `1s`.
- Validate all `lighthouse.hosts` and `static_host_map` VPN IPs are in the
subnet defined in our cert. Exit with a fatal error if they are not in our
subnet, as this is an invalid configuration (we will not have the proper
routes set up to communicate with these hosts). (#170)
- Use absolute paths to system binaries on macOS and Windows. (#191)
- Add configuration options for `handshakes`. This includes options to tweak
`try_interval`, `retries` and `wait_rotation`. See example config for
descriptions. (#179)
- Allow `-config` file to not end in `.yaml` or `yml`. Useful when using
`-test` and automated tools like Ansible that create temporary files without
suffixes. (#189)
- The config test mode, `-test`, is now more thorough and catches more parsing
issues. (#177)
- Various documentation and example fixes. (#196)
- Improved log messages. (#181, #200)
- Dependencies updated. (#188)
### Deprecated
- `punchy`, `punch_back` configuration options have been collapsed under the
now top level `punchy` config directive. (#210)
`punchy.punch` - This is the old `punchy` option. Should we perform NAT hole
punching (default false)?
`punchy.respond` - This is the old `punch_back` option. Should we respond to
hole punching by hole punching back (default false)?
### Fixed
- Reduce memory allocations when not using `unsafe_routes`. (#198)
- Ignore packets from self to self. (#192)
- MTU fixed for `unsafe_routes`. (#209)
## [1.1.0] - 2020-01-17 ## [1.1.0] - 2020-01-17
### Added ### Added
@@ -47,6 +185,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Initial public release. - Initial public release.
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.1.0...HEAD [Unreleased]: https://github.com/slackhq/nebula/compare/v1.3.0...HEAD
[1.3.0]: https://github.com/slackhq/nebula/releases/tag/v1.3.0
[1.2.0]: https://github.com/slackhq/nebula/releases/tag/v1.2.0
[1.1.0]: https://github.com/slackhq/nebula/releases/tag/v1.1.0 [1.1.0]: https://github.com/slackhq/nebula/releases/tag/v1.1.0
[1.0.0]: https://github.com/slackhq/nebula/releases/tag/v1.0.0 [1.0.0]: https://github.com/slackhq/nebula/releases/tag/v1.0.0

View File

@@ -3,6 +3,8 @@ BUILD_NUMBER ?= dev+$(shell date -u '+%Y%m%d%H%M%S')
GO111MODULE = on GO111MODULE = on
export GO111MODULE export GO111MODULE
LDFLAGS = -X main.Build=$(BUILD_NUMBER)
ALL_LINUX = linux-amd64 \ ALL_LINUX = linux-amd64 \
linux-386 \ linux-386 \
linux-ppc64le \ linux-ppc64le \
@@ -13,10 +15,12 @@ ALL_LINUX = linux-amd64 \
linux-mips \ linux-mips \
linux-mipsle \ linux-mipsle \
linux-mips64 \ linux-mips64 \
linux-mips64le linux-mips64le \
linux-mips-softfloat
ALL = $(ALL_LINUX) \ ALL = $(ALL_LINUX) \
darwin-amd64 \ darwin-amd64 \
freebsd-amd64 \
windows-amd64 windows-amd64
all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert) all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert)
@@ -25,31 +29,40 @@ release: $(ALL:%=build/nebula-%.tar.gz)
release-linux: $(ALL_LINUX:%=build/nebula-%.tar.gz) release-linux: $(ALL_LINUX:%=build/nebula-%.tar.gz)
release-freebsd: build/nebula-freebsd-amd64.tar.gz
bin-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe bin-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe
mv $? . mv $? .
bin-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert bin-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert
mv $? . mv $? .
bin-freebsd: build/freebsd-amd64/nebula build/freebsd-amd64/nebula-cert
mv $? .
bin: bin:
go build -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" -o ./nebula ${NEBULA_CMD_PATH} go build -trimpath -ldflags "$(LDFLAGS)" -o ./nebula ${NEBULA_CMD_PATH}
go build -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" -o ./nebula-cert ./cmd/nebula-cert go build -trimpath -ldflags "$(LDFLAGS)" -o ./nebula-cert ./cmd/nebula-cert
install: install:
go install -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" ${NEBULA_CMD_PATH} go install -trimpath -ldflags "$(LDFLAGS)" ${NEBULA_CMD_PATH}
go install -trimpath -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula-cert go install -trimpath -ldflags "$(LDFLAGS)" ./cmd/nebula-cert
build/linux-arm-%: GOENV += GOARM=$(word 3, $(subst -, ,$*))
build/linux-mips-%: GOENV += GOMIPS=$(word 3, $(subst -, ,$*))
# Build an extra small binary for mips-softfloat
build/linux-mips-softfloat/%: LDFLAGS += -s -w
build/%/nebula: .FORCE build/%/nebula: .FORCE
GOOS=$(firstword $(subst -, , $*)) \ GOOS=$(firstword $(subst -, , $*)) \
GOARCH=$(word 2, $(subst -, ,$*)) \ GOARCH=$(word 2, $(subst -, ,$*)) $(GOENV) \
GOARM=$(word 3, $(subst -, ,$*)) \ go build -trimpath -o $@ -ldflags "$(LDFLAGS)" ${NEBULA_CMD_PATH}
go build -trimpath -o $@ -ldflags "-X main.Build=$(BUILD_NUMBER)" ${NEBULA_CMD_PATH}
build/%/nebula-cert: .FORCE build/%/nebula-cert: .FORCE
GOOS=$(firstword $(subst -, , $*)) \ GOOS=$(firstword $(subst -, , $*)) \
GOARCH=$(word 2, $(subst -, ,$*)) \ GOARCH=$(word 2, $(subst -, ,$*)) $(GOENV) \
GOARM=$(word 3, $(subst -, ,$*)) \ go build -trimpath -o $@ -ldflags "$(LDFLAGS)" ./cmd/nebula-cert
go build -trimpath -o $@ -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula-cert
build/%/nebula.exe: build/%/nebula build/%/nebula.exe: build/%/nebula
mv $< $@ mv $< $@

48
allow_list.go Normal file
View File

@@ -0,0 +1,48 @@
package nebula
import (
"fmt"
"regexp"
)
type AllowList struct {
// The values of this cidrTree are `bool`, signifying allow/deny
cidrTree *CIDRTree
// To avoid ambiguity, all rules must be true, or all rules must be false.
nameRules []AllowListNameRule
}
type AllowListNameRule struct {
Name *regexp.Regexp
Allow bool
}
func (al *AllowList) Allow(ip uint32) bool {
if al == nil {
return true
}
result := al.cidrTree.MostSpecificContains(ip)
switch v := result.(type) {
case bool:
return v
default:
panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
}
}
func (al *AllowList) AllowName(name string) bool {
if al == nil || len(al.nameRules) == 0 {
return true
}
for _, rule := range al.nameRules {
if rule.Name.MatchString(name) {
return rule.Allow
}
}
// If no rules match, return the default, which is the inverse of the rules
return !al.nameRules[0].Allow
}

47
allow_list_test.go Normal file
View File

@@ -0,0 +1,47 @@
package nebula
import (
"net"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
)
func TestAllowList_Allow(t *testing.T) {
assert.Equal(t, true, ((*AllowList)(nil)).Allow(ip2int(net.ParseIP("1.1.1.1"))))
tree := NewCIDRTree()
tree.AddCIDR(getCIDR("0.0.0.0/0"), true)
tree.AddCIDR(getCIDR("10.0.0.0/8"), false)
tree.AddCIDR(getCIDR("10.42.42.0/24"), true)
al := &AllowList{cidrTree: tree}
assert.Equal(t, true, al.Allow(ip2int(net.ParseIP("1.1.1.1"))))
assert.Equal(t, false, al.Allow(ip2int(net.ParseIP("10.0.0.4"))))
assert.Equal(t, true, al.Allow(ip2int(net.ParseIP("10.42.42.42"))))
}
func TestAllowList_AllowName(t *testing.T) {
assert.Equal(t, true, ((*AllowList)(nil)).AllowName("docker0"))
rules := []AllowListNameRule{
{Name: regexp.MustCompile("^docker.*$"), Allow: false},
{Name: regexp.MustCompile("^tun.*$"), Allow: false},
}
al := &AllowList{nameRules: rules}
assert.Equal(t, false, al.AllowName("docker0"))
assert.Equal(t, false, al.AllowName("tun0"))
assert.Equal(t, true, al.AllowName("eth0"))
rules = []AllowListNameRule{
{Name: regexp.MustCompile("^eth.*$"), Allow: true},
{Name: regexp.MustCompile("^ens.*$"), Allow: true},
}
al = &AllowList{nameRules: rules}
assert.Equal(t, false, al.AllowName("docker0"))
assert.Equal(t, true, al.AllowName("eth0"))
assert.Equal(t, true, al.AllowName("ens5"))
}

View File

@@ -212,10 +212,10 @@ func TestBitsLostCounter(t *testing.T) {
func BenchmarkBits(b *testing.B) { func BenchmarkBits(b *testing.B) {
z := NewBits(10) z := NewBits(10)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
for i, _ := range z.bits { for i := range z.bits {
z.bits[i] = true z.bits[i] = true
} }
for i, _ := range z.bits { for i := range z.bits {
z.bits[i] = false z.bits[i] = false
} }

12
cert.go
View File

@@ -149,10 +149,16 @@ func loadCAFromConfig(c *Config) (*cert.NebulaCAPool, error) {
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err) return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
} }
// pki.blacklist entered the scene at about the same time we aliased x509 to pki, not supporting backwards compat for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
l.WithField("fingerprint", fp).Infof("Blocklisting cert")
CAs.BlocklistFingerprint(fp)
}
// Support deprecated config for at leaast one minor release to allow for migrations
for _, fp := range c.GetStringSlice("pki.blacklist", []string{}) { for _, fp := range c.GetStringSlice("pki.blacklist", []string{}) {
l.WithField("fingerprint", fp).Infof("Blacklisting cert") l.WithField("fingerprint", fp).Infof("Blocklisting cert")
CAs.BlacklistFingerprint(fp) l.Warn("pki.blacklist is deprecated and will not be supported in a future release. Please migrate your config to use pki.blocklist")
CAs.BlocklistFingerprint(fp)
} }
return CAs, nil return CAs, nil

View File

@@ -8,14 +8,14 @@ import (
type NebulaCAPool struct { type NebulaCAPool struct {
CAs map[string]*NebulaCertificate CAs map[string]*NebulaCertificate
certBlacklist map[string]struct{} certBlocklist map[string]struct{}
} }
// NewCAPool creates a CAPool // NewCAPool creates a CAPool
func NewCAPool() *NebulaCAPool { func NewCAPool() *NebulaCAPool {
ca := NebulaCAPool{ ca := NebulaCAPool{
CAs: make(map[string]*NebulaCertificate), CAs: make(map[string]*NebulaCertificate),
certBlacklist: make(map[string]struct{}), certBlocklist: make(map[string]struct{}),
} }
return &ca return &ca
@@ -67,24 +67,24 @@ func (ncp *NebulaCAPool) AddCACertificate(pemBytes []byte) ([]byte, error) {
return pemBytes, nil return pemBytes, nil
} }
// BlacklistFingerprint adds a cert fingerprint to the blacklist // BlocklistFingerprint adds a cert fingerprint to the blocklist
func (ncp *NebulaCAPool) BlacklistFingerprint(f string) { func (ncp *NebulaCAPool) BlocklistFingerprint(f string) {
ncp.certBlacklist[f] = struct{}{} ncp.certBlocklist[f] = struct{}{}
} }
// ResetCertBlacklist removes all previously blacklisted cert fingerprints // ResetCertBlocklist removes all previously blocklisted cert fingerprints
func (ncp *NebulaCAPool) ResetCertBlacklist() { func (ncp *NebulaCAPool) ResetCertBlocklist() {
ncp.certBlacklist = make(map[string]struct{}) ncp.certBlocklist = make(map[string]struct{})
} }
// IsBlacklisted returns true if the fingerprint fails to generate or has been explicitly blacklisted // IsBlocklisted returns true if the fingerprint fails to generate or has been explicitly blocklisted
func (ncp *NebulaCAPool) IsBlacklisted(c *NebulaCertificate) bool { func (ncp *NebulaCAPool) IsBlocklisted(c *NebulaCertificate) bool {
h, err := c.Sha256Sum() h, err := c.Sha256Sum()
if err != nil { if err != nil {
return true return true
} }
if _, ok := ncp.certBlacklist[h]; ok { if _, ok := ncp.certBlocklist[h]; ok {
return true return true
} }

View File

@@ -1,18 +1,18 @@
package cert package cert
import ( import (
"bytes"
"crypto" "crypto"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"net" "net"
"time" "time"
"bytes"
"encoding/json"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"golang.org/x/crypto/curve25519" "golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
@@ -244,10 +244,10 @@ func (nc *NebulaCertificate) Expired(t time.Time) bool {
return nc.Details.NotBefore.After(t) || nc.Details.NotAfter.Before(t) return nc.Details.NotBefore.After(t) || nc.Details.NotAfter.Before(t)
} }
// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blacklist, etc) // Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) { func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) {
if ncp.IsBlacklisted(nc) { if ncp.IsBlocklisted(nc) {
return false, fmt.Errorf("certificate has been blacklisted") return false, fmt.Errorf("certificate has been blocked")
} }
signer, err := ncp.GetCAForCert(nc) signer, err := ncp.GetCAForCert(nc)
@@ -468,6 +468,63 @@ func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) {
return json.Marshal(jc) return json.Marshal(jc)
} }
//func (nc *NebulaCertificate) Copy() *NebulaCertificate {
// r, err := nc.Marshal()
// if err != nil {
// //TODO
// return nil
// }
//
// c, err := UnmarshalNebulaCertificate(r)
// return c
//}
func (nc *NebulaCertificate) Copy() *NebulaCertificate {
c := &NebulaCertificate{
Details: NebulaCertificateDetails{
Name: nc.Details.Name,
Groups: make([]string, len(nc.Details.Groups)),
Ips: make([]*net.IPNet, len(nc.Details.Ips)),
Subnets: make([]*net.IPNet, len(nc.Details.Subnets)),
NotBefore: nc.Details.NotBefore,
NotAfter: nc.Details.NotAfter,
PublicKey: make([]byte, len(nc.Details.PublicKey)),
IsCA: nc.Details.IsCA,
Issuer: nc.Details.Issuer,
InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)),
},
Signature: make([]byte, len(nc.Signature)),
}
copy(c.Signature, nc.Signature)
copy(c.Details.Groups, nc.Details.Groups)
copy(c.Details.PublicKey, nc.Details.PublicKey)
for i, p := range nc.Details.Ips {
c.Details.Ips[i] = &net.IPNet{
IP: make(net.IP, len(p.IP)),
Mask: make(net.IPMask, len(p.Mask)),
}
copy(c.Details.Ips[i].IP, p.IP)
copy(c.Details.Ips[i].Mask, p.Mask)
}
for i, p := range nc.Details.Subnets {
c.Details.Subnets[i] = &net.IPNet{
IP: make(net.IP, len(p.IP)),
Mask: make(net.IPMask, len(p.Mask)),
}
copy(c.Details.Subnets[i].IP, p.IP)
copy(c.Details.Subnets[i].Mask, p.Mask)
}
for g := range nc.Details.InvertedGroups {
c.Details.InvertedGroups[g] = struct{}{}
}
return c
}
func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool { func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool {
for _, net := range rootIps { for _, net := range rootIps {
if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) { if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) {

View File

@@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/crypto/curve25519" "golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
@@ -172,13 +173,13 @@ func TestNebulaCertificate_Verify(t *testing.T) {
f, err := c.Sha256Sum() f, err := c.Sha256Sum()
assert.Nil(t, err) assert.Nil(t, err)
caPool.BlacklistFingerprint(f) caPool.BlocklistFingerprint(f)
v, err := c.Verify(time.Now(), caPool) v, err := c.Verify(time.Now(), caPool)
assert.False(t, v) assert.False(t, v)
assert.EqualError(t, err, "certificate has been blacklisted") assert.EqualError(t, err, "certificate has been blocked")
caPool.ResetCertBlacklist() caPool.ResetCertBlocklist()
v, err = c.Verify(time.Now(), caPool) v, err = c.Verify(time.Now(), caPool)
assert.True(t, v) assert.True(t, v)
assert.Nil(t, err) assert.Nil(t, err)
@@ -487,6 +488,17 @@ func TestMarshalingNebulaCertificateConsistency(t *testing.T) {
assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
} }
func TestNebulaCertificate_Copy(t *testing.T) {
ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
assert.Nil(t, err)
c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
assert.Nil(t, err)
cc := c.Copy()
util.AssertDeepCopyEqual(t, c, cc)
}
func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) { func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
pub, priv, err := ed25519.GenerateKey(rand.Reader) pub, priv, err := ed25519.GenerateKey(rand.Reader)
if before.IsZero() { if before.IsZero() {
@@ -499,10 +511,11 @@ func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
nc := &NebulaCertificate{ nc := &NebulaCertificate{
Details: NebulaCertificateDetails{ Details: NebulaCertificateDetails{
Name: "test ca", Name: "test ca",
NotBefore: before, NotBefore: time.Unix(before.Unix(), 0),
NotAfter: after, NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub, PublicKey: pub,
IsCA: true, IsCA: true,
InvertedGroups: make(map[string]struct{}),
}, },
} }
@@ -544,17 +557,17 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
if len(ips) == 0 { if len(ips) == 0 {
ips = []*net.IPNet{ ips = []*net.IPNet{
{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, {IP: net.ParseIP("10.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, {IP: net.ParseIP("10.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, {IP: net.ParseIP("10.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
} }
} }
if len(subnets) == 0 { if len(subnets) == 0 {
subnets = []*net.IPNet{ subnets = []*net.IPNet{
{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, {IP: net.ParseIP("9.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, {IP: net.ParseIP("9.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, {IP: net.ParseIP("9.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
} }
} }
@@ -566,11 +579,12 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
Ips: ips, Ips: ips,
Subnets: subnets, Subnets: subnets,
Groups: groups, Groups: groups,
NotBefore: before, NotBefore: time.Unix(before.Unix(), 0),
NotAfter: after, NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub, PublicKey: pub,
IsCA: false, IsCA: false,
Issuer: issuer, Issuer: issuer,
InvertedGroups: make(map[string]struct{}),
}, },
} }

View File

@@ -3,10 +3,11 @@ package main
import ( import (
"bytes" "bytes"
"errors" "errors"
"github.com/stretchr/testify/assert"
"io" "io"
"os" "os"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
//TODO: all flag parsing continueOnError will print to stderr on its own currently //TODO: all flag parsing continueOnError will print to stderr on its own currently

View File

@@ -4,11 +4,12 @@ import (
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"github.com/slackhq/nebula/cert"
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
"strings" "strings"
"github.com/slackhq/nebula/cert"
) )
type printFlags struct { type printFlags struct {

View File

@@ -2,12 +2,13 @@ package main
import ( import (
"bytes" "bytes"
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
"io/ioutil" "io/ioutil"
"os" "os"
"testing" "testing"
"time" "time"
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
) )
func Test_printSummary(t *testing.T) { func Test_printSummary(t *testing.T) {

View File

@@ -3,12 +3,13 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"github.com/slackhq/nebula/cert"
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
"strings" "strings"
"time" "time"
"github.com/slackhq/nebula/cert"
) )
type verifyFlags struct { type verifyFlags struct {

View File

@@ -3,13 +3,14 @@ package main
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ed25519"
"io/ioutil" "io/ioutil"
"os" "os"
"testing" "testing"
"time" "time"
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ed25519"
) )
func Test_verifySummary(t *testing.T) { func Test_verifySummary(t *testing.T) {

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
) )
@@ -45,5 +46,30 @@ func main() {
os.Exit(1) os.Exit(1)
} }
nebula.Main(*configPath, *configTest, Build) config := nebula.NewConfig()
err := config.Load(*configPath)
if err != nil {
fmt.Printf("failed to load config: %s", err)
os.Exit(1)
}
l := logrus.New()
l.Out = os.Stdout
c, err := nebula.Main(config, *configTest, Build, l, nil)
switch v := err.(type) {
case nebula.ContextualError:
v.Log(l)
os.Exit(1)
case error:
l.WithError(err).Error("Failed to start")
os.Exit(1)
}
if !*configTest {
c.Start()
c.ShutdownBlock()
}
os.Exit(0)
} }

View File

@@ -1,44 +1,53 @@
package main package main
import ( import (
"fmt"
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
"github.com/kardianos/service" "github.com/kardianos/service"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
) )
var logger service.Logger var logger service.Logger
type program struct { type program struct {
exit chan struct{}
configPath *string configPath *string
configTest *bool configTest *bool
build string build string
control *nebula.Control
} }
func (p *program) Start(s service.Service) error { func (p *program) Start(s service.Service) error {
logger.Info("Nebula service starting.")
p.exit = make(chan struct{})
// Start should not block. // Start should not block.
go p.run() logger.Info("Nebula service starting.")
return nil
config := nebula.NewConfig()
err := config.Load(*p.configPath)
if err != nil {
return fmt.Errorf("failed to load config: %s", err)
} }
func (p *program) run() error { l := logrus.New()
nebula.Main(*p.configPath, *p.configTest, Build) l.Out = os.Stdout
p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
if err != nil {
return err
}
p.control.Start()
return nil return nil
} }
func (p *program) Stop(s service.Service) error { func (p *program) Stop(s service.Service) error {
logger.Info("Nebula service stopping.") logger.Info("Nebula service stopping.")
close(p.exit) p.control.Stop()
return nil return nil
} }
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) { func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {
if *configPath == "" { if *configPath == "" {
ex, err := os.Executable() ex, err := os.Executable()
if err != nil { if err != nil {

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
) )
@@ -39,5 +40,30 @@ func main() {
os.Exit(1) os.Exit(1)
} }
nebula.Main(*configPath, *configTest, Build) config := nebula.NewConfig()
err := config.Load(*configPath)
if err != nil {
fmt.Printf("failed to load config: %s", err)
os.Exit(1)
}
l := logrus.New()
l.Out = os.Stdout
c, err := nebula.Main(config, *configTest, Build, l, nil)
switch v := err.(type) {
case nebula.ContextualError:
v.Log(l)
os.Exit(1)
case error:
l.WithError(err).Error("Failed to start")
os.Exit(1)
}
if !*configTest {
c.Start()
c.ShutdownBlock()
}
os.Exit(0)
} }

188
config.go
View File

@@ -1,19 +1,23 @@
package nebula package nebula
import ( import (
"errors"
"fmt" "fmt"
"github.com/imdario/mergo"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
"io/ioutil" "io/ioutil"
"net"
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"regexp"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"syscall" "syscall"
"time" "time"
"github.com/imdario/mergo"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
) )
type Config struct { type Config struct {
@@ -35,7 +39,7 @@ func (c *Config) Load(path string) error {
c.path = path c.path = path
c.files = make([]string, 0) c.files = make([]string, 0)
err := c.resolve(path) err := c.resolve(path, true)
if err != nil { if err != nil {
return err return err
} }
@@ -54,6 +58,13 @@ func (c *Config) Load(path string) error {
return nil return nil
} }
func (c *Config) LoadString(raw string) error {
if raw == "" {
return errors.New("Empty configuration")
}
return c.parseRaw([]byte(raw))
}
// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered // RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
// here should decide if they need to make a change to the current process before making the change. HasChanged can be // here should decide if they need to make a change to the current process before making the change. HasChanged can be
// used to help decide if a change is necessary. // used to help decide if a change is necessary.
@@ -213,10 +224,137 @@ func (c *Config) GetDuration(k string, d time.Duration) time.Duration {
return v return v
} }
func (c *Config) GetAllowList(k string, allowInterfaces bool) (*AllowList, error) {
r := c.Get(k)
if r == nil {
return nil, nil
}
rawMap, ok := r.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, r)
}
tree := NewCIDRTree()
var nameRules []AllowListNameRule
firstValue := true
allValuesMatch := true
defaultSet := false
var allValues bool
for rawKey, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
// Special rule for interface names
if rawCIDR == "interfaces" {
if !allowInterfaces {
return nil, fmt.Errorf("config `%s` does not support `interfaces`", k)
}
var err error
nameRules, err = c.getAllowListInterfaces(k, rawValue)
if err != nil {
return nil, err
}
continue
}
value, ok := rawValue.(bool)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
}
_, cidr, err := net.ParseCIDR(rawCIDR)
if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
}
// TODO: should we error on duplicate CIDRs in the config?
tree.AddCIDR(cidr, value)
if firstValue {
allValues = value
firstValue = false
} else {
if value != allValues {
allValuesMatch = false
}
}
// Check if this is 0.0.0.0/0
bits, size := cidr.Mask.Size()
if bits == 0 && size == 32 {
defaultSet = true
}
}
if !defaultSet {
if allValuesMatch {
_, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
tree.AddCIDR(zeroCIDR, !allValues)
} else {
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
}
}
return &AllowList{cidrTree: tree, nameRules: nameRules}, nil
}
func (c *Config) getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
var nameRules []AllowListNameRule
rawRules, ok := v.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
}
firstEntry := true
var allValues bool
for rawName, rawAllow := range rawRules {
name, ok := rawName.(string)
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
}
allow, ok := rawAllow.(bool)
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
}
nameRE, err := regexp.Compile("^" + name + "$")
if err != nil {
return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err)
}
nameRules = append(nameRules, AllowListNameRule{
Name: nameRE,
Allow: allow,
})
if firstEntry {
allValues = allow
firstEntry = false
} else {
if allow != allValues {
return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k)
}
}
}
return nameRules, nil
}
func (c *Config) Get(k string) interface{} { func (c *Config) Get(k string) interface{} {
return c.get(k, c.Settings) return c.get(k, c.Settings)
} }
func (c *Config) IsSet(k string) bool {
return c.get(k, c.Settings) != nil
}
func (c *Config) get(k string, v interface{}) interface{} { func (c *Config) get(k string, v interface{}) interface{} {
parts := strings.Split(k, ".") parts := strings.Split(k, ".")
for _, p := range parts { for _, p := range parts {
@@ -234,14 +372,16 @@ func (c *Config) get(k string, v interface{}) interface{} {
return v return v
} }
func (c *Config) resolve(path string) error { // direct signifies if this is the config path directly specified by the user,
// versus a file/dir found by recursing into that path
func (c *Config) resolve(path string, direct bool) error {
i, err := os.Stat(path) i, err := os.Stat(path)
if err != nil { if err != nil {
return nil return nil
} }
if !i.IsDir() { if !i.IsDir() {
c.addFile(path) c.addFile(path, direct)
return nil return nil
} }
@@ -251,7 +391,7 @@ func (c *Config) resolve(path string) error {
} }
for _, p := range paths { for _, p := range paths {
err := c.resolve(filepath.Join(path, p)) err := c.resolve(filepath.Join(path, p), false)
if err != nil { if err != nil {
return err return err
} }
@@ -260,10 +400,10 @@ func (c *Config) resolve(path string) error {
return nil return nil
} }
func (c *Config) addFile(path string) error { func (c *Config) addFile(path string, direct bool) error {
ext := filepath.Ext(path) ext := filepath.Ext(path)
if ext != ".yaml" && ext != ".yml" { if !direct && ext != ".yaml" && ext != ".yml" {
return nil return nil
} }
@@ -276,6 +416,18 @@ func (c *Config) addFile(path string) error {
return nil return nil
} }
func (c *Config) parseRaw(b []byte) error {
var m map[interface{}]interface{}
err := yaml.Unmarshal(b, &m)
if err != nil {
return err
}
c.Settings = m
return nil
}
func (c *Config) parse() error { func (c *Config) parse() error {
var m map[interface{}]interface{} var m map[interface{}]interface{}
@@ -328,12 +480,26 @@ func configLogger(c *Config) error {
} }
l.SetLevel(logLevel) l.SetLevel(logLevel)
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
timestampFormat := c.GetString("logging.timestamp_format", "")
fullTimestamp := (timestampFormat != "")
if timestampFormat == "" {
timestampFormat = time.RFC3339
}
logFormat := strings.ToLower(c.GetString("logging.format", "text")) logFormat := strings.ToLower(c.GetString("logging.format", "text"))
switch logFormat { switch logFormat {
case "text": case "text":
l.Formatter = &logrus.TextFormatter{} l.Formatter = &logrus.TextFormatter{
TimestampFormat: timestampFormat,
FullTimestamp: fullTimestamp,
DisableTimestamp: disableTimestamp,
}
case "json": case "json":
l.Formatter = &logrus.JSONFormatter{} l.Formatter = &logrus.JSONFormatter{
TimestampFormat: timestampFormat,
DisableTimestamp: disableTimestamp,
}
default: default:
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"}) return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
} }

View File

@@ -1,12 +1,13 @@
package nebula package nebula
import ( import (
"github.com/stretchr/testify/assert"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
) )
func TestConfig_Load(t *testing.T) { func TestConfig_Load(t *testing.T) {
@@ -86,6 +87,76 @@ func TestConfig_GetBool(t *testing.T) {
assert.Equal(t, false, c.GetBool("bool", true)) assert.Equal(t, false, c.GetBool("bool", true))
} }
func TestConfig_GetAllowList(t *testing.T) {
c := NewConfig()
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0": true,
}
r, err := c.GetAllowList("allowlist", false)
assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0")
assert.Nil(t, r)
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0/16": "abc",
}
r, err = c.GetAllowList("allowlist", false)
assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0/16": true,
"10.0.0.0/8": false,
}
r, err = c.GetAllowList("allowlist", false)
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
"10.42.42.0/24": true,
}
r, err = c.GetAllowList("allowlist", false)
if assert.NoError(t, err) {
assert.NotNil(t, r)
}
// Test interface names
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: false,
},
}
r, err = c.GetAllowList("allowlist", false)
assert.EqualError(t, err, "config `allowlist` does not support `interfaces`")
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: "foo",
},
}
r, err = c.GetAllowList("allowlist", true)
assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: false,
`eth.*`: true,
},
}
r, err = c.GetAllowList("allowlist", true)
assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
`docker.*`: false,
},
}
r, err = c.GetAllowList("allowlist", true)
if assert.NoError(t, err) {
assert.NotNil(t, r)
}
}
func TestConfig_HasChanged(t *testing.T) { func TestConfig_HasChanged(t *testing.T) {
// No reload has occurred, return false // No reload has occurred, return false
c := NewConfig() c := NewConfig()

View File

@@ -182,7 +182,7 @@ func (n *connectionManager) HandleMonitorTick(now time.Time) {
continue continue
} }
l.WithField("vpnIp", IntIp(vpnIP)). hostinfo.logger().
WithField("tunnelCheck", m{"state": "testing", "method": "active"}). WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
Debug("Tunnel status") Debug("Tunnel status")
@@ -191,7 +191,7 @@ func (n *connectionManager) HandleMonitorTick(now time.Time) {
n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
} else { } else {
l.Debugf("Hostinfo sadness: %s", IntIp(vpnIP)) hostinfo.logger().Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
} }
n.AddPendingDeletion(vpnIP) n.AddPendingDeletion(vpnIP)
} }
@@ -233,7 +233,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil { if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
cn = hostinfo.ConnectionState.peerCert.Details.Name cn = hostinfo.ConnectionState.peerCert.Details.Name
} }
l.WithField("vpnIp", IntIp(vpnIP)). hostinfo.logger().
WithField("tunnelCheck", m{"state": "dead", "method": "active"}). WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
WithField("certName", cn). WithField("certName", cn).
Info("Tunnel status") Info("Tunnel status")

View File

@@ -28,7 +28,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
rawCertificateNoKey: []byte{}, rawCertificateNoKey: []byte{},
} }
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false) lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &Tun{}, inside: &Tun{},
@@ -36,7 +36,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
certState: cs, certState: cs,
firewall: &Firewall{}, firewall: &Firewall{},
lightHouse: lh, lightHouse: lh,
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}), handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
} }
now := time.Now() now := time.Now()
@@ -91,7 +91,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
rawCertificateNoKey: []byte{}, rawCertificateNoKey: []byte{},
} }
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false) lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &Tun{}, inside: &Tun{},
@@ -99,7 +99,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
certState: cs, certState: cs,
firewall: &Firewall{}, firewall: &Firewall{},
lightHouse: lh, lightHouse: lh,
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}), handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
} }
now := time.Now() now := time.Now()

216
control.go Normal file
View File

@@ -0,0 +1,216 @@
package nebula
import (
"encoding/binary"
"fmt"
"net"
"os"
"os/signal"
"syscall"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"golang.org/x/net/ipv4"
)
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
type Control struct {
f *Interface
l *logrus.Logger
}
type ControlHostInfo struct {
VpnIP net.IP `json:"vpnIp"`
LocalIndex uint32 `json:"localIndex"`
RemoteIndex uint32 `json:"remoteIndex"`
RemoteAddrs []udpAddr `json:"remoteAddrs"`
CachedPackets int `json:"cachedPackets"`
Cert *cert.NebulaCertificate `json:"cert"`
MessageCounter uint64 `json:"messageCounter"`
CurrentRemote udpAddr `json:"currentRemote"`
}
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
func (c *Control) Start() {
c.f.run()
}
// Stop signals nebula to shutdown, returns after the shutdown is complete
func (c *Control) Stop() {
//TODO: stop tun and udp routines, the lock on hostMap effectively does that though
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
c.f.hostMap.Lock()
for _, h := range c.f.hostMap.Hosts {
if h.ConnectionState.ready {
c.f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message")
}
}
c.f.hostMap.Unlock()
c.l.Info("Goodbye")
}
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
func (c *Control) ShutdownBlock() {
sigChan := make(chan os.Signal)
signal.Notify(sigChan, syscall.SIGTERM)
signal.Notify(sigChan, syscall.SIGINT)
rawSig := <-sigChan
sig := rawSig.String()
c.l.WithField("signal", sig).Info("Caught signal, shutting down")
c.Stop()
}
// RebindUDPServer asks the UDP listener to rebind it's listener. Mainly used on mobile clients when interfaces change
func (c *Control) RebindUDPServer() {
_ = c.f.outside.Rebind()
}
// ListHostmap returns details about the actual or pending (handshaking) hostmap
func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
var hm *HostMap
if pendingMap {
hm = c.f.handshakeManager.pendingHostMap
} else {
hm = c.f.hostMap
}
hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Hosts))
i := 0
for _, v := range hm.Hosts {
hosts[i] = copyHostInfo(v)
i++
}
hm.RUnlock()
return hosts
}
// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found
func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInfo {
var hm *HostMap
if pending {
hm = c.f.handshakeManager.pendingHostMap
} else {
hm = c.f.hostMap
}
h, err := hm.QueryVpnIP(vpnIP)
if err != nil {
return nil
}
ch := copyHostInfo(h)
return &ch
}
// SetRemoteForTunnel forces a tunnel to use a specific remote
func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInfo {
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
if err != nil {
return nil
}
hostInfo.SetRemote(addr.Copy())
ch := copyHostInfo(hostInfo)
return &ch
}
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
func (c *Control) CloseTunnel(vpnIP uint32, localOnly bool) bool {
hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
if err != nil {
return false
}
if !localOnly {
c.f.send(
closeTunnel,
0,
hostInfo.ConnectionState,
hostInfo,
hostInfo.remote,
[]byte{},
make([]byte, 12, 12),
make([]byte, mtu),
)
}
c.f.closeTunnel(hostInfo)
return true
}
func copyHostInfo(h *HostInfo) ControlHostInfo {
addrs := h.RemoteUDPAddrs()
chi := ControlHostInfo{
VpnIP: int2ip(h.hostId),
LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId,
RemoteAddrs: make([]udpAddr, len(addrs), len(addrs)),
CachedPackets: len(h.packetStore),
MessageCounter: *h.ConnectionState.messageCounter,
}
if c := h.GetCert(); c != nil {
chi.Cert = c.Copy()
}
if h.remote != nil {
chi.CurrentRemote = *h.remote
}
for i, addr := range addrs {
chi.RemoteAddrs[i] = addr.Copy()
}
return chi
}
// Hook provides the ability to hook into the network path for a particular
// message sub type. Any received message of that subtype that is allowed by
// the firewall will be written to the provided write func instead of the
// inside interface.
// TODO: make this an io.Writer
func (c *Control) Hook(t NebulaMessageSubType, w func([]byte) error) error {
if t == 0 {
return fmt.Errorf("non-default message subtype must be specified")
}
if _, ok := c.f.handlers[Version][message][t]; ok {
return fmt.Errorf("message subtype %d already hooked", t)
}
c.f.handlers[Version][message][t] = c.f.newHook(w)
return nil
}
// Send provides the ability to send arbitrary message packets to peer nodes.
// The provided payload will be encapsulated in a Nebula Firewall packet
// (IPv4 plus ports) from the node IP to the provided destination nebula IP.
// Any protocol handling above layer 3 (IP) must be managed by the caller.
func (c *Control) Send(ip uint32, port uint16, st NebulaMessageSubType, payload []byte) {
headerLen := ipv4.HeaderLen + minFwPacketLen
length := headerLen + len(payload)
packet := make([]byte, length)
packet[0] = 0x45 // IPv4 HL=20
packet[9] = 114 // Declare as arbitrary 0-hop protocol
binary.BigEndian.PutUint16(packet[2:4], uint16(length))
binary.BigEndian.PutUint32(packet[12:16], ip2int(c.f.inside.CidrNet().IP.To4()))
binary.BigEndian.PutUint32(packet[16:20], ip)
// Set identical values for src and dst port as they're only
// used for nebula firewall rule/conntrack matching.
binary.BigEndian.PutUint16(packet[20:22], port)
binary.BigEndian.PutUint16(packet[22:24], port)
copy(packet[headerLen:], payload)
fp := &FirewallPacket{}
nb := make([]byte, 12)
out := make([]byte, mtu)
c.f.consumeInsidePacket(st, packet, fp, nb, out)
}

111
control_test.go Normal file
View File

@@ -0,0 +1,111 @@
package nebula
import (
"net"
"reflect"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/util"
"github.com/stretchr/testify/assert"
)
func TestControl_GetHostInfoByVpnIP(t *testing.T) {
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller
hm := NewHostMap("test", &net.IPNet{}, make([]*net.IPNet, 0))
remote1 := NewUDPAddr(100, 4444)
remote2 := NewUDPAddr(101, 4444)
ipNet := net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
}
ipNet2 := net.IPNet{
IP: net.IPv4(1, 2, 3, 5),
Mask: net.IPMask{255, 255, 255, 0},
}
crt := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "test",
Ips: []*net.IPNet{&ipNet},
Subnets: []*net.IPNet{},
Groups: []string{"default-group"},
NotBefore: time.Unix(1, 0),
NotAfter: time.Unix(2, 0),
PublicKey: []byte{5, 6, 7, 8},
IsCA: false,
Issuer: "the-issuer",
InvertedGroups: map[string]struct{}{"default-group": {}},
},
Signature: []byte{1, 2, 1, 2, 1, 3},
}
counter := uint64(0)
remotes := []*HostInfoDest{NewHostInfoDest(remote1), NewHostInfoDest(remote2)}
hm.Add(ip2int(ipNet.IP), &HostInfo{
remote: remote1,
Remotes: remotes,
ConnectionState: &ConnectionState{
peerCert: crt,
messageCounter: &counter,
},
remoteIndexId: 200,
localIndexId: 201,
hostId: ip2int(ipNet.IP),
})
hm.Add(ip2int(ipNet2.IP), &HostInfo{
remote: remote1,
Remotes: remotes,
ConnectionState: &ConnectionState{
peerCert: nil,
messageCounter: &counter,
},
remoteIndexId: 200,
localIndexId: 201,
hostId: ip2int(ipNet2.IP),
})
c := Control{
f: &Interface{
hostMap: hm,
},
l: logrus.New(),
}
thi := c.GetHostInfoByVpnIP(ip2int(ipNet.IP), false)
expectedInfo := ControlHostInfo{
VpnIP: net.IPv4(1, 2, 3, 4).To4(),
LocalIndex: 201,
RemoteIndex: 200,
RemoteAddrs: []udpAddr{*remote1, *remote2},
CachedPackets: 0,
Cert: crt.Copy(),
MessageCounter: 0,
CurrentRemote: *NewUDPAddr(100, 4444),
}
// Make sure we don't have any unexpected fields
assertFields(t, []string{"VpnIP", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
util.AssertDeepCopyEqual(t, &expectedInfo, thi)
// Make sure we don't panic if the host info doesn't have a cert yet
assert.NotPanics(t, func() {
thi = c.GetHostInfoByVpnIP(ip2int(ipNet2.IP), false)
})
}
func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
val := reflect.ValueOf(actualStruct).Elem()
fields := make([]string, val.NumField())
for i := 0; i < val.NumField(); i++ {
fields[i] = val.Type().Field(i).Name
}
assert.Equal(t, expected, fields)
}

15
dist/arch/nebula.service vendored Normal file
View File

@@ -0,0 +1,15 @@
[Unit]
Description=nebula
Wants=basic.target network-online.target
After=basic.target network.target network-online.target
[Service]
SyslogIdentifier=nebula
StandardOutput=syslog
StandardError=syslog
ExecReload=/bin/kill -HUP $MAINPID
ExecStart=/usr/bin/nebula -config /etc/nebula/config.yml
Restart=always
[Install]
WantedBy=multi-user.target

113
dist/wireshark/nebula.lua vendored Normal file
View File

@@ -0,0 +1,113 @@
local nebula = Proto("nebula", "nebula")
local default_settings = {
port = 4242,
all_ports = false,
}
nebula.prefs.port = Pref.uint("Port number", default_settings.port, "The UDP port number for Nebula")
nebula.prefs.all_ports = Pref.bool("All ports", default_settings.all_ports, "Assume nebula packets on any port, useful when dealing with hole punching")
local pf_version = ProtoField.new("version", "nebula.version", ftypes.UINT8, nil, base.DEC, 0xF0)
local pf_type = ProtoField.new("type", "nebula.type", ftypes.UINT8, {
[0] = "handshake",
[1] = "message",
[2] = "recvError",
[3] = "lightHouse",
[4] = "test",
[5] = "closeTunnel",
}, base.DEC, 0x0F)
local pf_subtype = ProtoField.new("subtype", "nebula.subtype", ftypes.UINT8, nil, base.DEC)
local pf_subtype_test = ProtoField.new("subtype", "nebula.subtype", ftypes.UINT8, {
[0] = "request",
[1] = "reply",
}, base.DEC)
local pf_subtype_handshake = ProtoField.new("subtype", "nebula.subtype", ftypes.UINT8, {
[0] = "ix_psk0",
}, base.DEC)
local pf_reserved = ProtoField.new("reserved", "nebula.reserved", ftypes.UINT16, nil, base.HEX)
local pf_remote_index = ProtoField.new("remote index", "nebula.remote_index", ftypes.UINT32, nil, base.DEC)
local pf_message_counter = ProtoField.new("counter", "nebula.counter", ftypes.UINT64, nil, base.DEC)
local pf_payload = ProtoField.new("payload", "nebula.payload", ftypes.BYTES, nil, base.NONE)
nebula.fields = { pf_version, pf_type, pf_subtype, pf_subtype_handshake, pf_subtype_test, pf_reserved, pf_remote_index, pf_message_counter, pf_payload }
local ef_holepunch = ProtoExpert.new("nebula.holepunch.expert", "Nebula hole punch packet", expert.group.PROTOCOL, expert.severity.NOTE)
local ef_punchy = ProtoExpert.new("nebula.punchy.expert", "Nebula punchy keepalive packet", expert.group.PROTOCOL, expert.severity.NOTE)
nebula.experts = { ef_holepunch, ef_punchy }
local type_field = Field.new("nebula.type")
local subtype_field = Field.new("nebula.subtype")
function nebula.dissector(tvbuf, pktinfo, root)
-- set the protocol column to show our protocol name
pktinfo.cols.protocol:set("NEBULA")
local pktlen = tvbuf:reported_length_remaining()
local tree = root:add(nebula, tvbuf:range(0,pktlen))
if pktlen == 0 then
tree:add_proto_expert_info(ef_holepunch)
pktinfo.cols.info:append(" (holepunch)")
return
elseif pktlen == 1 then
tree:add_proto_expert_info(ef_punchy)
pktinfo.cols.info:append(" (punchy)")
return
end
tree:add(pf_version, tvbuf:range(0,1))
local type = tree:add(pf_type, tvbuf:range(0,1))
local nebula_type = bit32.band(tvbuf:range(0,1):uint(), 0x0F)
if nebula_type == 0 then
local stage = tvbuf(8,8):uint64()
tree:add(pf_subtype_handshake, tvbuf:range(1,1))
type:append_text(" stage " .. stage)
pktinfo.cols.info:append(" (" .. type_field().display .. ", stage " .. stage .. ", " .. subtype_field().display .. ")")
elseif nebula_type == 4 then
tree:add(pf_subtype_test, tvbuf:range(1,1))
pktinfo.cols.info:append(" (" .. type_field().display .. ", " .. subtype_field().display .. ")")
else
tree:add(pf_subtype, tvbuf:range(1,1))
pktinfo.cols.info:append(" (" .. type_field().display .. ")")
end
tree:add(pf_reserved, tvbuf:range(2,2))
tree:add(pf_remote_index, tvbuf:range(4,4))
tree:add(pf_message_counter, tvbuf:range(8,8))
tree:add(pf_payload, tvbuf:range(16,tvbuf:len() - 16))
end
function nebula.prefs_changed()
if default_settings.all_ports == nebula.prefs.all_ports and default_settings.port == nebula.prefs.port then
-- Nothing changed, bail
return
end
-- Remove our old dissector
DissectorTable.get("udp.port"):remove_all(nebula)
if nebula.prefs.all_ports and default_settings.all_ports ~= nebula.prefs.all_ports then
default_settings.all_port = nebula.prefs.all_ports
for i=0, 65535 do
DissectorTable.get("udp.port"):add(i, nebula)
end
-- no need to establish again on specific ports
return
end
if default_settings.all_ports ~= nebula.prefs.all_ports then
-- Add our new port dissector
default_settings.port = nebula.prefs.port
DissectorTable.get("udp.port"):add(default_settings.port, nebula)
end
end
DissectorTable.get("udp.port"):add(default_settings.port, nebula)

View File

@@ -7,8 +7,8 @@ pki:
ca: /etc/nebula/ca.crt ca: /etc/nebula/ca.crt
cert: /etc/nebula/host.crt cert: /etc/nebula/host.crt
key: /etc/nebula/host.key key: /etc/nebula/host.key
#blacklist is a list of certificate fingerprints that we will refuse to talk to #blocklist is a list of certificate fingerprints that we will refuse to talk to
#blacklist: #blocklist:
# - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72 # - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network). # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
@@ -36,9 +36,41 @@ lighthouse:
interval: 60 interval: 60
# hosts is a list of lighthouse hosts this node should report to and query from # hosts is a list of lighthouse hosts this node should report to and query from
# IMPORTANT: THIS SHOULD BE EMPTY ON LIGHTHOUSE NODES # IMPORTANT: THIS SHOULD BE EMPTY ON LIGHTHOUSE NODES
# IMPORTANT2: THIS SHOULD BE LIGHTHOUSES' NEBULA IPs, NOT LIGHTHOUSES' REAL ROUTABLE IPs
hosts: hosts:
- "192.168.100.1" - "192.168.100.1"
# remote_allow_list allows you to control ip ranges that this node will
# consider when handshaking to another node. By default, any remote IPs are
# allowed. You can provide CIDRs here with `true` to allow and `false` to
# deny. The most specific CIDR rule applies to each remote. If all rules are
# "allow", the default will be "deny", and vice-versa. If both "allow" and
# "deny" rules are present, then you MUST set a rule for "0.0.0.0/0" as the
# default.
#remote_allow_list:
# Example to block IPs from this subnet from being used for remote IPs.
#"172.16.0.0/12": false
# A more complicated example, allow public IPs but only private IPs from a specific subnet
#"0.0.0.0/0": true
#"10.0.0.0/8": false
#"10.42.42.0/24": true
# local_allow_list allows you to filter which local IP addresses we advertise
# to the lighthouses. This uses the same logic as `remote_allow_list`, but
# additionally, you can specify an `interfaces` map of regular expressions
# to match against interface names. The regexp must match the entire name.
# All interface rules must be either true or false (and the default will be
# the inverse). CIDR rules are matched after interface name rules.
# Default is all local IP addresses.
#local_allow_list:
# Example to block tun0 and all docker interfaces.
#interfaces:
#tun0: false
#'docker.*': false
# Example to only advertise this subnet to the lighthouse.
#"10.0.0.0/8": true
# Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined, # Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined,
# however using port 0 will dynamically assign a port and is recommended for roaming nodes. # however using port 0 will dynamically assign a port and is recommended for roaming nodes.
listen: listen:
@@ -54,11 +86,17 @@ listen:
#read_buffer: 10485760 #read_buffer: 10485760
#write_buffer: 10485760 #write_buffer: 10485760
# Punchy continues to punch inbound/outbound at a regular interval to avoid expiration of firewall nat mappings punchy:
punchy: true # Continues to punch inbound/outbound at a regular interval to avoid expiration of firewall nat mappings
# punch_back means that a node you are trying to reach will connect back out to you if your hole punching fails punch: true
# this is extremely useful if one node is behind a difficult nat, such as symmetric
#punch_back: true # respond means that a node you are trying to reach will connect back out to you if your hole punching fails
# this is extremely useful if one node is behind a difficult nat, such as a symmetric NAT
# Default is false
#respond: true
# delays a punch response for misbehaving NATs, default is 1 second, respond must be true to take effect
#delay: 1s
# Cipher allows you to choose between the available ciphers for your network. # Cipher allows you to choose between the available ciphers for your network.
# IMPORTANT: this value must be identical on ALL NODES/LIGHTHOUSES. We do not/will not support use of different ciphers simultaneously! # IMPORTANT: this value must be identical on ALL NODES/LIGHTHOUSES. We do not/will not support use of different ciphers simultaneously!
@@ -86,6 +124,8 @@ punchy: true
# Configure the private interface. Note: addr is baked into the nebula certificate # Configure the private interface. Note: addr is baked into the nebula certificate
tun: tun:
# When tun is disabled, a lighthouse can be started without a local tun interface (and therefore without root)
disabled: false
# Name of the device # Name of the device
dev: nebula1 dev: nebula1
# Toggles forwarding of local broadcast packets, the address of which depends on the ip/mask encoded in pki.cert # Toggles forwarding of local broadcast packets, the address of which depends on the ip/mask encoded in pki.cert
@@ -116,6 +156,16 @@ logging:
level: info level: info
# json or text formats currently available. Default is text # json or text formats currently available. Default is text
format: text format: text
# Disable timestamp logging. useful when output is redirected to logging system that already adds timestamps. Default is false
#disable_timestamp: true
# timestamp format is specified in Go time format, see:
# https://golang.org/pkg/time/#pkg-constants
# default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339)
# default when `format: text`:
# when TTY attached: seconds since beginning of execution
# otherwise: "2006-01-02T15:04:05Z07:00" (RFC3339)
# As an example, to log as RFC3339 with millisecond precision, set to:
#timestamp_format: "2006-01-02T15:04:05.000Z07:00"
#stats: #stats:
#type: graphite #type: graphite
@@ -131,10 +181,31 @@ logging:
#subsystem: nebula #subsystem: nebula
#interval: 10s #interval: 10s
# enables counter metrics for meta packets
# e.g.: `messages.tx.handshake`
# NOTE: `message.{tx,rx}.recv_error` is always emitted
#message_metrics: false
# enables detailed counter metrics for lighthouse packets
# e.g.: `lighthouse.rx.HostQuery`
#lighthouse_metrics: false
# Handshake Manger Settings
#handshakes:
# Total time to try a handshake = sequence of `try_interval * retries`
# With 100ms interval and 20 retries it is 23.5 seconds
#try_interval: 100ms
#retries: 20
# wait_rotation is the number of handshake attempts to do before starting to try non-local IP addresses
#wait_rotation: 5
# trigger_buffer is the size of the buffer channel for quickly sending handshakes
# after receiving the response for lighthouse queries
#trigger_buffer: 64
# Nebula security group configuration # Nebula security group configuration
firewall: firewall:
conntrack: conntrack:
tcp_timeout: 120h tcp_timeout: 12m
udp_timeout: 3m udp_timeout: 3m
default_timeout: 10m default_timeout: 10m
max_connections: 100000 max_connections: 100000

View File

@@ -1,21 +1,21 @@
package nebula package nebula
import ( import (
"crypto/sha256"
"encoding/binary" "encoding/binary"
"encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net" "net"
"sync"
"time"
"crypto/sha256"
"encoding/hex"
"errors"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"sync"
"time"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
) )
@@ -38,13 +38,19 @@ type FirewallInterface interface {
type conn struct { type conn struct {
Expires time.Time // Time when this conntrack entry will expire Expires time.Time // Time when this conntrack entry will expire
Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
Sent time.Time // If tcp rtt tracking is enabled this will be when Seq was last set Sent time.Time // If tcp rtt tracking is enabled this will be when Seq was last set
Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
// record why the original connection passed the firewall, so we can re-validate
// after ruleset changes. Note, rulesVersion is a uint16 so that these two
// fields pack for free after the uint32 above
incoming bool
rulesVersion uint16
} }
// TODO: need conntrack max tracked connections handling // TODO: need conntrack max tracked connections handling
type Firewall struct { type Firewall struct {
Conns map[FirewallPacket]*conn Conntrack *FirewallConntrack
InRules *FirewallTable InRules *FirewallTable
OutRules *FirewallTable OutRules *FirewallTable
@@ -55,18 +61,23 @@ type Firewall struct {
UDPTimeout time.Duration //linux: 180s max UDPTimeout time.Duration //linux: 180s max
DefaultTimeout time.Duration //linux: 600s DefaultTimeout time.Duration //linux: 600s
TimerWheel *TimerWheel
// Used to ensure we don't emit local packets for ips we don't own // Used to ensure we don't emit local packets for ips we don't own
localIps *CIDRTree localIps *CIDRTree
connMutex sync.Mutex
rules string rules string
rulesVersion uint16
trackTCPRTT bool trackTCPRTT bool
metricTCPRTT metrics.Histogram metricTCPRTT metrics.Histogram
} }
type FirewallConntrack struct {
sync.Mutex
Conns map[FirewallPacket]*conn
TimerWheel *TimerWheel
}
type FirewallTable struct { type FirewallTable struct {
TCP firewallPort TCP firewallPort
UDP firewallPort UDP firewallPort
@@ -172,10 +183,12 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
} }
return &Firewall{ return &Firewall{
Conntrack: &FirewallConntrack{
Conns: make(map[FirewallPacket]*conn), Conns: make(map[FirewallPacket]*conn),
TimerWheel: NewTimerWheel(min, max),
},
InRules: newFirewallTable(), InRules: newFirewallTable(),
OutRules: newFirewallTable(), OutRules: newFirewallTable(),
TimerWheel: NewTimerWheel(min, max),
TCPTimeout: tcpTimeout, TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout, UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout, DefaultTimeout: defaultTimeout,
@@ -208,11 +221,17 @@ func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, er
// AddRule properly creates the in memory rule structure for a firewall table. // AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error { func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
// https://github.com/golang/go/issues/14131
sIp := ""
if ip != nil {
sIp = ip.String()
}
// We need this rule string because we generate a hash. Removing this will break firewall reload. // We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf( ruleString := fmt.Sprintf(
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s", "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, ip, caName, caSha, incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha,
) )
f.rules += ruleString + "\n" f.rules += ruleString + "\n"
@@ -220,7 +239,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
if !incoming { if !incoming {
direction = "outgoing" direction = "outgoing"
} }
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": ip, "caName": caName, "caSha": caSha}). l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
Info("Firewall rule added") Info("Firewall rule added")
var ( var (
@@ -347,20 +366,33 @@ func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterfa
return nil return nil
} }
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool { var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) error {
// Check if we spoke to this tuple, if we did then allow this packet // Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(packet, fp, incoming) { if f.inConns(packet, fp, incoming, h, caPool) {
return false return nil
} }
// Make sure remote address matches nebula certificate // Make sure remote address matches nebula certificate
if h.remoteCidr.Contains(fp.RemoteIP) == nil { if remoteCidr := h.remoteCidr; remoteCidr != nil {
return true if remoteCidr.Contains(fp.RemoteIP) == nil {
return ErrInvalidRemoteIP
}
} else {
// Simple case: Certificate has one IP and no subnets
if fp.RemoteIP != h.hostId {
return ErrInvalidRemoteIP
}
} }
// Make sure we are supposed to be handling this local ip address // Make sure we are supposed to be handling this local ip address
if f.localIps.Contains(fp.LocalIP) == nil { if f.localIps.Contains(fp.LocalIP) == nil {
return true return ErrInvalidLocalIP
} }
table := f.OutRules table := f.OutRules
@@ -370,13 +402,13 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *Host
// We now know which firewall table to check against // We now know which firewall table to check against
if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) { if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) {
return true return ErrNoMatchingRule
} }
// We always want to conntrack since it is a faster operation // We always want to conntrack since it is a faster operation
f.addConn(packet, fp, incoming) f.addConn(packet, fp, incoming)
return false return nil
} }
// Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new // Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new
@@ -386,26 +418,66 @@ func (f *Firewall) Destroy() {
} }
func (f *Firewall) EmitStats() { func (f *Firewall) EmitStats() {
conntrackCount := len(f.Conns) conntrack := f.Conntrack
conntrack.Lock()
conntrackCount := len(conntrack.Conns)
conntrack.Unlock()
metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount)) metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount))
metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
} }
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool { func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool {
f.connMutex.Lock() conntrack := f.Conntrack
conntrack.Lock()
// Purge every time we test // Purge every time we test
ep, has := f.TimerWheel.Purge() ep, has := conntrack.TimerWheel.Purge()
if has { if has {
f.evict(ep) f.evict(ep)
} }
c, ok := f.Conns[fp] c, ok := conntrack.Conns[fp]
if !ok { if !ok {
f.connMutex.Unlock() conntrack.Unlock()
return false return false
} }
if c.rulesVersion != f.rulesVersion {
// This conntrack entry was for an older rule set, validate
// it still passes with the current rule set
table := f.OutRules
if c.incoming {
table = f.InRules
}
// We now know which firewall table to check against
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
if l.Level >= logrus.DebugLevel {
h.logger().
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion).
Debugln("dropping old conntrack entry, does not match new ruleset")
}
delete(conntrack.Conns, fp)
conntrack.Unlock()
return false
}
if l.Level >= logrus.DebugLevel {
h.logger().
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion).
Debugln("keeping old conntrack entry, does match new ruleset")
}
c.rulesVersion = f.rulesVersion
}
switch fp.Protocol { switch fp.Protocol {
case fwProtoTCP: case fwProtoTCP:
c.Expires = time.Now().Add(f.TCPTimeout) c.Expires = time.Now().Add(f.TCPTimeout)
@@ -420,7 +492,7 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool
c.Expires = time.Now().Add(f.DefaultTimeout) c.Expires = time.Now().Add(f.DefaultTimeout)
} }
f.connMutex.Unlock() conntrack.Unlock()
return true return true
} }
@@ -441,14 +513,19 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
timeout = f.DefaultTimeout timeout = f.DefaultTimeout
} }
f.connMutex.Lock() conntrack := f.Conntrack
if _, ok := f.Conns[fp]; !ok { conntrack.Lock()
f.TimerWheel.Add(fp, timeout) if _, ok := conntrack.Conns[fp]; !ok {
conntrack.TimerWheel.Add(fp, timeout)
} }
// Record which rulesVersion allowed this connection, so we can retest after
// firewall reload
c.incoming = incoming
c.rulesVersion = f.rulesVersion
c.Expires = time.Now().Add(timeout) c.Expires = time.Now().Add(timeout)
f.Conns[fp] = c conntrack.Conns[fp] = c
f.connMutex.Unlock() conntrack.Unlock()
} }
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
@@ -456,7 +533,8 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
func (f *Firewall) evict(p FirewallPacket) { func (f *Firewall) evict(p FirewallPacket) {
//TODO: report a stat if the tcp rtt tracking was never resolved? //TODO: report a stat if the tcp rtt tracking was never resolved?
// Are we still tracking this conn? // Are we still tracking this conn?
t, ok := f.Conns[p] conntrack := f.Conntrack
t, ok := conntrack.Conns[p]
if !ok { if !ok {
return return
} }
@@ -465,12 +543,12 @@ func (f *Firewall) evict(p FirewallPacket) {
// Timeout is in the future, re-add the timer // Timeout is in the future, re-add the timer
if newT > 0 { if newT > 0 {
f.TimerWheel.Add(p, newT) conntrack.TimerWheel.Add(p, newT)
return return
} }
// This conn is done // This conn is done
delete(f.Conns, p) delete(conntrack.Conns, p)
} }
func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {

View File

@@ -17,37 +17,39 @@ import (
func TestNewFirewall(t *testing.T) { func TestNewFirewall(t *testing.T) {
c := &cert.NebulaCertificate{} c := &cert.NebulaCertificate{}
fw := NewFirewall(time.Second, time.Minute, time.Hour, c) fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
assert.NotNil(t, fw.Conns) conntrack := fw.Conntrack
assert.NotNil(t, conntrack)
assert.NotNil(t, conntrack.Conns)
assert.NotNil(t, conntrack.TimerWheel)
assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.InRules)
assert.NotNil(t, fw.OutRules) assert.NotNil(t, fw.OutRules)
assert.NotNil(t, fw.TimerWheel)
assert.Equal(t, time.Second, fw.TCPTimeout) assert.Equal(t, time.Second, fw.TCPTimeout)
assert.Equal(t, time.Minute, fw.UDPTimeout) assert.Equal(t, time.Minute, fw.UDPTimeout)
assert.Equal(t, time.Hour, fw.DefaultTimeout) assert.Equal(t, time.Hour, fw.DefaultTimeout)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Second, time.Hour, time.Minute, c) fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Hour, time.Second, time.Minute, c) fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Hour, time.Minute, time.Second, c) fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Minute, time.Hour, time.Second, c) fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Minute, time.Second, time.Hour, c) fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
} }
func TestFirewall_AddRule(t *testing.T) { func TestFirewall_AddRule(t *testing.T) {
@@ -171,6 +173,7 @@ func TestFirewall_Drop(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c, peerCert: &c,
}, },
hostId: ip2int(ipNet.IP),
} }
h.CreateRemoteCIDR(&c) h.CreateRemoteCIDR(&c)
@@ -179,44 +182,44 @@ func TestFirewall_Drop(t *testing.T) {
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
assert.True(t, fw.Drop([]byte{}, p, false, &h, cp)) assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
// Allow inbound // Allow inbound
resetConntrack(fw) resetConntrack(fw)
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
// Allow outbound because conntrack // Allow outbound because conntrack
assert.False(t, fw.Drop([]byte{}, p, false, &h, cp)) assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
// test remote mismatch // test remote mismatch
oldRemote := p.RemoteIP oldRemote := p.RemoteIP
p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10)) p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
assert.True(t, fw.Drop([]byte{}, p, false, &h, cp)) assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrInvalidRemoteIP)
p.RemoteIP = oldRemote p.RemoteIP = oldRemote
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
assert.True(t, fw.Drop([]byte{}, p, true, &h, cp)) assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
assert.True(t, fw.Drop([]byte{}, p, true, &h, cp)) assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
} }
func BenchmarkFirewallTable_match(b *testing.B) { func BenchmarkFirewallTable_match(b *testing.B) {
@@ -344,6 +347,7 @@ func TestFirewall_Drop2(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c, peerCert: &c,
}, },
hostId: ip2int(ipNet.IP),
} }
h.CreateRemoteCIDR(&c) h.CreateRemoteCIDR(&c)
@@ -366,10 +370,10 @@ func TestFirewall_Drop2(t *testing.T) {
cp := cert.NewCAPool() cp := cert.NewCAPool()
// h1/c1 lacks the proper groups // h1/c1 lacks the proper groups
assert.True(t, fw.Drop([]byte{}, p, true, &h1, cp)) assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp), ErrNoMatchingRule)
// c has the proper groups // c has the proper groups
resetConntrack(fw) resetConntrack(fw)
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp)) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
} }
func TestFirewall_Drop3(t *testing.T) { func TestFirewall_Drop3(t *testing.T) {
@@ -410,6 +414,7 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c1, peerCert: &c1,
}, },
hostId: ip2int(ipNet.IP),
} }
h1.CreateRemoteCIDR(&c1) h1.CreateRemoteCIDR(&c1)
@@ -424,6 +429,7 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c2, peerCert: &c2,
}, },
hostId: ip2int(ipNet.IP),
} }
h2.CreateRemoteCIDR(&c2) h2.CreateRemoteCIDR(&c2)
@@ -438,6 +444,7 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c3, peerCert: &c3,
}, },
hostId: ip2int(ipNet.IP),
} }
h3.CreateRemoteCIDR(&c3) h3.CreateRemoteCIDR(&c3)
@@ -447,13 +454,81 @@ func TestFirewall_Drop3(t *testing.T) {
cp := cert.NewCAPool() cp := cert.NewCAPool()
// c1 should pass because host match // c1 should pass because host match
assert.False(t, fw.Drop([]byte{}, p, true, &h1, cp)) assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp))
// c2 should pass because ca sha match // c2 should pass because ca sha match
resetConntrack(fw) resetConntrack(fw)
assert.False(t, fw.Drop([]byte{}, p, true, &h2, cp)) assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp))
// c3 should fail because no match // c3 should fail because no match
resetConntrack(fw) resetConntrack(fw)
assert.True(t, fw.Drop([]byte{}, p, true, &h3, cp)) assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp), ErrNoMatchingRule)
}
func TestFirewall_DropConntrackReload(t *testing.T) {
ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob)
defer l.SetOutput(out)
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
ip2int(net.IPv4(1, 2, 3, 4)),
10,
90,
fwProtoUDP,
false,
}
ipNet := net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
}
c := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "host1",
Ips: []*net.IPNet{&ipNet},
Groups: []string{"default-group"},
InvertedGroups: map[string]struct{}{"default-group": {}},
Issuer: "signer-shasum",
},
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c,
},
hostId: ip2int(ipNet.IP),
}
h.CreateRemoteCIDR(&c)
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
// Allow outbound because conntrack
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
oldFw := fw
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
// Allow outbound because conntrack and new rules allow port 10
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
oldFw = fw
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
// Drop outbound because conntrack doesn't match new ruleset
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
} }
func BenchmarkLookup(b *testing.B) { func BenchmarkLookup(b *testing.B) {
@@ -856,7 +931,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
} }
func resetConntrack(fw *Firewall) { func resetConntrack(fw *Firewall) {
fw.connMutex.Lock() fw.Conntrack.Lock()
fw.Conns = map[FirewallPacket]*conn{} fw.Conntrack.Conns = map[FirewallPacket]*conn{}
fw.connMutex.Unlock() fw.Conntrack.Unlock()
} }

6
go.mod
View File

@@ -11,7 +11,7 @@ require (
github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6 github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6
github.com/golang/protobuf v1.3.2 github.com/golang/protobuf v1.3.2
github.com/imdario/mergo v0.3.8 github.com/imdario/mergo v0.3.8
github.com/kardianos/service v1.0.0 github.com/kardianos/service v1.1.0
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
github.com/kr/pretty v0.1.0 // indirect github.com/kr/pretty v0.1.0 // indirect
github.com/miekg/dns v1.1.25 github.com/miekg/dns v1.1.25
@@ -22,10 +22,10 @@ require (
github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563 github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563
github.com/sirupsen/logrus v1.4.2 github.com/sirupsen/logrus v1.4.2
github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b
github.com/stretchr/testify v1.4.0 github.com/stretchr/testify v1.6.1
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553
golang.org/x/sys v0.0.0-20191210023423-ac6580df4449 golang.org/x/sys v0.0.0-20191210023423-ac6580df4449
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect

12
go.sum
View File

@@ -46,6 +46,8 @@ github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/u
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/kardianos/service v1.0.0 h1:HgQS3mFfOlyntWX8Oke98JcJLqt1DBcHR4kxShpYef0= github.com/kardianos/service v1.0.0 h1:HgQS3mFfOlyntWX8Oke98JcJLqt1DBcHR4kxShpYef0=
github.com/kardianos/service v1.0.0/go.mod h1:8CzDhVuCuugtsHyZoTvsOBuvonN/UDBvl0kH+BUxvbo= github.com/kardianos/service v1.0.0/go.mod h1:8CzDhVuCuugtsHyZoTvsOBuvonN/UDBvl0kH+BUxvbo=
github.com/kardianos/service v1.1.0 h1:QV2SiEeWK42P0aEmGcsAgjApw/lRxkwopvT+Gu6t1/0=
github.com/kardianos/service v1.1.0/go.mod h1:RrJI2xn5vve/r32U5suTbeaSGoMU6GbNPoj36CVYcHc=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
@@ -103,8 +105,8 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a h1:Bt1IVPhiCDMqwGrc2nnbIN4QKvJGx6SK2NzWBmW00ao= github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a h1:Bt1IVPhiCDMqwGrc2nnbIN4QKvJGx6SK2NzWBmW00ao=
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk= github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k= github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
@@ -112,8 +114,8 @@ github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY= golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY=
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 h1:ULYEB3JvPRE/IfO+9uO7vKV/xzVTO7XPAwm8xbf4w2g= golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975 h1:/Tl7pH94bvbAAHBdZJT947M/+gp0+CqQXDtMRC0fseo=
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@@ -152,3 +154,5 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

83
handler.go Normal file
View File

@@ -0,0 +1,83 @@
package nebula
func (f *Interface) newHook(w func([]byte) error) InsideHandler {
fn := func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.decryptTo(w, hostInfo, header.MessageCounter, out, packet, fwPacket, nb)
}
return f.encrypted(fn)
}
func (f *Interface) encrypted(h InsideHandler) InsideHandler {
return func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
if !f.handleEncrypted(ci, addr, header) {
return
}
h(hostInfo, ci, addr, header, out, packet, fwPacket, nb)
f.handleHostRoaming(hostInfo, addr)
f.connectionManager.In(hostInfo.hostId)
}
}
func (f *Interface) rxMetrics(h InsideHandler) InsideHandler {
return func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
h(hostInfo, ci, addr, header, out, packet, fwPacket, nb)
}
}
func (f *Interface) handleMessagePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
f.decryptTo(f.inside.WriteRaw, hostInfo, header.MessageCounter, out, packet, fwPacket, nb)
}
func (f *Interface) handleLighthousePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
d, err := f.decrypt(hostInfo, header.MessageCounter, out, packet, header, nb)
if err != nil {
hostInfo.logger().WithError(err).WithField("udpAddr", addr).
WithField("packet", packet).
Error("Failed to decrypt lighthouse packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
return
}
f.lightHouse.HandleRequest(addr, hostInfo.hostId, d, hostInfo.GetCert(), f)
}
func (f *Interface) handleTestPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
d, err := f.decrypt(hostInfo, header.MessageCounter, out, packet, header, nb)
if err != nil {
hostInfo.logger().WithError(err).WithField("udpAddr", addr).
WithField("packet", packet).
Error("Failed to decrypt test packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
return
}
if header.Subtype == testRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostInfo, addr)
f.send(test, testReply, ci, hostInfo, hostInfo.remote, d, nb, out)
}
}
func (f *Interface) handleHandshakePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
HandleIncomingHandshake(f, addr, packet, header, hostInfo)
}
func (f *Interface) handleRecvErrorPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
// TODO: Remove this with recv_error deprecation
f.handleRecvError(addr, header)
}
func (f *Interface) handleCloseTunnelPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
hostInfo.logger().WithField("udpAddr", addr).
Info("Close tunnel received, tearing down.")
f.closeTunnel(hostInfo)
}

View File

@@ -13,6 +13,11 @@ func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Head
// return // return
//} //}
if !f.lightHouse.remoteAllowList.Allow(udp2ipInt(addr)) {
l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return
}
tearDown := false tearDown := false
switch h.Subtype { switch h.Subtype {
case handshakeIXPSK0: case handshakeIXPSK0:

View File

@@ -1,11 +1,10 @@
package nebula package nebula
import ( import (
"bytes"
"sync/atomic" "sync/atomic"
"time" "time"
"bytes"
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
) )
@@ -98,6 +97,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hostinfo, _ := f.handshakeManager.pendingHostMap.QueryReverseIndex(hs.Details.InitiatorIndex) hostinfo, _ := f.handshakeManager.pendingHostMap.QueryReverseIndex(hs.Details.InitiatorIndex)
if hostinfo != nil && bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) { if hostinfo != nil && bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) {
if msg, ok := hostinfo.HandshakePacket[2]; ok { if msg, ok := hostinfo.HandshakePacket[2]; ok {
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err := f.outside.WriteTo(msg, addr) err := f.outside.WriteTo(msg, addr)
if err != nil { if err != nil {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
@@ -125,10 +125,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
return true return true
} }
vpnIP := ip2int(remoteCert.Details.Ips[0].IP) vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
myIndex, err := generateIndex() myIndex, err := generateIndex()
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
return true return true
} }
@@ -136,11 +140,15 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hostinfo, err = f.handshakeManager.AddIndex(myIndex, ci) hostinfo, err = f.handshakeManager.AddIndex(myIndex, ci)
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Error adding index to connection manager") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Error adding index to connection manager")
return true return true
} }
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message received") Info("Handshake message received")
@@ -152,6 +160,8 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hsBytes, err := proto.Marshal(hs) hsBytes, err := proto.Marshal(hs)
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return true return true
} }
@@ -160,12 +170,16 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes) msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes)
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return true return true
} }
if f.hostMap.CheckHandshakeCompleteIP(vpnIP) && vpnIP < ip2int(f.certState.certificate.Details.Ips[0].IP) { if f.hostMap.CheckHandshakeCompleteIP(vpnIP) && vpnIP < ip2int(f.certState.certificate.Details.Ips[0].IP) {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Prevented a handshake race") Info("Prevented a handshake race")
@@ -184,14 +198,19 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hostinfo.HandshakePacket[2] = make([]byte, len(msg)) hostinfo.HandshakePacket[2] = make([]byte, len(msg))
copy(hostinfo.HandshakePacket[2], msg) copy(hostinfo.HandshakePacket[2], msg)
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err := f.outside.WriteTo(msg, addr) err := f.outside.WriteTo(msg, addr)
if err != nil { if err != nil {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake") WithError(err).Error("Failed to send handshake")
} else { } else {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake message sent") Info("Handshake message sent")
@@ -214,6 +233,8 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
ho, err := f.hostMap.QueryVpnIP(vpnIP) ho, err := f.hostMap.QueryVpnIP(vpnIP)
if err == nil && ho.localIndexId != 0 { if err == nil && ho.localIndexId != 0 {
l.WithField("vpnIp", vpnIP). l.WithField("vpnIp", vpnIP).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("action", "removing stale index"). WithField("action", "removing stale index").
WithField("index", ho.localIndexId). WithField("index", ho.localIndexId).
Debug("Handshake processing") Debug("Handshake processing")
@@ -226,6 +247,8 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hostinfo.handshakeComplete() hostinfo.handshakeComplete()
} else { } else {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Noise did not arrive at a key") Error("Noise did not arrive at a key")
return true return true
@@ -284,9 +307,13 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
return true return true
} }
vpnIP := ip2int(remoteCert.Details.Ips[0].IP) vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
duration := time.Since(hostinfo.handshakeStart).Nanoseconds() duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("durationNs", duration). WithField("durationNs", duration).
@@ -324,6 +351,8 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
ho, err := f.hostMap.QueryVpnIP(vpnIP) ho, err := f.hostMap.QueryVpnIP(vpnIP)
if err == nil && ho.localIndexId != 0 { if err == nil && ho.localIndexId != 0 {
l.WithField("vpnIp", vpnIP). l.WithField("vpnIp", vpnIP).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("action", "removing stale index"). WithField("action", "removing stale index").
WithField("index", ho.localIndexId). WithField("index", ho.localIndexId).
Debug("Handshake processing") Debug("Handshake processing")
@@ -337,6 +366,8 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
f.metricHandshakes.Update(duration) f.metricHandshakes.Update(duration)
} else { } else {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key") Error("Noise did not arrive at a key")
return true return true

View File

@@ -13,41 +13,78 @@ import (
const ( const (
// Total time to try a handshake = sequence of HandshakeTryInterval * HandshakeRetries // Total time to try a handshake = sequence of HandshakeTryInterval * HandshakeRetries
// With 100ms interval and 20 retries is 23.5 seconds // With 100ms interval and 20 retries is 23.5 seconds
HandshakeTryInterval = time.Millisecond * 100 DefaultHandshakeTryInterval = time.Millisecond * 100
HandshakeRetries = 20 DefaultHandshakeRetries = 20
// HandshakeWaitRotation is the number of handshake attempts to do before starting to use other ips addresses // DefaultHandshakeWaitRotation is the number of handshake attempts to do before starting to use other ips addresses
HandshakeWaitRotation = 5 DefaultHandshakeWaitRotation = 5
DefaultHandshakeTriggerBuffer = 64
) )
var (
defaultHandshakeConfig = HandshakeConfig{
tryInterval: DefaultHandshakeTryInterval,
retries: DefaultHandshakeRetries,
waitRotation: DefaultHandshakeWaitRotation,
triggerBuffer: DefaultHandshakeTriggerBuffer,
}
)
type HandshakeConfig struct {
tryInterval time.Duration
retries int
waitRotation int
triggerBuffer int
messageMetrics *MessageMetrics
}
type HandshakeManager struct { type HandshakeManager struct {
pendingHostMap *HostMap pendingHostMap *HostMap
mainHostMap *HostMap mainHostMap *HostMap
lightHouse *LightHouse lightHouse *LightHouse
outside *udpConn outside *udpConn
config HandshakeConfig
// can be used to trigger outbound handshake for the given vpnIP
trigger chan uint32
OutboundHandshakeTimer *SystemTimerWheel OutboundHandshakeTimer *SystemTimerWheel
InboundHandshakeTimer *SystemTimerWheel InboundHandshakeTimer *SystemTimerWheel
messageMetrics *MessageMetrics
} }
func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn) *HandshakeManager { func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
return &HandshakeManager{ return &HandshakeManager{
pendingHostMap: NewHostMap("pending", tunCidr, preferredRanges), pendingHostMap: NewHostMap("pending", tunCidr, preferredRanges),
mainHostMap: mainHostMap, mainHostMap: mainHostMap,
lightHouse: lightHouse, lightHouse: lightHouse,
outside: outside, outside: outside,
OutboundHandshakeTimer: NewSystemTimerWheel(HandshakeTryInterval, HandshakeTryInterval*HandshakeRetries), config: config,
InboundHandshakeTimer: NewSystemTimerWheel(HandshakeTryInterval, HandshakeTryInterval*HandshakeRetries),
trigger: make(chan uint32, config.triggerBuffer),
OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
messageMetrics: config.messageMetrics,
} }
} }
func (c *HandshakeManager) Run(f EncWriter) { func (c *HandshakeManager) Run(f EncWriter) {
clockSource := time.Tick(HandshakeTryInterval) clockSource := time.Tick(c.config.tryInterval)
for now := range clockSource { for {
select {
case vpnIP := <-c.trigger:
l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
c.handleOutbound(vpnIP, f, true)
case now := <-clockSource:
c.NextOutboundHandshakeTimerTick(now, f) c.NextOutboundHandshakeTimerTick(now, f)
c.NextInboundHandshakeTimerTick(now) c.NextInboundHandshakeTimerTick(now)
} }
} }
}
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) { func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
c.OutboundHandshakeTimer.advance(now) c.OutboundHandshakeTimer.advance(now)
@@ -57,46 +94,63 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWr
break break
} }
vpnIP := ep.(uint32) vpnIP := ep.(uint32)
c.handleOutbound(vpnIP, f, false)
index, err := c.pendingHostMap.GetIndexByVpnIP(vpnIP) }
if err != nil {
continue
} }
func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseTriggered bool) {
index, err := c.pendingHostMap.GetIndexByVpnIP(vpnIP)
if err != nil {
return
}
hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP) hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP)
if err != nil { if err != nil {
continue return
} }
// If we haven't finished the handshake and we haven't hit max retries, query // If we haven't finished the handshake and we haven't hit max retries, query
// lighthouse and then send the handshake packet again. // lighthouse and then send the handshake packet again.
if hostinfo.HandshakeCounter < HandshakeRetries && !hostinfo.HandshakeComplete { if hostinfo.HandshakeCounter < c.config.retries && !hostinfo.HandshakeComplete {
if hostinfo.remote == nil { if hostinfo.remote == nil {
// We continue to query the lighthouse because hosts may // We continue to query the lighthouse because hosts may
// come online during handshake retries. If the query // come online during handshake retries. If the query
// succeeds (no error), add the lighthouse info to hostinfo // succeeds (no error), add the lighthouse info to hostinfo
ips, err := c.lightHouse.Query(vpnIP, f) ips := c.lightHouse.QueryCache(vpnIP)
// If we have no responses yet, or only one IP (the host hadn't
// finished reporting its own IPs yet), then send another query to
// the LH.
if len(ips) <= 1 {
ips, err = c.lightHouse.Query(vpnIP, f)
}
if err == nil { if err == nil {
for _, ip := range ips { for _, ip := range ips {
hostinfo.AddRemote(ip) hostinfo.AddRemote(ip)
} }
hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges) hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges)
} }
} else if lighthouseTriggered {
// We were triggered by a lighthouse HostQueryReply packet, but
// we have already picked a remote for this host (this can happen
// if we are configured with multiple lighthouses). So we can skip
// this trigger and let the timerwheel handle the rest of the
// process
return
} }
hostinfo.HandshakeCounter++ hostinfo.HandshakeCounter++
// We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through // We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through
// all the others until we can stand up a connection. // all the others until we can stand up a connection.
if hostinfo.HandshakeCounter > HandshakeWaitRotation { if hostinfo.HandshakeCounter > c.config.waitRotation {
hostinfo.rotateRemote() hostinfo.rotateRemote()
} }
// Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation // Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation
if hostinfo.HandshakeReady && hostinfo.remote != nil { if hostinfo.HandshakeReady && hostinfo.remote != nil {
c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote) err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
if err != nil { if err != nil {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", hostinfo.remote). hostinfo.logger().WithField("udpAddr", hostinfo.remote).
WithField("initiatorIndex", hostinfo.localIndexId). WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId). WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
@@ -104,7 +158,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWr
} else { } else {
//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should //TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
// keep the real packet struct around for logging purposes // keep the real packet struct around for logging purposes
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", hostinfo.remote). hostinfo.logger().WithField("udpAddr", hostinfo.remote).
WithField("initiatorIndex", hostinfo.localIndexId). WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId). WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
@@ -113,14 +167,15 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWr
} }
// Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try // Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try
if !lighthouseTriggered {
//l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter)) //l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter))
c.OutboundHandshakeTimer.Add(vpnIP, HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter)) c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
}
} else { } else {
c.pendingHostMap.DeleteVpnIP(vpnIP) c.pendingHostMap.DeleteVpnIP(vpnIP)
c.pendingHostMap.DeleteIndex(index) c.pendingHostMap.DeleteIndex(index)
} }
} }
}
func (c *HandshakeManager) NextInboundHandshakeTimerTick(now time.Time) { func (c *HandshakeManager) NextInboundHandshakeTimerTick(now time.Time) {
c.InboundHandshakeTimer.advance(now) c.InboundHandshakeTimer.advance(now)
@@ -144,7 +199,8 @@ func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
hostinfo := c.pendingHostMap.AddVpnIP(vpnIP) hostinfo := c.pendingHostMap.AddVpnIP(vpnIP)
// We lock here and use an array to insert items to prevent locking the // We lock here and use an array to insert items to prevent locking the
// main receive thread for very long by waiting to add items to the pending map // main receive thread for very long by waiting to add items to the pending map
c.OutboundHandshakeTimer.Add(vpnIP, HandshakeTryInterval) c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval)
return hostinfo return hostinfo
} }

View File

@@ -21,7 +21,7 @@ func Test_NewHandshakeManagerIndex(t *testing.T) {
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap("test", vpncidr, preferredRanges) mainHM := NewHostMap("test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}) blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now() now := time.Now()
blah.NextInboundHandshakeTimerTick(now) blah.NextInboundHandshakeTimerTick(now)
@@ -37,8 +37,8 @@ func Test_NewHandshakeManagerIndex(t *testing.T) {
// Adding something to pending should not affect the main hostmap // Adding something to pending should not affect the main hostmap
assert.Len(t, mainHM.Indexes, 0) assert.Len(t, mainHM.Indexes, 0)
// Jump ahead 8 seconds // Jump ahead 8 seconds
for i := 1; i <= HandshakeRetries; i++ { for i := 1; i <= DefaultHandshakeRetries; i++ {
next_tick := now.Add(HandshakeTryInterval * time.Duration(i)) next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i))
blah.NextInboundHandshakeTimerTick(next_tick) blah.NextInboundHandshakeTimerTick(next_tick)
} }
// Confirm they are still in the pending index list // Confirm they are still in the pending index list
@@ -63,7 +63,7 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
mw := &mockEncWriter{} mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges) mainHM := NewHostMap("test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}) blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now() now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw) blah.NextOutboundHandshakeTimerTick(now, mw)
@@ -81,8 +81,8 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
// Jump ahead `HandshakeRetries` ticks // Jump ahead `HandshakeRetries` ticks
cumulative := time.Duration(0) cumulative := time.Duration(0)
for i := 0; i <= HandshakeRetries+1; i++ { for i := 0; i <= DefaultHandshakeRetries+1; i++ {
cumulative += time.Duration(i)*HandshakeTryInterval + 1 cumulative += time.Duration(i)*DefaultHandshakeTryInterval + 1
next_tick := now.Add(cumulative) next_tick := now.Add(cumulative)
//l.Infoln(next_tick) //l.Infoln(next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick, mw) blah.NextOutboundHandshakeTimerTick(next_tick, mw)
@@ -93,7 +93,7 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v)) assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v))
} }
// Jump ahead 1 more second // Jump ahead 1 more second
cumulative += time.Duration(HandshakeRetries+1) * HandshakeTryInterval cumulative += time.Duration(DefaultHandshakeRetries+1) * DefaultHandshakeTryInterval
next_tick := now.Add(cumulative) next_tick := now.Add(cumulative)
//l.Infoln(next_tick) //l.Infoln(next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick, mw) blah.NextOutboundHandshakeTimerTick(next_tick, mw)
@@ -103,6 +103,56 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
} }
} }
func Test_NewHandshakeManagerTrigger(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ip := ip2int(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
lh := &LightHouse{}
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
blah.AddVpnIP(ip)
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
// Trigger the same method the channel will
blah.handleOutbound(ip, mw, true)
// Make sure the trigger doesn't schedule another timer entry
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
hi := blah.pendingHostMap.Hosts[ip]
assert.Nil(t, hi.remote)
lh.addrMap = map[uint32][]udpAddr{
ip: {*NewUDPAddrFromString("10.1.1.1:4242")},
}
// This should trigger the hostmap to populate the hostinfo
blah.handleOutbound(ip, mw, true)
assert.NotNil(t, hi.remote)
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
}
func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
for _, i := range tw.wheel {
n := i.Head
for n != nil {
c++
n = n.Next
}
}
return c
}
func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) { func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
@@ -112,7 +162,7 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
mw := &mockEncWriter{} mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges) mainHM := NewHostMap("test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}) blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now() now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw) blah.NextOutboundHandshakeTimerTick(now, mw)
@@ -125,8 +175,8 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
// Jump ahead `HandshakeRetries` ticks. Eviction should happen in pending // Jump ahead `HandshakeRetries` ticks. Eviction should happen in pending
// but not main hostmap // but not main hostmap
cumulative := time.Duration(0) cumulative := time.Duration(0)
for i := 1; i <= HandshakeRetries+2; i++ { for i := 1; i <= DefaultHandshakeRetries+2; i++ {
cumulative += HandshakeTryInterval * time.Duration(i) cumulative += DefaultHandshakeTryInterval * time.Duration(i)
next_tick := now.Add(cumulative) next_tick := now.Add(cumulative)
blah.NextOutboundHandshakeTimerTick(next_tick, mw) blah.NextOutboundHandshakeTimerTick(next_tick, mw)
} }
@@ -161,7 +211,7 @@ func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap("test", vpncidr, preferredRanges) mainHM := NewHostMap("test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}) blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now() now := time.Now()
blah.NextInboundHandshakeTimerTick(now) blah.NextInboundHandshakeTimerTick(now)
@@ -171,12 +221,12 @@ func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
blah.pendingHostMap.AddVpnIPHostInfo(101010, hostinfo) blah.pendingHostMap.AddVpnIPHostInfo(101010, hostinfo)
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(101010)) assert.Contains(t, blah.pendingHostMap.Hosts, uint32(101010))
for i := 1; i <= HandshakeRetries+2; i++ { for i := 1; i <= DefaultHandshakeRetries+2; i++ {
next_tick := now.Add(HandshakeTryInterval * time.Duration(i)) next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i))
blah.NextInboundHandshakeTimerTick(next_tick) blah.NextInboundHandshakeTimerTick(next_tick)
} }
next_tick := now.Add(HandshakeTryInterval*HandshakeRetries + 3) next_tick := now.Add(DefaultHandshakeTryInterval*DefaultHandshakeRetries + 3)
blah.NextInboundHandshakeTimerTick(next_tick) blah.NextInboundHandshakeTimerTick(next_tick)
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010)) assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010))
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234)) assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234))

View File

@@ -54,6 +54,7 @@ var typeMap = map[NebulaMessageType]string{
} }
const ( const (
subTypeNone NebulaMessageSubType = 0
testRequest NebulaMessageSubType = 0 testRequest NebulaMessageSubType = 0
testReply NebulaMessageSubType = 1 testReply NebulaMessageSubType = 1
) )

View File

@@ -1,9 +1,10 @@
package nebula package nebula
import ( import (
"github.com/stretchr/testify/assert"
"reflect" "reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
type headerTest struct { type headerTest struct {

View File

@@ -30,6 +30,7 @@ type HostMap struct {
vpnCIDR *net.IPNet vpnCIDR *net.IPNet
defaultRoute uint32 defaultRoute uint32
unsafeRoutes *CIDRTree unsafeRoutes *CIDRTree
metricsEnabled bool
} }
type HostInfo struct { type HostInfo struct {
@@ -384,8 +385,16 @@ func (hm *HostMap) PunchList() []*udpAddr {
} }
func (hm *HostMap) Punchy(conn *udpConn) { func (hm *HostMap) Punchy(conn *udpConn) {
var metricsTxPunchy metrics.Counter
if hm.metricsEnabled {
metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
} else {
metricsTxPunchy = metrics.NilCounter{}
}
for { for {
for _, addr := range hm.PunchList() { for _, addr := range hm.PunchList() {
metricsTxPunchy.Inc(1)
conn.WriteTo([]byte{1}, addr) conn.WriteTo([]byte{1}, addr)
} }
time.Sleep(time.Second * 30) time.Sleep(time.Second * 30)
@@ -532,13 +541,13 @@ func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, pac
copy(tempPacket, packet) copy(tempPacket, packet)
//l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket) //l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket)
i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket}) i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket})
l.WithField("vpnIp", IntIp(i.hostId)). i.logger().
WithField("length", len(i.packetStore)). WithField("length", len(i.packetStore)).
WithField("stored", true). WithField("stored", true).
Debugf("Packet store") Debugf("Packet store")
} else if l.Level >= logrus.DebugLevel { } else if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(i.hostId)). i.logger().
WithField("length", len(i.packetStore)). WithField("length", len(i.packetStore)).
WithField("stored", false). WithField("stored", false).
Debugf("Packet store") Debugf("Packet store")
@@ -556,7 +565,7 @@ func (i *HostInfo) handshakeComplete() {
//TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen. //TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen.
// Clamping it to 2 gets us out of the woods for now // Clamping it to 2 gets us out of the woods for now
*i.ConnectionState.messageCounter = 2 *i.ConnectionState.messageCounter = 2
l.WithField("vpnIp", IntIp(i.hostId)).Debugf("Sending %d stored packets", len(i.packetStore)) i.logger().Debugf("Sending %d stored packets", len(i.packetStore))
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
for _, cp := range i.packetStore { for _, cp := range i.packetStore {
@@ -623,6 +632,11 @@ func (i *HostInfo) RecvErrorExceeded() bool {
} }
func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) { func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
if len(c.Details.Ips) == 1 && len(c.Details.Subnets) == 0 {
// Simple case, no CIDRTree needed
return
}
remoteCidr := NewCIDRTree() remoteCidr := NewCIDRTree()
for _, ip := range c.Details.Ips { for _, ip := range c.Details.Ips {
remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
@@ -634,6 +648,22 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
i.remoteCidr = remoteCidr i.remoteCidr = remoteCidr
} }
func (i *HostInfo) logger() *logrus.Entry {
if i == nil {
return logrus.NewEntry(l)
}
li := l.WithField("vpnIp", IntIp(i.hostId))
if connState := i.ConnectionState; connState != nil {
if peerCert := connState.peerCert; peerCert != nil {
li = li.WithField("certName", peerCert.Details.Name)
}
}
return li
}
//######################## //########################
func NewHostInfoDest(addr *udpAddr) *HostInfoDest { func NewHostInfoDest(addr *udpAddr) *HostInfoDest {
@@ -734,11 +764,16 @@ func (d *HostInfoDest) ProbeReceived(probeCount int) {
// Utility functions // Utility functions
func localIps() *[]net.IP { func localIps(allowList *AllowList) *[]net.IP {
//FIXME: This function is pretty garbage //FIXME: This function is pretty garbage
var ips []net.IP var ips []net.IP
ifaces, _ := net.Interfaces() ifaces, _ := net.Interfaces()
for _, i := range ifaces { for _, i := range ifaces {
allow := allowList.AllowName(i.Name)
l.WithField("interfaceName", i.Name).WithField("allow", allow).Debug("localAllowList.AllowName")
if !allow {
continue
}
addrs, _ := i.Addrs() addrs, _ := i.Addrs()
for _, addr := range addrs { for _, addr := range addrs {
var ip net.IP var ip net.IP
@@ -750,6 +785,12 @@ func localIps() *[]net.IP {
ip = v.IP ip = v.IP
} }
if ip.To4() != nil && ip.IsLoopback() == false { if ip.To4() != nil && ip.IsLoopback() == false {
allow := allowList.Allow(ip2int(ip))
l.WithField("localIp", ip).WithField("allow", allow).Debug("localAllowList.Allow")
if !allow {
continue
}
ips = append(ips, ip) ips = append(ips, ip)
} }
} }

View File

@@ -161,6 +161,4 @@ func BenchmarkHostmappromote2(b *testing.B) {
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), g) m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), g)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y) m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
} }
b.Errorf("hi")
} }

View File

@@ -7,7 +7,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte) { func (f *Interface) consumeInsidePacket(st NebulaMessageSubType, packet []byte, fwPacket *FirewallPacket, nb, out []byte) {
err := newPacket(packet, false, fwPacket) err := newPacket(packet, false, fwPacket)
if err != nil { if err != nil {
l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
@@ -19,12 +19,25 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
return return
} }
// Ignore packets from self to self
if fwPacket.RemoteIP == f.lightHouse.myIp {
return
}
// Ignore broadcast packets // Ignore broadcast packets
if f.dropMulticast && isMulticast(fwPacket.RemoteIP) { if f.dropMulticast && isMulticast(fwPacket.RemoteIP) {
return return
} }
hostinfo := f.getOrHandshake(fwPacket.RemoteIP) hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
if hostinfo == nil {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
}
return
}
ci := hostinfo.ConnectionState ci := hostinfo.ConnectionState
if ci.ready == false { if ci.ready == false {
@@ -32,28 +45,35 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
// the packet queue. // the packet queue.
ci.queueLock.Lock() ci.queueLock.Lock()
if !ci.ready { if !ci.ready {
hostinfo.cachePacket(message, 0, packet, f.sendMessageNow) hostinfo.cachePacket(message, st, packet, f.sendMessageNow)
ci.queueLock.Unlock() ci.queueLock.Unlock()
return return
} }
ci.queueLock.Unlock() ci.queueLock.Unlock()
} }
if !f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs) { dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs)
f.send(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out) if dropReason == nil {
if f.lightHouse != nil && *ci.messageCounter%5000 == 0 { mc := f.sendNoMetrics(message, st, ci, hostinfo, hostinfo.remote, packet, nb, out)
if f.lightHouse != nil && mc%5000 == 0 {
f.lightHouse.Query(fwPacket.RemoteIP, f) f.lightHouse.Query(fwPacket.RemoteIP, f)
} }
} else if l.Level >= logrus.DebugLevel { } else if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket). hostinfo.logger().
WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping outbound packet") Debugln("dropping outbound packet")
} }
} }
// getOrHandshake returns nil if the vpnIp is not routable
func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo { func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
if f.hostMap.vpnCIDR.Contains(int2ip(vpnIp)) == false { if f.hostMap.vpnCIDR.Contains(int2ip(vpnIp)) == false {
vpnIp = f.hostMap.queryUnsafeRoute(vpnIp) vpnIp = f.hostMap.queryUnsafeRoute(vpnIp)
if vpnIp == 0 {
return nil
}
} }
hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f) hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f)
@@ -86,6 +106,15 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
ixHandshakeStage0(f, vpnIp, hostinfo) ixHandshakeStage0(f, vpnIp, hostinfo)
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us. // FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
//xx_handshakeStage0(f, ip, hostinfo) //xx_handshakeStage0(f, ip, hostinfo)
// If this is a static host, we don't need to wait for the HostQueryReply
// We can trigger the handshake right now
if _, ok := f.lightHouse.staticList[vpnIp]; ok {
select {
case f.handshakeManager.trigger <- vpnIp:
default:
}
}
} }
return hostinfo return hostinfo
@@ -100,12 +129,17 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
} }
// check if packet is in outbound fw rules // check if packet is in outbound fw rules
if f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs) { dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs)
l.WithField("fwPacket", fp).Debugln("dropping cached packet") if dropReason != nil {
if l.Level >= logrus.DebugLevel {
l.WithField("fwPacket", fp).
WithField("reason", dropReason).
Debugln("dropping cached packet")
}
return return
} }
f.send(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out) f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
if f.lightHouse != nil && *hostInfo.ConnectionState.messageCounter%5000 == 0 { if f.lightHouse != nil && *hostInfo.ConnectionState.messageCounter%5000 == 0 {
f.lightHouse.Query(fp.RemoteIP, f) f.lightHouse.Query(fp.RemoteIP, f)
} }
@@ -114,6 +148,13 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
hostInfo := f.getOrHandshake(vpnIp) hostInfo := f.getOrHandshake(vpnIp)
if hostInfo == nil {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(vpnIp)).
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
}
return
}
if !hostInfo.ConnectionState.ready { if !hostInfo.ConnectionState.ready {
// Because we might be sending stored packets, lock here to stop new things going to // Because we might be sending stored packets, lock here to stop new things going to
@@ -138,6 +179,13 @@ func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
// SendMessageToAll handles real ip:port lookup and sends to all known addresses for vpnIp // SendMessageToAll handles real ip:port lookup and sends to all known addresses for vpnIp
func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
hostInfo := f.getOrHandshake(vpnIp) hostInfo := f.getOrHandshake(vpnIp)
if hostInfo == nil {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(vpnIp)).
Debugln("dropping SendMessageToAll, vpnIp not in our CIDR or in unsafe routes")
}
return
}
if hostInfo.ConnectionState.ready == false { if hostInfo.ConnectionState.ready == false {
// Because we might be sending stored packets, lock here to stop new things going to // Because we might be sending stored packets, lock here to stop new things going to
@@ -162,9 +210,14 @@ func (f *Interface) sendMessageToAll(t NebulaMessageType, st NebulaMessageSubTyp
} }
func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) { func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) {
f.messageMetrics.Tx(t, st, 1)
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out)
}
func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) uint64 {
if ci.eKey == nil { if ci.eKey == nil {
//TODO: log warning //TODO: log warning
return return 0
} }
var err error var err error
@@ -180,18 +233,19 @@ func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *Conne
//TODO: see above note on lock //TODO: see above note on lock
//ci.writeLock.Unlock() //ci.writeLock.Unlock()
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)). hostinfo.logger().WithError(err).
WithField("udpAddr", remote).WithField("counter", c). WithField("udpAddr", remote).WithField("counter", c).
WithField("attemptedCounter", ci.messageCounter). WithField("attemptedCounter", ci.messageCounter).
Error("Failed to encrypt outgoing packet") Error("Failed to encrypt outgoing packet")
return return c
} }
err = f.outside.WriteTo(out, remote) err = f.outside.WriteTo(out, remote)
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)). hostinfo.logger().WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet") WithField("udpAddr", remote).Error("Failed to write outgoing packet")
} }
return c
} }
func isMulticast(ip uint32) bool { func isMulticast(ip uint32) bool {

View File

@@ -2,6 +2,8 @@ package nebula
import ( import (
"errors" "errors"
"io"
"net"
"os" "os"
"time" "time"
@@ -10,10 +12,20 @@ import (
const mtu = 9001 const mtu = 9001
type Inside interface {
io.ReadWriteCloser
Activate() error
CidrNet() *net.IPNet
DeviceName() string
WriteRaw([]byte) error
}
type InsideHandler func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fp *FirewallPacket, nb []byte)
type InterfaceConfig struct { type InterfaceConfig struct {
HostMap *HostMap HostMap *HostMap
Outside *udpConn Outside *udpConn
Inside *Tun Inside Inside
certState *CertState certState *CertState
Cipher string Cipher string
Firewall *Firewall Firewall *Firewall
@@ -25,12 +37,16 @@ type InterfaceConfig struct {
DropLocalBroadcast bool DropLocalBroadcast bool
DropMulticast bool DropMulticast bool
UDPBatchSize int UDPBatchSize int
udpQueues int
tunQueues int
MessageMetrics *MessageMetrics
version string
} }
type Interface struct { type Interface struct {
hostMap *HostMap hostMap *HostMap
outside *udpConn outside *udpConn
inside *Tun inside Inside
certState *CertState certState *CertState
cipher string cipher string
firewall *Firewall firewall *Firewall
@@ -43,11 +59,15 @@ type Interface struct {
dropLocalBroadcast bool dropLocalBroadcast bool
dropMulticast bool dropMulticast bool
udpBatchSize int udpBatchSize int
udpQueues int
tunQueues int
version string version string
metricRxRecvError metrics.Counter // handlers are mapped by protocol version -> type -> subtype
metricTxRecvError metrics.Counter handlers map[uint8]map[NebulaMessageType]map[NebulaMessageSubType]InsideHandler
metricHandshakes metrics.Histogram metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
} }
func NewInterface(c *InterfaceConfig) (*Interface, error) { func NewInterface(c *InterfaceConfig) (*Interface, error) {
@@ -79,35 +99,64 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
dropLocalBroadcast: c.DropLocalBroadcast, dropLocalBroadcast: c.DropLocalBroadcast,
dropMulticast: c.DropMulticast, dropMulticast: c.DropMulticast,
udpBatchSize: c.UDPBatchSize, udpBatchSize: c.UDPBatchSize,
udpQueues: c.udpQueues,
tunQueues: c.tunQueues,
version: c.version,
metricRxRecvError: metrics.GetOrRegisterCounter("messages.rx.recv_error", nil),
metricTxRecvError: metrics.GetOrRegisterCounter("messages.tx.recv_error", nil),
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)), metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
messageMetrics: c.MessageMetrics,
} }
ifce.connectionManager = newConnectionManager(ifce, c.checkInterval, c.pendingDeletionInterval) ifce.connectionManager = newConnectionManager(ifce, c.checkInterval, c.pendingDeletionInterval)
ifce.handlers = map[uint8]map[NebulaMessageType]map[NebulaMessageSubType]InsideHandler{
Version: {
handshake: {
handshakeIXPSK0: ifce.rxMetrics(ifce.handleHandshakePacket),
},
message: {
subTypeNone: ifce.encrypted(ifce.handleMessagePacket),
},
recvError: {
subTypeNone: ifce.rxMetrics(ifce.handleRecvErrorPacket),
},
lightHouse: {
subTypeNone: ifce.rxMetrics(ifce.encrypted(ifce.handleLighthousePacket)),
},
test: {
testRequest: ifce.rxMetrics(ifce.encrypted(ifce.handleTestPacket)),
testReply: ifce.rxMetrics(ifce.encrypted(ifce.handleTestPacket)),
},
closeTunnel: {
subTypeNone: ifce.rxMetrics(ifce.encrypted(ifce.handleCloseTunnelPacket)),
},
},
}
return ifce, nil return ifce, nil
} }
func (f *Interface) Run(tunRoutines, udpRoutines int, buildVersion string) { func (f *Interface) run() {
// actually turn on tun dev // actually turn on tun dev
if err := f.inside.Activate(); err != nil { if err := f.inside.Activate(); err != nil {
l.Fatal(err) l.Fatal(err)
} }
f.version = buildVersion addr, err := f.outside.LocalAddr()
l.WithField("interface", f.inside.Device).WithField("network", f.inside.Cidr.String()). if err != nil {
WithField("build", buildVersion). l.WithError(err).Error("Failed to get udp listen address")
}
l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
WithField("build", f.version).WithField("udpAddr", addr).
Info("Nebula interface is active") Info("Nebula interface is active")
// Launch n queues to read packets from udp // Launch n queues to read packets from udp
for i := 0; i < udpRoutines; i++ { for i := 0; i < f.udpQueues; i++ {
go f.listenOut(i) go f.listenOut(i)
} }
// Launch n queues to read packets from tun dev // Launch n queues to read packets from tun dev
for i := 0; i < tunRoutines; i++ { for i := 0; i < f.tunQueues; i++ {
go f.listenIn(i) go f.listenIn(i)
} }
} }
@@ -147,7 +196,7 @@ func (f *Interface) listenIn(i int) {
os.Exit(2) os.Exit(2)
} }
f.consumeInsidePacket(packet[:n], fwPacket, nb, out) f.consumeInsidePacket(subTypeNone, packet[:n], fwPacket, nb, out)
} }
} }
@@ -205,11 +254,28 @@ func (f *Interface) reloadFirewall(c *Config) {
} }
oldFw := f.firewall oldFw := f.firewall
conntrack := oldFw.Conntrack
conntrack.Lock()
defer conntrack.Unlock()
fw.rulesVersion = oldFw.rulesVersion + 1
// If rulesVersion is back to zero, we have wrapped all the way around. Be
// safe and just reset conntrack in this case.
if fw.rulesVersion == 0 {
l.WithField("firewallHash", fw.GetRuleHash()).
WithField("oldFirewallHash", oldFw.GetRuleHash()).
WithField("rulesVersion", fw.rulesVersion).
Warn("firewall rulesVersion has overflowed, resetting conntrack")
} else {
fw.Conntrack = conntrack
}
f.firewall = fw f.firewall = fw
oldFw.Destroy() oldFw.Destroy()
l.WithField("firewallHash", fw.GetRuleHash()). l.WithField("firewallHash", fw.GetRuleHash()).
WithField("oldFirewallHash", oldFw.GetRuleHash()). WithField("oldFirewallHash", oldFw.GetRuleHash()).
WithField("rulesVersion", fw.rulesVersion).
Info("New firewall has been installed") Info("New firewall has been installed")
} }

View File

@@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
) )
@@ -19,6 +20,19 @@ type LightHouse struct {
// Local cache of answers from light houses // Local cache of answers from light houses
addrMap map[uint32][]udpAddr addrMap map[uint32][]udpAddr
// filters remote addresses allowed for each host
// - When we are a lighthouse, this filters what addresses we store and
// respond with.
// - When we are not a lighthouse, this filters which addresses we accept
// from lighthouses.
remoteAllowList *AllowList
// filters local addresses that we advertise to lighthouses
localAllowList *AllowList
// used to trigger the HandshakeManager when we receive HostQueryReply
handshakeTrigger chan<- uint32
// staticList exists to avoid having a bool in each addrMap entry // staticList exists to avoid having a bool in each addrMap entry
// since static should be rare // since static should be rare
staticList map[uint32]struct{} staticList map[uint32]struct{}
@@ -26,6 +40,10 @@ type LightHouse struct {
interval int interval int
nebulaPort int nebulaPort int
punchBack bool punchBack bool
punchDelay time.Duration
metrics *MessageMetrics
metricHolepunchTx metrics.Counter
} }
type EncWriter interface { type EncWriter interface {
@@ -33,7 +51,7 @@ type EncWriter interface {
SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
} }
func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort int, pc *udpConn, punchBack bool) *LightHouse { func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort int, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
h := LightHouse{ h := LightHouse{
amLighthouse: amLighthouse, amLighthouse: amLighthouse,
myIp: myIp, myIp: myIp,
@@ -44,6 +62,15 @@ func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, n
interval: interval, interval: interval,
punchConn: pc, punchConn: pc,
punchBack: punchBack, punchBack: punchBack,
punchDelay: punchDelay,
}
if metricsEnabled {
h.metrics = newLighthouseMetrics()
h.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil)
} else {
h.metricHolepunchTx = metrics.NilCounter{}
} }
for _, ip := range ips { for _, ip := range ips {
@@ -53,6 +80,20 @@ func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, n
return &h return &h
} }
func (lh *LightHouse) SetRemoteAllowList(allowList *AllowList) {
lh.Lock()
defer lh.Unlock()
lh.remoteAllowList = allowList
}
func (lh *LightHouse) SetLocalAllowList(allowList *AllowList) {
lh.Lock()
defer lh.Unlock()
lh.localAllowList = allowList
}
func (lh *LightHouse) ValidateLHStaticEntries() error { func (lh *LightHouse) ValidateLHStaticEntries() error {
for lhIP, _ := range lh.lighthouses { for lhIP, _ := range lh.lighthouses {
if _, ok := lh.staticList[lhIP]; !ok { if _, ok := lh.staticList[lhIP]; !ok {
@@ -85,6 +126,7 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
return return
} }
lh.metricTx(NebulaMeta_HostQuery, int64(len(lh.lighthouses)))
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
for n := range lh.lighthouses { for n := range lh.lighthouses {
@@ -133,6 +175,13 @@ func (lh *LightHouse) AddRemote(vpnIP uint32, toIp *udpAddr, static bool) {
return return
} }
} }
allow := lh.remoteAllowList.Allow(udp2ipInt(toIp))
l.WithField("remoteIp", toIp).WithField("allow", allow).Debug("remoteAllowList.Allow")
if !allow {
return
}
//l.Debugf("Adding reply of %s as %s\n", IntIp(vpnIP), toIp) //l.Debugf("Adding reply of %s as %s\n", IntIp(vpnIP), toIp)
if static { if static {
lh.staticList[vpnIP] = struct{}{} lh.staticList[vpnIP] = struct{}{}
@@ -201,7 +250,7 @@ func (lh *LightHouse) LhUpdateWorker(f EncWriter) {
for { for {
ipp := []*IpAndPort{} ipp := []*IpAndPort{}
for _, e := range *localIps() { for _, e := range *localIps(lh.localAllowList) {
// Only add IPs that aren't my VPN/tun IP // Only add IPs that aren't my VPN/tun IP
if ip2int(e) != lh.myIp { if ip2int(e) != lh.myIp {
ipp = append(ipp, &IpAndPort{Ip: ip2int(e), Port: uint32(lh.nebulaPort)}) ipp = append(ipp, &IpAndPort{Ip: ip2int(e), Port: uint32(lh.nebulaPort)})
@@ -216,6 +265,7 @@ func (lh *LightHouse) LhUpdateWorker(f EncWriter) {
}, },
} }
lh.metricTx(NebulaMeta_HostUpdateNotification, int64(len(lh.lighthouses)))
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
for vpnIp := range lh.lighthouses { for vpnIp := range lh.lighthouses {
@@ -248,6 +298,8 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
return return
} }
lh.metricRx(n.Type, 1)
switch n.Type { switch n.Type {
case NebulaMeta_HostQuery: case NebulaMeta_HostQuery:
// Exit if we don't answer queries // Exit if we don't answer queries
@@ -275,6 +327,7 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply") l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
return return
} }
lh.metricTx(NebulaMeta_HostQueryReply, 1)
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, reply, make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnIp(lightHouse, 0, vpnIp, reply, make([]byte, 12, 12), make([]byte, mtu))
// This signals the other side to punch some zero byte udp packets // This signals the other side to punch some zero byte udp packets
@@ -293,6 +346,7 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
}, },
} }
reply, _ := proto.Marshal(answer) reply, _ := proto.Marshal(answer)
lh.metricTx(NebulaMeta_HostPunchNotification, 1)
f.SendMessageToVpnIp(lightHouse, 0, n.Details.VpnIp, reply, make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnIp(lightHouse, 0, n.Details.VpnIp, reply, make([]byte, 12, 12), make([]byte, mtu))
} }
//fmt.Println(reply, remoteaddr) //fmt.Println(reply, remoteaddr)
@@ -307,6 +361,11 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
ans := NewUDPAddr(a.Ip, uint16(a.Port)) ans := NewUDPAddr(a.Ip, uint16(a.Port))
lh.AddRemote(n.Details.VpnIp, ans, false) lh.AddRemote(n.Details.VpnIp, ans, false)
} }
// Non-blocking attempt to trigger, skip if it would block
select {
case lh.handshakeTrigger <- n.Details.VpnIp:
default:
}
case NebulaMeta_HostUpdateNotification: case NebulaMeta_HostUpdateNotification:
//Simple check that the host sent this not someone else //Simple check that the host sent this not someone else
@@ -328,10 +387,9 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
for _, a := range n.Details.IpAndPorts { for _, a := range n.Details.IpAndPorts {
vpnPeer := NewUDPAddr(a.Ip, uint16(a.Port)) vpnPeer := NewUDPAddr(a.Ip, uint16(a.Port))
go func() { go func() {
for i := 0; i < 5; i++ { time.Sleep(lh.punchDelay)
lh.metricHolepunchTx.Inc(1)
lh.punchConn.WriteTo(empty, vpnPeer) lh.punchConn.WriteTo(empty, vpnPeer)
time.Sleep(time.Second * 1)
}
}() }()
l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp)) l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
@@ -349,6 +407,13 @@ func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *c
} }
} }
func (lh *LightHouse) metricRx(t NebulaMeta_MessageType, i int64) {
lh.metrics.Rx(NebulaMessageType(t), 0, i)
}
func (lh *LightHouse) metricTx(t NebulaMeta_MessageType, i int64) {
lh.metrics.Tx(NebulaMessageType(t), 0, i)
}
/* /*
func (f *Interface) sendPathCheck(ci *ConnectionState, endpoint *net.UDPAddr, counter int) { func (f *Interface) sendPathCheck(ci *ConnectionState, endpoint *net.UDPAddr, counter int) {
c := ci.messageCounter c := ci.messageCounter

View File

@@ -4,7 +4,7 @@ import (
"net" "net"
"testing" "testing"
proto "github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -52,7 +52,7 @@ func Test_lhStaticMapping(t *testing.T) {
udpServer, _ := NewListener("0.0.0.0", 0, true) udpServer, _ := NewListener("0.0.0.0", 0, true)
meh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false) meh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true) meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true)
err := meh.ValidateLHStaticEntries() err := meh.ValidateLHStaticEntries()
assert.Nil(t, err) assert.Nil(t, err)
@@ -60,7 +60,7 @@ func Test_lhStaticMapping(t *testing.T) {
lh2 := "10.128.0.3" lh2 := "10.128.0.3"
lh2IP := net.ParseIP(lh2) lh2IP := net.ParseIP(lh2)
meh = NewLightHouse(true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false) meh = NewLightHouse(true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true) meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true)
err = meh.ValidateLHStaticEntries() err = meh.ValidateLHStaticEntries()
assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry") assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")

39
logger.go Normal file
View File

@@ -0,0 +1,39 @@
package nebula
import (
"errors"
"github.com/sirupsen/logrus"
)
type ContextualError struct {
RealError error
Fields map[string]interface{}
Context string
}
func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
return ContextualError{Context: msg, Fields: fields, RealError: realError}
}
func (ce ContextualError) Error() string {
if ce.RealError == nil {
return ce.Context
}
return ce.RealError.Error()
}
func (ce ContextualError) Unwrap() error {
if ce.RealError == nil {
return errors.New(ce.Context)
}
return ce.RealError
}
func (ce *ContextualError) Log(lr *logrus.Logger) {
if ce.RealError != nil {
lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
} else {
lr.WithFields(ce.Fields).Error(ce.Context)
}
}

67
logger_test.go Normal file
View File

@@ -0,0 +1,67 @@
package nebula
import (
"errors"
"testing"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)
type TestLogWriter struct {
Logs []string
}
func NewTestLogWriter() *TestLogWriter {
return &TestLogWriter{Logs: make([]string, 0)}
}
func (tl *TestLogWriter) Write(p []byte) (n int, err error) {
tl.Logs = append(tl.Logs, string(p))
return len(p), nil
}
func (tl *TestLogWriter) Reset() {
tl.Logs = tl.Logs[:0]
}
func TestContextualError_Log(t *testing.T) {
l := logrus.New()
l.Formatter = &logrus.TextFormatter{
DisableTimestamp: true,
DisableColors: true,
}
tl := NewTestLogWriter()
l.Out = tl
// Test a full context line
tl.Reset()
e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs)
// Test a line with an error and msg but no fields
tl.Reset()
e = NewContextualError("test message", nil, errors.New("error"))
e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\" error=error\n"}, tl.Logs)
// Test just a context and fields
tl.Reset()
e = NewContextualError("test message", m{"field": "1"}, nil)
e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\" field=1\n"}, tl.Logs)
// Test just a context
tl.Reset()
e = NewContextualError("test message", nil, nil)
e.Log(l)
assert.Equal(t, []string{"level=error msg=\"test message\"\n"}, tl.Logs)
// Test just an error
tl.Reset()
e = NewContextualError("", nil, errors.New("error"))
e.Log(l)
assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs)
}

185
main.go
View File

@@ -4,11 +4,8 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"net" "net"
"os"
"os/signal"
"strconv" "strconv"
"strings" "strings"
"syscall"
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -16,36 +13,31 @@ import (
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
// The caller should provide a real logger, we have one just in case
var l = logrus.New() var l = logrus.New()
type m map[string]interface{} type m map[string]interface{}
func Main(configPath string, configTest bool, buildVersion string) { func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) {
l.Out = os.Stdout l = logger
l.Formatter = &logrus.TextFormatter{ l.Formatter = &logrus.TextFormatter{
FullTimestamp: true, FullTimestamp: true,
} }
config := NewConfig()
err := config.Load(configPath)
if err != nil {
l.WithError(err).Error("Failed to load config")
os.Exit(1)
}
// Print the config if in test, the exit comes later // Print the config if in test, the exit comes later
if configTest { if configTest {
b, err := yaml.Marshal(config.Settings) b, err := yaml.Marshal(config.Settings)
if err != nil { if err != nil {
l.Println(err) return nil, err
os.Exit(1)
} }
// Print the final config
l.Println(string(b)) l.Println(string(b))
} }
err = configLogger(config) err := configLogger(config)
if err != nil { if err != nil {
l.WithError(err).Error("Failed to configure the logger") return nil, NewContextualError("Failed to configure the logger", nil, err)
} }
config.RegisterReloadCallback(func(c *Config) { config.RegisterReloadCallback(func(c *Config) {
@@ -59,20 +51,20 @@ func Main(configPath string, configTest bool, buildVersion string) {
trustedCAs, err = loadCAFromConfig(config) trustedCAs, err = loadCAFromConfig(config)
if err != nil { if err != nil {
//The errors coming out of loadCA are already nicely formatted //The errors coming out of loadCA are already nicely formatted
l.WithError(err).Fatal("Failed to load ca from config") return nil, NewContextualError("Failed to load ca from config", nil, err)
} }
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints") l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
cs, err := NewCertStateFromConfig(config) cs, err := NewCertStateFromConfig(config)
if err != nil { if err != nil {
//The errors coming out of NewCertStateFromConfig are already nicely formatted //The errors coming out of NewCertStateFromConfig are already nicely formatted
l.WithError(err).Fatal("Failed to load certificate from config") return nil, NewContextualError("Failed to load certificate from config", nil, err)
} }
l.WithField("cert", cs.certificate).Debug("Client nebula certificate") l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
fw, err := NewFirewallFromConfig(cs.certificate, config) fw, err := NewFirewallFromConfig(cs.certificate, config)
if err != nil { if err != nil {
l.WithError(err).Fatal("Error while loading firewall rules") return nil, NewContextualError("Error while loading firewall rules", nil, err)
} }
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started") l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
@@ -80,11 +72,11 @@ func Main(configPath string, configTest bool, buildVersion string) {
tunCidr := cs.certificate.Details.Ips[0] tunCidr := cs.certificate.Details.Ips[0]
routes, err := parseRoutes(config, tunCidr) routes, err := parseRoutes(config, tunCidr)
if err != nil { if err != nil {
l.WithError(err).Fatal("Could not parse tun.routes") return nil, NewContextualError("Could not parse tun.routes", nil, err)
} }
unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr) unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
if err != nil { if err != nil {
l.WithError(err).Fatal("Could not parse tun.unsafe_routes") return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
} }
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
@@ -92,7 +84,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
if config.GetBool("sshd.enabled", false) { if config.GetBool("sshd.enabled", false) {
err = configSSH(ssh, config) err = configSSH(ssh, config)
if err != nil { if err != nil {
l.WithError(err).Fatal("Error while configuring the sshd") return nil, NewContextualError("Error while configuring the sshd", nil, err)
} }
} }
@@ -101,14 +93,24 @@ func Main(configPath string, configTest bool, buildVersion string) {
// tun config, listeners, anything modifying the computer should be below // tun config, listeners, anything modifying the computer should be below
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
if configTest { var tun Inside
os.Exit(0) if !configTest {
}
config.CatchHUP() config.CatchHUP()
// set up our tun dev switch {
tun, err := newTun( case config.GetBool("tun.disabled", false):
tun = newDisabledTun(tunCidr, l)
case tunFd != nil:
tun, err = newTunFromFd(
*tunFd,
tunCidr,
config.GetInt("tun.mtu", DEFAULT_MTU),
routes,
unsafeRoutes,
config.GetInt("tun.tx_queue", 500),
)
default:
tun, err = newTun(
config.GetString("tun.dev", ""), config.GetString("tun.dev", ""),
tunCidr, tunCidr,
config.GetInt("tun.mtu", DEFAULT_MTU), config.GetInt("tun.mtu", DEFAULT_MTU),
@@ -116,17 +118,24 @@ func Main(configPath string, configTest bool, buildVersion string) {
unsafeRoutes, unsafeRoutes,
config.GetInt("tun.tx_queue", 500), config.GetInt("tun.tx_queue", 500),
) )
}
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to get a tun/tap device") return nil, NewContextualError("Failed to get a tun/tap device", nil, err)
}
} }
// set up our UDP listener // set up our UDP listener
udpQueues := config.GetInt("listen.routines", 1) udpQueues := config.GetInt("listen.routines", 1)
udpServer, err := NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1) var udpServer *udpConn
if !configTest {
udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1)
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to open udp listener") return nil, NewContextualError("Failed to open udp listener", nil, err)
} }
udpServer.reloadConfig(config) udpServer.reloadConfig(config)
}
// Set up my internal host map // Set up my internal host map
var preferredRanges []*net.IPNet var preferredRanges []*net.IPNet
@@ -136,7 +145,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
for _, rawPreferredRange := range rawPreferredRanges { for _, rawPreferredRange := range rawPreferredRanges {
_, preferredRange, err := net.ParseCIDR(rawPreferredRange) _, preferredRange, err := net.ParseCIDR(rawPreferredRange)
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to parse preferred ranges") return nil, NewContextualError("Failed to parse preferred ranges", nil, err)
} }
preferredRanges = append(preferredRanges, preferredRange) preferredRanges = append(preferredRanges, preferredRange)
} }
@@ -149,7 +158,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
if rawLocalRange != "" { if rawLocalRange != "" {
_, localRange, err := net.ParseCIDR(rawLocalRange) _, localRange, err := net.ParseCIDR(rawLocalRange)
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to parse local range") return nil, NewContextualError("Failed to parse local_range", nil, err)
} }
// Check if the entry for local_range was already specified in // Check if the entry for local_range was already specified in
@@ -169,6 +178,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
hostMap := NewHostMap("main", tunCidr, preferredRanges) hostMap := NewHostMap("main", tunCidr, preferredRanges)
hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0")))) hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
hostMap.addUnsafeRoutes(&unsafeRoutes) hostMap.addUnsafeRoutes(&unsafeRoutes)
hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created") l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
@@ -177,23 +187,22 @@ func Main(configPath string, configTest bool, buildVersion string) {
go hostMap.Promoter(config.GetInt("promoter.interval")) go hostMap.Promoter(config.GetInt("promoter.interval"))
*/ */
punchy := config.GetBool("punchy", false) punchy := NewPunchyFromConfig(config)
if punchy == true { if punchy.Punch && !configTest {
l.Info("UDP hole punching enabled") l.Info("UDP hole punching enabled")
go hostMap.Punchy(udpServer) go hostMap.Punchy(udpServer)
} }
port := config.GetInt("listen.port", 0) port := config.GetInt("listen.port", 0)
// If port is dynamic, discover it // If port is dynamic, discover it
if port == 0 { if port == 0 && !configTest {
uPort, err := udpServer.LocalAddr() uPort, err := udpServer.LocalAddr()
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to get listening port") return nil, NewContextualError("Failed to get listening port", nil, err)
} }
port = int(uPort.Port) port = int(uPort.Port)
} }
punchBack := config.GetBool("punch_back", false)
amLighthouse := config.GetBool("lighthouse.am_lighthouse", false) amLighthouse := config.GetBool("lighthouse.am_lighthouse", false)
// warn if am_lighthouse is enabled but upstream lighthouses exists // warn if am_lighthouse is enabled but upstream lighthouses exists
@@ -206,7 +215,10 @@ func Main(configPath string, configTest bool, buildVersion string) {
for i, host := range rawLighthouseHosts { for i, host := range rawLighthouseHosts {
ip := net.ParseIP(host) ip := net.ParseIP(host)
if ip == nil { if ip == nil {
l.WithField("host", host).Fatalf("Unable to parse lighthouse host entry %v", i+1) return nil, NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
}
if !tunCidr.Contains(ip) {
return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
} }
lighthouseHosts[i] = ip2int(ip) lighthouseHosts[i] = ip2int(ip)
} }
@@ -219,12 +231,29 @@ func Main(configPath string, configTest bool, buildVersion string) {
config.GetInt("lighthouse.interval", 10), config.GetInt("lighthouse.interval", 10),
port, port,
udpServer, udpServer,
punchBack, punchy.Respond,
punchy.Delay,
config.GetBool("stats.lighthouse_metrics", false),
) )
remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false)
if err != nil {
return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
}
lightHouse.SetRemoteAllowList(remoteAllowList)
localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true)
if err != nil {
return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
}
lightHouse.SetLocalAllowList(localAllowList)
//TODO: Move all of this inside functions in lighthouse.go //TODO: Move all of this inside functions in lighthouse.go
for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) { for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
vpnIp := net.ParseIP(fmt.Sprintf("%v", k)) vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
if !tunCidr.Contains(vpnIp) {
return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
}
vals, ok := v.([]interface{}) vals, ok := v.([]interface{})
if ok { if ok {
for _, v := range vals { for _, v := range vals {
@@ -234,7 +263,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
ip := addr.IP ip := addr.IP
port, err := strconv.Atoi(parts[1]) port, err := strconv.Atoi(parts[1])
if err != nil { if err != nil {
l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v) return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
} }
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true) lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
} }
@@ -247,7 +276,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
ip := addr.IP ip := addr.IP
port, err := strconv.Atoi(parts[1]) port, err := strconv.Atoi(parts[1])
if err != nil { if err != nil {
l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v) return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
} }
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true) lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
} }
@@ -259,7 +288,24 @@ func Main(configPath string, configTest bool, buildVersion string) {
l.WithError(err).Error("Lighthouse unreachable") l.WithError(err).Error("Lighthouse unreachable")
} }
handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpServer) var messageMetrics *MessageMetrics
if config.GetBool("stats.message_metrics", false) {
messageMetrics = newMessageMetrics()
} else {
messageMetrics = newMessageMetricsOnlyRecvError()
}
handshakeConfig := HandshakeConfig{
tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries),
waitRotation: config.GetInt("handshakes.wait_rotation", DefaultHandshakeWaitRotation),
triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
messageMetrics: messageMetrics,
}
handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpServer, handshakeConfig)
lightHouse.handshakeTrigger = handshakeManager.trigger
//TODO: These will be reused for psk //TODO: These will be reused for psk
//handshakeMACKey := config.GetString("handshake_mac.key", "") //handshakeMACKey := config.GetString("handshake_mac.key", "")
@@ -283,37 +329,47 @@ func Main(configPath string, configTest bool, buildVersion string) {
DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false), DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false),
DropMulticast: config.GetBool("tun.drop_multicast", false), DropMulticast: config.GetBool("tun.drop_multicast", false),
UDPBatchSize: config.GetInt("listen.batch", 64), UDPBatchSize: config.GetInt("listen.batch", 64),
udpQueues: udpQueues,
tunQueues: config.GetInt("tun.routines", 1),
MessageMetrics: messageMetrics,
version: buildVersion,
} }
switch ifConfig.Cipher { switch ifConfig.Cipher {
case "aes": case "aes":
noiseEndiannes = binary.BigEndian noiseEndianness = binary.BigEndian
case "chachapoly": case "chachapoly":
noiseEndiannes = binary.LittleEndian noiseEndianness = binary.LittleEndian
default: default:
l.Fatalf("Unknown cipher: %v", ifConfig.Cipher) return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
} }
ifce, err := NewInterface(ifConfig) var ifce *Interface
if !configTest {
ifce, err = NewInterface(ifConfig)
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to initialize interface") return nil, fmt.Errorf("failed to initialize interface: %s", err)
} }
ifce.RegisterConfigChangeCallbacks(config) ifce.RegisterConfigChangeCallbacks(config)
go handshakeManager.Run(ifce) go handshakeManager.Run(ifce)
go lightHouse.LhUpdateWorker(ifce) go lightHouse.LhUpdateWorker(ifce)
}
err = startStats(config) err = startStats(config, configTest)
if err != nil { if err != nil {
l.WithError(err).Fatal("Failed to start stats emitter") return nil, NewContextualError("Failed to start stats emitter", nil, err)
}
if configTest {
return nil, nil
} }
//TODO: check if we _should_ be emitting stats //TODO: check if we _should_ be emitting stats
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10)) go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce) attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
ifce.Run(config.GetInt("tun.routines", 1), udpQueues, buildVersion)
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host // Start DNS server last to allow using the nebula IP as lighthouse.dns.host
if amLighthouse && serveDns { if amLighthouse && serveDns {
@@ -321,30 +377,5 @@ func Main(configPath string, configTest bool, buildVersion string) {
go dnsMain(hostMap, config) go dnsMain(hostMap, config)
} }
// Just sit here and be friendly, main thread. return &Control{ifce, l}, nil
shutdownBlock(ifce)
}
func shutdownBlock(ifce *Interface) {
var sigChan = make(chan os.Signal)
signal.Notify(sigChan, syscall.SIGTERM)
signal.Notify(sigChan, syscall.SIGINT)
sig := <-sigChan
l.WithField("signal", sig).Info("Caught signal, shutting down")
//TODO: stop tun and udp routines, the lock on hostMap does effectively does that though
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
ifce.hostMap.Lock()
for _, h := range ifce.hostMap.Hosts {
if h.ConnectionState.ready {
ifce.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message")
}
}
ifce.hostMap.Unlock()
l.WithField("signal", sig).Info("Goodbye")
os.Exit(0)
} }

97
message_metrics.go Normal file
View File

@@ -0,0 +1,97 @@
package nebula
import (
"fmt"
"github.com/rcrowley/go-metrics"
)
type MessageMetrics struct {
rx [][]metrics.Counter
tx [][]metrics.Counter
rxUnknown metrics.Counter
txUnknown metrics.Counter
}
func (m *MessageMetrics) Rx(t NebulaMessageType, s NebulaMessageSubType, i int64) {
if m != nil {
if t >= 0 && int(t) < len(m.rx) && s >= 0 && int(s) < len(m.rx[t]) {
m.rx[t][s].Inc(i)
} else if m.rxUnknown != nil {
m.rxUnknown.Inc(i)
}
}
}
func (m *MessageMetrics) Tx(t NebulaMessageType, s NebulaMessageSubType, i int64) {
if m != nil {
if t >= 0 && int(t) < len(m.tx) && s >= 0 && int(s) < len(m.tx[t]) {
m.tx[t][s].Inc(i)
} else if m.txUnknown != nil {
m.txUnknown.Inc(i)
}
}
}
func newMessageMetrics() *MessageMetrics {
gen := func(t string) [][]metrics.Counter {
return [][]metrics.Counter{
{
metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.handshake_ixpsk0", t), nil),
},
nil,
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.recv_error", t), nil)},
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.lighthouse", t), nil)},
{
metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.test_request", t), nil),
metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.test_response", t), nil),
},
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.close_tunnel", t), nil)},
}
}
return &MessageMetrics{
rx: gen("rx"),
tx: gen("tx"),
rxUnknown: metrics.GetOrRegisterCounter("messages.rx.other", nil),
txUnknown: metrics.GetOrRegisterCounter("messages.tx.other", nil),
}
}
// Historically we only recorded recv_error, so this is backwards compat
func newMessageMetricsOnlyRecvError() *MessageMetrics {
gen := func(t string) [][]metrics.Counter {
return [][]metrics.Counter{
nil,
nil,
{metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.recv_error", t), nil)},
}
}
return &MessageMetrics{
rx: gen("rx"),
tx: gen("tx"),
}
}
func newLighthouseMetrics() *MessageMetrics {
gen := func(t string) [][]metrics.Counter {
h := make([][]metrics.Counter, len(NebulaMeta_MessageType_name))
used := []NebulaMeta_MessageType{
NebulaMeta_HostQuery,
NebulaMeta_HostQueryReply,
NebulaMeta_HostUpdateNotification,
NebulaMeta_HostPunchNotification,
}
for _, i := range used {
h[i] = []metrics.Counter{metrics.GetOrRegisterCounter(fmt.Sprintf("lighthouse.%s.%s", t, i.String()), nil)}
}
return h
}
return &MessageMetrics{
rx: gen("rx"),
tx: gen("tx"),
rxUnknown: metrics.GetOrRegisterCounter("lighthouse.rx.other", nil),
txUnknown: metrics.GetOrRegisterCounter("lighthouse.tx.other", nil),
}
}

View File

@@ -8,11 +8,11 @@ import (
"github.com/flynn/noise" "github.com/flynn/noise"
) )
type endiannes interface { type endianness interface {
PutUint64(b []byte, v uint64) PutUint64(b []byte, v uint64)
} }
var noiseEndiannes endiannes = binary.BigEndian var noiseEndianness endianness = binary.BigEndian
type NebulaCipherState struct { type NebulaCipherState struct {
c noise.Cipher c noise.Cipher
@@ -37,7 +37,7 @@ func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, n
nb[1] = 0 nb[1] = 0
nb[2] = 0 nb[2] = 0
nb[3] = 0 nb[3] = 0
noiseEndiannes.PutUint64(nb[4:], n) noiseEndianness.PutUint64(nb[4:], n)
out = s.c.(cipher.AEAD).Seal(out, nb, plaintext, ad) out = s.c.(cipher.AEAD).Seal(out, nb, plaintext, ad)
//l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext)) //l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext))
return out, nil return out, nil
@@ -52,7 +52,7 @@ func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64,
nb[1] = 0 nb[1] = 0
nb[2] = 0 nb[2] = 0
nb[3] = 0 nb[3] = 0
noiseEndiannes.PutUint64(nb[4:], n) noiseEndianness.PutUint64(nb[4:], n)
return s.c.(cipher.AEAD).Open(out, nb, ciphertext, ad) return s.c.(cipher.AEAD).Open(out, nb, ciphertext, ad)
} else { } else {
return []byte{}, nil return []byte{}, nil

View File

@@ -2,18 +2,14 @@ package nebula
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"fmt"
"time"
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
// "github.com/google/gopacket"
// "github.com/google/gopacket/layers"
// "encoding/binary"
"errors"
"fmt"
"time"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
@@ -43,92 +39,15 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
ci = hostinfo.ConnectionState ci = hostinfo.ConnectionState
} }
switch header.Type { handle := f.handlers[header.Version][header.Type][header.Subtype]
case message:
if !f.handleEncrypted(ci, addr, header) { if handle == nil {
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
hostinfo.logger().Debugf("Unexpected packet received from %s", addr)
return return
} }
f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb) handle(hostinfo, ci, addr, header, out, packet, fwPacket, nb)
// Fallthrough to the bottom to record incoming traffic
case lightHouse:
if !f.handleEncrypted(ci, addr, header) {
return
}
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
if err != nil {
l.WithError(err).WithField("udpAddr", addr).WithField("vpnIp", IntIp(hostinfo.hostId)).
WithField("packet", packet).
Error("Failed to decrypt lighthouse packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
return
}
f.lightHouse.HandleRequest(addr, hostinfo.hostId, d, hostinfo.GetCert(), f)
// Fallthrough to the bottom to record incoming traffic
case test:
if !f.handleEncrypted(ci, addr, header) {
return
}
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
if err != nil {
l.WithError(err).WithField("udpAddr", addr).WithField("vpnIp", IntIp(hostinfo.hostId)).
WithField("packet", packet).
Error("Failed to decrypt test packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
return
}
if header.Subtype == testRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, addr)
f.send(test, testReply, ci, hostinfo, hostinfo.remote, d, nb, out)
}
// Fallthrough to the bottom to record incoming traffic
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
// are unauthenticated
case handshake:
HandleIncomingHandshake(f, addr, packet, header, hostinfo)
return
case recvError:
// TODO: Remove this with recv_error deprecation
f.handleRecvError(addr, header)
return
case closeTunnel:
if !f.handleEncrypted(ci, addr, header) {
return
}
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo)
return
default:
l.Debugf("Unexpected packet received from %s", addr)
return
}
f.handleHostRoaming(hostinfo, addr)
f.connectionManager.In(hostinfo.hostId)
} }
func (f *Interface) closeTunnel(hostInfo *HostInfo) { func (f *Interface) closeTunnel(hostInfo *HostInfo) {
@@ -142,15 +61,19 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) {
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
if hostDidRoam(hostinfo.remote, addr) { if hostDidRoam(hostinfo.remote, addr) {
if !f.lightHouse.remoteAllowList.Allow(udp2ipInt(addr)) {
hostinfo.logger().WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
return
}
if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSupressSeconds*time.Second { if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSupressSeconds*time.Second {
if l.Level >= logrus.DebugLevel { if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
Debugf("Supressing roam back to previous remote for %d seconds", RoamingSupressSeconds) Debugf("Supressing roam back to previous remote for %d seconds", RoamingSupressSeconds)
} }
return return
} }
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
Info("Host roamed to new udp ip/port.") Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now() hostinfo.lastRoam = time.Now()
remoteCopy := *hostinfo.remote remoteCopy := *hostinfo.remote
@@ -244,7 +167,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
} }
if !hostinfo.ConnectionState.window.Update(mc) { if !hostinfo.ConnectionState.window.Update(mc) {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("header", header). hostinfo.logger().WithField("header", header).
Debugln("dropping out of window packet") Debugln("dropping out of window packet")
return nil, errors.New("out of window packet") return nil, errors.New("out of window packet")
} }
@@ -252,12 +175,12 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return out, nil return out, nil
} }
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) { func (f *Interface) decryptTo(write func([]byte) error, hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
var err error var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb) out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).Error("Failed to decrypt packet") hostinfo.logger().WithError(err).Error("Failed to decrypt packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB //TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(hostinfo.remote, header.RemoteIndex) //f.sendRecvError(hostinfo.remote, header.RemoteIndex)
return return
@@ -265,32 +188,36 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
err = newPacket(out, true, fwPacket) err = newPacket(out, true, fwPacket)
if err != nil { if err != nil {
l.WithError(err).WithField("packet", out).WithField("hostInfo", IntIp(hostinfo.hostId)). hostinfo.logger().WithError(err).WithField("packet", out).
Warnf("Error while validating inbound packet") Warnf("Error while validating inbound packet")
return return
} }
if !hostinfo.ConnectionState.window.Update(messageCounter) { if !hostinfo.ConnectionState.window.Update(messageCounter) {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket). hostinfo.logger().WithField("fwPacket", fwPacket).
Debugln("dropping out of window packet") Debugln("dropping out of window packet")
return return
} }
if f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs) { dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs)
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket). if dropReason != nil {
if l.Level >= logrus.DebugLevel {
hostinfo.logger().WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping inbound packet") Debugln("dropping inbound packet")
}
return return
} }
f.connectionManager.In(hostinfo.hostId) f.connectionManager.In(hostinfo.hostId)
err = f.inside.WriteRaw(out) err = write(out)
if err != nil { if err != nil {
l.WithError(err).Error("Failed to write to tun") l.WithError(err).Error("Failed to write to tun")
} }
} }
func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) { func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
f.metricTxRecvError.Inc(1) f.messageMetrics.Tx(recvError, 0, 1)
//TODO: this should be a signed message so we can trust that we should drop the index //TODO: this should be a signed message so we can trust that we should drop the index
b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0) b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0)
@@ -303,8 +230,6 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
} }
func (f *Interface) handleRecvError(addr *udpAddr, h *Header) { func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
f.metricRxRecvError.Inc(1)
// This flag is to stop caring about recv_error from old versions // This flag is to stop caring about recv_error from old versions
// This should go away when the old version is gone from prod // This should go away when the old version is gone from prod
if l.Level >= logrus.DebugLevel { if l.Level >= logrus.DebugLevel {

View File

@@ -1,10 +1,11 @@
package nebula package nebula
import ( import (
"github.com/stretchr/testify/assert"
"golang.org/x/net/ipv4"
"net" "net"
"testing" "testing"
"github.com/stretchr/testify/assert"
"golang.org/x/net/ipv4"
) )
func Test_newPacket(t *testing.T) { func Test_newPacket(t *testing.T) {

30
punchy.go Normal file
View File

@@ -0,0 +1,30 @@
package nebula
import "time"
type Punchy struct {
Punch bool
Respond bool
Delay time.Duration
}
func NewPunchyFromConfig(c *Config) *Punchy {
p := &Punchy{}
if c.IsSet("punchy.punch") {
p.Punch = c.GetBool("punchy.punch", false)
} else {
// Deprecated fallback
p.Punch = c.GetBool("punchy", false)
}
if c.IsSet("punchy.respond") {
p.Respond = c.GetBool("punchy.respond", false)
} else {
// Deprecated fallback
p.Respond = c.GetBool("punch_back", false)
}
p.Delay = c.GetDuration("punchy.delay", time.Second)
return p
}

44
punchy_test.go Normal file
View File

@@ -0,0 +1,44 @@
package nebula
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestNewPunchyFromConfig(t *testing.T) {
c := NewConfig()
// Test defaults
p := NewPunchyFromConfig(c)
assert.Equal(t, false, p.Punch)
assert.Equal(t, false, p.Respond)
assert.Equal(t, time.Second, p.Delay)
// punchy deprecation
c.Settings["punchy"] = true
p = NewPunchyFromConfig(c)
assert.Equal(t, true, p.Punch)
// punchy.punch
c.Settings["punchy"] = map[interface{}]interface{}{"punch": true}
p = NewPunchyFromConfig(c)
assert.Equal(t, true, p.Punch)
// punch_back deprecation
c.Settings["punch_back"] = true
p = NewPunchyFromConfig(c)
assert.Equal(t, true, p.Respond)
// punchy.respond
c.Settings["punchy"] = map[interface{}]interface{}{"respond": true}
c.Settings["punch_back"] = false
p = NewPunchyFromConfig(c)
assert.Equal(t, true, p.Respond)
// punchy.delay
c.Settings["punchy"] = map[interface{}]interface{}{"delay": "1m"}
p = NewPunchyFromConfig(c)
assert.Equal(t, time.Minute, p.Delay)
}

56
ssh.go
View File

@@ -5,8 +5,6 @@ import (
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/sshd"
"io/ioutil" "io/ioutil"
"net" "net"
"os" "os"
@@ -14,6 +12,9 @@ import (
"runtime/pprof" "runtime/pprof"
"strings" "strings"
"syscall" "syscall"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/sshd"
) )
type sshListHostMapFlags struct { type sshListHostMapFlags struct {
@@ -65,10 +66,11 @@ func configSSH(ssh *sshd.SSHServer, c *Config) error {
return fmt.Errorf("sshd.listen must be provided") return fmt.Errorf("sshd.listen must be provided")
} }
port := strings.Split(listen, ":") _, port, err := net.SplitHostPort(listen)
if len(port) < 2 { if err != nil {
return fmt.Errorf("sshd.listen does not have a port") return fmt.Errorf("invalid sshd.listen address: %s", err)
} else if port[1] == "22" { }
if port == "22" {
return fmt.Errorf("sshd.listen can not use port 22") return fmt.Errorf("sshd.listen can not use port 22")
} }
@@ -461,7 +463,12 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
return w.WriteLine("No vpn ip was provided") return w.WriteLine("No vpn ip was provided")
} }
vpnIp := ip2int(net.ParseIP(a[0])) parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
if vpnIp == 0 { if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
@@ -481,7 +488,12 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine("No vpn ip was provided") return w.WriteLine("No vpn ip was provided")
} }
vpnIp := ip2int(net.ParseIP(a[0])) parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
if vpnIp == 0 { if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
@@ -519,7 +531,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine("No vpn ip was provided") return w.WriteLine("No vpn ip was provided")
} }
vpnIp := ip2int(net.ParseIP(a[0])) parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
if vpnIp == 0 { if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
@@ -571,7 +588,12 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine("Address could not be parsed") return w.WriteLine("Address could not be parsed")
} }
vpnIp := ip2int(net.ParseIP(a[0])) parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
if vpnIp == 0 { if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
@@ -647,7 +669,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
cert := ifce.certState.certificate cert := ifce.certState.certificate
if len(a) > 0 { if len(a) > 0 {
vpnIp := ip2int(net.ParseIP(a[0])) parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
if vpnIp == 0 { if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
@@ -694,7 +721,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine("No vpn ip was provided") return w.WriteLine("No vpn ip was provided")
} }
vpnIp := ip2int(net.ParseIP(a[0])) parsedIp := net.ParseIP(a[0])
if parsedIp == nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
vpnIp := ip2int(parsedIp)
if vpnIp == 0 { if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }

View File

@@ -4,9 +4,10 @@ import (
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"github.com/armon/go-radix"
"sort" "sort"
"strings" "strings"
"github.com/armon/go-radix"
) )
// CommandFlags is a function called before help or command execution to parse command line flags // CommandFlags is a function called before help or command execution to parse command line flags

View File

@@ -2,10 +2,11 @@ package sshd
import ( import (
"fmt" "fmt"
"net"
"github.com/armon/go-radix" "github.com/armon/go-radix"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"net"
) )
type SSHServer struct { type SSHServer struct {

View File

@@ -2,13 +2,14 @@ package sshd
import ( import (
"fmt" "fmt"
"sort"
"strings"
"github.com/anmitsu/go-shlex" "github.com/anmitsu/go-shlex"
"github.com/armon/go-radix" "github.com/armon/go-radix"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal" "golang.org/x/crypto/ssh/terminal"
"sort"
"strings"
) )
type session struct { type session struct {

View File

@@ -3,18 +3,19 @@ package nebula
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/cyberdelia/go-metrics-graphite"
mp "github.com/nbrownus/go-metrics-prometheus"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rcrowley/go-metrics"
"log" "log"
"net" "net"
"net/http" "net/http"
"time" "time"
graphite "github.com/cyberdelia/go-metrics-graphite"
mp "github.com/nbrownus/go-metrics-prometheus"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rcrowley/go-metrics"
) )
func startStats(c *Config) error { func startStats(c *Config, configTest bool) error {
mType := c.GetString("stats.type", "") mType := c.GetString("stats.type", "")
if mType == "" || mType == "none" { if mType == "" || mType == "none" {
return nil return nil
@@ -27,9 +28,9 @@ func startStats(c *Config) error {
switch mType { switch mType {
case "graphite": case "graphite":
startGraphiteStats(interval, c) startGraphiteStats(interval, c, configTest)
case "prometheus": case "prometheus":
startPrometheusStats(interval, c) startPrometheusStats(interval, c, configTest)
default: default:
return fmt.Errorf("stats.type was not understood: %s", mType) return fmt.Errorf("stats.type was not understood: %s", mType)
} }
@@ -43,7 +44,7 @@ func startStats(c *Config) error {
return nil return nil
} }
func startGraphiteStats(i time.Duration, c *Config) error { func startGraphiteStats(i time.Duration, c *Config, configTest bool) error {
proto := c.GetString("stats.protocol", "tcp") proto := c.GetString("stats.protocol", "tcp")
host := c.GetString("stats.host", "") host := c.GetString("stats.host", "")
if host == "" { if host == "" {
@@ -57,11 +58,13 @@ func startGraphiteStats(i time.Duration, c *Config) error {
} }
l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr) l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr)
if !configTest {
go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr) go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr)
}
return nil return nil
} }
func startPrometheusStats(i time.Duration, c *Config) error { func startPrometheusStats(i time.Duration, c *Config, configTest bool) error {
namespace := c.GetString("stats.namespace", "") namespace := c.GetString("stats.namespace", "")
subsystem := c.GetString("stats.subsystem", "") subsystem := c.GetString("stats.subsystem", "")
@@ -79,11 +82,13 @@ func startPrometheusStats(i time.Duration, c *Config) error {
pClient := mp.NewPrometheusProvider(metrics.DefaultRegistry, namespace, subsystem, pr, i) pClient := mp.NewPrometheusProvider(metrics.DefaultRegistry, namespace, subsystem, pr, i)
go pClient.UpdatePrometheusMetrics() go pClient.UpdatePrometheusMetrics()
if !configTest {
go func() { go func() {
l.Infof("Prometheus stats listening on %s at %s", listen, path) l.Infof("Prometheus stats listening on %s at %s", listen, path)
http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l})) http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l}))
log.Fatal(http.ListenAndServe(listen, nil)) log.Fatal(http.ListenAndServe(listen, nil))
}() }()
}
return nil return nil
} }

View File

@@ -1,9 +1,10 @@
package nebula package nebula
import ( import (
"github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
) )
func TestNewTimerWheel(t *testing.T) { func TestNewTimerWheel(t *testing.T) {

76
tun_android.go Normal file
View File

@@ -0,0 +1,76 @@
package nebula
import (
"fmt"
"io"
"net"
"os"
"golang.org/x/sys/unix"
)
type Tun struct {
io.ReadWriteCloser
fd int
Device string
Cidr *net.IPNet
MaxMTU int
DefaultMTU int
TXQueueLen int
Routes []route
UnsafeRoutes []route
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
ifce = &Tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
Device: "android",
Cidr: cidr,
DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen,
Routes: routes,
UnsafeRoutes: unsafeRoutes,
}
return
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTun not supported in Android")
}
func (c *Tun) WriteRaw(b []byte) error {
var nn int
for {
max := len(b)
n, err := unix.Write(c.fd, b[nn:max])
if n > 0 {
nn += n
}
if nn == len(b) {
return err
}
if err != nil {
return err
}
if n == 0 {
return io.ErrUnexpectedEOF
}
}
}
func (c Tun) Activate() error {
return nil
}
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}

View File

@@ -132,7 +132,7 @@ func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) {
via, ok := rVia.(string) via, ok := rVia.(string)
if !ok { if !ok {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: %v", i+1, err) return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia)
} }
nVia := net.ParseIP(via) nVia := net.ParseIP(via)
@@ -147,6 +147,7 @@ func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) {
r := route{ r := route{
via: &nVia, via: &nVia,
mtu: mtu,
} }
_, r.route, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) _, r.route, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))

View File

@@ -1,3 +1,5 @@
// +build !ios
package nebula package nebula
import ( import (
@@ -20,8 +22,9 @@ type Tun struct {
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 { if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in Darwin") return nil, fmt.Errorf("route MTU not supported in Darwin")
} }
// NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate() // NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate()
return &Tun{ return &Tun{
Cidr: cidr, Cidr: cidr,
@@ -30,30 +33,34 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
}, nil }, nil
} }
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}
func (c *Tun) Activate() error { func (c *Tun) Activate() error {
var err error var err error
c.Interface, err = water.New(water.Config{ c.Interface, err = water.New(water.Config{
DeviceType: water.TUN, DeviceType: water.TUN,
}) })
if err != nil { if err != nil {
return fmt.Errorf("Activate failed: %v", err) return fmt.Errorf("activate failed: %v", err)
} }
c.Device = c.Interface.Name() c.Device = c.Interface.Name()
// TODO use syscalls instead of exec.Command // TODO use syscalls instead of exec.Command
if err = exec.Command("ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil { if err = exec.Command("/sbin/ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
} }
if err = exec.Command("route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil { if err = exec.Command("/sbin/route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err) return fmt.Errorf("failed to run 'route add': %s", err)
} }
if err = exec.Command("ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil { if err = exec.Command("/sbin/ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
} }
// Unsafe path routes // Unsafe path routes
for _, r := range c.UnsafeRoutes { for _, r := range c.UnsafeRoutes {
if err = exec.Command("route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil { if err = exec.Command("/sbin/route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.route.String(), err) return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.route.String(), err)
} }
} }
@@ -61,6 +68,14 @@ func (c *Tun) Activate() error {
return nil return nil
} }
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}
func (c *Tun) WriteRaw(b []byte) error { func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b) _, err := c.Write(b)
return err return err

74
tun_disabled.go Normal file
View File

@@ -0,0 +1,74 @@
package nebula
import (
"fmt"
"io"
"net"
"strings"
log "github.com/sirupsen/logrus"
)
type disabledTun struct {
block chan struct{}
cidr *net.IPNet
logger *log.Logger
}
func newDisabledTun(cidr *net.IPNet, l *log.Logger) *disabledTun {
return &disabledTun{
cidr: cidr,
block: make(chan struct{}),
logger: l,
}
}
func (*disabledTun) Activate() error {
return nil
}
func (t *disabledTun) CidrNet() *net.IPNet {
return t.cidr
}
func (*disabledTun) DeviceName() string {
return "disabled"
}
func (t *disabledTun) Read(b []byte) (int, error) {
<-t.block
return 0, io.EOF
}
func (t *disabledTun) Write(b []byte) (int, error) {
t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
return len(b), nil
}
func (t *disabledTun) WriteRaw(b []byte) error {
_, err := t.Write(b)
return err
}
func (t *disabledTun) Close() error {
if t.block != nil {
close(t.block)
t.block = nil
}
return nil
}
type prettyPacket []byte
func (p prettyPacket) String() string {
var s strings.Builder
for i, b := range p {
if i > 0 && i%8 == 0 {
s.WriteString(" ")
}
s.WriteString(fmt.Sprintf("%02x ", b))
}
return s.String()
}

89
tun_freebsd.go Normal file
View File

@@ -0,0 +1,89 @@
package nebula
import (
"fmt"
"io"
"net"
"os"
"os/exec"
"regexp"
"strconv"
"strings"
)
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
type Tun struct {
Device string
Cidr *net.IPNet
MTU int
UnsafeRoutes []route
io.ReadWriteCloser
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in FreeBSD")
}
if strings.HasPrefix(deviceName, "/dev/") {
deviceName = strings.TrimPrefix(deviceName, "/dev/")
}
if !deviceNameRE.MatchString(deviceName) {
return nil, fmt.Errorf("tun.dev must match `tun[0-9]+`")
}
return &Tun{
Device: deviceName,
Cidr: cidr,
MTU: defaultMTU,
UnsafeRoutes: unsafeRoutes,
}, nil
}
func (c *Tun) Activate() error {
var err error
c.ReadWriteCloser, err = os.OpenFile("/dev/"+c.Device, os.O_RDWR, 0)
if err != nil {
return fmt.Errorf("Activate failed: %v", err)
}
// TODO use syscalls instead of exec.Command
l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String())
if err = exec.Command("/sbin/ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU))
if err = exec.Command("/sbin/ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
for _, r := range c.UnsafeRoutes {
l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.route.String(), err)
}
}
return nil
}
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}
func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b)
return err
}

113
tun_ios.go Normal file
View File

@@ -0,0 +1,113 @@
// +build ios
package nebula
import (
"errors"
"fmt"
"io"
"net"
"os"
"sync"
"syscall"
)
type Tun struct {
io.ReadWriteCloser
Device string
Cidr *net.IPNet
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTun not supported in iOS")
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("route MTU not supported in Darwin")
}
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
ifce = &Tun{
Cidr: cidr,
Device: "iOS",
ReadWriteCloser: &tunReadCloser{f: file},
}
return
}
func (c *Tun) Activate() error {
return nil
}
func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b)
return err
}
// The following is hoisted up from water, we do this so we can inject our own fd on iOS
type tunReadCloser struct {
f io.ReadWriteCloser
rMu sync.Mutex
rBuf []byte
wMu sync.Mutex
wBuf []byte
}
func (t *tunReadCloser) Read(to []byte) (int, error) {
t.rMu.Lock()
defer t.rMu.Unlock()
if cap(t.rBuf) < len(to)+4 {
t.rBuf = make([]byte, len(to)+4)
}
t.rBuf = t.rBuf[:len(to)+4]
n, err := t.f.Read(t.rBuf)
copy(to, t.rBuf[4:])
return n - 4, err
}
func (t *tunReadCloser) Write(from []byte) (int, error) {
if len(from) == 0 {
return 0, syscall.EIO
}
t.wMu.Lock()
defer t.wMu.Unlock()
if cap(t.wBuf) < len(from)+4 {
t.wBuf = make([]byte, len(from)+4)
}
t.wBuf = t.wBuf[:len(from)+4]
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
t.wBuf[3] = syscall.AF_INET
} else if ipVer == 6 {
t.wBuf[3] = syscall.AF_INET6
} else {
return 0, errors.New("unable to determine IP version from packet")
}
copy(t.wBuf[4:], from)
n, err := t.f.Write(t.wBuf)
return n - 4, err
}
func (t *tunReadCloser) Close() error {
return t.f.Close()
}
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}

View File

@@ -1,3 +1,5 @@
// +build !android
package nebula package nebula
import ( import (
@@ -75,6 +77,23 @@ type ifreqQLEN struct {
pad [8]byte pad [8]byte
} }
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
ifce = &Tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
Device: "tun0",
Cidr: cidr,
DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen,
Routes: routes,
UnsafeRoutes: unsafeRoutes,
}
return
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {
@@ -216,6 +235,7 @@ func (c Tun) Activate() error {
LinkIndex: link.Attrs().Index, LinkIndex: link.Attrs().Index,
Dst: dr, Dst: dr,
MTU: c.DefaultMTU, MTU: c.DefaultMTU,
AdvMSS: c.advMSS(route{}),
Scope: unix.RT_SCOPE_LINK, Scope: unix.RT_SCOPE_LINK,
Src: c.Cidr.IP, Src: c.Cidr.IP,
Protocol: unix.RTPROT_KERNEL, Protocol: unix.RTPROT_KERNEL,
@@ -233,6 +253,7 @@ func (c Tun) Activate() error {
LinkIndex: link.Attrs().Index, LinkIndex: link.Attrs().Index,
Dst: r.route, Dst: r.route,
MTU: r.mtu, MTU: r.mtu,
AdvMSS: c.advMSS(r),
Scope: unix.RT_SCOPE_LINK, Scope: unix.RT_SCOPE_LINK,
} }
@@ -248,6 +269,7 @@ func (c Tun) Activate() error {
LinkIndex: link.Attrs().Index, LinkIndex: link.Attrs().Index,
Dst: r.route, Dst: r.route,
MTU: r.mtu, MTU: r.mtu,
AdvMSS: c.advMSS(r),
Scope: unix.RT_SCOPE_LINK, Scope: unix.RT_SCOPE_LINK,
} }
@@ -265,3 +287,24 @@ func (c Tun) Activate() error {
return nil return nil
} }
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}
func (c Tun) advMSS(r route) int {
mtu := r.mtu
if r.mtu == 0 {
mtu = c.DefaultMTU
}
// We only need to set advmss if the route MTU does not match the device MTU
if mtu != c.MaxMTU {
return mtu - 40
}
return 0
}

31
tun_linux_test.go Normal file
View File

@@ -0,0 +1,31 @@
package nebula
import "testing"
var runAdvMSSTests = []struct {
name string
tun Tun
r route
expected int
}{
// Standard case, default MTU is the device max MTU
{"default", Tun{DefaultMTU: 1440, MaxMTU: 1440}, route{}, 0},
{"default-min", Tun{DefaultMTU: 1440, MaxMTU: 1440}, route{mtu: 1440}, 0},
{"default-low", Tun{DefaultMTU: 1440, MaxMTU: 1440}, route{mtu: 1200}, 1160},
// Case where we have a route MTU set higher than the default
{"route", Tun{DefaultMTU: 1440, MaxMTU: 8941}, route{}, 1400},
{"route-min", Tun{DefaultMTU: 1440, MaxMTU: 8941}, route{mtu: 1440}, 1400},
{"route-high", Tun{DefaultMTU: 1440, MaxMTU: 8941}, route{mtu: 8941}, 0},
}
func TestTunAdvMSS(t *testing.T) {
for _, tt := range runAdvMSSTests {
t.Run(tt.name, func(t *testing.T) {
o := tt.tun.advMSS(tt.r)
if o != tt.expected {
t.Errorf("got %d, want %d", o, tt.expected)
}
})
}
}

View File

@@ -1,9 +1,11 @@
package nebula package nebula
import ( import (
"github.com/stretchr/testify/assert" "fmt"
"net" "net"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func Test_parseRoutes(t *testing.T) { func Test_parseRoutes(t *testing.T) {
@@ -100,3 +102,126 @@ func Test_parseRoutes(t *testing.T) {
t.Fatal("Did not see both routes") t.Fatal("Did not see both routes")
} }
} }
func Test_parseUnsafeRoutes(t *testing.T) {
c := NewConfig()
_, n, _ := net.ParseCIDR("10.0.0.0/24")
// test no routes config
routes, err := parseUnsafeRoutes(c, n)
assert.Nil(t, err)
assert.Len(t, routes, 0)
// not an array
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "tun.unsafe_routes is not an array")
// no routes
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, err)
assert.Len(t, routes, 0)
// weird route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
// no via
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
// invalid via
for _, invalidValue := range []interface{}{
127, false, nil, 1.0, []string{"1", "2"},
} {
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
}
// unparsable via
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: nope")
// missing route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
// unparsable route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: invalid CIDR address: nope")
// within network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24")
// below network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Len(t, routes, 1)
assert.Nil(t, err)
// above network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Len(t, routes, 1)
assert.Nil(t, err)
// no mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Len(t, routes, 1)
assert.Equal(t, DEFAULT_MTU, routes[0].mtu)
// bad mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
// low mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
// happy case
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29"},
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32"},
}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, err)
assert.Len(t, routes, 2)
tested := 0
for _, r := range routes {
if r.mtu == 8000 {
assert.Equal(t, "1.0.0.1/32", r.route.String())
tested++
} else {
assert.Equal(t, 9000, r.mtu)
assert.Equal(t, "1.0.0.0/29", r.route.String())
tested++
}
}
if tested != 2 {
t.Fatal("Did not see both unsafe_routes")
}
}

View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"net" "net"
"os/exec" "os/exec"
"strconv"
"github.com/songgao/water" "github.com/songgao/water"
) )
@@ -12,21 +13,25 @@ type Tun struct {
Device string Device string
Cidr *net.IPNet Cidr *net.IPNet
MTU int MTU int
UnsafeRoutes []route
*water.Interface *water.Interface
} }
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 { if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in Windows") return nil, fmt.Errorf("route MTU not supported in Windows")
}
if len(unsafeRoutes) > 0 {
return nil, fmt.Errorf("unsafeRoutes not supported in Windows")
} }
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
return &Tun{ return &Tun{
Cidr: cidr, Cidr: cidr,
MTU: defaultMTU, MTU: defaultMTU,
UnsafeRoutes: unsafeRoutes,
}, nil }, nil
} }
@@ -47,7 +52,7 @@ func (c *Tun) Activate() error {
// TODO use syscalls instead of exec.Command // TODO use syscalls instead of exec.Command
err = exec.Command( err = exec.Command(
"netsh", "interface", "ipv4", "set", "address", `C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address",
fmt.Sprintf("name=%s", c.Device), fmt.Sprintf("name=%s", c.Device),
"source=static", "source=static",
fmt.Sprintf("addr=%s", c.Cidr.IP), fmt.Sprintf("addr=%s", c.Cidr.IP),
@@ -58,7 +63,7 @@ func (c *Tun) Activate() error {
return fmt.Errorf("failed to run 'netsh' to set address: %s", err) return fmt.Errorf("failed to run 'netsh' to set address: %s", err)
} }
err = exec.Command( err = exec.Command(
"netsh", "interface", "ipv4", "set", "interface", `C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "interface",
c.Device, c.Device,
fmt.Sprintf("mtu=%d", c.MTU), fmt.Sprintf("mtu=%d", c.MTU),
).Run() ).Run()
@@ -66,9 +71,31 @@ func (c *Tun) Activate() error {
return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err) return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err)
} }
iface, err := net.InterfaceByName(c.Device)
if err != nil {
return fmt.Errorf("failed to find interface named %s: %v", c.Device, err)
}
for _, r := range c.UnsafeRoutes {
err = exec.Command(
"C:\\Windows\\System32\\route.exe", "add", r.route.String(), r.via.String(), "IF", strconv.Itoa(iface.Index),
).Run()
if err != nil {
return fmt.Errorf("failed to add the unsafe_route %s: %v", r.route.String(), err)
}
}
return nil return nil
} }
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}
func (c *Tun) WriteRaw(b []byte) error { func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b) _, err := c.Write(b)
return err return err

36
udp_android.go Normal file
View File

@@ -0,0 +1,36 @@
package nebula
import (
"fmt"
"net"
"syscall"
"golang.org/x/sys/unix"
)
func NewListenConfig(multi bool) net.ListenConfig {
return net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
if multi {
var controlErr error
err := c.Control(func(fd uintptr) {
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err)
return
}
})
if err != nil {
return err
}
if controlErr != nil {
return controlErr
}
}
return nil
},
}
}
func (u *udpConn) Rebind() error {
return nil
}

View File

@@ -32,3 +32,12 @@ func NewListenConfig(multi bool) net.ListenConfig {
}, },
} }
} }
func (u *udpConn) Rebind() error {
file, err := u.File()
if err != nil {
return err
}
return syscall.SetsockoptInt(int(file.Fd()), unix.IPPROTO_IP, unix.IP_BOUND_IF, 0)
}

38
udp_freebsd.go Normal file
View File

@@ -0,0 +1,38 @@
package nebula
// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig
import (
"fmt"
"net"
"syscall"
"golang.org/x/sys/unix"
)
func NewListenConfig(multi bool) net.ListenConfig {
return net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
if multi {
var controlErr error
err := c.Control(func(fd uintptr) {
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err)
return
}
})
if err != nil {
return err
}
if controlErr != nil {
return controlErr
}
}
return nil
},
}
}
func (u *udpConn) Rebind() error {
return nil
}

View File

@@ -1,4 +1,4 @@
// +build !linux // +build !linux android
// udp_generic implements the nebula UDP interface in pure Go stdlib. This // udp_generic implements the nebula UDP interface in pure Go stdlib. This
// means it can be used on platforms like Darwin and Windows. // means it can be used on platforms like Darwin and Windows.
@@ -65,6 +65,17 @@ func (ua *udpAddr) Equals(t *udpAddr) bool {
return ua.IP.Equal(t.IP) && ua.Port == t.Port return ua.IP.Equal(t.IP) && ua.Port == t.Port
} }
func (ua *udpAddr) Copy() udpAddr {
nu := udpAddr{net.UDPAddr{
Port: ua.Port,
Zone: ua.Zone,
IP: make(net.IP, len(ua.IP)),
}}
copy(nu.IP, ua.IP)
return nu
}
func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error { func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error {
_, err := uc.UDPConn.WriteToUDP(b, &addr.UDPAddr) _, err := uc.UDPConn.WriteToUDP(b, &addr.UDPAddr)
return err return err

View File

@@ -1,3 +1,5 @@
// +build !android
package nebula package nebula
import ( import (
@@ -69,9 +71,11 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
var lip [4]byte var lip [4]byte
copy(lip[:], net.ParseIP(ip).To4()) copy(lip[:], net.ParseIP(ip).To4())
if multi {
if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err) return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err)
} }
}
if err = unix.Bind(fd, &unix.SockaddrInet4{Addr: lip, Port: port}); err != nil { if err = unix.Bind(fd, &unix.SockaddrInet4{Addr: lip, Port: port}); err != nil {
return nil, fmt.Errorf("unable to bind to socket: %s", err) return nil, fmt.Errorf("unable to bind to socket: %s", err)
@@ -85,6 +89,14 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
return &udpConn{sysFd: fd}, err return &udpConn{sysFd: fd}, err
} }
func (u *udpConn) Rebind() error {
return nil
}
func (ua *udpAddr) Copy() udpAddr {
return *ua
}
func (u *udpConn) SetRecvBuffer(n int) error { func (u *udpConn) SetRecvBuffer(n int) error {
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
} }
@@ -137,9 +149,13 @@ func (u *udpConn) ListenOut(f *Interface) {
//TODO: should we track this? //TODO: should we track this?
//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015)) //metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
msgs, buffers, names := u.PrepareRawMessages(f.udpBatchSize) msgs, buffers, names := u.PrepareRawMessages(f.udpBatchSize)
read := u.ReadMulti
if f.udpBatchSize == 1 {
read = u.ReadSingle
}
for { for {
n, err := u.ReadMulti(msgs) n, err := read(msgs)
if err != nil { if err != nil {
l.WithError(err).Error("Failed to read packets") l.WithError(err).Error("Failed to read packets")
continue continue
@@ -155,34 +171,24 @@ func (u *udpConn) ListenOut(f *Interface) {
} }
} }
func (u *udpConn) Read(addr *udpAddr, b []byte) ([]byte, error) { func (u *udpConn) ReadSingle(msgs []rawMessage) (int, error) {
var rsa rawSockaddrAny
var rLen = unix.SizeofSockaddrAny
for { for {
n, _, err := unix.Syscall6( n, _, err := unix.Syscall6(
unix.SYS_RECVFROM, unix.SYS_RECVMSG,
uintptr(u.sysFd), uintptr(u.sysFd),
uintptr(unsafe.Pointer(&b[0])), uintptr(unsafe.Pointer(&(msgs[0].Hdr))),
uintptr(len(b)), 0,
uintptr(0), 0,
uintptr(unsafe.Pointer(&rsa)), 0,
uintptr(unsafe.Pointer(&rLen)), 0,
) )
if err != 0 { if err != 0 {
return nil, &net.OpError{Op: "read", Err: err} return 0, &net.OpError{Op: "recvmsg", Err: err}
} }
if rsa.Addr.Family == unix.AF_INET { msgs[0].Len = uint32(n)
addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1]) return 1, nil
addr.IP = uint32(rsa.Addr.Data[2])<<24 + uint32(rsa.Addr.Data[3])<<16 + uint32(rsa.Addr.Data[4])<<8 + uint32(rsa.Addr.Data[5])
} else {
addr.Port = 0
addr.IP = 0
}
return b[:n], nil
} }
} }
@@ -280,13 +286,6 @@ func (ua *udpAddr) Equals(t *udpAddr) bool {
return ua.IP == t.IP && ua.Port == t.Port return ua.IP == t.IP && ua.Port == t.Port
} }
func (ua *udpAddr) Copy() *udpAddr {
return &udpAddr{
Port: ua.Port,
IP: ua.IP,
}
}
func (ua *udpAddr) String() string { func (ua *udpAddr) String() string {
return fmt.Sprintf("%s:%v", int2ip(ua.IP), ua.Port) return fmt.Sprintf("%s:%v", int2ip(ua.IP), ua.Port)
} }

View File

@@ -1,5 +1,6 @@
// +build linux // +build linux
// +build 386 amd64p32 arm mips mipsle // +build 386 amd64p32 arm mips mipsle
// +build !android
package nebula package nebula

View File

@@ -1,5 +1,6 @@
// +build linux // +build linux
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x // +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x
// +build !android
package nebula package nebula

View File

@@ -20,3 +20,7 @@ func NewListenConfig(multi bool) net.ListenConfig {
}, },
} }
} }
func (u *udpConn) Rebind() error {
return nil
}

130
util/assert.go Normal file
View File

@@ -0,0 +1,130 @@
package util
import (
"fmt"
"reflect"
"testing"
"time"
"unsafe"
"github.com/stretchr/testify/assert"
)
// AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory
// There is currently a special case for `time.loc` (as this code traverses into unexported fields)
func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) {
v1 := reflect.ValueOf(a)
v2 := reflect.ValueOf(b)
if !assert.Equal(t, v1.Type(), v2.Type()) {
return
}
traverseDeepCopy(t, v1, v2, v1.Type().String())
}
func traverseDeepCopy(t *testing.T, v1 reflect.Value, v2 reflect.Value, name string) bool {
switch v1.Kind() {
case reflect.Array:
for i := 0; i < v1.Len(); i++ {
if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
return false
}
}
return true
case reflect.Slice:
if v1.IsNil() || v2.IsNil() {
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil %+v, %+v", name, v1, v2)
}
if !assert.Equal(t, v1.Len(), v2.Len(), "%s did not have the same length", name) {
return false
}
// A slice with cap 0
if v1.Cap() != 0 && !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same slice %v == %v", name, v1.Pointer(), v2.Pointer()) {
return false
}
v1c := v1.Cap()
v2c := v2.Cap()
if v1c > 0 && v2c > 0 && v1.Slice(0, v1c).Slice(v1c-1, v1c-1).Pointer() == v2.Slice(0, v2c).Slice(v2c-1, v2c-1).Pointer() {
return assert.Fail(t, "", "%s share some underlying memory", name)
}
for i := 0; i < v1.Len(); i++ {
if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
return false
}
}
return true
case reflect.Interface:
if v1.IsNil() || v2.IsNil() {
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
}
return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
case reflect.Ptr:
local := reflect.ValueOf(time.Local).Pointer()
if local == v1.Pointer() && local == v2.Pointer() {
return true
}
if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s points to the same memory", name) {
return false
}
return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
case reflect.Struct:
for i, n := 0, v1.NumField(); i < n; i++ {
if !traverseDeepCopy(t, v1.Field(i), v2.Field(i), name+"."+v1.Type().Field(i).Name) {
return false
}
}
return true
case reflect.Map:
if v1.IsNil() || v2.IsNil() {
return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
}
if !assert.Equal(t, v1.Len(), v2.Len(), "%s are not the same length", name) {
return false
}
if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same memory", name) {
return false
}
for _, k := range v1.MapKeys() {
val1 := v1.MapIndex(k)
val2 := v2.MapIndex(k)
if !assert.True(t, val1.IsValid(), "%s is an invalid key in %s", k, name) {
return false
}
if !assert.True(t, val2.IsValid(), "%s is an invalid key in %s", k, name) {
return false
}
if !traverseDeepCopy(t, val1, val2, name+fmt.Sprintf("%s[%s]", name, k)) {
return false
}
}
return true
default:
if v1.CanInterface() && v2.CanInterface() {
return assert.Equal(t, v1.Interface(), v2.Interface(), "%s was not equal", name)
}
e1 := reflect.NewAt(v1.Type(), unsafe.Pointer(v1.UnsafeAddr())).Elem().Interface()
e2 := reflect.NewAt(v2.Type(), unsafe.Pointer(v2.UnsafeAddr())).Elem().Interface()
return assert.Equal(t, e1, e2, "%s (unexported) was not equal", name)
}
}