Compare commits

...

63 Commits

Author SHA1 Message Date
Nate Brown
2e1d6743be v1.4.0 (#458)
Update CHANGELOG for Nebula v1.4.0

Co-authored-by: Wade Simmons <wade@wades.im>
2021-05-10 21:23:49 -04:00
Nate Brown
d004fae4f9 Unlock the hostmap quickly, lock hostinfo instead (#459) 2021-05-05 13:10:55 -05:00
Nate Brown
95f4c8a01b Don't check for rebind if we are closing the tunnel (#457) 2021-05-04 19:15:24 -05:00
Nate Brown
9ff73cb02f Increase the timestamp resolution for handshakes (#453) 2021-05-03 14:10:00 -05:00
John Maguire
98c391396c Remove log when no handshake message is sent (#452) 2021-04-30 18:19:40 -05:00
Nate Brown
1bc6f5fe6c Minor windows focused improvements (#443)
Co-authored-by: Wade Simmons <wadey@slack-corp.com>
2021-04-30 15:04:47 -05:00
Wade Simmons
44cb697552 Add more metrics (#450)
* Add more metrics

This change adds the following counter metrics:

Metrics to track packets dropped at the firewall:

    firewall.dropped.local_ip
    firewall.dropped.remote_ip
    firewall.dropped.no_rule

Metrics to track handshakes attempts that have been initiated and ones
that have timed out (ones that have completed are tracked by the
existing "handshakes" histogram).

    handshake_manager.initiated
    handshake_manager.timed_out

Metrics to track when cached_packets are dropped because we run out of
buffer space, and how many are sent once the handshake completes.

    hostinfo.cached_packets.dropped
    hostinfo.cached_packets.sent

This change also notes how many cached packets we have when we log the
final "Handshake received" message for either stage1 for stage2.

* separate incoming/outgoing metrics

* remove "allowed" firewall metrics

We don't need this on the hotpath, they aren't worh it.

* don't need pointers here
2021-04-27 22:23:18 -04:00
Nathan Brown
db23fdf9bc Dont apply race avoidance to existing handshakes, use the handshake time to determine who wins (#451)
Co-authored-by: Wade Simmons <wadey@slack-corp.com>
2021-04-27 21:15:34 -05:00
Nathan Brown
df7c7eec4a Get out faster on nil udpAddr (#449) 2021-04-26 20:21:47 -05:00
Nathan Brown
6f37280e8e Fully close tunnels when CloseAllTunnels is called (#448) 2021-04-26 10:42:24 -05:00
Nathan Brown
a0735dd7d5 Add locking around ssh conns to avoid concurrent map access on reload (#447) 2021-04-23 14:43:16 -05:00
Nathan Brown
1deb5d98e8 Fix tun funcs for ios and android (#446) 2021-04-22 15:23:40 -05:00
Nathan Brown
a1ee521d79 Fix a failed return in an error case (#445) 2021-04-17 18:47:31 -05:00
brad-defined
7859140711 Only set serveDns if the host is also configured to be a lighthouse. (#433) 2021-04-16 13:33:56 -05:00
brad-defined
17106f83a0 Ensure the Nebula device exists before attempting to bind to the Nebula IP (#375) 2021-04-16 10:34:28 -05:00
Nathan Brown
ab08be1e3e Don't panic on a nil response from the lighthouse (#442) 2021-04-15 09:12:21 -05:00
Nathan Brown
710df6a876 Refactor remotes and handshaking to give every address a fair shot (#437) 2021-04-14 13:50:09 -05:00
John Maguire
20bef975cd Remove obsolete systemd unit settings (take 2) (#438) 2021-04-07 12:02:40 -05:00
Nathan Brown
480036fbc8 Remove unused structs in hostmap.go (#430) 2021-04-01 22:07:11 -05:00
Nathan Brown
1499be3e40 Fix name resolution for host names in config (#431) 2021-04-01 21:48:41 -05:00
Nathan Brown
64d8e5aa96 More LH cleanup (#429) 2021-04-01 10:23:31 -05:00
Nathan Brown
75f7bda0a4 Lighthouse performance pass (#418) 2021-03-31 17:32:02 -05:00
Nathan Brown
e7e55618ff Include bad backets in the good handshake test (#428) 2021-03-31 13:36:10 -05:00
Nathan Brown
0c2e5973e1 Simple lie test (#427) 2021-03-31 10:26:35 -05:00
Nathan Brown
830d6d4639 Start of end to end testing with a good handshake between two nodes (#425) 2021-03-29 14:29:20 -05:00
Nathan Brown
883e09a392 Don't use a global ca pool (#426) 2021-03-29 12:10:19 -05:00
Wade Simmons
4603b5b2dd fix PromoteEvery check (#424)
This check was accidentally typo'd in #396 from `%` to `&`. Restore the
correct functionality here (we want to do the check every "PromoteEvery"
count packets).
2021-03-26 15:01:05 -04:00
Wade Simmons
a71541fb0b export build version as a prometheus label (#405)
This is how Prometheus recommends you do it, and how they do it
themselves in their client. This makes it easy to see which versions you
have deployed in your fleet, and query over it too.
2021-03-26 14:16:35 -04:00
Nathan Brown
3ea7e1b75f Don't use a global logger (#423) 2021-03-26 09:46:30 -05:00
Nathan Brown
7a9f9dbded Don't craft buffers if we don't need them (#416) 2021-03-22 18:25:06 -05:00
Nathan Brown
7073d204a8 IPv6 support for outside (udp) (#369) 2021-03-18 20:37:24 -05:00
Joe Doss
9e94442ce7 Add fedora dist files. (#413) 2021-03-18 12:33:43 -07:00
Joe Doss
13471f5792 Remove obsolete systemd unit settings. (#412) 2021-03-18 12:29:36 -07:00
Thomas Roten
ea07a89cc8 Ensure mutex is unlocked when adding remote IP. (#406)
Currently, if you use the remote allow list config, as soon as you attempt to create a tunnel to a node that has a blocked IP address, a mutex is locked and never unlocked. This happens even if the node has an allowed remote IP address in addition to the blocked remote IP address.

This pull request ensures that the lighthouse mutex is unlocked whenever we attempt to add a remote IP.
2021-03-16 12:41:35 -04:00
Ryan Huber
3aaaea6309 don't allow a useless handshake with yourself (#402)
* don't allow a useless handshake with yourself

* remove helper
2021-03-15 12:58:23 -07:00
Wade Simmons
5506da3de9 Fix selection of UDP remote to use during stage2 (#404)
The change for #401 incorrectly called HostInfo.ForcePromoteBest in
stage2, when we really we want to pick the remote that we received the
response from.
2021-03-12 21:43:24 -05:00
Wade Simmons
6c55d67f18 Refactor handshake_ix (#401)
There are some subtle race conditions with the previous handshake_ix implementation, mostly around collisions with localIndexId. This change refactors it so that we have a "commit" phase during the handshake where we grab the lock for the hostmap and ensure that we have a unique local index before storing it. We also now avoid using the pending hostmap at all for receiving stage1 packets, since we have everything we need to just store the completed handshake.

Co-authored-by: Nate Brown <nbrown.us@gmail.com>
Co-authored-by: Ryan Huber <rhuber@gmail.com>
Co-authored-by: forfuncsake <drussell@slack-corp.com>
2021-03-12 14:16:25 -05:00
Wade Simmons
64d8035d09 fix race in getOrHandshake (#400)
We missed this race with #396 (and I think this is also the crash in
issue #226). We need to lock a little higher in the getOrHandshake
method, before we reset hostinfo.ConnectionInfo. Previously, two
routines could enter this section and confuse the handshake process.

This could result in the other side sending a recv_error that also has
a race with setting hostinfo.ConnectionInfo back to nil. So we make sure
to grab the lock in handleRecvError as well.

Neither of these code paths are in the hot path (handling packets
between two hosts over an active tunnel) so there should be no
performance concerns.
2021-03-09 09:27:02 -05:00
Ryan Huber
73a5ed90b2 Do not allow someone to run a nebula lighthouse with an ephemeral port (#399)
* Do not allow someone to run a nebula lighthouse with an ephemeral port

* derp - we discover the port so we have to check the config setting

* No context needed for this error

* gofmt yourself

* Revert "gofmt yourself"

This reverts commit c01423498e.

* Revert "No context needed for this error"

This reverts commit 6792af6846.

* snip snap snip snap
2021-03-08 12:42:06 -08:00
Wade Simmons
d604270966 Fix most known data races (#396)
This change fixes all of the known data races that `make smoke-docker-race` finds, except for one.

Most of these races are around the handshake phase for a hostinfo, so we add a RWLock to the hostinfo and Lock during each of the handshake stages.

Some of the other races are around consistently using `atomic` around the `messageCounter` field. To make this harder to mess up, I have renamed the field to `atomicMessageCounter` (I also removed the unnecessary extra pointer deference as we can just point directly to the struct field).

The last remaining data race is around reading `ConnectionInfo.ready`, which is a boolean that is only written to once when the handshake has finished. Due to it being in the hot path for packets and the rare case that this could actually be an issue, holding off on fixing that one for now.

here is the results of `make smoke-docker-race`:

before:

    lighthouse1: Found 2 data race(s)
    host2:       Found 36 data race(s)
    host3:       Found 17 data race(s)
    host4:       Found 31 data race(s)

after:

    host2: Found 1 data race(s)
    host4: Found 1 data race(s)

Fixes: #147
Fixes: #226
Fixes: #283
Fixes: #316
2021-03-05 21:18:33 -05:00
Nathan Brown
29c5f31f90 Add a check in the makefile to ensure a minimum version of go is installed (#383) 2021-03-02 13:29:05 -06:00
Nathan Brown
b6234abfb3 Add a way to trigger punch backs via lighthouse (#394) 2021-03-01 19:06:01 -06:00
Wade Simmons
2a4beb41b9 Routine-local conntrack cache (#391)
Previously, every packet we see gets a lock on the conntrack table and updates it. When running with multiple routines, this can cause heavy lock contention and limit our ability for the threads to run independently. This change caches reads from the conntrack table for a very short period of time to reduce this lock contention. This cache will currently default to disabled unless you are running with multiple routines, in which case the default cache delay will be 1 second. This means that entries in the conntrack table may be up to 1 second out of date and remain in a routine local cache for up to 1 second longer than the global table.

Instead of calling time.Now() for every packet, this cache system relies on a tick thread that updates the current cache "version" each tick. Every packet we check if the cache version is out of date, and reset the cache if so.
2021-03-01 19:52:17 -05:00
Wade Simmons
d232ccbfab add metrics for the udp sockets using SO_MEMINFO (#390)
Retrieve the current socket stats using SO_MEMINFO and report them as
metrics gauges. If SO_MEMINFO isn't supported, we don't report these metrics.
2021-03-01 19:51:33 -05:00
Nathan Brown
ecfb40f29c Fix osx for mq changes, this does not implement mq on osx (#395) 2021-03-01 16:57:05 -05:00
Wade Simmons
1bae5b2550 more validation in pending hostmap deletes (#344)
We are currently seeing some cases where we are not deleting entries
correctly from the pending hostmap. I believe this is a case of
an inbound timer tick firing and deleting the Hosts map entry for
a newer handshake attempt than intended, thus leaving the old Indexes
entry orphaned. This change adds some extra checking when deleteing from
the Indexes and Hosts maps to ensure we clean everything up correctly.
2021-03-01 12:40:46 -05:00
Wade Simmons
73081d99bc add make smoke-docker (#287)
This makes it easier to use the docker container smoke test that
GitHub actions runs. There is also `make smoke-docker-race` that runs the
smoke test with `-race` enabled.
2021-03-01 11:15:15 -05:00
Tim Rots
e7e6a23cde fix a few typos (#302) 2021-03-01 11:14:34 -05:00
Wade Simmons
a0583ebdca tun_disabled: reply to ICMP Echo Request (#342)
This change allows a server running with `tun.disabled: true` (usually
a lighthouse) to still reply to ICMP EchoRequest packets. This allows
you to "ping" the lighthouse Nebula IP as a quick check to make sure the
tunnel is up, even when running with tun.disabled.

This is still gated by allowing `icmp` packets in the inbound firewall
rules.
2021-03-01 11:09:41 -05:00
Wade Simmons
27d9a67dda Proper multiqueue support for tun devices (#382)
This change is for Linux only.

Previously, when running with multiple tun.routines, we would only have one file descriptor. This change instead sets IFF_MULTI_QUEUE and opens a file descriptor for each routine. This allows us to process with multiple threads while preventing out of order packet reception issues.

To attempt to distribute the flows across the queues, we try to write to the tun/UDP queue that corresponds with the one we read from. So if we read a packet from tun queue "2", we will write the outgoing encrypted packet to UDP queue "2". Because of the nature of how multi queue works with flows, a given host tunnel will be sticky to a given routine (so if you try to performance benchmark by only using one tunnel between two hosts, you are only going to be using a max of one thread for each direction).

Because this system works much better when we can correlate flows between the tun and udp routines, we are deprecating the undocumented "tun.routines" and "listen.routines" parameters and introducing a new "routines" parameter that sets the value for both. If you use the old undocumented parameters, the max of the values will be used and a warning logged.

Co-authored-by: Nate Brown <nbrown.us@gmail.com>
2021-02-25 15:01:14 -05:00
John Maguire
2bce222550 List possible cipher options in example config (#385) 2021-02-19 21:46:42 -06:00
Wade Simmons
3dd1108099 Go 1.16 and darwin-arm64 (#381)
This commit switches to Go 1.16 and adds a release binary for darwin-arm64.

Fixes: #343
2021-02-17 13:11:57 -05:00
Nathan Brown
d4b81f9b8d Add QR code support to nebula-cert (#297) 2021-02-11 18:53:25 -06:00
brad-defined
454bc8a6bb Check certificate banner during nebula-cert print (#373) 2021-02-05 14:52:32 -06:00
Wade Simmons
ce9ad37431 fix regression with LightHouseHandler and punchBack (#346)
The change introduced by #320 incorrectly re-uses the output buffer for
sending punchBack packets. Since we are currently spawning a new
goroutine for each send here, we need to allocate a new buffer each
time. We can come back and optimize this in the future, but for now we
should fix the regression.
2020-11-25 17:49:26 -05:00
Wade Simmons
ee7c27093c add HostMap.RemoteIndexes (#329)
This change adds an index based on HostInfo.remoteIndexId. This allows
us to use HostMap.QueryReverseIndex without having to loop over all
entries in the map (this can be a bottleneck under high traffic
lighthouses).

Without this patch, a high traffic lighthouse server receiving recv_error
packets and lots of handshakes, cpu pprof trace can look like this:

      flat  flat%   sum%        cum   cum%
    2000ms 32.26% 32.26%     3040ms 49.03%  github.com/slackhq/nebula.(*HostMap).QueryReverseIndex
     870ms 14.03% 46.29%     1060ms 17.10%  runtime.mapiternext

Which shows 50% of total cpu time is being spent in QueryReverseIndex.
2020-11-23 14:51:16 -05:00
Wade Simmons
2e7ca027a4 Lighthouse handler optimizations (#320)
We noticed that the number of memory allocations LightHouse.HandleRequest creates for each call can seriously impact performance for high traffic lighthouses. This PR introduces a benchmark in the first commit and then optimizes memory usage by creating a LightHouseHandler struct. This struct allows us to re-use memory between each lighthouse request (one instance per UDP listener go-routine).
2020-11-23 14:50:01 -05:00
mhp
672ce1f0a8 Move slice allocations in connection manager monitor loop (#340)
* Move slice allocations in connection manager monitor loop

* move further out

Co-authored-by: Miran Park <mpark@slack-corp.com>
2020-11-19 15:44:05 -08:00
Wade Simmons
384b1166ea fix panic in UnmarshalNebulaCertificate (#339)
This fixes a panic in UnmarshalNebulaCertificate when unmarshaling
a payload with Details set to nil.

Fixes: #332
2020-11-19 08:44:54 -05:00
Wade Simmons
0389596f66 don't mark handshake packets as "lost" (#331)
Packet 1 is always a stage 1 handshake and packet 2 is always stage 2.
Normal packets don't start flowing until the message counter is 3 or
higher.

Currently we only receive either packet 1 or 2 depending on if
we are the initiator or responder for the handshake, so we end up
marking one of these as "lost". We should mark these packets as "seen"
when we are the one sending them, since we don't expect to see them from
the other side.
2020-11-16 14:03:08 -05:00
Ryan Huber
43a3988afc i don't think this is used at all anymore (#323) 2020-10-29 21:43:50 -04:00
Brian Kelly
5c23676a0f Added line to systemd config template to start Nebula before sshd (#317)
During shutdown, this will keep Nebula alive until after sshd is finished. This cleanly terminates ssh clients accessing a server over a Nebula tunnel.
2020-10-29 21:43:02 -04:00
Nathan Brown
f6d0b4b893 Update README for supported platforms (#312) 2020-10-12 13:11:32 -05:00
99 changed files with 7948 additions and 2767 deletions

View File

@@ -14,10 +14,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Set up Go 1.15
- name: Set up Go 1.16
uses: actions/setup-go@v1
with:
go-version: 1.15
go-version: 1.16
id: go
- name: Check out code into the Go module directory
@@ -26,9 +26,9 @@ jobs:
- uses: actions/cache@v1
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-gofmt-${{ hashFiles('**/go.sum') }}
key: ${{ runner.os }}-gofmt1.16-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gofmt-
${{ runner.os }}-gofmt1.16-
- name: Install goimports
run: |

View File

@@ -10,10 +10,10 @@ jobs:
name: Build Linux All
runs-on: ubuntu-latest
steps:
- name: Set up Go 1.15
- name: Set up Go 1.16
uses: actions/setup-go@v1
with:
go-version: 1.15
go-version: 1.16
- name: Checkout code
uses: actions/checkout@v2
@@ -34,10 +34,10 @@ jobs:
name: Build Windows amd64
runs-on: windows-latest
steps:
- name: Set up Go 1.15
- name: Set up Go 1.16
uses: actions/setup-go@v1
with:
go-version: 1.15
go-version: 1.16
- name: Checkout code
uses: actions/checkout@v2
@@ -58,10 +58,10 @@ jobs:
name: Build Darwin amd64
runs-on: macOS-latest
steps:
- name: Set up Go 1.15
- name: Set up Go 1.16
uses: actions/setup-go@v1
with:
go-version: 1.15
go-version: 1.16
- name: Checkout code
uses: actions/checkout@v2
@@ -69,6 +69,7 @@ jobs:
- name: Build
run: |
make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" service build/nebula-darwin-amd64.tar.gz
make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" service build/nebula-darwin-arm64.tar.gz
mkdir release
mv build/*.tar.gz release
@@ -159,6 +160,16 @@ jobs:
asset_name: nebula-darwin-amd64.tar.gz
asset_content_type: application/gzip
- name: Upload darwin-arm64
uses: actions/upload-release-asset@v1.0.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./darwin-latest/nebula-darwin-arm64.tar.gz
asset_name: nebula-darwin-arm64.tar.gz
asset_content_type: application/gzip
- name: Upload windows-amd64
uses: actions/upload-release-asset@v1.0.1
env:

View File

@@ -18,10 +18,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Set up Go 1.15
- name: Set up Go 1.16
uses: actions/setup-go@v1
with:
go-version: 1.15
go-version: 1.16
id: go
- name: Check out code into the Go module directory
@@ -30,12 +30,12 @@ jobs:
- uses: actions/cache@v1
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
key: ${{ runner.os }}-go1.16-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
${{ runner.os }}-go1.16-
- name: build
run: make
run: make bin-docker
- name: setup docker image
working-directory: ./.github/workflows/smoke

View File

@@ -1,5 +1,7 @@
FROM debian:buster
ADD ./build /
ADD ./build /nebula
ENTRYPOINT ["/nebula"]
WORKDIR /nebula
ENTRYPOINT ["/nebula/nebula"]

View File

@@ -8,8 +8,8 @@ mkdir ./build
(
cd build
cp ../../../../nebula .
cp ../../../../nebula-cert .
cp ../../../../build/linux-amd64/nebula .
cp ../../../../build/linux-amd64/nebula-cert .
HOST="lighthouse1" \
AM_LIGHTHOUSE=true \
@@ -29,11 +29,11 @@ mkdir ./build
OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
../genconfig.sh >host4.yml
./nebula-cert ca -name "Smoke Test"
./nebula-cert sign -name "lighthouse1" -groups "lighthouse,lighthouse1" -ip "192.168.100.1/24"
./nebula-cert sign -name "host2" -groups "host,host2" -ip "192.168.100.2/24"
./nebula-cert sign -name "host3" -groups "host,host3" -ip "192.168.100.3/24"
./nebula-cert sign -name "host4" -groups "host,host4" -ip "192.168.100.4/24"
../../../../nebula-cert ca -name "Smoke Test"
../../../../nebula-cert sign -name "lighthouse1" -groups "lighthouse,lighthouse1" -ip "192.168.100.1/24"
../../../../nebula-cert sign -name "host2" -groups "host,host2" -ip "192.168.100.2/24"
../../../../nebula-cert sign -name "host3" -groups "host,host3" -ip "192.168.100.3/24"
../../../../nebula-cert sign -name "host4" -groups "host,host4" -ip "192.168.100.4/24"
)
docker build -t nebula:smoke .
sudo docker build -t nebula:smoke .

View File

@@ -33,9 +33,9 @@ lighthouse_hosts() {
cat <<EOF
pki:
ca: /ca.crt
cert: /${HOST}.crt
key: /${HOST}.key
ca: ca.crt
cert: ${HOST}.crt
key: ${HOST}.key
lighthouse:
am_lighthouse: ${AM_LIGHTHOUSE:-false}

View File

@@ -1,19 +1,33 @@
#!/bin/sh
#!/bin/bash
set -e -x
docker run --name lighthouse1 --rm nebula:smoke -config lighthouse1.yml -test
docker run --name host2 --rm nebula:smoke -config host2.yml -test
docker run --name host3 --rm nebula:smoke -config host3.yml -test
docker run --name host4 --rm nebula:smoke -config host4.yml -test
set -o pipefail
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config lighthouse1.yml &
mkdir -p logs
cleanup() {
set +e
if [ "$(jobs -r)" ]
then
sudo docker kill lighthouse1 host2 host3 host4
fi
}
trap cleanup EXIT
sudo docker run --name lighthouse1 --rm nebula:smoke -config lighthouse1.yml -test
sudo docker run --name host2 --rm nebula:smoke -config host2.yml -test
sudo docker run --name host3 --rm nebula:smoke -config host3.yml -test
sudo docker run --name host4 --rm nebula:smoke -config host4.yml -test
sudo docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 &
sleep 1
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host2.yml &
sudo docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host2.yml 2>&1 | tee logs/host2 &
sleep 1
docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host3.yml &
sudo docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host3.yml 2>&1 | tee logs/host3 &
sleep 1
docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host4.yml &
sudo docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke -config host4.yml 2>&1 | tee logs/host4 &
sleep 1
set +x
@@ -21,35 +35,35 @@ echo
echo " *** Testing ping from lighthouse1"
echo
set -x
docker exec lighthouse1 ping -c1 192.168.100.2
docker exec lighthouse1 ping -c1 192.168.100.3
sudo docker exec lighthouse1 ping -c1 192.168.100.2
sudo docker exec lighthouse1 ping -c1 192.168.100.3
set +x
echo
echo " *** Testing ping from host2"
echo
set -x
docker exec host2 ping -c1 192.168.100.1
sudo docker exec host2 ping -c1 192.168.100.1
# Should fail because not allowed by host3 inbound firewall
! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
! sudo docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
set +x
echo
echo " *** Testing ping from host3"
echo
set -x
docker exec host3 ping -c1 192.168.100.1
docker exec host3 ping -c1 192.168.100.2
sudo docker exec host3 ping -c1 192.168.100.1
sudo docker exec host3 ping -c1 192.168.100.2
set +x
echo
echo " *** Testing ping from host4"
echo
set -x
docker exec host4 ping -c1 192.168.100.1
sudo docker exec host4 ping -c1 192.168.100.1
# Should fail because not allowed by host4 outbound firewall
! docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1
! docker exec host4 ping -c1 192.168.100.3 -w5 || exit 1
! sudo docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1
! sudo docker exec host4 ping -c1 192.168.100.3 -w5 || exit 1
set +x
echo
@@ -57,7 +71,13 @@ echo " *** Testing conntrack"
echo
set -x
# host2 can ping host3 now that host3 pinged it first
docker exec host2 ping -c1 192.168.100.3
sudo docker exec host2 ping -c1 192.168.100.3
# host4 can ping host2 once conntrack established
docker exec host2 ping -c1 192.168.100.4
docker exec host4 ping -c1 192.168.100.2
sudo docker exec host2 ping -c1 192.168.100.4
sudo docker exec host4 ping -c1 192.168.100.2
sudo docker exec host4 sh -c 'kill 1'
sudo docker exec host3 sh -c 'kill 1'
sudo docker exec host2 sh -c 'kill 1'
sudo docker exec lighthouse1 sh -c 'kill 1'
sleep 1

View File

@@ -18,10 +18,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Set up Go 1.15
- name: Set up Go 1.16
uses: actions/setup-go@v1
with:
go-version: 1.15
go-version: 1.16
id: go
- name: Check out code into the Go module directory
@@ -30,9 +30,9 @@ jobs:
- uses: actions/cache@v1
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
key: ${{ runner.os }}-go1.16-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
${{ runner.os }}-go1.16-
- name: Build
run: make all
@@ -40,6 +40,9 @@ jobs:
- name: Test
run: make test
- name: End 2 end
run: make e2evv
test:
name: Build and test on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
@@ -48,10 +51,10 @@ jobs:
os: [windows-latest, macOS-latest]
steps:
- name: Set up Go 1.15
- name: Set up Go 1.16
uses: actions/setup-go@v1
with:
go-version: 1.15
go-version: 1.16
id: go
- name: Check out code into the Go module directory
@@ -60,9 +63,9 @@ jobs:
- uses: actions/cache@v1
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
key: ${{ runner.os }}-go1.16-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
${{ runner.os }}-go1.16-
- name: Build nebula
run: go build ./cmd/nebula
@@ -72,3 +75,6 @@ jobs:
- name: Test
run: go test -v ./...
- name: End 2 end
run: make e2evv

View File

@@ -7,10 +7,70 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [1.4.0] - 2021-05-11
### Added
- Ability to output qr code images in `print`, `ca`, and `sign` modes for `nebula-cert`.
This is useful when configuring mobile clients. (#297)
- Experimental: Nebula can now do work on more than 2 cpu cores in send and receive paths via
the new `routines` config option. (#382, #391, #395)
- ICMP ping requests can be responded to when the `tun.disabled` is `true`.
This is useful so that you can "ping" a lighthouse running in this mode. (#342)
- Run smoke tests via `make smoke-docker`. (#287)
- More reported stats, udp memory use on linux, build version (when using Prometheus), firewall,
handshake, and cached packet stats. (#390, #405, #450, #453)
- IPv6 support for the underlay network. (#369)
- End to end testing, run with `make e2e`. (#425, #427, #428)
### Changed
- Updated the kardianos/service go library from 1.0.0 to 1.1.0, which
now creates launchd plist to write stdout/stderr to files by default.
- Darwin will now log stdout/stderr to a file when using `-service` mode. (#303)
- Example systemd unit file now better arranged startup order when using `sshd`
and other fixes. (#317, #412, #438)
- Reduced memory utilization/garbage collection. (#320, #323, #340)
- Reduced CPU utilization. (#329)
- Build against go 1.16. (#381)
- Refactored handshakes to improve performance and correctness. (#401, #402, #404, #416, #451)
- Improved roaming support for mobile clients. (#394, #457)
- Lighthouse performance and correctness improvements. (#406, #418, #429, #433, #437, #442, #449)
- Better ordered startup to enable `sshd`, `stats`, and `dns` subsystems to listen on
the nebula interface. (#375)
### Fixed
- No longer report handshake packets as `lost` in stats. (#331)
- Error handling in the `cert` package. (#339, #373)
- Orphaned pending hostmap entries are cleaned up. (#344)
- Most known data races are now resolved. (#396, #400, #424)
- Refuse to run a lighthouse on an ephemeral port. (#399)
- Removed the global references. (#423, #426, #446)
- Reloading via ssh command avoids a panic. (#447)
- Shutdown is now performed in a cleaner way. (#448)
- Logs will now find their way to Windows event viewer when running under `-service` mode
in Windows. (#443)
## [1.3.0] - 2020-09-22
@@ -185,7 +245,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Initial public release.
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.3.0...HEAD
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.4.0...HEAD
[1.4.0]: https://github.com/slackhq/nebula/releases/tag/v1.4.0
[1.3.0]: https://github.com/slackhq/nebula/releases/tag/v1.3.0
[1.2.0]: https://github.com/slackhq/nebula/releases/tag/v1.2.0
[1.1.0]: https://github.com/slackhq/nebula/releases/tag/v1.1.0

View File

@@ -1,8 +1,31 @@
GOMINVERSION = 1.16
NEBULA_CMD_PATH = "./cmd/nebula"
BUILD_NUMBER ?= dev+$(shell date -u '+%Y%m%d%H%M%S')
GO111MODULE = on
export GO111MODULE
# Set up OS specific bits
ifeq ($(OS),Windows_NT)
#TODO: we should be able to ditch awk as well
GOVERSION := $(shell go version | awk "{print substr($$3, 3)}")
GOISMIN := $(shell IF "$(GOVERSION)" GEQ "$(GOMINVERSION)" ECHO 1)
NEBULA_CMD_SUFFIX = .exe
NULL_FILE = nul
else
GOVERSION := $(shell go version | awk '{print substr($$3, 3)}')
GOISMIN := $(shell expr "$(GOVERSION)" ">=" "$(GOMINVERSION)")
NEBULA_CMD_SUFFIX =
NULL_FILE = /dev/null
endif
# Only defined the build number if we haven't already
ifndef BUILD_NUMBER
ifeq ($(shell git describe --exact-match 2>$(NULL_FILE)),)
BUILD_NUMBER = $(shell git describe --abbrev=0 --match "v*" | cut -dv -f2)-$(shell git branch --show-current)-$(shell git describe --long --dirty | cut -d- -f2-)
else
BUILD_NUMBER = $(shell git describe --exact-match --dirty | cut -dv -f2)
endif
endif
LDFLAGS = -X main.Build=$(BUILD_NUMBER)
ALL_LINUX = linux-amd64 \
@@ -20,9 +43,25 @@ ALL_LINUX = linux-amd64 \
ALL = $(ALL_LINUX) \
darwin-amd64 \
darwin-arm64 \
freebsd-amd64 \
windows-amd64
e2e:
$(TEST_ENV) go test -tags=e2e_testing -count=1 $(TEST_FLAGS) ./e2e
e2ev: TEST_FLAGS = -v
e2ev: e2e
e2evv: TEST_ENV += TEST_LOGS=1
e2evv: e2ev
e2evvv: TEST_ENV += TEST_LOGS=2
e2evvv: e2ev
e2evvvv: TEST_ENV += TEST_LOGS=3
e2evvvv: e2ev
all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert)
release: $(ALL:%=build/nebula-%.tar.gz)
@@ -31,6 +70,8 @@ release-linux: $(ALL_LINUX:%=build/nebula-%.tar.gz)
release-freebsd: build/nebula-freebsd-amd64.tar.gz
BUILD_ARGS = -trimpath
bin-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe
mv $? .
@@ -41,12 +82,12 @@ bin-freebsd: build/freebsd-amd64/nebula build/freebsd-amd64/nebula-cert
mv $? .
bin:
go build -trimpath -ldflags "$(LDFLAGS)" -o ./nebula ${NEBULA_CMD_PATH}
go build -trimpath -ldflags "$(LDFLAGS)" -o ./nebula-cert ./cmd/nebula-cert
go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula${NEBULA_CMD_SUFFIX} ${NEBULA_CMD_PATH}
go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula-cert${NEBULA_CMD_SUFFIX} ./cmd/nebula-cert
install:
go install -trimpath -ldflags "$(LDFLAGS)" ${NEBULA_CMD_PATH}
go install -trimpath -ldflags "$(LDFLAGS)" ./cmd/nebula-cert
go install $(BUILD_ARGS) -ldflags "$(LDFLAGS)" ${NEBULA_CMD_PATH}
go install $(BUILD_ARGS) -ldflags "$(LDFLAGS)" ./cmd/nebula-cert
build/linux-arm-%: GOENV += GOARM=$(word 3, $(subst -, ,$*))
build/linux-mips-%: GOENV += GOMIPS=$(word 3, $(subst -, ,$*))
@@ -57,12 +98,12 @@ build/linux-mips-softfloat/%: LDFLAGS += -s -w
build/%/nebula: .FORCE
GOOS=$(firstword $(subst -, , $*)) \
GOARCH=$(word 2, $(subst -, ,$*)) $(GOENV) \
go build -trimpath -o $@ -ldflags "$(LDFLAGS)" ${NEBULA_CMD_PATH}
go build $(BUILD_ARGS) -o $@ -ldflags "$(LDFLAGS)" ${NEBULA_CMD_PATH}
build/%/nebula-cert: .FORCE
GOOS=$(firstword $(subst -, , $*)) \
GOARCH=$(word 2, $(subst -, ,$*)) $(GOENV) \
go build -trimpath -o $@ -ldflags "$(LDFLAGS)" ./cmd/nebula-cert
go build $(BUILD_ARGS) -o $@ -ldflags "$(LDFLAGS)" ./cmd/nebula-cert
build/%/nebula.exe: build/%/nebula
mv $< $@
@@ -100,20 +141,29 @@ bench-cpu-long:
proto: nebula.pb.go cert/cert.pb.go
nebula.pb.go: nebula.proto .FORCE
go build github.com/golang/protobuf/protoc-gen-go
PATH="$(PWD):$(PATH)" protoc --go_out=. $<
rm protoc-gen-go
go build github.com/gogo/protobuf/protoc-gen-gogofaster
PATH="$(CURDIR):$(PATH)" protoc --gogofaster_out=paths=source_relative:. $<
rm protoc-gen-gogofaster
cert/cert.pb.go: cert/cert.proto .FORCE
$(MAKE) -C cert cert.pb.go
service:
@echo > /dev/null
@echo > $(NULL_FILE)
$(eval NEBULA_CMD_PATH := "./cmd/nebula-service")
ifeq ($(words $(MAKECMDGOALS)),1)
$(MAKE) service ${.DEFAULT_GOAL} --no-print-directory
@$(MAKE) service ${.DEFAULT_GOAL} --no-print-directory
endif
bin-docker: bin build/linux-amd64/nebula build/linux-amd64/nebula-cert
smoke-docker: bin-docker
cd .github/workflows/smoke/ && ./build.sh
cd .github/workflows/smoke/ && ./smoke.sh
smoke-docker-race: BUILD_ARGS = -race
smoke-docker-race: smoke-docker
.FORCE:
.PHONY: test test-cov-html bench bench-cpu bench-cpu-long bin proto release service
.PHONY: e2e e2ev e2evv e2evvv e2evvvv test test-cov-html bench bench-cpu bench-cpu-long bin proto release service smoke-docker smoke-docker-race
.DEFAULT_GOAL := bin

View File

@@ -1,7 +1,6 @@
## What is Nebula?
Nebula is a scalable overlay networking tool with a focus on performance, simplicity and security.
It lets you seamlessly connect computers anywhere in the world. Nebula is portable, and runs on Linux, OSX, and Windows.
(Also: keep this quiet, but we have an early prototype running on iOS).
It lets you seamlessly connect computers anywhere in the world. Nebula is portable, and runs on Linux, OSX, Windows, iOS, and Android.
It can be used to connect a small number of computers, but is also able to connect tens of thousands of computers.
Nebula incorporates a number of existing concepts like encryption, security groups, certificates,
@@ -13,6 +12,22 @@ You can read more about Nebula [here](https://medium.com/p/884110a5579).
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/enQtOTA5MDI4NDg3MTg4LTkwY2EwNTI4NzQyMzc0M2ZlODBjNWI3NTY1MzhiOThiMmZlZjVkMTI0NGY4YTMyNjUwMWEyNzNkZTJmYzQxOGU)
## Supported Platforms
#### Desktop and Server
Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for downloads
- Linux - 64 and 32 bit, arm, and others
- Windows
- MacOS
- Freebsd
#### Mobile
- [iOS](https://apps.apple.com/us/app/mobile-nebula/id1509587936?itsct=apps_box&amp;itscg=30200)
- [Android](https://play.google.com/store/apps/details?id=net.defined.mobile_nebula&pcampaignid=pcampaignidMKT-Other-global-all-co-prtnr-py-PartBadge-Mar2515-1)
## Technical Overview
Nebula is a mutually authenticated peer-to-peer software defined network based on the [Noise Protocol Framework](https://noiseprotocol.org/).

View File

@@ -2,12 +2,13 @@ package nebula
import (
"fmt"
"net"
"regexp"
)
type AllowList struct {
// The values of this cidrTree are `bool`, signifying allow/deny
cidrTree *CIDRTree
cidrTree *CIDR6Tree
// To avoid ambiguity, all rules must be true, or all rules must be false.
nameRules []AllowListNameRule
@@ -18,7 +19,7 @@ type AllowListNameRule struct {
Allow bool
}
func (al *AllowList) Allow(ip uint32) bool {
func (al *AllowList) Allow(ip net.IP) bool {
if al == nil {
return true
}
@@ -32,6 +33,34 @@ func (al *AllowList) Allow(ip uint32) bool {
}
}
func (al *AllowList) AllowIpV4(ip uint32) bool {
if al == nil {
return true
}
result := al.cidrTree.MostSpecificContainsIpV4(ip)
switch v := result.(type) {
case bool:
return v
default:
panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
}
}
func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
if al == nil {
return true
}
result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
switch v := result.(type) {
case bool:
return v
default:
panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
}
}
func (al *AllowList) AllowName(name string) bool {
if al == nil || len(al.nameRules) == 0 {
return true

View File

@@ -9,17 +9,26 @@ import (
)
func TestAllowList_Allow(t *testing.T) {
assert.Equal(t, true, ((*AllowList)(nil)).Allow(ip2int(net.ParseIP("1.1.1.1"))))
assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))
tree := NewCIDRTree()
tree := NewCIDR6Tree()
tree.AddCIDR(getCIDR("0.0.0.0/0"), true)
tree.AddCIDR(getCIDR("10.0.0.0/8"), false)
tree.AddCIDR(getCIDR("10.42.42.42/32"), true)
tree.AddCIDR(getCIDR("10.42.0.0/16"), true)
tree.AddCIDR(getCIDR("10.42.42.0/24"), true)
tree.AddCIDR(getCIDR("10.42.42.0/24"), false)
tree.AddCIDR(getCIDR("::1/128"), true)
tree.AddCIDR(getCIDR("::2/128"), false)
al := &AllowList{cidrTree: tree}
assert.Equal(t, true, al.Allow(ip2int(net.ParseIP("1.1.1.1"))))
assert.Equal(t, false, al.Allow(ip2int(net.ParseIP("10.0.0.4"))))
assert.Equal(t, true, al.Allow(ip2int(net.ParseIP("10.42.42.42"))))
assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1")))
assert.Equal(t, false, al.Allow(net.ParseIP("10.0.0.4")))
assert.Equal(t, true, al.Allow(net.ParseIP("10.42.42.42")))
assert.Equal(t, false, al.Allow(net.ParseIP("10.42.42.41")))
assert.Equal(t, true, al.Allow(net.ParseIP("10.42.0.1")))
assert.Equal(t, true, al.Allow(net.ParseIP("::1")))
assert.Equal(t, false, al.Allow(net.ParseIP("::2")))
}
func TestAllowList_AllowName(t *testing.T) {

View File

@@ -26,7 +26,7 @@ func NewBits(bits uint64) *Bits {
}
}
func (b *Bits) Check(i uint64) bool {
func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
// If i is the next number, return true.
if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
return true
@@ -47,7 +47,7 @@ func (b *Bits) Check(i uint64) bool {
return false
}
func (b *Bits) Update(i uint64) bool {
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// If i is the next number, return true and update current.
if i == b.current+1 {
// Report missed packets, we can only understand what was missed after the first window has been gone through

View File

@@ -7,6 +7,7 @@ import (
)
func TestBits(t *testing.T) {
l := NewTestLogger()
b := NewBits(10)
// make sure it is the right size
@@ -14,46 +15,46 @@ func TestBits(t *testing.T) {
// This is initialized to zero - receive one. This should work.
assert.True(t, b.Check(1))
u := b.Update(1)
assert.True(t, b.Check(l, 1))
u := b.Update(l, 1)
assert.True(t, u)
assert.EqualValues(t, 1, b.current)
g := []bool{false, true, false, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits)
// Receive two
assert.True(t, b.Check(2))
u = b.Update(2)
assert.True(t, b.Check(l, 2))
u = b.Update(l, 2)
assert.True(t, u)
assert.EqualValues(t, 2, b.current)
g = []bool{false, true, true, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits)
// Receive two again - it will fail
assert.False(t, b.Check(2))
u = b.Update(2)
assert.False(t, b.Check(l, 2))
u = b.Update(l, 2)
assert.False(t, u)
assert.EqualValues(t, 2, b.current)
// Jump ahead to 15, which should clear everything and set the 6th element
assert.True(t, b.Check(15))
u = b.Update(15)
assert.True(t, b.Check(l, 15))
u = b.Update(l, 15)
assert.True(t, u)
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, false, true, false, false, false, false}
assert.Equal(t, g, b.bits)
// Mark 14, which is allowed because it is in the window
assert.True(t, b.Check(14))
u = b.Update(14)
assert.True(t, b.Check(l, 14))
u = b.Update(l, 14)
assert.True(t, u)
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits)
// Mark 5, which is not allowed because it is not in the window
assert.False(t, b.Check(5))
u = b.Update(5)
assert.False(t, b.Check(l, 5))
u = b.Update(l, 5)
assert.False(t, u)
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false}
@@ -61,63 +62,65 @@ func TestBits(t *testing.T) {
// make sure we handle wrapping around once to the current position
b = NewBits(10)
assert.True(t, b.Update(1))
assert.True(t, b.Update(11))
assert.True(t, b.Update(l, 1))
assert.True(t, b.Update(l, 11))
assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false}, b.bits)
// Walk through a few windows in order
b = NewBits(10)
for i := uint64(0); i <= 100; i++ {
assert.True(t, b.Check(i), "Error while checking %v", i)
assert.True(t, b.Update(i), "Error while updating %v", i)
assert.True(t, b.Check(l, i), "Error while checking %v", i)
assert.True(t, b.Update(l, i), "Error while updating %v", i)
}
}
func TestBitsDupeCounter(t *testing.T) {
l := NewTestLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(1))
assert.True(t, b.Update(l, 1))
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.False(t, b.Update(1))
assert.False(t, b.Update(l, 1))
assert.Equal(t, int64(1), b.dupeCounter.Count())
assert.True(t, b.Update(2))
assert.True(t, b.Update(l, 2))
assert.Equal(t, int64(1), b.dupeCounter.Count())
assert.True(t, b.Update(3))
assert.True(t, b.Update(l, 3))
assert.Equal(t, int64(1), b.dupeCounter.Count())
assert.False(t, b.Update(1))
assert.False(t, b.Update(l, 1))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.Equal(t, int64(2), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func TestBitsOutOfWindowCounter(t *testing.T) {
l := NewTestLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(20))
assert.True(t, b.Update(l, 20))
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
assert.True(t, b.Update(21))
assert.True(t, b.Update(22))
assert.True(t, b.Update(23))
assert.True(t, b.Update(24))
assert.True(t, b.Update(25))
assert.True(t, b.Update(26))
assert.True(t, b.Update(27))
assert.True(t, b.Update(28))
assert.True(t, b.Update(29))
assert.True(t, b.Update(l, 21))
assert.True(t, b.Update(l, 22))
assert.True(t, b.Update(l, 23))
assert.True(t, b.Update(l, 24))
assert.True(t, b.Update(l, 25))
assert.True(t, b.Update(l, 26))
assert.True(t, b.Update(l, 27))
assert.True(t, b.Update(l, 28))
assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
assert.False(t, b.Update(0))
assert.False(t, b.Update(l, 0))
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
//tODO: make sure lostcounter doesn't increase in orderly increment
@@ -127,23 +130,24 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
}
func TestBitsLostCounter(t *testing.T) {
l := NewTestLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
//assert.True(t, b.Update(0))
assert.True(t, b.Update(0))
assert.True(t, b.Update(20))
assert.True(t, b.Update(21))
assert.True(t, b.Update(22))
assert.True(t, b.Update(23))
assert.True(t, b.Update(24))
assert.True(t, b.Update(25))
assert.True(t, b.Update(26))
assert.True(t, b.Update(27))
assert.True(t, b.Update(28))
assert.True(t, b.Update(29))
assert.True(t, b.Update(l, 0))
assert.True(t, b.Update(l, 20))
assert.True(t, b.Update(l, 21))
assert.True(t, b.Update(l, 22))
assert.True(t, b.Update(l, 23))
assert.True(t, b.Update(l, 24))
assert.True(t, b.Update(l, 25))
assert.True(t, b.Update(l, 26))
assert.True(t, b.Update(l, 27))
assert.True(t, b.Update(l, 28))
assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(20), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
@@ -153,56 +157,56 @@ func TestBitsLostCounter(t *testing.T) {
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(0))
assert.True(t, b.Update(l, 0))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(9))
assert.True(t, b.Update(l, 9))
assert.Equal(t, int64(0), b.lostCounter.Count())
// 10 will set 0 index, 0 was already set, no lost packets
assert.True(t, b.Update(10))
assert.True(t, b.Update(l, 10))
assert.Equal(t, int64(0), b.lostCounter.Count())
// 11 will set 1 index, 1 was missed, we should see 1 packet lost
assert.True(t, b.Update(11))
assert.True(t, b.Update(l, 11))
assert.Equal(t, int64(1), b.lostCounter.Count())
// Now let's fill in the window, should end up with 8 lost packets
assert.True(t, b.Update(12))
assert.True(t, b.Update(13))
assert.True(t, b.Update(14))
assert.True(t, b.Update(15))
assert.True(t, b.Update(16))
assert.True(t, b.Update(17))
assert.True(t, b.Update(18))
assert.True(t, b.Update(19))
assert.True(t, b.Update(l, 12))
assert.True(t, b.Update(l, 13))
assert.True(t, b.Update(l, 14))
assert.True(t, b.Update(l, 15))
assert.True(t, b.Update(l, 16))
assert.True(t, b.Update(l, 17))
assert.True(t, b.Update(l, 18))
assert.True(t, b.Update(l, 19))
assert.Equal(t, int64(8), b.lostCounter.Count())
// Jump ahead by a window size
assert.True(t, b.Update(29))
assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(8), b.lostCounter.Count())
// Now lets walk ahead normally through the window, the missed packets should fill in
assert.True(t, b.Update(30))
assert.True(t, b.Update(31))
assert.True(t, b.Update(32))
assert.True(t, b.Update(33))
assert.True(t, b.Update(34))
assert.True(t, b.Update(35))
assert.True(t, b.Update(36))
assert.True(t, b.Update(37))
assert.True(t, b.Update(38))
assert.True(t, b.Update(l, 30))
assert.True(t, b.Update(l, 31))
assert.True(t, b.Update(l, 32))
assert.True(t, b.Update(l, 33))
assert.True(t, b.Update(l, 34))
assert.True(t, b.Update(l, 35))
assert.True(t, b.Update(l, 36))
assert.True(t, b.Update(l, 37))
assert.True(t, b.Update(l, 38))
// 39 packets tracked, 22 seen, 17 lost
assert.Equal(t, int64(17), b.lostCounter.Count())
// Jump ahead by 2 windows, should have recording 1 full window missing
assert.True(t, b.Update(58))
assert.True(t, b.Update(l, 58))
assert.Equal(t, int64(27), b.lostCounter.Count())
// Now lets walk ahead normally through the window, the missed packets should fill in from this window
assert.True(t, b.Update(59))
assert.True(t, b.Update(60))
assert.True(t, b.Update(61))
assert.True(t, b.Update(62))
assert.True(t, b.Update(63))
assert.True(t, b.Update(64))
assert.True(t, b.Update(65))
assert.True(t, b.Update(66))
assert.True(t, b.Update(67))
assert.True(t, b.Update(l, 59))
assert.True(t, b.Update(l, 60))
assert.True(t, b.Update(l, 61))
assert.True(t, b.Update(l, 62))
assert.True(t, b.Update(l, 63))
assert.True(t, b.Update(l, 64))
assert.True(t, b.Update(l, 65))
assert.True(t, b.Update(l, 66))
assert.True(t, b.Update(l, 67))
// 68 packets tracked, 32 seen, 36 missed
assert.Equal(t, int64(36), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())

View File

@@ -7,11 +7,10 @@ import (
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
)
var trustedCAs *cert.NebulaCAPool
type CertState struct {
certificate *cert.NebulaCertificate
rawCertificate []byte
@@ -119,7 +118,7 @@ func NewCertStateFromConfig(c *Config) (*CertState, error) {
return NewCertState(nebulaCert, rawKey)
}
func loadCAFromConfig(c *Config) (*cert.NebulaCAPool, error) {
func loadCAFromConfig(l *logrus.Logger, c *Config) (*cert.NebulaCAPool, error) {
var rawCA []byte
var err error

View File

@@ -2,8 +2,8 @@ GO111MODULE = on
export GO111MODULE
cert.pb.go: cert.proto .FORCE
go build github.com/golang/protobuf/protoc-gen-go
PATH="$(PWD):$(PATH)" protoc --go_out=. $<
go build google.golang.org/protobuf/cmd/protoc-gen-go
PATH="$(CURDIR):$(PATH)" protoc --go_out=. --go_opt=paths=source_relative $<
rm protoc-gen-go
.FORCE:

View File

@@ -61,6 +61,10 @@ func UnmarshalNebulaCertificate(b []byte) (*NebulaCertificate, error) {
return nil, err
}
if rc.Details == nil {
return nil, fmt.Errorf("encoded Details was nil")
}
if len(rc.Details.Ips)%2 != 0 {
return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found")
}
@@ -123,6 +127,9 @@ func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, er
if p == nil {
return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
}
if p.Type != CertBanner {
return nil, r, fmt.Errorf("bytes did not contain a proper nebula certificate banner")
}
nc, err := UnmarshalNebulaCertificate(p.Bytes)
return nc, r, err
}

View File

@@ -1,202 +1,298 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v3.14.0
// source: cert.proto
package cert
import (
fmt "fmt"
proto "github.com/golang/protobuf/proto"
math "math"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type RawNebulaCertificate struct {
Details *RawNebulaCertificateDetails `protobuf:"bytes,1,opt,name=Details,json=details,proto3" json:"Details,omitempty"`
Signature []byte `protobuf:"bytes,2,opt,name=Signature,json=signature,proto3" json:"Signature,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Details *RawNebulaCertificateDetails `protobuf:"bytes,1,opt,name=Details,proto3" json:"Details,omitempty"`
Signature []byte `protobuf:"bytes,2,opt,name=Signature,proto3" json:"Signature,omitempty"`
}
func (m *RawNebulaCertificate) Reset() { *m = RawNebulaCertificate{} }
func (m *RawNebulaCertificate) String() string { return proto.CompactTextString(m) }
func (*RawNebulaCertificate) ProtoMessage() {}
func (x *RawNebulaCertificate) Reset() {
*x = RawNebulaCertificate{}
if protoimpl.UnsafeEnabled {
mi := &file_cert_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *RawNebulaCertificate) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*RawNebulaCertificate) ProtoMessage() {}
func (x *RawNebulaCertificate) ProtoReflect() protoreflect.Message {
mi := &file_cert_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RawNebulaCertificate.ProtoReflect.Descriptor instead.
func (*RawNebulaCertificate) Descriptor() ([]byte, []int) {
return fileDescriptor_a142e29cbef9b1cf, []int{0}
return file_cert_proto_rawDescGZIP(), []int{0}
}
func (m *RawNebulaCertificate) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_RawNebulaCertificate.Unmarshal(m, b)
}
func (m *RawNebulaCertificate) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_RawNebulaCertificate.Marshal(b, m, deterministic)
}
func (m *RawNebulaCertificate) XXX_Merge(src proto.Message) {
xxx_messageInfo_RawNebulaCertificate.Merge(m, src)
}
func (m *RawNebulaCertificate) XXX_Size() int {
return xxx_messageInfo_RawNebulaCertificate.Size(m)
}
func (m *RawNebulaCertificate) XXX_DiscardUnknown() {
xxx_messageInfo_RawNebulaCertificate.DiscardUnknown(m)
}
var xxx_messageInfo_RawNebulaCertificate proto.InternalMessageInfo
func (m *RawNebulaCertificate) GetDetails() *RawNebulaCertificateDetails {
if m != nil {
return m.Details
func (x *RawNebulaCertificate) GetDetails() *RawNebulaCertificateDetails {
if x != nil {
return x.Details
}
return nil
}
func (m *RawNebulaCertificate) GetSignature() []byte {
if m != nil {
return m.Signature
func (x *RawNebulaCertificate) GetSignature() []byte {
if x != nil {
return x.Signature
}
return nil
}
type RawNebulaCertificateDetails struct {
Name string `protobuf:"bytes,1,opt,name=Name,json=name,proto3" json:"Name,omitempty"`
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Name string `protobuf:"bytes,1,opt,name=Name,proto3" json:"Name,omitempty"`
// Ips and Subnets are in big endian 32 bit pairs, 1st the ip, 2nd the mask
Ips []uint32 `protobuf:"varint,2,rep,packed,name=Ips,json=ips,proto3" json:"Ips,omitempty"`
Subnets []uint32 `protobuf:"varint,3,rep,packed,name=Subnets,json=subnets,proto3" json:"Subnets,omitempty"`
Groups []string `protobuf:"bytes,4,rep,name=Groups,json=groups,proto3" json:"Groups,omitempty"`
NotBefore int64 `protobuf:"varint,5,opt,name=NotBefore,json=notBefore,proto3" json:"NotBefore,omitempty"`
NotAfter int64 `protobuf:"varint,6,opt,name=NotAfter,json=notAfter,proto3" json:"NotAfter,omitempty"`
PublicKey []byte `protobuf:"bytes,7,opt,name=PublicKey,json=publicKey,proto3" json:"PublicKey,omitempty"`
IsCA bool `protobuf:"varint,8,opt,name=IsCA,json=isCA,proto3" json:"IsCA,omitempty"`
Ips []uint32 `protobuf:"varint,2,rep,packed,name=Ips,proto3" json:"Ips,omitempty"`
Subnets []uint32 `protobuf:"varint,3,rep,packed,name=Subnets,proto3" json:"Subnets,omitempty"`
Groups []string `protobuf:"bytes,4,rep,name=Groups,proto3" json:"Groups,omitempty"`
NotBefore int64 `protobuf:"varint,5,opt,name=NotBefore,proto3" json:"NotBefore,omitempty"`
NotAfter int64 `protobuf:"varint,6,opt,name=NotAfter,proto3" json:"NotAfter,omitempty"`
PublicKey []byte `protobuf:"bytes,7,opt,name=PublicKey,proto3" json:"PublicKey,omitempty"`
IsCA bool `protobuf:"varint,8,opt,name=IsCA,proto3" json:"IsCA,omitempty"`
// sha-256 of the issuer certificate, if this field is blank the cert is self-signed
Issuer []byte `protobuf:"bytes,9,opt,name=Issuer,json=issuer,proto3" json:"Issuer,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
Issuer []byte `protobuf:"bytes,9,opt,name=Issuer,proto3" json:"Issuer,omitempty"`
}
func (m *RawNebulaCertificateDetails) Reset() { *m = RawNebulaCertificateDetails{} }
func (m *RawNebulaCertificateDetails) String() string { return proto.CompactTextString(m) }
func (*RawNebulaCertificateDetails) ProtoMessage() {}
func (x *RawNebulaCertificateDetails) Reset() {
*x = RawNebulaCertificateDetails{}
if protoimpl.UnsafeEnabled {
mi := &file_cert_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *RawNebulaCertificateDetails) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*RawNebulaCertificateDetails) ProtoMessage() {}
func (x *RawNebulaCertificateDetails) ProtoReflect() protoreflect.Message {
mi := &file_cert_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RawNebulaCertificateDetails.ProtoReflect.Descriptor instead.
func (*RawNebulaCertificateDetails) Descriptor() ([]byte, []int) {
return fileDescriptor_a142e29cbef9b1cf, []int{1}
return file_cert_proto_rawDescGZIP(), []int{1}
}
func (m *RawNebulaCertificateDetails) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_RawNebulaCertificateDetails.Unmarshal(m, b)
}
func (m *RawNebulaCertificateDetails) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_RawNebulaCertificateDetails.Marshal(b, m, deterministic)
}
func (m *RawNebulaCertificateDetails) XXX_Merge(src proto.Message) {
xxx_messageInfo_RawNebulaCertificateDetails.Merge(m, src)
}
func (m *RawNebulaCertificateDetails) XXX_Size() int {
return xxx_messageInfo_RawNebulaCertificateDetails.Size(m)
}
func (m *RawNebulaCertificateDetails) XXX_DiscardUnknown() {
xxx_messageInfo_RawNebulaCertificateDetails.DiscardUnknown(m)
}
var xxx_messageInfo_RawNebulaCertificateDetails proto.InternalMessageInfo
func (m *RawNebulaCertificateDetails) GetName() string {
if m != nil {
return m.Name
func (x *RawNebulaCertificateDetails) GetName() string {
if x != nil {
return x.Name
}
return ""
}
func (m *RawNebulaCertificateDetails) GetIps() []uint32 {
if m != nil {
return m.Ips
func (x *RawNebulaCertificateDetails) GetIps() []uint32 {
if x != nil {
return x.Ips
}
return nil
}
func (m *RawNebulaCertificateDetails) GetSubnets() []uint32 {
if m != nil {
return m.Subnets
func (x *RawNebulaCertificateDetails) GetSubnets() []uint32 {
if x != nil {
return x.Subnets
}
return nil
}
func (m *RawNebulaCertificateDetails) GetGroups() []string {
if m != nil {
return m.Groups
func (x *RawNebulaCertificateDetails) GetGroups() []string {
if x != nil {
return x.Groups
}
return nil
}
func (m *RawNebulaCertificateDetails) GetNotBefore() int64 {
if m != nil {
return m.NotBefore
func (x *RawNebulaCertificateDetails) GetNotBefore() int64 {
if x != nil {
return x.NotBefore
}
return 0
}
func (m *RawNebulaCertificateDetails) GetNotAfter() int64 {
if m != nil {
return m.NotAfter
func (x *RawNebulaCertificateDetails) GetNotAfter() int64 {
if x != nil {
return x.NotAfter
}
return 0
}
func (m *RawNebulaCertificateDetails) GetPublicKey() []byte {
if m != nil {
return m.PublicKey
func (x *RawNebulaCertificateDetails) GetPublicKey() []byte {
if x != nil {
return x.PublicKey
}
return nil
}
func (m *RawNebulaCertificateDetails) GetIsCA() bool {
if m != nil {
return m.IsCA
func (x *RawNebulaCertificateDetails) GetIsCA() bool {
if x != nil {
return x.IsCA
}
return false
}
func (m *RawNebulaCertificateDetails) GetIssuer() []byte {
if m != nil {
return m.Issuer
func (x *RawNebulaCertificateDetails) GetIssuer() []byte {
if x != nil {
return x.Issuer
}
return nil
}
func init() {
proto.RegisterType((*RawNebulaCertificate)(nil), "cert.RawNebulaCertificate")
proto.RegisterType((*RawNebulaCertificateDetails)(nil), "cert.RawNebulaCertificateDetails")
var File_cert_proto protoreflect.FileDescriptor
var file_cert_proto_rawDesc = []byte{
0x0a, 0x0a, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x63, 0x65,
0x72, 0x74, 0x22, 0x71, 0x0a, 0x14, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43,
0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x3b, 0x0a, 0x07, 0x44, 0x65,
0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65,
0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74,
0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x52, 0x07,
0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69, 0x67, 0x6e, 0x61,
0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53, 0x69, 0x67, 0x6e,
0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0xf9, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62,
0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65,
0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20,
0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x49, 0x70, 0x73,
0x18, 0x02, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x03, 0x49, 0x70, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x53,
0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x07, 0x53, 0x75,
0x62, 0x6e, 0x65, 0x74, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18,
0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x1c, 0x0a,
0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03,
0x52, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x4e,
0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x4e,
0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x50, 0x75, 0x62, 0x6c, 0x69,
0x63, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x50, 0x75, 0x62, 0x6c,
0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41, 0x18, 0x08, 0x20,
0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06, 0x49, 0x73, 0x73,
0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73, 0x73, 0x75, 0x65,
0x72, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f,
0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71, 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63,
0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
func init() { proto.RegisterFile("cert.proto", fileDescriptor_a142e29cbef9b1cf) }
var (
file_cert_proto_rawDescOnce sync.Once
file_cert_proto_rawDescData = file_cert_proto_rawDesc
)
var fileDescriptor_a142e29cbef9b1cf = []byte{
// 279 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x90, 0xcf, 0x4a, 0xf4, 0x30,
0x14, 0xc5, 0xc9, 0xa4, 0x5f, 0xdb, 0xe4, 0x53, 0x90, 0x20, 0x12, 0xd4, 0x45, 0x9c, 0x55, 0x56,
0xb3, 0xd0, 0xa5, 0xab, 0x71, 0x04, 0x29, 0x42, 0x91, 0xcc, 0x13, 0xa4, 0xf5, 0x76, 0x08, 0x74,
0x9a, 0x9a, 0x3f, 0x88, 0x8f, 0xee, 0x4e, 0x9a, 0x4e, 0x77, 0xe2, 0xee, 0x9e, 0x5f, 0xce, 0x49,
0x4e, 0x2e, 0xa5, 0x2d, 0xb8, 0xb0, 0x19, 0x9d, 0x0d, 0x96, 0x65, 0xd3, 0xbc, 0xfe, 0xa0, 0x97,
0x4a, 0x7f, 0xd6, 0xd0, 0xc4, 0x5e, 0xef, 0xc0, 0x05, 0xd3, 0x99, 0x56, 0x07, 0x60, 0x8f, 0xb4,
0x78, 0x86, 0xa0, 0x4d, 0xef, 0x39, 0x12, 0x48, 0xfe, 0xbf, 0xbf, 0xdb, 0xa4, 0xec, 0x6f, 0xe6,
0x93, 0x51, 0x15, 0xef, 0xf3, 0xc0, 0x6e, 0x29, 0xd9, 0x9b, 0xc3, 0xa0, 0x43, 0x74, 0xc0, 0x57,
0x02, 0xc9, 0x33, 0x45, 0xfc, 0x02, 0xd6, 0xdf, 0x88, 0xde, 0xfc, 0x71, 0x0d, 0x63, 0x34, 0xab,
0xf5, 0x11, 0xd2, 0xbb, 0x44, 0x65, 0x83, 0x3e, 0x02, 0xbb, 0xa0, 0xb8, 0x1a, 0x3d, 0x5f, 0x09,
0x2c, 0xcf, 0x15, 0x36, 0xa3, 0x67, 0x9c, 0x16, 0xfb, 0xd8, 0x0c, 0x10, 0x3c, 0xc7, 0x89, 0x16,
0x7e, 0x96, 0xec, 0x8a, 0xe6, 0x2f, 0xce, 0xc6, 0xd1, 0xf3, 0x4c, 0x60, 0x49, 0x54, 0x7e, 0x48,
0x6a, 0x6a, 0x55, 0xdb, 0xf0, 0x04, 0x9d, 0x75, 0xc0, 0xff, 0x09, 0x24, 0xb1, 0x22, 0xc3, 0x02,
0xd8, 0x35, 0x2d, 0x6b, 0x1b, 0xb6, 0x5d, 0x00, 0xc7, 0xf3, 0x74, 0x58, 0x0e, 0x27, 0x3d, 0x25,
0xdf, 0x62, 0xd3, 0x9b, 0xf6, 0x15, 0xbe, 0x78, 0x31, 0xff, 0x67, 0x5c, 0xc0, 0xd4, 0xb7, 0xf2,
0xbb, 0x2d, 0x2f, 0x05, 0x92, 0xa5, 0xca, 0x8c, 0xdf, 0x6d, 0xa7, 0x0e, 0x95, 0xf7, 0x11, 0x1c,
0x27, 0xc9, 0x9e, 0x9b, 0xa4, 0x9a, 0x3c, 0xed, 0xfe, 0xe1, 0x27, 0x00, 0x00, 0xff, 0xff, 0x2c,
0xe3, 0x08, 0x37, 0x89, 0x01, 0x00, 0x00,
func file_cert_proto_rawDescGZIP() []byte {
file_cert_proto_rawDescOnce.Do(func() {
file_cert_proto_rawDescData = protoimpl.X.CompressGZIP(file_cert_proto_rawDescData)
})
return file_cert_proto_rawDescData
}
var file_cert_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_cert_proto_goTypes = []interface{}{
(*RawNebulaCertificate)(nil), // 0: cert.RawNebulaCertificate
(*RawNebulaCertificateDetails)(nil), // 1: cert.RawNebulaCertificateDetails
}
var file_cert_proto_depIdxs = []int32{
1, // 0: cert.RawNebulaCertificate.Details:type_name -> cert.RawNebulaCertificateDetails
1, // [1:1] is the sub-list for method output_type
1, // [1:1] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name
}
func init() { file_cert_proto_init() }
func file_cert_proto_init() {
if File_cert_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_cert_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*RawNebulaCertificate); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_cert_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*RawNebulaCertificateDetails); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_cert_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_cert_proto_goTypes,
DependencyIndexes: file_cert_proto_depIdxs,
MessageInfos: file_cert_proto_msgTypes,
}.Build()
File_cert_proto = out.File
file_cert_proto_rawDesc = nil
file_cert_proto_goTypes = nil
file_cert_proto_depIdxs = nil
}

View File

@@ -1,6 +1,8 @@
syntax = "proto3";
package cert;
option go_package = "github.com/slackhq/nebula/cert";
//import "google/protobuf/timestamp.proto";
message RawNebulaCertificate {

View File

@@ -447,6 +447,255 @@ BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
}
func appendByteSlices(b ...[]byte) []byte {
retSlice := []byte{}
for _, v := range b {
retSlice = append(retSlice, v...)
}
return retSlice
}
func TestUnmrshalCertPEM(t *testing.T) {
goodCert := []byte(`
# A good cert
-----BEGIN NEBULA CERTIFICATE-----
CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
-----END NEBULA CERTIFICATE-----
`)
badBanner := []byte(`# A bad banner
-----BEGIN NOT A NEBULA CERTIFICATE-----
CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
-----END NOT A NEBULA CERTIFICATE-----
`)
invalidPem := []byte(`# Not a valid PEM format
-BEGIN NEBULA CERTIFICATE-----
CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
-END NEBULA CERTIFICATE----`)
certBundle := appendByteSlices(goodCert, badBanner, invalidPem)
// Success test case
cert, rest, err := UnmarshalNebulaCertificateFromPEM(certBundle)
assert.NotNil(t, cert)
assert.Equal(t, rest, append(badBanner, invalidPem...))
assert.Nil(t, err)
// Fail due to invalid banner.
cert, rest, err = UnmarshalNebulaCertificateFromPEM(rest)
assert.Nil(t, cert)
assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "bytes did not contain a proper nebula certificate banner")
// Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary.
cert, rest, err = UnmarshalNebulaCertificateFromPEM(rest)
assert.Nil(t, cert)
assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
}
func TestUnmarshalEd25519PrivateKey(t *testing.T) {
privKey := []byte(`# A good key
-----BEGIN NEBULA ED25519 PRIVATE KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
-----END NEBULA ED25519 PRIVATE KEY-----
`)
shortKey := []byte(`# A short key
-----BEGIN NEBULA ED25519 PRIVATE KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
-----END NEBULA ED25519 PRIVATE KEY-----
`)
invalidBanner := []byte(`# Invalid banner
-----BEGIN NOT A NEBULA PRIVATE KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
-----END NOT A NEBULA PRIVATE KEY-----
`)
invalidPem := []byte(`# Not a valid PEM format
-BEGIN NEBULA ED25519 PRIVATE KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
-END NEBULA ED25519 PRIVATE KEY-----`)
keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem)
// Success test case
k, rest, err := UnmarshalEd25519PrivateKey(keyBundle)
assert.Len(t, k, 64)
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
assert.Nil(t, err)
// Fail due to short key
k, rest, err = UnmarshalEd25519PrivateKey(rest)
assert.Nil(t, k)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
// Fail due to invalid banner
k, rest, err = UnmarshalEd25519PrivateKey(rest)
assert.Nil(t, k)
assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519 private key banner")
// Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary.
k, rest, err = UnmarshalEd25519PrivateKey(rest)
assert.Nil(t, k)
assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
}
func TestUnmarshalX25519PrivateKey(t *testing.T) {
privKey := []byte(`# A good key
-----BEGIN NEBULA X25519 PRIVATE KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA X25519 PRIVATE KEY-----
`)
shortKey := []byte(`# A short key
-----BEGIN NEBULA X25519 PRIVATE KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
-----END NEBULA X25519 PRIVATE KEY-----
`)
invalidBanner := []byte(`# Invalid banner
-----BEGIN NOT A NEBULA PRIVATE KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-----END NOT A NEBULA PRIVATE KEY-----
`)
invalidPem := []byte(`# Not a valid PEM format
-BEGIN NEBULA X25519 PRIVATE KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-END NEBULA X25519 PRIVATE KEY-----`)
keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem)
// Success test case
k, rest, err := UnmarshalX25519PrivateKey(keyBundle)
assert.Len(t, k, 32)
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
assert.Nil(t, err)
// Fail due to short key
k, rest, err = UnmarshalX25519PrivateKey(rest)
assert.Nil(t, k)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
assert.EqualError(t, err, "key was not 32 bytes, is invalid X25519 private key")
// Fail due to invalid banner
k, rest, err = UnmarshalX25519PrivateKey(rest)
assert.Nil(t, k)
assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "bytes did not contain a proper nebula X25519 private key banner")
// Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary.
k, rest, err = UnmarshalX25519PrivateKey(rest)
assert.Nil(t, k)
assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
}
func TestUnmarshalEd25519PublicKey(t *testing.T) {
pubKey := []byte(`# A good key
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA ED25519 PUBLIC KEY-----
`)
shortKey := []byte(`# A short key
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
-----END NEBULA ED25519 PUBLIC KEY-----
`)
invalidBanner := []byte(`# Invalid banner
-----BEGIN NOT A NEBULA PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-----END NOT A NEBULA PUBLIC KEY-----
`)
invalidPem := []byte(`# Not a valid PEM format
-BEGIN NEBULA ED25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-END NEBULA ED25519 PUBLIC KEY-----`)
keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem)
// Success test case
k, rest, err := UnmarshalEd25519PublicKey(keyBundle)
assert.Equal(t, len(k), 32)
assert.Nil(t, err)
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
// Fail due to short key
k, rest, err = UnmarshalEd25519PublicKey(rest)
assert.Nil(t, k)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
assert.EqualError(t, err, "key was not 32 bytes, is invalid ed25519 public key")
// Fail due to invalid banner
k, rest, err = UnmarshalEd25519PublicKey(rest)
assert.Nil(t, k)
assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519 public key banner")
assert.Equal(t, rest, invalidPem)
// Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary.
k, rest, err = UnmarshalEd25519PublicKey(rest)
assert.Nil(t, k)
assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
}
func TestUnmarshalX25519PublicKey(t *testing.T) {
pubKey := []byte(`# A good key
-----BEGIN NEBULA X25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA X25519 PUBLIC KEY-----
`)
shortKey := []byte(`# A short key
-----BEGIN NEBULA X25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
-----END NEBULA X25519 PUBLIC KEY-----
`)
invalidBanner := []byte(`# Invalid banner
-----BEGIN NOT A NEBULA PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-----END NOT A NEBULA PUBLIC KEY-----
`)
invalidPem := []byte(`# Not a valid PEM format
-BEGIN NEBULA X25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-END NEBULA X25519 PUBLIC KEY-----`)
keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem)
// Success test case
k, rest, err := UnmarshalX25519PublicKey(keyBundle)
assert.Equal(t, len(k), 32)
assert.Nil(t, err)
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
// Fail due to short key
k, rest, err = UnmarshalX25519PublicKey(rest)
assert.Nil(t, k)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
assert.EqualError(t, err, "key was not 32 bytes, is invalid X25519 public key")
// Fail due to invalid banner
k, rest, err = UnmarshalX25519PublicKey(rest)
assert.Nil(t, k)
assert.EqualError(t, err, "bytes did not contain a proper nebula X25519 public key banner")
assert.Equal(t, rest, invalidPem)
// Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary.
k, rest, err = UnmarshalX25519PublicKey(rest)
assert.Nil(t, k)
assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
}
// Ensure that upgrading the protobuf library does not change how certificates
// are marshalled, since this would break signature verification
func TestMarshalingNebulaCertificateConsistency(t *testing.T) {
@@ -499,6 +748,13 @@ func TestNebulaCertificate_Copy(t *testing.T) {
util.AssertDeepCopyEqual(t, c, cc)
}
func TestUnmarshalNebulaCertificate(t *testing.T) {
// Test that we don't panic with an invalid certificate (#332)
data := []byte("\x98\x00\x00")
_, err := UnmarshalNebulaCertificate(data)
assert.EqualError(t, err, "encoded Details was nil")
}
func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if before.IsZero() {

184
cidr6_radix.go Normal file
View File

@@ -0,0 +1,184 @@
package nebula
import (
"encoding/binary"
"net"
)
const startbit6 = uint64(1 << 63)
type CIDR6Tree struct {
root4 *CIDRNode
root6 *CIDRNode
}
func NewCIDR6Tree() *CIDR6Tree {
tree := new(CIDR6Tree)
tree.root4 = &CIDRNode{}
tree.root6 = &CIDRNode{}
return tree
}
func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
var node, next *CIDRNode
cidrIP, ipv4 := isIPV4(cidr.IP)
if ipv4 {
node = tree.root4
next = tree.root4
} else {
node = tree.root6
next = tree.root6
}
for i := 0; i < len(cidrIP); i += 4 {
ip := binary.BigEndian.Uint32(cidrIP[i : i+4])
mask := binary.BigEndian.Uint32(cidr.Mask[i : i+4])
bit := startbit
// Find our last ancestor in the tree
for bit&mask != 0 {
if ip&bit != 0 {
next = node.right
} else {
next = node.left
}
if next == nil {
break
}
bit = bit >> 1
node = next
}
// Build up the rest of the tree we don't already have
for bit&mask != 0 {
next = &CIDRNode{}
next.parent = node
if ip&bit != 0 {
node.right = next
} else {
node.left = next
}
bit >>= 1
node = next
}
}
// Final node marks our cidr, set the value
node.value = val
}
// Finds the most specific match
func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) {
var node *CIDRNode
wholeIP, ipv4 := isIPV4(ip)
if ipv4 {
node = tree.root4
} else {
node = tree.root6
}
for i := 0; i < len(wholeIP); i += 4 {
ip := ip2int(wholeIP[i : i+4])
bit := startbit
for node != nil {
if node.value != nil {
value = node.value
}
if bit == 0 {
break
}
if ip&bit != 0 {
node = node.right
} else {
node = node.left
}
bit >>= 1
}
}
return value
}
func (tree *CIDR6Tree) MostSpecificContainsIpV4(ip uint32) (value interface{}) {
bit := startbit
node := tree.root4
for node != nil {
if node.value != nil {
value = node.value
}
if ip&bit != 0 {
node = node.right
} else {
node = node.left
}
bit >>= 1
}
return value
}
func (tree *CIDR6Tree) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
ip := hi
node := tree.root6
for i := 0; i < 2; i++ {
bit := startbit6
for node != nil {
if node.value != nil {
value = node.value
}
if bit == 0 {
break
}
if ip&bit != 0 {
node = node.right
} else {
node = node.left
}
bit >>= 1
}
ip = lo
}
return value
}
func isIPV4(ip net.IP) (net.IP, bool) {
if len(ip) == net.IPv4len {
return ip, true
}
if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff {
return ip[12:16], true
}
return ip, false
}
func isZeros(p net.IP) bool {
for i := 0; i < len(p); i++ {
if p[i] != 0 {
return false
}
}
return true
}

77
cidr6_radix_test.go Normal file
View File

@@ -0,0 +1,77 @@
package nebula
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
tree := NewCIDR6Tree()
tree.AddCIDR(getCIDR("1.0.0.0/8"), "1")
tree.AddCIDR(getCIDR("2.1.0.0/16"), "2")
tree.AddCIDR(getCIDR("3.1.1.0/24"), "3")
tree.AddCIDR(getCIDR("4.1.1.1/24"), "4a")
tree.AddCIDR(getCIDR("4.1.1.1/30"), "4b")
tree.AddCIDR(getCIDR("4.1.1.1/32"), "4c")
tree.AddCIDR(getCIDR("254.0.0.0/4"), "5")
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/64"), "6a")
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/80"), "6b")
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/96"), "6c")
tests := []struct {
Result interface{}
IP string
}{
{"1", "1.0.0.0"},
{"1", "1.255.255.255"},
{"2", "2.1.0.0"},
{"2", "2.1.255.255"},
{"3", "3.1.1.0"},
{"3", "3.1.1.255"},
{"4a", "4.1.1.255"},
{"4b", "4.1.1.2"},
{"4c", "4.1.1.1"},
{"5", "240.0.0.0"},
{"5", "255.255.255.255"},
{"6a", "1:2:0:4:1:1:1:1"},
{"6b", "1:2:0:4:5:1:1:1"},
{"6c", "1:2:0:4:5:0:0:0"},
{nil, "239.0.0.0"},
{nil, "4.1.2.2"},
}
for _, tt := range tests {
assert.Equal(t, tt.Result, tree.MostSpecificContains(net.ParseIP(tt.IP)))
}
tree = NewCIDR6Tree()
tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
tree.AddCIDR(getCIDR("::/0"), "cool6")
assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("0.0.0.0")))
assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("255.255.255.255")))
assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("::")))
assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8")))
}
func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
tree := NewCIDR6Tree()
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/64"), "6a")
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/80"), "6b")
tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/96"), "6c")
tests := []struct {
Result interface{}
IP string
}{
{"6a", "1:2:0:4:1:1:1:1"},
{"6b", "1:2:0:4:5:1:1:1"},
{"6c", "1:2:0:4:5:0:0:0"},
}
for _, tt := range tests {
ip := NewIp6AndPort(net.ParseIP(tt.IP), 0)
assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(ip.Hi, ip.Lo))
}
}

View File

@@ -76,7 +76,7 @@ func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) {
node.value = val
}
// Finds the first match, which way be the least specific
// Finds the first match, which may be the least specific
func (tree *CIDRTree) Contains(ip uint32) (value interface{}) {
bit := startbit
node := tree.root
@@ -116,7 +116,6 @@ func (tree *CIDRTree) MostSpecificContains(ip uint32) (value interface{}) {
}
bit >>= 1
}
return value

View File

@@ -11,6 +11,7 @@ import (
"strings"
"time"
"github.com/skip2/go-qrcode"
"github.com/slackhq/nebula/cert"
"golang.org/x/crypto/ed25519"
)
@@ -21,6 +22,7 @@ type caFlags struct {
duration *time.Duration
outKeyPath *string
outCertPath *string
outQRPath *string
groups *string
ips *string
subnets *string
@@ -33,6 +35,7 @@ func newCaFlags() *caFlags {
cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to")
cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to")
cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use")
cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ip and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use")
cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ip and network in CIDR notation. This will limit which subnet addresses and networks subordinate certs can use")
@@ -146,6 +149,18 @@ func ca(args []string, out io.Writer, errOut io.Writer) error {
return fmt.Errorf("error while writing out-crt: %s", err)
}
if *cf.outQRPath != "" {
b, err = qrcode.Encode(string(b), qrcode.Medium, -5)
if err != nil {
return fmt.Errorf("error while generating qr code: %s", err)
}
err = ioutil.WriteFile(*cf.outQRPath, b, 0600)
if err != nil {
return fmt.Errorf("error while writing out-qr: %s", err)
}
}
return nil
}

View File

@@ -37,6 +37,8 @@ func Test_caHelp(t *testing.T) {
" \tOptional: path to write the certificate to (default \"ca.crt\")\n"+
" -out-key string\n"+
" \tOptional: path to write the private key to (default \"ca.key\")\n"+
" -out-qr string\n"+
" \tOptional: output a qr code image (png) of the certificate\n"+
" -subnets string\n"+
" \tOptional: comma separated list of ip and network in CIDR notation. This will limit which subnet addresses and networks subordinate certs can use\n",
ob.String(),

View File

@@ -9,19 +9,22 @@ import (
"os"
"strings"
"github.com/skip2/go-qrcode"
"github.com/slackhq/nebula/cert"
)
type printFlags struct {
set *flag.FlagSet
json *bool
path *string
set *flag.FlagSet
json *bool
outQRPath *string
path *string
}
func newPrintFlags() *printFlags {
pf := printFlags{set: flag.NewFlagSet("print", flag.ContinueOnError)}
pf.set.Usage = func() {}
pf.json = pf.set.Bool("json", false, "Optional: outputs certificates in json format")
pf.outQRPath = pf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
pf.path = pf.set.String("path", "", "Required: path to the certificate")
return &pf
@@ -44,6 +47,8 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
}
var c *cert.NebulaCertificate
var qrBytes []byte
part := 0
for {
c, rawCert, err = cert.UnmarshalNebulaCertificateFromPEM(rawCert)
@@ -61,9 +66,31 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
out.Write([]byte("\n"))
}
if *pf.outQRPath != "" {
b, err := c.MarshalToPEM()
if err != nil {
return fmt.Errorf("error while marshalling cert to PEM: %s", err)
}
qrBytes = append(qrBytes, b...)
}
if rawCert == nil || len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" {
break
}
part++
}
if *pf.outQRPath != "" {
b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5)
if err != nil {
return fmt.Errorf("error while generating qr code: %s", err)
}
err = ioutil.WriteFile(*pf.outQRPath, b, 0600)
if err != nil {
return fmt.Errorf("error while writing out-qr: %s", err)
}
}
return nil

View File

@@ -23,6 +23,8 @@ func Test_printHelp(t *testing.T) {
"Usage of "+os.Args[0]+" print <flags>: prints details about a certificate\n"+
" -json\n"+
" \tOptional: outputs certificates in json format\n"+
" -out-qr string\n"+
" \tOptional: output a qr code image (png) of the certificate\n"+
" -path string\n"+
" \tRequired: path to the certificate\n",
ob.String(),

View File

@@ -11,6 +11,7 @@ import (
"strings"
"time"
"github.com/skip2/go-qrcode"
"github.com/slackhq/nebula/cert"
"golang.org/x/crypto/curve25519"
)
@@ -25,6 +26,7 @@ type signFlags struct {
inPubPath *string
outKeyPath *string
outCertPath *string
outQRPath *string
groups *string
subnets *string
}
@@ -40,8 +42,9 @@ func newSignFlags() *signFlags {
sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key")
sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to")
sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to")
sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
sf.subnets = sf.set.String("subnets", "", "Optional: comma seperated list of subnet this cert can serve for")
sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of subnet this cert can serve for")
return &sf
}
@@ -203,6 +206,18 @@ func signCert(args []string, out io.Writer, errOut io.Writer) error {
return fmt.Errorf("error while writing out-crt: %s", err)
}
if *sf.outQRPath != "" {
b, err = qrcode.Encode(string(b), qrcode.Medium, -5)
if err != nil {
return fmt.Errorf("error while generating qr code: %s", err)
}
err = ioutil.WriteFile(*sf.outQRPath, b, 0600)
if err != nil {
return fmt.Errorf("error while writing out-qr: %s", err)
}
}
return nil
}

View File

@@ -45,8 +45,10 @@ func Test_signHelp(t *testing.T) {
" \tOptional: path to write the certificate to\n"+
" -out-key string\n"+
" \tOptional (if in-pub not set): path to write the private key to\n"+
" -out-qr string\n"+
" \tOptional: output a qr code image (png) of the certificate\n"+
" -subnets string\n"+
" \tOptional: comma seperated list of subnet this cert can serve for\n",
" \tOptional: comma separated list of subnet this cert can serve for\n",
ob.String(),
)
}
@@ -286,5 +288,4 @@ func Test_signCert(t *testing.T) {
assert.EqualError(t, signCert(args, ob, eb), "refusing to overwrite existing cert: "+crtF.Name())
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
}

View File

@@ -0,0 +1,9 @@
// +build !windows
package main
import "github.com/sirupsen/logrus"
func HookLogger(l *logrus.Logger) {
// Do nothing, let the logs flow to stdout/stderr
}

View File

@@ -0,0 +1,54 @@
package main
import (
"fmt"
"io/ioutil"
"os"
"github.com/kardianos/service"
"github.com/sirupsen/logrus"
)
// HookLogger routes the logrus logs through the service logger so that they end up in the Windows Event Viewer
// logrus output will be discarded
func HookLogger(l *logrus.Logger) {
l.AddHook(newLogHook(logger))
l.SetOutput(ioutil.Discard)
}
type logHook struct {
sl service.Logger
}
func newLogHook(sl service.Logger) *logHook {
return &logHook{sl: sl}
}
func (h *logHook) Fire(entry *logrus.Entry) error {
line, err := entry.String()
if err != nil {
fmt.Fprintf(os.Stderr, "Unable to read entry, %v", err)
return err
}
switch entry.Level {
case logrus.PanicLevel:
return h.sl.Error(line)
case logrus.FatalLevel:
return h.sl.Error(line)
case logrus.ErrorLevel:
return h.sl.Error(line)
case logrus.WarnLevel:
return h.sl.Warning(line)
case logrus.InfoLevel:
return h.sl.Info(line)
case logrus.DebugLevel:
return h.sl.Info(line)
default:
return nil
}
}
func (h *logHook) Levels() []logrus.Level {
return logrus.AllLevels
}

View File

@@ -46,15 +46,16 @@ func main() {
os.Exit(1)
}
config := nebula.NewConfig()
l := logrus.New()
l.Out = os.Stdout
config := nebula.NewConfig(l)
err := config.Load(*configPath)
if err != nil {
fmt.Printf("failed to load config: %s", err)
os.Exit(1)
}
l := logrus.New()
l.Out = os.Stdout
c, err := nebula.Main(config, *configTest, Build, l, nil)
switch v := err.(type) {

View File

@@ -24,14 +24,15 @@ func (p *program) Start(s service.Service) error {
// Start should not block.
logger.Info("Nebula service starting.")
config := nebula.NewConfig()
l := logrus.New()
HookLogger(l)
config := nebula.NewConfig(l)
err := config.Load(*p.configPath)
if err != nil {
return fmt.Errorf("failed to load config: %s", err)
}
l := logrus.New()
l.Out = os.Stdout
p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
if err != nil {
return err
@@ -69,6 +70,10 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag *
build: build,
}
// Here are what the different loggers are doing:
// - `log` is the standard go log utility, meant to be used while the process is still attached to stdout/stderr
// - `logger` is the service log utility that may be attached to a special place depending on OS (Windows will have it attached to the event log)
// - above, in `Run` we create a `logrus.Logger` which is what nebula expects to use
s, err := service.New(prg, svcConfig)
if err != nil {
log.Fatal(err)
@@ -84,6 +89,7 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag *
for {
err := <-errs
if err != nil {
// Route any errors from the system logger to stdout as a best effort to notice issues there
log.Print(err)
}
}
@@ -93,6 +99,7 @@ func doService(configPath *string, configTest *bool, build string, serviceFlag *
case "run":
err = s.Run()
if err != nil {
// Route any errors to the system logger
logger.Error(err)
}
default:

View File

@@ -40,15 +40,16 @@ func main() {
os.Exit(1)
}
config := nebula.NewConfig()
l := logrus.New()
l.Out = os.Stdout
config := nebula.NewConfig(l)
err := config.Load(*configPath)
if err != nil {
fmt.Printf("failed to load config: %s", err)
os.Exit(1)
}
l := logrus.New()
l.Out = os.Stdout
c, err := nebula.Main(config, *configTest, Build, l, nil)
switch v := err.(type) {

View File

@@ -26,11 +26,13 @@ type Config struct {
Settings map[interface{}]interface{}
oldSettings map[interface{}]interface{}
callbacks []func(*Config)
l *logrus.Logger
}
func NewConfig() *Config {
func NewConfig(l *logrus.Logger) *Config {
return &Config{
Settings: make(map[interface{}]interface{}),
l: l,
}
}
@@ -99,12 +101,12 @@ func (c *Config) HasChanged(k string) bool {
newVals, err := yaml.Marshal(nv)
if err != nil {
l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
}
oldVals, err := yaml.Marshal(ov)
if err != nil {
l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
}
return string(newVals) != string(oldVals)
@@ -118,7 +120,7 @@ func (c *Config) CatchHUP() {
go func() {
for range ch {
l.Info("Caught HUP, reloading config")
c.l.Info("Caught HUP, reloading config")
c.ReloadConfig()
}
}()
@@ -132,7 +134,7 @@ func (c *Config) ReloadConfig() {
err := c.Load(c.path)
if err != nil {
l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
return
}
@@ -235,13 +237,18 @@ func (c *Config) GetAllowList(k string, allowInterfaces bool) (*AllowList, error
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, r)
}
tree := NewCIDRTree()
tree := NewCIDR6Tree()
var nameRules []AllowListNameRule
firstValue := true
allValuesMatch := true
defaultSet := false
var allValues bool
// Keep track of the rules we have added for both ipv4 and ipv6
type allowListRules struct {
firstValue bool
allValuesMatch bool
defaultSet bool
allValues bool
}
rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
for rawKey, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
@@ -276,31 +283,48 @@ func (c *Config) GetAllowList(k string, allowInterfaces bool) (*AllowList, error
// TODO: should we error on duplicate CIDRs in the config?
tree.AddCIDR(cidr, value)
if firstValue {
allValues = value
firstValue = false
maskBits, maskSize := cidr.Mask.Size()
var rules *allowListRules
if maskSize == 32 {
rules = &rules4
} else {
if value != allValues {
allValuesMatch = false
rules = &rules6
}
if rules.firstValue {
rules.allValues = value
rules.firstValue = false
} else {
if value != rules.allValues {
rules.allValuesMatch = false
}
}
// Check if this is 0.0.0.0/0
bits, size := cidr.Mask.Size()
if bits == 0 && size == 32 {
defaultSet = true
// Check if this is 0.0.0.0/0 or ::/0
if maskBits == 0 {
rules.defaultSet = true
}
}
if !defaultSet {
if allValuesMatch {
if !rules4.defaultSet {
if rules4.allValuesMatch {
_, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
tree.AddCIDR(zeroCIDR, !allValues)
tree.AddCIDR(zeroCIDR, !rules4.allValues)
} else {
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
}
}
if !rules6.defaultSet {
if rules6.allValuesMatch {
_, zeroCIDR, _ := net.ParseCIDR("::/0")
tree.AddCIDR(zeroCIDR, !rules6.allValues)
} else {
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
}
}
return &AllowList{cidrTree: tree, nameRules: nameRules}, nil
}
@@ -478,7 +502,7 @@ func configLogger(c *Config) error {
if err != nil {
return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
}
l.SetLevel(logLevel)
c.l.SetLevel(logLevel)
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
timestampFormat := c.GetString("logging.timestamp_format", "")
@@ -490,13 +514,13 @@ func configLogger(c *Config) error {
logFormat := strings.ToLower(c.GetString("logging.format", "text"))
switch logFormat {
case "text":
l.Formatter = &logrus.TextFormatter{
c.l.Formatter = &logrus.TextFormatter{
TimestampFormat: timestampFormat,
FullTimestamp: fullTimestamp,
DisableTimestamp: disableTimestamp,
}
case "json":
l.Formatter = &logrus.JSONFormatter{
c.l.Formatter = &logrus.JSONFormatter{
TimestampFormat: timestampFormat,
DisableTimestamp: disableTimestamp,
}

View File

@@ -11,14 +11,15 @@ import (
)
func TestConfig_Load(t *testing.T) {
l := NewTestLogger()
dir, err := ioutil.TempDir("", "config-test")
// invalid yaml
c := NewConfig()
c := NewConfig(l)
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
// simple multi config merge
c = NewConfig()
c = NewConfig(l)
os.RemoveAll(dir)
os.Mkdir(dir, 0755)
@@ -40,8 +41,9 @@ func TestConfig_Load(t *testing.T) {
}
func TestConfig_Get(t *testing.T) {
l := NewTestLogger()
// test simple type
c := NewConfig()
c := NewConfig(l)
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
assert.Equal(t, "hi", c.Get("firewall.outbound"))
@@ -55,13 +57,15 @@ func TestConfig_Get(t *testing.T) {
}
func TestConfig_GetStringSlice(t *testing.T) {
c := NewConfig()
l := NewTestLogger()
c := NewConfig(l)
c.Settings["slice"] = []interface{}{"one", "two"}
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
}
func TestConfig_GetBool(t *testing.T) {
c := NewConfig()
l := NewTestLogger()
c := NewConfig(l)
c.Settings["bool"] = true
assert.Equal(t, true, c.GetBool("bool", false))
@@ -88,7 +92,8 @@ func TestConfig_GetBool(t *testing.T) {
}
func TestConfig_GetAllowList(t *testing.T) {
c := NewConfig()
l := NewTestLogger()
c := NewConfig(l)
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0": true,
}
@@ -109,6 +114,16 @@ func TestConfig_GetAllowList(t *testing.T) {
r, err = c.GetAllowList("allowlist", false)
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
"10.42.42.0/24": true,
"fd00::/8": true,
"fd00:fd00::/16": false,
}
r, err = c.GetAllowList("allowlist", false)
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
@@ -119,6 +134,19 @@ func TestConfig_GetAllowList(t *testing.T) {
assert.NotNil(t, r)
}
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
"10.0.0.0/8": false,
"10.42.42.0/24": true,
"::/0": false,
"fd00::/8": true,
"fd00:fd00::/16": false,
}
r, err = c.GetAllowList("allowlist", false)
if assert.NoError(t, err) {
assert.NotNil(t, r)
}
// Test interface names
c.Settings["allowlist"] = map[interface{}]interface{}{
@@ -158,20 +186,21 @@ func TestConfig_GetAllowList(t *testing.T) {
}
func TestConfig_HasChanged(t *testing.T) {
l := NewTestLogger()
// No reload has occurred, return false
c := NewConfig()
c := NewConfig(l)
c.Settings["test"] = "hi"
assert.False(t, c.HasChanged(""))
// Test key change
c = NewConfig()
c = NewConfig(l)
c.Settings["test"] = "hi"
c.oldSettings = map[interface{}]interface{}{"test": "no"}
assert.True(t, c.HasChanged("test"))
assert.True(t, c.HasChanged(""))
// No key change
c = NewConfig()
c = NewConfig(l)
c.Settings["test"] = "hi"
c.oldSettings = map[interface{}]interface{}{"test": "hi"}
assert.False(t, c.HasChanged("test"))
@@ -179,12 +208,13 @@ func TestConfig_HasChanged(t *testing.T) {
}
func TestConfig_ReloadConfig(t *testing.T) {
l := NewTestLogger()
done := make(chan bool, 1)
dir, err := ioutil.TempDir("", "config-test")
assert.Nil(t, err)
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
c := NewConfig()
c := NewConfig(l)
assert.Nil(t, c.Load(dir))
assert.False(t, c.HasChanged("outer.inner"))

View File

@@ -28,10 +28,11 @@ type connectionManager struct {
checkInterval int
pendingDeletionInterval int
l *logrus.Logger
// I wanted to call one matLock
}
func newConnectionManager(intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
func newConnectionManager(l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
nc := &connectionManager{
hostMap: intf.hostMap,
in: make(map[uint32]struct{}),
@@ -47,6 +48,7 @@ func newConnectionManager(intf *Interface, checkInterval, pendingDeletionInterva
pendingDeletionTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
checkInterval: checkInterval,
pendingDeletionInterval: pendingDeletionInterval,
l: l,
}
nc.Start()
return nc
@@ -141,14 +143,17 @@ func (n *connectionManager) Start() {
func (n *connectionManager) Run() {
clockSource := time.Tick(500 * time.Millisecond)
p := []byte("")
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for now := range clockSource {
n.HandleMonitorTick(now)
n.HandleMonitorTick(now, p, nb, out)
n.HandleDeletionTick(now)
}
}
func (n *connectionManager) HandleMonitorTick(now time.Time) {
func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) {
n.TrafficTimer.advance(now)
for {
ep := n.TrafficTimer.Purge()
@@ -163,8 +168,8 @@ func (n *connectionManager) HandleMonitorTick(now time.Time) {
// If we saw incoming packets from this ip, just return
if traf {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(vpnIP)).
if n.l.Level >= logrus.DebugLevel {
n.l.WithField("vpnIp", IntIp(vpnIP)).
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
Debug("Tunnel status")
}
@@ -176,22 +181,22 @@ func (n *connectionManager) HandleMonitorTick(now time.Time) {
// If we didn't we may need to probe or destroy the conn
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
if err != nil {
l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP)
continue
}
hostinfo.logger().
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
Debug("Tunnel status")
if hostinfo != nil && hostinfo.ConnectionState != nil {
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, p, nb, out)
} else {
hostinfo.logger().Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
hostinfo.logger(n.l).Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
}
n.AddPendingDeletion(vpnIP)
}
@@ -211,7 +216,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
// If we saw incoming packets from this ip, just return
traf := n.CheckIn(vpnIP)
if traf {
l.WithField("vpnIp", IntIp(vpnIP)).
n.l.WithField("vpnIp", IntIp(vpnIP)).
WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
Debug("Tunnel status")
n.ClearIP(vpnIP)
@@ -223,7 +228,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
if err != nil {
n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP)
l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
continue
}
@@ -233,7 +238,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
cn = hostinfo.ConnectionState.peerCert.Details.Name
}
hostinfo.logger().
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
WithField("certName", cn).
Info("Tunnel status")
@@ -244,8 +249,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
if n.intf.lightHouse != nil {
n.intf.lightHouse.DeleteVpnIP(vpnIP)
}
n.hostMap.DeleteVpnIP(vpnIP)
n.hostMap.DeleteIndex(hostinfo.localIndexId)
n.hostMap.DeleteHostInfo(hostinfo)
} else {
n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP)

View File

@@ -13,6 +13,7 @@ import (
var vpnIP uint32
func Test_NewConnectionManagerTest(t *testing.T) {
l := NewTestLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
@@ -20,7 +21,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
preferredRanges := []*net.IPNet{localrange}
// Very incomplete mock objects
hostMap := NewHostMap("test", vpncidr, preferredRanges)
hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
cs := &CertState{
rawCertificate: []byte{},
privateKey: []byte{},
@@ -28,7 +29,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
rawCertificateNoKey: []byte{},
}
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
ifce := &Interface{
hostMap: hostMap,
inside: &Tun{},
@@ -36,19 +37,22 @@ func Test_NewConnectionManagerTest(t *testing.T) {
certState: cs,
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
l: l,
}
now := time.Now()
// Create manager
nc := newConnectionManager(ifce, 5, 10)
nc.HandleMonitorTick(now)
nc := newConnectionManager(l, ifce, 5, 10)
p := []byte("")
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
nc.HandleMonitorTick(now, p, nb, out)
// Add an ip we have established a connection w/ to hostmap
hostinfo := nc.hostMap.AddVpnIP(vpnIP)
hostinfo.ConnectionState = &ConnectionState{
certState: cs,
H: &noise.HandshakeState{},
messageCounter: new(uint64),
certState: cs,
H: &noise.HandshakeState{},
}
// We saw traffic out to vpnIP
@@ -57,18 +61,18 @@ func Test_NewConnectionManagerTest(t *testing.T) {
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
// Move ahead 5s. Nothing should happen
next_tick := now.Add(5 * time.Second)
nc.HandleMonitorTick(next_tick)
nc.HandleMonitorTick(next_tick, p, nb, out)
nc.HandleDeletionTick(next_tick)
// Move ahead 6s. We haven't heard back
next_tick = now.Add(6 * time.Second)
nc.HandleMonitorTick(next_tick)
nc.HandleMonitorTick(next_tick, p, nb, out)
nc.HandleDeletionTick(next_tick)
// This host should now be up for deletion
assert.Contains(t, nc.pendingDeletion, vpnIP)
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
// Move ahead some more
next_tick = now.Add(45 * time.Second)
nc.HandleMonitorTick(next_tick)
nc.HandleMonitorTick(next_tick, p, nb, out)
nc.HandleDeletionTick(next_tick)
// The host should be evicted
assert.NotContains(t, nc.pendingDeletion, vpnIP)
@@ -77,13 +81,14 @@ func Test_NewConnectionManagerTest(t *testing.T) {
}
func Test_NewConnectionManagerTest2(t *testing.T) {
l := NewTestLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange}
// Very incomplete mock objects
hostMap := NewHostMap("test", vpncidr, preferredRanges)
hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
cs := &CertState{
rawCertificate: []byte{},
privateKey: []byte{},
@@ -91,7 +96,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
rawCertificateNoKey: []byte{},
}
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
ifce := &Interface{
hostMap: hostMap,
inside: &Tun{},
@@ -99,19 +104,22 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
certState: cs,
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
l: l,
}
now := time.Now()
// Create manager
nc := newConnectionManager(ifce, 5, 10)
nc.HandleMonitorTick(now)
nc := newConnectionManager(l, ifce, 5, 10)
p := []byte("")
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
nc.HandleMonitorTick(now, p, nb, out)
// Add an ip we have established a connection w/ to hostmap
hostinfo := nc.hostMap.AddVpnIP(vpnIP)
hostinfo.ConnectionState = &ConnectionState{
certState: cs,
H: &noise.HandshakeState{},
messageCounter: new(uint64),
certState: cs,
H: &noise.HandshakeState{},
}
// We saw traffic out to vpnIP
@@ -120,11 +128,11 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
// Move ahead 5s. Nothing should happen
next_tick := now.Add(5 * time.Second)
nc.HandleMonitorTick(next_tick)
nc.HandleMonitorTick(next_tick, p, nb, out)
nc.HandleDeletionTick(next_tick)
// Move ahead 6s. We haven't heard back
next_tick = now.Add(6 * time.Second)
nc.HandleMonitorTick(next_tick)
nc.HandleMonitorTick(next_tick, p, nb, out)
nc.HandleDeletionTick(next_tick)
// This host should now be up for deletion
assert.Contains(t, nc.pendingDeletion, vpnIP)
@@ -133,7 +141,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
nc.In(vpnIP)
// Move ahead some more
next_tick = now.Add(45 * time.Second)
nc.HandleMonitorTick(next_tick)
nc.HandleMonitorTick(next_tick, p, nb, out)
nc.HandleDeletionTick(next_tick)
// The host should be evicted
assert.NotContains(t, nc.pendingDeletion, vpnIP)

View File

@@ -4,28 +4,30 @@ import (
"crypto/rand"
"encoding/json"
"sync"
"sync/atomic"
"github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
)
const ReplayWindow = 1024
type ConnectionState struct {
eKey *NebulaCipherState
dKey *NebulaCipherState
H *noise.HandshakeState
certState *CertState
peerCert *cert.NebulaCertificate
initiator bool
messageCounter *uint64
window *Bits
queueLock sync.Mutex
writeLock sync.Mutex
ready bool
eKey *NebulaCipherState
dKey *NebulaCipherState
H *noise.HandshakeState
certState *CertState
peerCert *cert.NebulaCertificate
initiator bool
atomicMessageCounter uint64
window *Bits
queueLock sync.Mutex
writeLock sync.Mutex
ready bool
}
func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
cs := noise.NewCipherSuite(noise.DH25519, noise.CipherAESGCM, noise.HashSHA256)
if f.cipher == "chachapoly" {
cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
@@ -36,7 +38,7 @@ func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePa
b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
b.Update(0)
b.Update(l, 0)
hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: cs,
@@ -54,12 +56,11 @@ func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePa
// The queue and ready params prevent a counter race that would happen when
// sending stored packets and simultaneously accepting new traffic.
ci := &ConnectionState{
H: hs,
initiator: initiator,
window: b,
ready: false,
certState: curCertState,
messageCounter: new(uint64),
H: hs,
initiator: initiator,
window: b,
ready: false,
certState: curCertState,
}
return ci
@@ -69,7 +70,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
return json.Marshal(m{
"certificate": cs.peerCert,
"initiator": cs.initiator,
"message_counter": cs.messageCounter,
"message_counter": atomic.LoadUint64(&cs.atomicMessageCounter),
"ready": cs.ready,
})
}

View File

@@ -4,6 +4,7 @@ import (
"net"
"os"
"os/signal"
"sync/atomic"
"syscall"
"github.com/sirupsen/logrus"
@@ -14,39 +15,48 @@ import (
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
type Control struct {
f *Interface
l *logrus.Logger
f *Interface
l *logrus.Logger
sshStart func()
statsStart func()
dnsStart func()
}
type ControlHostInfo struct {
VpnIP net.IP `json:"vpnIp"`
LocalIndex uint32 `json:"localIndex"`
RemoteIndex uint32 `json:"remoteIndex"`
RemoteAddrs []udpAddr `json:"remoteAddrs"`
RemoteAddrs []*udpAddr `json:"remoteAddrs"`
CachedPackets int `json:"cachedPackets"`
Cert *cert.NebulaCertificate `json:"cert"`
MessageCounter uint64 `json:"messageCounter"`
CurrentRemote udpAddr `json:"currentRemote"`
CurrentRemote *udpAddr `json:"currentRemote"`
}
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
func (c *Control) Start() {
// Activate the interface
c.f.activate()
// Call all the delayed funcs that waited patiently for the interface to be created.
if c.sshStart != nil {
go c.sshStart()
}
if c.statsStart != nil {
go c.statsStart()
}
if c.dnsStart != nil {
go c.dnsStart()
}
// Start reading packets.
c.f.run()
}
// Stop signals nebula to shutdown, returns after the shutdown is complete
func (c *Control) Stop() {
//TODO: stop tun and udp routines, the lock on hostMap effectively does that though
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
c.f.hostMap.Lock()
for _, h := range c.f.hostMap.Hosts {
if h.ConnectionState.ready {
c.f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message")
}
}
c.f.hostMap.Unlock()
c.CloseAllTunnels(false)
c.l.Info("Goodbye")
}
@@ -65,27 +75,21 @@ func (c *Control) ShutdownBlock() {
// RebindUDPServer asks the UDP listener to rebind it's listener. Mainly used on mobile clients when interfaces change
func (c *Control) RebindUDPServer() {
_ = c.f.outside.Rebind()
// Trigger a lighthouse update, useful for mobile clients that should have an update interval of 0
c.f.lightHouse.SendUpdate(c.f)
// Let the main interface know that we rebound so that underlying tunnels know to trigger punches from their remotes
c.f.rebindCount++
}
// ListHostmap returns details about the actual or pending (handshaking) hostmap
func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
var hm *HostMap
if pendingMap {
hm = c.f.handshakeManager.pendingHostMap
return listHostMap(c.f.handshakeManager.pendingHostMap)
} else {
hm = c.f.hostMap
return listHostMap(c.f.hostMap)
}
hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Hosts))
i := 0
for _, v := range hm.Hosts {
hosts[i] = copyHostInfo(v)
i++
}
hm.RUnlock()
return hosts
}
// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found
@@ -102,7 +106,7 @@ func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInf
return nil
}
ch := copyHostInfo(h)
ch := copyHostInfo(h, c.f.hostMap.preferredRanges)
return &ch
}
@@ -114,7 +118,7 @@ func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInf
}
hostInfo.SetRemote(addr.Copy())
ch := copyHostInfo(hostInfo)
ch := copyHostInfo(hostInfo, c.f.hostMap.preferredRanges)
return &ch
}
@@ -138,19 +142,46 @@ func (c *Control) CloseTunnel(vpnIP uint32, localOnly bool) bool {
)
}
c.f.closeTunnel(hostInfo)
c.f.closeTunnel(hostInfo, false)
return true
}
func copyHostInfo(h *HostInfo) ControlHostInfo {
addrs := h.RemoteUDPAddrs()
// CloseAllTunnels is just like CloseTunnel except it goes through and shuts them all down, optionally you can avoid shutting down lighthouse tunnels
// the int returned is a count of tunnels closed
func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
c.f.hostMap.Lock()
for _, h := range c.f.hostMap.Hosts {
if excludeLighthouses {
if _, ok := c.f.lightHouse.lighthouses[h.hostId]; ok {
continue
}
}
if h.ConnectionState.ready {
c.f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.f.closeTunnel(h, true)
c.l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message")
closed++
}
}
c.f.hostMap.Unlock()
return
}
func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
chi := ControlHostInfo{
VpnIP: int2ip(h.hostId),
LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId,
RemoteAddrs: make([]udpAddr, len(addrs), len(addrs)),
CachedPackets: len(h.packetStore),
MessageCounter: *h.ConnectionState.messageCounter,
VpnIP: int2ip(h.hostId),
LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId,
RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),
CachedPackets: len(h.packetStore),
}
if h.ConnectionState != nil {
chi.MessageCounter = atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter)
}
if c := h.GetCert(); c != nil {
@@ -158,12 +189,21 @@ func copyHostInfo(h *HostInfo) ControlHostInfo {
}
if h.remote != nil {
chi.CurrentRemote = *h.remote
}
for i, addr := range addrs {
chi.RemoteAddrs[i] = addr.Copy()
chi.CurrentRemote = h.remote.Copy()
}
return chi
}
func listHostMap(hm *HostMap) []ControlHostInfo {
hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Hosts))
i := 0
for _, v := range hm.Hosts {
hosts[i] = copyHostInfo(v, hm.preferredRanges)
i++
}
hm.RUnlock()
return hosts
}

View File

@@ -13,18 +13,19 @@ import (
)
func TestControl_GetHostInfoByVpnIP(t *testing.T) {
l := NewTestLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller
hm := NewHostMap("test", &net.IPNet{}, make([]*net.IPNet, 0))
remote1 := NewUDPAddr(100, 4444)
remote2 := NewUDPAddr(101, 4444)
hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
remote1 := NewUDPAddr(int2ip(100), 4444)
remote2 := NewUDPAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
ipNet := net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
}
ipNet2 := net.IPNet{
IP: net.IPv4(1, 2, 3, 5),
IP: net.ParseIP("1:2:3:4:5:6:7:8"),
Mask: net.IPMask{255, 255, 255, 0},
}
@@ -43,15 +44,15 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
},
Signature: []byte{1, 2, 1, 2, 1, 3},
}
counter := uint64(0)
remotes := []*HostInfoDest{NewHostInfoDest(remote1), NewHostInfoDest(remote2)}
remotes := NewRemoteList()
remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
hm.Add(ip2int(ipNet.IP), &HostInfo{
remote: remote1,
Remotes: remotes,
remotes: remotes,
ConnectionState: &ConnectionState{
peerCert: crt,
messageCounter: &counter,
peerCert: crt,
},
remoteIndexId: 200,
localIndexId: 201,
@@ -60,10 +61,9 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
hm.Add(ip2int(ipNet2.IP), &HostInfo{
remote: remote1,
Remotes: remotes,
remotes: remotes,
ConnectionState: &ConnectionState{
peerCert: nil,
messageCounter: &counter,
peerCert: nil,
},
remoteIndexId: 200,
localIndexId: 201,
@@ -83,11 +83,11 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
VpnIP: net.IPv4(1, 2, 3, 4).To4(),
LocalIndex: 201,
RemoteIndex: 200,
RemoteAddrs: []udpAddr{*remote1, *remote2},
RemoteAddrs: []*udpAddr{remote2, remote1},
CachedPackets: 0,
Cert: crt.Copy(),
MessageCounter: 0,
CurrentRemote: *NewUDPAddr(100, 4444),
CurrentRemote: NewUDPAddr(int2ip(100), 4444),
}
// Make sure we don't have any unexpected fields

128
control_tester.go Normal file
View File

@@ -0,0 +1,128 @@
// +build e2e_testing
package nebula
import (
"net"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// WaitForTypeByIndex will pipe all messages from this control device into the pipeTo control device
// returning after a message matching the criteria has been piped
func (c *Control) WaitForType(msgType NebulaMessageType, subType NebulaMessageSubType, pipeTo *Control) {
h := &Header{}
for {
p := c.f.outside.Get(true)
if err := h.Parse(p.Data); err != nil {
panic(err)
}
pipeTo.InjectUDPPacket(p)
if h.Type == msgType && h.Subtype == subType {
return
}
}
}
// WaitForTypeByIndex is similar to WaitForType except it adds an index check
// Useful if you have many nodes communicating and want to wait to find a specific nodes packet
func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType NebulaMessageType, subType NebulaMessageSubType, pipeTo *Control) {
h := &Header{}
for {
p := c.f.outside.Get(true)
if err := h.Parse(p.Data); err != nil {
panic(err)
}
pipeTo.InjectUDPPacket(p)
if h.RemoteIndex == toIndex && h.Type == msgType && h.Subtype == subType {
return
}
}
}
// InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp
// This is necessary if you did not configure static hosts or are not running a lighthouse
func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
c.f.lightHouse.Lock()
remoteList := c.f.lightHouse.unlockedGetRemoteList(ip2int(vpnIp))
remoteList.Lock()
defer remoteList.Unlock()
c.f.lightHouse.Unlock()
iVpnIp := ip2int(vpnIp)
if v4 := toAddr.IP.To4(); v4 != nil {
remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
} else {
remoteList.unlockedPrependV6(iVpnIp, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)))
}
}
// GetFromTun will pull a packet off the tun side of nebula
func (c *Control) GetFromTun(block bool) []byte {
return c.f.inside.(*Tun).Get(block)
}
// GetFromUDP will pull a udp packet off the udp side of nebula
func (c *Control) GetFromUDP(block bool) *UdpPacket {
return c.f.outside.Get(block)
}
func (c *Control) GetUDPTxChan() <-chan *UdpPacket {
return c.f.outside.txPackets
}
func (c *Control) GetTunTxChan() <-chan []byte {
return c.f.inside.(*Tun).txPackets
}
// InjectUDPPacket will inject a packet into the udp side of nebula
func (c *Control) InjectUDPPacket(p *UdpPacket) {
c.f.outside.Send(p)
}
// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16, data []byte) {
ip := layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
SrcIP: c.f.inside.CidrNet().IP,
DstIP: toIp,
}
udp := layers.UDP{
SrcPort: layers.UDPPort(fromPort),
DstPort: layers.UDPPort(toPort),
}
err := udp.SetNetworkLayerForChecksum(&ip)
if err != nil {
panic(err)
}
buffer := gopacket.NewSerializeBuffer()
opt := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
if err != nil {
panic(err)
}
c.f.inside.(*Tun).Send(buffer.Bytes())
}
func (c *Control) GetUDPAddr() string {
return c.f.outside.addr.String()
}
func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[ip2int(vpnIp)]
if !ok {
return false
}
c.f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
return true
}

View File

@@ -5,8 +5,6 @@ After=basic.target network.target network-online.target
[Service]
SyslogIdentifier=nebula
StandardOutput=syslog
StandardError=syslog
ExecReload=/bin/kill -HUP $MAINPID
ExecStart=/usr/bin/nebula -config /etc/nebula/config.yml
Restart=always

15
dist/fedora/nebula.service vendored Normal file
View File

@@ -0,0 +1,15 @@
[Unit]
Description=Nebula overlay networking tool
After=basic.target network.target network-online.target
Before=sshd.service
Wants=basic.target network-online.target
[Service]
ExecReload=/bin/kill -HUP $MAINPID
ExecStart=/usr/bin/nebula -config /etc/nebula/config.yml
Restart=always
SyslogIdentifier=nebula
[Install]
WantedBy=multi-user.target

View File

@@ -7,6 +7,7 @@ import (
"sync"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
)
// This whole thing should be rewritten to use context
@@ -63,7 +64,7 @@ func (d *dnsRecords) Add(host, data string) {
d.Unlock()
}
func parseQuery(m *dns.Msg, w dns.ResponseWriter) {
func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
for _, q := range m.Question {
switch q.Qtype {
case dns.TypeA:
@@ -95,37 +96,44 @@ func parseQuery(m *dns.Msg, w dns.ResponseWriter) {
}
}
func handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Compress = false
switch r.Opcode {
case dns.OpcodeQuery:
parseQuery(m, w)
parseQuery(l, m, w)
}
w.WriteMsg(m)
}
func dnsMain(hostMap *HostMap, c *Config) {
func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() {
dnsR = newDnsRecords(hostMap)
// attach request handler func
dns.HandleFunc(".", handleDnsRequest)
dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
handleDnsRequest(l, w, r)
})
c.RegisterReloadCallback(reloadDns)
startDns(c)
c.RegisterReloadCallback(func(c *Config) {
reloadDns(l, c)
})
return func() {
startDns(l, c)
}
}
func getDnsServerAddr(c *Config) string {
return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))
}
func startDns(c *Config) {
func startDns(l *logrus.Logger, c *Config) {
dnsAddr = getDnsServerAddr(c)
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
l.Debugf("Starting DNS responder at %s\n", dnsAddr)
l.WithField("dnsListener", dnsAddr).Infof("Starting DNS responder")
err := dnsServer.ListenAndServe()
defer dnsServer.Shutdown()
if err != nil {
@@ -133,7 +141,7 @@ func startDns(c *Config) {
}
}
func reloadDns(c *Config) {
func reloadDns(l *logrus.Logger, c *Config) {
if dnsAddr == getDnsServerAddr(c) {
l.Debug("No DNS server config change detected")
return
@@ -141,5 +149,5 @@ func reloadDns(c *Config) {
l.Debug("Restarting DNS server")
dnsServer.Shutdown()
go startDns(c)
go startDns(l, c)
}

3
e2e/doc.go Normal file
View File

@@ -0,0 +1,3 @@
package e2e
// This file exists to allow `go fmt` to traverse here on its own. The build tags were keeping it out before

181
e2e/handshakes_test.go Normal file
View File

@@ -0,0 +1,181 @@
// +build e2e_testing
package e2e
import (
"net"
"testing"
"time"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert"
)
func TestGoodHandshake(t *testing.T) {
ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1})
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2})
// Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
t.Log("Get their stage 1 packet so that we can play with it")
stage1Packet := theirControl.GetFromUDP(true)
t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
badPacket := stage1Packet.Copy()
badPacket.Data = badPacket.Data[:len(badPacket.Data)-nebula.HeaderLen]
myControl.InjectUDPPacket(badPacket)
t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
myControl.InjectUDPPacket(stage1Packet)
t.Log("Wait until we see my cached packet come through")
myControl.WaitForType(1, 0, theirControl)
t.Log("Make sure our host infos are correct")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
t.Log("Get that cached packet and make sure it looks right")
myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
t.Log("Do a bidirectional tunnel test")
assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, router.NewR(myControl, theirControl))
myControl.Stop()
theirControl.Stop()
//TODO: assert hostmaps
}
func TestWrongResponderHandshake(t *testing.T) {
ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
// The IPs here are chosen on purpose:
// The current remote handling will sort by preference, public, and then lexically.
// So we need them to have a higher address than evil (we could apply a preference though)
myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100})
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99})
evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2})
// Add their real udp addr, which should be tried after evil.
myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
myControl.InjectLightHouseAddr(theirVpnIp, evilUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(myControl, theirControl, evilControl)
// Start the servers
myControl.Start()
theirControl.Start()
evilControl.Start()
t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
r.RouteForAllExitFunc(func(p *nebula.UdpPacket, c *nebula.Control) router.ExitType {
h := &nebula.Header{}
err := h.Parse(p.Data)
if err != nil {
panic(err)
}
if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 {
return router.RouteAndExit
}
return router.KeepRouting
})
//TODO: Assert pending hostmap - I should have a correct hostinfo for them now
t.Log("My cached packet should be received by them")
myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
t.Log("Test the tunnel with them")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
t.Log("Flush all packets from all controllers")
r.FlushAll()
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), true), "My pending hostmap should not contain evil")
assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), false), "My main hostmap should not contain evil")
//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
//TODO: assert hostmaps for everyone
t.Log("Success!")
myControl.Stop()
theirControl.Stop()
}
func Test_Case1_Stage1Race(t *testing.T) {
ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1})
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2})
// Put their info in our lighthouse and vice versa
myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIp, myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(myControl, theirControl)
// Start the servers
myControl.Start()
theirControl.Start()
t.Log("Trigger a handshake to start on both me and them")
myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIp, 80, 80, []byte("Hi from them"))
t.Log("Get both stage 1 handshake packets")
myHsForThem := myControl.GetFromUDP(true)
theirHsForMe := theirControl.GetFromUDP(true)
t.Log("Now inject both stage 1 handshake packets")
myControl.InjectUDPPacket(theirHsForMe)
theirControl.InjectUDPPacket(myHsForThem)
//TODO: they should win, grab their index for me and make sure I use it in the end.
t.Log("They should not have a stage 2 (won the race) but I should send one")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
t.Log("Route for me until I send a message packet to them")
myControl.WaitForType(1, 0, theirControl)
t.Log("My cached packet should be received by them")
myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
t.Log("Route for them until I send a message packet to me")
theirControl.WaitForType(1, 0, myControl)
t.Log("Their cached packet should be received by me")
theirCachedPacket := myControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIp, myVpnIp, 80, 80)
t.Log("Do a bidirectional tunnel test")
assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
myControl.Stop()
theirControl.Stop()
//TODO: assert hostmaps
}
//TODO: add a test with many lies

317
e2e/helpers_test.go Normal file
View File

@@ -0,0 +1,317 @@
// +build e2e_testing
package e2e
import (
"crypto/rand"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"testing"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
"gopkg.in/yaml.v2"
)
type m map[string]interface{}
// newSimpleServer creates a nebula instance with many assumptions
func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP) (*nebula.Control, net.IP, *net.UDPAddr) {
l := NewTestLogger()
vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
copy(vpnIpNet.IP, udpIp)
vpnIpNet.IP[1] += 128
udpAddr := net.UDPAddr{
IP: udpIp,
Port: 4242,
}
_, _, myPrivKey, myPEM := newTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
caB, err := caCrt.MarshalToPEM()
if err != nil {
panic(err)
}
mc := m{
"pki": m{
"ca": string(caB),
"cert": string(myPEM),
"key": string(myPrivKey),
},
//"tun": m{"disabled": true},
"firewall": m{
"outbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
"inbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
},
//"handshakes": m{
// "try_interval": "1s",
//},
"listen": m{
"host": udpAddr.IP.String(),
"port": udpAddr.Port,
},
"logging": m{
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
"level": l.Level.String(),
},
}
cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}
config := nebula.NewConfig(l)
config.LoadString(string(cb))
control, err := nebula.Main(config, false, "e2e-test", l, nil)
if err != nil {
panic(err)
}
return control, vpnIpNet.IP, &udpAddr
}
// newTestCaCert will generate a CA cert
func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
}
nc := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "test ca",
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
IsCA: true,
InvertedGroups: make(map[string]struct{}),
},
}
if len(ips) > 0 {
nc.Details.Ips = ips
}
if len(subnets) > 0 {
nc.Details.Subnets = subnets
}
if len(groups) > 0 {
nc.Details.Groups = groups
}
err = nc.Sign(priv)
if err != nil {
panic(err)
}
pem, err := nc.MarshalToPEM()
if err != nil {
panic(err)
}
return nc, pub, priv, pem
}
// newTestCert will generate a signed certificate with the provided details.
// Expiry times are defaulted if you do not pass them in
func newTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
issuer, err := ca.Sha256Sum()
if err != nil {
panic(err)
}
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
}
pub, rawPriv := x25519Keypair()
nc := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: name,
Ips: []*net.IPNet{ip},
Subnets: subnets,
Groups: groups,
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
IsCA: false,
Issuer: issuer,
InvertedGroups: make(map[string]struct{}),
},
}
err = nc.Sign(key)
if err != nil {
panic(err)
}
pem, err := nc.MarshalToPEM()
if err != nil {
panic(err)
}
return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem
}
func x25519Keypair() ([]byte, []byte) {
var pubkey, privkey [32]byte
if _, err := io.ReadFull(rand.Reader, privkey[:]); err != nil {
panic(err)
}
curve25519.ScalarBaseMult(&pubkey, &privkey)
return pubkey[:], privkey[:]
}
func ip2int(ip []byte) uint32 {
if len(ip) == 16 {
return binary.BigEndian.Uint32(ip[12:16])
}
return binary.BigEndian.Uint32(ip)
}
func int2ip(nn uint32) net.IP {
ip := make(net.IP, 4)
binary.BigEndian.PutUint32(ip, nn)
return ip
}
type doneCb func()
func deadline(t *testing.T, seconds time.Duration) doneCb {
timeout := time.After(seconds * time.Second)
done := make(chan bool)
go func() {
select {
case <-timeout:
t.Fatal("Test did not finish in time")
case <-done:
}
}()
return func() {
done <- true
}
}
func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) {
// Send a packet from them to me
controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
bPacket := r.RouteUntilTxTun(controlB, controlA)
assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
// And once more from me to them
controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A"))
aPacket := r.RouteUntilTxTun(controlA, controlB)
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
}
func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) {
// Get both host infos
hBinA := controlA.GetHostInfoByVpnIP(ip2int(vpnIpB), false)
assert.NotNil(t, hBinA, "Host B was not found by vpnIP in controlA")
hAinB := controlB.GetHostInfoByVpnIP(ip2int(vpnIpA), false)
assert.NotNil(t, hAinB, "Host A was not found by vpnIP in controlB")
// Check that both vpn and real addr are correct
assert.Equal(t, vpnIpB, hBinA.VpnIP, "Host B VpnIp is wrong in control A")
assert.Equal(t, vpnIpA, hAinB.VpnIP, "Host A VpnIp is wrong in control B")
assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A")
assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B")
assert.Equal(t, addrB.Port, int(hBinA.CurrentRemote.Port), "Host B remote port is wrong in control A")
assert.Equal(t, addrA.Port, int(hAinB.CurrentRemote.Port), "Host A remote port is wrong in control B")
// Check that our indexes match
assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index")
assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
//TODO: Would be nice to assert this memory
//checkIndexes := func(name string, hm *HostMap, hi *HostInfo) {
// hBbyIndex := hmA.Indexes[hBinA.localIndexId]
// assert.NotNil(t, hBbyIndex, "Could not host info by local index in %s", name)
// assert.Equal(t, &hBbyIndex, &hBinA, "%s Indexes map did not point to the right host info", name)
//
// //TODO: remote indexes are susceptible to collision
// hBbyRemoteIndex := hmA.RemoteIndexes[hBinA.remoteIndexId]
// assert.NotNil(t, hBbyIndex, "Could not host info by remote index in %s", name)
// assert.Equal(t, &hBbyRemoteIndex, &hBinA, "%s RemoteIndexes did not point to the right host info", name)
//}
//
//// Check hostmap indexes too
//checkIndexes("hmA", hmA, hBinA)
//checkIndexes("hmB", hmB, hAinB)
}
func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp net.IP, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
assert.NotNil(t, v4, "No ipv4 data found")
assert.Equal(t, fromIp, v4.SrcIP, "Source ip was incorrect")
assert.Equal(t, toIp, v4.DstIP, "Dest ip was incorrect")
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
assert.NotNil(t, udp, "No udp data found")
assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect")
assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect")
data := packet.ApplicationLayer()
assert.NotNil(t, data)
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
}
func NewTestLogger() *logrus.Logger {
l := logrus.New()
v := os.Getenv("TEST_LOGS")
if v == "" {
l.SetOutput(ioutil.Discard)
return l
}
switch v {
case "2":
l.SetLevel(logrus.DebugLevel)
case "3":
l.SetLevel(logrus.TraceLevel)
default:
l.SetLevel(logrus.InfoLevel)
}
return l
}

3
e2e/router/doc.go Normal file
View File

@@ -0,0 +1,3 @@
package router
// This file exists to allow `go fmt` to traverse here on its own. The build tags were keeping it out before

320
e2e/router/router.go Normal file
View File

@@ -0,0 +1,320 @@
// +build e2e_testing
package router
import (
"fmt"
"net"
"reflect"
"strconv"
"sync"
"github.com/slackhq/nebula"
)
type R struct {
// Simple map of the ip:port registered on a control to the control
// Basically a router, right?
controls map[string]*nebula.Control
// A map for inbound packets for a control that doesn't know about this address
inNat map[string]*nebula.Control
// A last used map, if an inbound packet hit the inNat map then
// all return packets should use the same last used inbound address for the outbound sender
// map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver
outNat map[string]net.UDPAddr
// All interactions are locked to help serialize behavior
sync.Mutex
}
type ExitType int
const (
// Keeps routing, the function will get called again on the next packet
KeepRouting ExitType = 0
// Does not route this packet and exits immediately
ExitNow ExitType = 1
// Routes this packet and exits immediately afterwards
RouteAndExit ExitType = 2
)
type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) ExitType
func NewR(controls ...*nebula.Control) *R {
r := &R{
controls: make(map[string]*nebula.Control),
inNat: make(map[string]*nebula.Control),
outNat: make(map[string]net.UDPAddr),
}
for _, c := range controls {
addr := c.GetUDPAddr()
if _, ok := r.controls[addr]; ok {
panic("Duplicate listen address: " + addr)
}
r.controls[addr] = c
}
return r
}
// AddRoute will place the nebula controller at the ip and port specified.
// It does not look at the addr attached to the instance.
// If a route is used, this will behave like a NAT for the return path.
// Rewriting the source ip:port to what was last sent to from the origin
func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) {
r.Lock()
defer r.Unlock()
inAddr := net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))
if _, ok := r.inNat[inAddr]; ok {
panic("Duplicate listen address inNat: " + inAddr)
}
r.inNat[inAddr] = c
}
// OnceFrom will route a single packet from sender then return
// If the router doesn't have the nebula controller for that address, we panic
func (r *R) OnceFrom(sender *nebula.Control) {
r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) ExitType {
return RouteAndExit
})
}
// RouteUntilTxTun will route for sender and return when a packet is seen on receivers tun
// If the router doesn't have the nebula controller for that address, we panic
func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []byte {
tunTx := receiver.GetTunTxChan()
udpTx := sender.GetUDPTxChan()
for {
select {
// Maybe we already have something on the tun for us
case b := <-tunTx:
return b
// Nope, lets push the sender along
case p := <-udpTx:
outAddr := sender.GetUDPAddr()
r.Lock()
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
c := r.getControl(outAddr, inAddr, p)
if c == nil {
r.Unlock()
panic("No control for udp tx")
}
c.InjectUDPPacket(p)
r.Unlock()
}
}
}
// RouteExitFunc will call the whatDo func with each udp packet from sender.
// whatDo can return:
// - exitNow: the packet will not be routed and this call will return immediately
// - routeAndExit: this call will return immediately after routing the last packet from sender
// - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
h := &nebula.Header{}
for {
p := sender.GetFromUDP(true)
r.Lock()
if err := h.Parse(p.Data); err != nil {
panic(err)
}
outAddr := sender.GetUDPAddr()
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
receiver := r.getControl(outAddr, inAddr, p)
if receiver == nil {
r.Unlock()
panic("Can't route for host: " + inAddr)
}
e := whatDo(p, receiver)
switch e {
case ExitNow:
r.Unlock()
return
case RouteAndExit:
receiver.InjectUDPPacket(p)
r.Unlock()
return
case KeepRouting:
receiver.InjectUDPPacket(p)
default:
panic(fmt.Sprintf("Unknown exitFunc return: %v", e))
}
r.Unlock()
}
}
// RouteUntilAfterMsgType will route for sender until a message type is seen and sent from sender
// If the router doesn't have the nebula controller for that address, we panic
func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType nebula.NebulaMessageType, subType nebula.NebulaMessageSubType) {
h := &nebula.Header{}
r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
if err := h.Parse(p.Data); err != nil {
panic(err)
}
if h.Type == msgType && h.Subtype == subType {
return RouteAndExit
}
return KeepRouting
})
}
// RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr
// finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit`
// If the router doesn't have the nebula controller for that address, we panic
func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) {
if finish == KeepRouting {
finish = RouteAndExit
}
r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
return finish
}
return KeepRouting
})
}
// RouteForAllExitFunc will route for every registered controller and calls the whatDo func with each udp packet from
// whatDo can return:
// - exitNow: the packet will not be routed and this call will return immediately
// - routeAndExit: this call will return immediately after routing the last packet from sender
// - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
sc := make([]reflect.SelectCase, len(r.controls))
cm := make([]*nebula.Control, len(r.controls))
i := 0
for _, c := range r.controls {
sc[i] = reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(c.GetUDPTxChan()),
Send: reflect.Value{},
}
cm[i] = c
i++
}
for {
x, rx, _ := reflect.Select(sc)
r.Lock()
p := rx.Interface().(*nebula.UdpPacket)
outAddr := cm[x].GetUDPAddr()
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
receiver := r.getControl(outAddr, inAddr, p)
if receiver == nil {
r.Unlock()
panic("Can't route for host: " + inAddr)
}
e := whatDo(p, receiver)
switch e {
case ExitNow:
r.Unlock()
return
case RouteAndExit:
receiver.InjectUDPPacket(p)
r.Unlock()
return
case KeepRouting:
receiver.InjectUDPPacket(p)
default:
panic(fmt.Sprintf("Unknown exitFunc return: %v", e))
}
r.Unlock()
}
}
// FlushAll will route for every registered controller, exiting once there are no packets left to route
func (r *R) FlushAll() {
sc := make([]reflect.SelectCase, len(r.controls))
cm := make([]*nebula.Control, len(r.controls))
i := 0
for _, c := range r.controls {
sc[i] = reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(c.GetUDPTxChan()),
Send: reflect.Value{},
}
cm[i] = c
i++
}
// Add a default case to exit when nothing is left to send
sc = append(sc, reflect.SelectCase{
Dir: reflect.SelectDefault,
Chan: reflect.Value{},
Send: reflect.Value{},
})
for {
x, rx, ok := reflect.Select(sc)
if !ok {
return
}
r.Lock()
p := rx.Interface().(*nebula.UdpPacket)
outAddr := cm[x].GetUDPAddr()
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
receiver := r.getControl(outAddr, inAddr, p)
if receiver == nil {
r.Unlock()
panic("Can't route for host: " + inAddr)
}
r.Unlock()
}
}
// getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
// This is an internal router function, the caller must hold the lock
func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Control {
if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok {
p.FromIp = newAddr.IP
p.FromPort = uint16(newAddr.Port)
}
c, ok := r.inNat[toAddr]
if ok {
sHost, sPort, err := net.SplitHostPort(toAddr)
if err != nil {
panic(err)
}
port, err := strconv.Atoi(sPort)
if err != nil {
panic(err)
}
r.outNat[c.GetUDPAddr()+":"+fromAddr] = net.UDPAddr{
IP: net.ParseIP(sHost),
Port: port,
}
return c
}
return r.controls[toAddr]
}

View File

@@ -74,6 +74,7 @@ lighthouse:
# Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined,
# however using port 0 will dynamically assign a port and is recommended for roaming nodes.
listen:
# To listen on both any ipv4 and ipv6 use "[::]"
host: 0.0.0.0
port: 4242
# Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg)
@@ -86,6 +87,15 @@ listen:
#read_buffer: 10485760
#write_buffer: 10485760
# EXPERIMENTAL: This option is currently only supported on linux and may
# change in future minor releases.
#
# Routines is the number of thread pairs to run that consume from the tun and UDP queues.
# Currently, this defaults to 1 which means we have 1 tun queue reader and 1
# UDP queue reader. Setting this above one will set IFF_MULTI_QUEUE on the tun
# device and SO_REUSEPORT on the UDP socket to allow multiple queues.
#routines: 1
punchy:
# Continues to punch inbound/outbound at a regular interval to avoid expiration of firewall nat mappings
punch: true
@@ -98,7 +108,7 @@ punchy:
# delays a punch response for misbehaving NATs, default is 1 second, respond must be true to take effect
#delay: 1s
# Cipher allows you to choose between the available ciphers for your network.
# Cipher allows you to choose between the available ciphers for your network. Options are chachapoly or aes
# IMPORTANT: this value must be identical on ALL NODES/LIGHTHOUSES. We do not/will not support use of different ciphers simultaneously!
#cipher: chachapoly
@@ -192,16 +202,16 @@ logging:
# Handshake Manger Settings
#handshakes:
# Total time to try a handshake = sequence of `try_interval * retries`
# With 100ms interval and 20 retries it is 23.5 seconds
# Handshakes are sent to all known addresses at each interval with a linear backoff,
# Wait try_interval after the 1st attempt, 2 * try_interval after the 2nd, etc, until the handshake is older than timeout
# A 100ms interval with the default 10 retries will give a handshake 5.5 seconds to resolve before timing out
#try_interval: 100ms
#retries: 20
# wait_rotation is the number of handshake attempts to do before starting to try non-local IP addresses
#wait_rotation: 5
# trigger_buffer is the size of the buffer channel for quickly sending handshakes
# after receiving the response for lighthouse queries
#trigger_buffer: 64
# Nebula security group configuration
firewall:
conntrack:

View File

@@ -5,8 +5,6 @@ After=basic.target network.target
[Service]
SyslogIdentifier=nebula
StandardOutput=syslog
StandardError=syslog
ExecReload=/bin/kill -HUP $MAINPID
ExecStart=/usr/local/bin/nebula -config /etc/nebula/config.yml
Restart=always

View File

@@ -2,11 +2,10 @@
Description=nebula
Wants=basic.target
After=basic.target network.target
Before=sshd.service
[Service]
SyslogIdentifier=nebula
StandardOutput=syslog
StandardError=syslog
ExecReload=/bin/kill -HUP $MAINPID
ExecStart=/usr/local/bin/nebula -config /etc/nebula/config.yml
Restart=always

View File

@@ -12,6 +12,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/rcrowley/go-metrics"
@@ -67,8 +68,18 @@ type Firewall struct {
rules string
rulesVersion uint16
trackTCPRTT bool
metricTCPRTT metrics.Histogram
trackTCPRTT bool
metricTCPRTT metrics.Histogram
incomingMetrics firewallMetrics
outgoingMetrics firewallMetrics
l *logrus.Logger
}
type firewallMetrics struct {
droppedLocalIP metrics.Counter
droppedRemoteIP metrics.Counter
droppedNoRule metrics.Counter
}
type FirewallConntrack struct {
@@ -155,7 +166,7 @@ func (fp FirewallPacket) MarshalJSON() ([]byte, error) {
}
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
//TODO: error on 0 duration
var min, max time.Duration
@@ -193,12 +204,25 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout,
localIps: localIps,
metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
l: l,
metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
incomingMetrics: firewallMetrics{
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil),
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil),
droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
},
outgoingMetrics: firewallMetrics{
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil),
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil),
droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
},
}
}
func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
fw := NewFirewall(
l,
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
@@ -206,12 +230,12 @@ func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, er
//TODO: max_connections
)
err := AddFirewallRulesFromConfig(false, c, fw)
err := AddFirewallRulesFromConfig(l, false, c, fw)
if err != nil {
return nil, err
}
err = AddFirewallRulesFromConfig(true, c, fw)
err = AddFirewallRulesFromConfig(l, true, c, fw)
if err != nil {
return nil, err
}
@@ -239,7 +263,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
if !incoming {
direction = "outgoing"
}
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
Info("Firewall rule added")
var (
@@ -275,7 +299,7 @@ func (f *Firewall) GetRuleHash() string {
return hex.EncodeToString(sum[:])
}
func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterface) error {
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, fw FirewallInterface) error {
var table string
if inbound {
table = "firewall.inbound"
@@ -295,7 +319,7 @@ func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterfa
for i, t := range rs {
var groups []string
r, err := convertRule(t, table, i)
r, err := convertRule(l, t, table, i)
if err != nil {
return fmt.Errorf("%s rule #%v; %s", table, i, err)
}
@@ -372,26 +396,29 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) error {
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) error {
// Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(packet, fp, incoming, h, caPool) {
if f.inConns(packet, fp, incoming, h, caPool, localCache) {
return nil
}
// Make sure remote address matches nebula certificate
if remoteCidr := h.remoteCidr; remoteCidr != nil {
if remoteCidr.Contains(fp.RemoteIP) == nil {
f.metrics(incoming).droppedRemoteIP.Inc(1)
return ErrInvalidRemoteIP
}
} else {
// Simple case: Certificate has one IP and no subnets
if fp.RemoteIP != h.hostId {
f.metrics(incoming).droppedRemoteIP.Inc(1)
return ErrInvalidRemoteIP
}
}
// Make sure we are supposed to be handling this local ip address
if f.localIps.Contains(fp.LocalIP) == nil {
f.metrics(incoming).droppedLocalIP.Inc(1)
return ErrInvalidLocalIP
}
@@ -402,6 +429,7 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *Host
// We now know which firewall table to check against
if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) {
f.metrics(incoming).droppedNoRule.Inc(1)
return ErrNoMatchingRule
}
@@ -411,6 +439,14 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *Host
return nil
}
func (f *Firewall) metrics(incoming bool) firewallMetrics {
if incoming {
return f.incomingMetrics
} else {
return f.outgoingMetrics
}
}
// Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new
// firewall object is created
func (f *Firewall) Destroy() {
@@ -426,7 +462,12 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
}
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool {
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) bool {
if localCache != nil {
if _, ok := localCache[fp]; ok {
return true
}
}
conntrack := f.Conntrack
conntrack.Lock()
@@ -453,8 +494,8 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
// We now know which firewall table to check against
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
if l.Level >= logrus.DebugLevel {
h.logger().
if f.l.Level >= logrus.DebugLevel {
h.logger(f.l).
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion).
@@ -466,8 +507,8 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
return false
}
if l.Level >= logrus.DebugLevel {
h.logger().
if f.l.Level >= logrus.DebugLevel {
h.logger(f.l).
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion).
@@ -494,6 +535,10 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
conntrack.Unlock()
if localCache != nil {
localCache[fp] = struct{}{}
}
return true
}
@@ -785,7 +830,7 @@ type rule struct {
CASha string
}
func convertRule(p interface{}, table string, i int) (rule, error) {
func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
r := rule{}
m, ok := p.(map[interface{}]interface{})
@@ -923,3 +968,54 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
c.Seq = 0
return true
}
// ConntrackCache is used as a local routine cache to know if a given flow
// has been seen in the conntrack table.
type ConntrackCache map[FirewallPacket]struct{}
type ConntrackCacheTicker struct {
cacheV uint64
cacheTick uint64
cache ConntrackCache
}
func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
if d == 0 {
return nil
}
c := &ConntrackCacheTicker{
cache: ConntrackCache{},
}
go c.tick(d)
return c
}
func (c *ConntrackCacheTicker) tick(d time.Duration) {
for {
time.Sleep(d)
atomic.AddUint64(&c.cacheTick, 1)
}
}
// Get checks if the cache ticker has moved to the next version before returning
// the map. If it has moved, we reset the map.
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
if c == nil {
return nil
}
if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
c.cacheV = tick
if ll := len(c.cache); ll > 0 {
if l.Level == logrus.DebugLevel {
l.WithField("len", ll).Debug("resetting conntrack cache")
}
c.cache = make(ConntrackCache, ll)
}
}
return c.cache
}

View File

@@ -15,8 +15,9 @@ import (
)
func TestNewFirewall(t *testing.T) {
l := NewTestLogger()
c := &cert.NebulaCertificate{}
fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
conntrack := fw.Conntrack
assert.NotNil(t, conntrack)
assert.NotNil(t, conntrack.Conns)
@@ -31,35 +32,34 @@ func TestNewFirewall(t *testing.T) {
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
}
func TestFirewall_AddRule(t *testing.T) {
l := NewTestLogger()
ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob)
defer l.SetOutput(out)
c := &cert.NebulaCertificate{}
fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.NotNil(t, fw.InRules)
assert.NotNil(t, fw.OutRules)
@@ -74,7 +74,7 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
assert.False(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
@@ -83,7 +83,7 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
assert.False(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
@@ -92,23 +92,23 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP)))
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
// Set any and clear fields
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
@@ -125,26 +125,25 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
// Test error conditions
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", ""))
}
func TestFirewall_Drop(t *testing.T) {
l := NewTestLogger()
ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob)
defer l.SetOutput(out)
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
@@ -177,49 +176,49 @@ func TestFirewall_Drop(t *testing.T) {
}
h.CreateRemoteCIDR(&c)
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
// Allow outbound because conntrack
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteIP
p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrInvalidRemoteIP)
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteIP = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp), ErrNoMatchingRule)
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp), ErrNoMatchingRule)
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
}
func BenchmarkFirewallTable_match(b *testing.B) {
@@ -317,10 +316,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}
func TestFirewall_Drop2(t *testing.T) {
l := NewTestLogger()
ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob)
defer l.SetOutput(out)
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
@@ -365,22 +363,21 @@ func TestFirewall_Drop2(t *testing.T) {
}
h1.CreateRemoteCIDR(&c1)
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
cp := cert.NewCAPool()
// h1/c1 lacks the proper groups
assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp), ErrNoMatchingRule)
assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp, nil), ErrNoMatchingRule)
// c has the proper groups
resetConntrack(fw)
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
}
func TestFirewall_Drop3(t *testing.T) {
l := NewTestLogger()
ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob)
defer l.SetOutput(out)
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
@@ -448,26 +445,25 @@ func TestFirewall_Drop3(t *testing.T) {
}
h3.CreateRemoteCIDR(&c3)
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
cp := cert.NewCAPool()
// c1 should pass because host match
assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp, nil))
// c2 should pass because ca sha match
resetConntrack(fw)
assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp, nil))
// c3 should fail because no match
resetConntrack(fw)
assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp), ErrNoMatchingRule)
assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp, nil), ErrNoMatchingRule)
}
func TestFirewall_DropConntrackReload(t *testing.T) {
l := NewTestLogger()
ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob)
defer l.SetOutput(out)
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
@@ -500,35 +496,35 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
}
h.CreateRemoteCIDR(&c)
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
// Allow outbound because conntrack
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
oldFw := fw
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
// Allow outbound because conntrack and new rules allow port 10
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
oldFw = fw
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
// Drop outbound because conntrack doesn't match new ruleset
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
}
func BenchmarkLookup(b *testing.B) {
@@ -647,124 +643,126 @@ func Test_parsePort(t *testing.T) {
}
func TestNewFirewallFromConfig(t *testing.T) {
l := NewTestLogger()
// Test a bad rule definition
c := &cert.NebulaCertificate{}
conf := NewConfig()
conf := NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
_, err := NewFirewallFromConfig(c, conf)
_, err := NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code
conf = NewConfig()
conf = NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(c, conf)
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha
conf = NewConfig()
conf = NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
_, err = NewFirewallFromConfig(c, conf)
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
// Test code/port error
conf = NewConfig()
conf = NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(c, conf)
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(c, conf)
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error
conf = NewConfig()
conf = NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(c, conf)
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error
conf = NewConfig()
conf = NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(c, conf)
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
// Test both group and groups
conf = NewConfig()
conf = NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(c, conf)
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
}
func TestAddFirewallRulesFromConfig(t *testing.T) {
l := NewTestLogger()
// Test adding tcp rule
conf := NewConfig()
conf := NewConfig(l)
mf := &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding udp rule
conf = NewConfig()
conf = NewConfig(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding icmp rule
conf = NewConfig()
conf = NewConfig(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding any rule
conf = NewConfig()
conf = NewConfig(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding rule with ca_sha
conf = NewConfig()
conf = NewConfig(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name
conf = NewConfig()
conf = NewConfig(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
// Test single group
conf = NewConfig()
conf = NewConfig(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
// Test single groups
conf = NewConfig()
conf = NewConfig(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
// Test multiple AND groups
conf = NewConfig()
conf = NewConfig(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
// Test Add error
conf = NewConfig()
conf = NewConfig(l)
mf = &mockFirewall{}
mf.nextCallReturn = errors.New("test error")
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
assert.EqualError(t, AddFirewallRulesFromConfig(true, conf, mf), "firewall.inbound rule #0; `test error`")
assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
}
func TestTCPRTTTracking(t *testing.T) {
@@ -859,17 +857,16 @@ func TestTCPRTTTracking(t *testing.T) {
}
func TestFirewall_convertRule(t *testing.T) {
l := NewTestLogger()
ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob)
defer l.SetOutput(out)
// Ensure group array of 1 is converted and a warning is printed
c := map[interface{}]interface{}{
"group": []interface{}{"group1"},
}
r, err := convertRule(c, "test", 1)
r, err := convertRule(l, c, "test", 1)
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
assert.Nil(t, err)
assert.Equal(t, "group1", r.Group)
@@ -880,7 +877,7 @@ func TestFirewall_convertRule(t *testing.T) {
"group": []interface{}{"group1", "group2"},
}
r, err = convertRule(c, "test", 1)
r, err = convertRule(l, c, "test", 1)
assert.Equal(t, "", ob.String())
assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
@@ -890,7 +887,7 @@ func TestFirewall_convertRule(t *testing.T) {
"group": "group1",
}
r, err = convertRule(c, "test", 1)
r, err = convertRule(l, c, "test", 1)
assert.Nil(t, err)
assert.Equal(t, "group1", r.Group)
}

18
go.mod
View File

@@ -1,6 +1,6 @@
module github.com/slackhq/nebula
go 1.12
go 1.16
require (
github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239
@@ -8,12 +8,13 @@ require (
github.com/cespare/xxhash/v2 v2.1.1 // indirect
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect
github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6
github.com/golang/protobuf v1.3.2
github.com/flynn/noise v0.0.0-20210331153838-4bdb43be3117
github.com/gogo/protobuf v1.3.2
github.com/golang/protobuf v1.5.0
github.com/google/gopacket v1.1.19
github.com/imdario/mergo v0.3.8
github.com/kardianos/service v1.1.0
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
github.com/kr/pretty v0.1.0 // indirect
github.com/miekg/dns v1.1.25
github.com/nbrownus/go-metrics-prometheus v0.0.0-20180622211546-6e6d5173d99c
github.com/prometheus/client_golang v1.2.1
@@ -21,13 +22,14 @@ require (
github.com/prometheus/procfs v0.0.8 // indirect
github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563
github.com/sirupsen/logrus v1.4.2
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b
github.com/stretchr/testify v1.6.1
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553
golang.org/x/sys v0.0.0-20191210023423-ac6580df4449
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68
google.golang.org/protobuf v1.26.0
gopkg.in/yaml.v2 v2.2.7
)

78
go.sum
View File

@@ -7,11 +7,9 @@ github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYU
github.com/armon/go-radix v1.0.0 h1:F4z6KzEeeQIMeLFa97iZU6vupzoecKdU5TX24SNppXI=
github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/beorn7/perks v1.0.0 h1:HWo1m869IqiPhD389kmkxeTalrjNbbJTC8LXupb+sl0=
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.1.0 h1:yTUvW7Vhb89inJ+8irsUqiWjh8iT6sQPZiQzI6ReGkA=
github.com/cespare/xxhash/v2 v2.1.0/go.mod h1:dgIUBU3pDso/gPgZ1osOZ0iQf77oPR28Tjxl5dIMyVM=
github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
@@ -22,38 +20,43 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6 h1:u/UEqS66A5ckRmS4yNpjmVH56sVtS/RfclBAYocb4as=
github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6/go.mod h1:1i71OnUq3iUe1ma7Lr6yG6/rjvM3emb6yoL7xLFzcVQ=
github.com/flynn/noise v0.0.0-20210331153838-4bdb43be3117 h1:Dxhvhray2DpvNnrZEnoGG5rz238fUeQTh4sdzTr+d1U=
github.com/flynn/noise v0.0.0-20210331153838-4bdb43be3117/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/imdario/mergo v0.3.8 h1:CGgOkSJeqMRmt0D9XLWExdT4m4F1vd3FV3VPt+0VxkQ=
github.com/imdario/mergo v0.3.8/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/kardianos/service v1.0.0 h1:HgQS3mFfOlyntWX8Oke98JcJLqt1DBcHR4kxShpYef0=
github.com/kardianos/service v1.0.0/go.mod h1:8CzDhVuCuugtsHyZoTvsOBuvonN/UDBvl0kH+BUxvbo=
github.com/kardianos/service v1.1.0 h1:QV2SiEeWK42P0aEmGcsAgjApw/lRxkwopvT+Gu6t1/0=
github.com/kardianos/service v1.1.0/go.mod h1:RrJI2xn5vve/r32U5suTbeaSGoMU6GbNPoj36CVYcHc=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
@@ -62,10 +65,8 @@ github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5
github.com/miekg/dns v1.1.25 h1:dFwPR6SfLtrSwgDcIq2bcU/gVutB4sNApq2HBdqcakg=
github.com/miekg/dns v1.1.25/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.1 h1:9f412s+6RmYXLWZSEzVVgPGK7C2PphHj5RJrvfx9AWI=
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/nbrownus/go-metrics-prometheus v0.0.0-20180622211546-6e6d5173d99c h1:G/mfx/MWYuaaGlHkZQBBXFAJiYnRt/GaOVxnRHjlxg4=
@@ -79,12 +80,10 @@ github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5Fsn
github.com/prometheus/client_golang v1.2.1 h1:JnMpQc6ppsNgw9QPAGF6Dod479itz7lvlsMzzNayLOI=
github.com/prometheus/client_golang v1.2.1/go.mod h1:XMU6Z2MjaRKVu/dC1qupJI9SiNkDYzz3xecMgSW/F+U=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90 h1:S/YWwWx/RA8rT8tKFRuGUZhuA90OyIBpPCXkcbwU8DE=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.0.0-20191202183732-d1d2010b5bee h1:iBZPTYkGLvdu6+A5TsMUJQkQX9Ad4aCEnSQtdxPuTCQ=
github.com/prometheus/client_model v0.0.0-20191202183732-d1d2010b5bee/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/common v0.4.1 h1:K0MGApIoQvMw27RTdJkPbr3JZ7DNbtxQNyi5STVM6Kw=
github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
github.com/prometheus/common v0.7.0 h1:L+1lyG48J1zAQXA3RBX/nG/B3gjlHq0zTt2tlbJLyCY=
github.com/prometheus/common v0.7.0/go.mod h1:DjGbpBbp5NYNiECxcL/VnbXCCaQpKd3tt26CguLLsqA=
@@ -98,12 +97,13 @@ github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563/go.mod h1:bCqn
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b h1:+y4hCMc/WKsDbAPsOQZgBSaSZ26uh2afyaWeVg/3s/c=
github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
@@ -111,24 +111,34 @@ github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a h1:Bt1IVPhiC
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY=
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975 h1:/Tl7pH94bvbAAHBdZJT947M/+gp0+CqQXDtMRC0fseo=
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 h1:efeOvDhwQ29Dj3SdAV/MJf8oukgn+8D8WgaCaRMchF8=
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -138,19 +148,33 @@ golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190922100055-0a153f010e69/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191210023423-ac6580df4449 h1:gSbV7h1NRL2G1xTg/owz62CST1oJBmxy4QpMMregXVQ=
golang.org/x/sys v0.0.0-20191210023423-ac6580df4449/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190907020128-2ca718005c18/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View File

@@ -6,31 +6,23 @@ const (
)
func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) {
newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex)
//TODO: For stage 1 we won't have hostinfo yet but stage 2 and above would require it, this check may be helpful in those cases
//if err != nil {
// l.WithError(err).WithField("udpAddr", addr).Error("Error while finding host info for handshake message")
// return
//}
if !f.lightHouse.remoteAllowList.Allow(udp2ipInt(addr)) {
l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
if !f.lightHouse.remoteAllowList.Allow(addr.IP) {
f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return
}
tearDown := false
switch h.Subtype {
case handshakeIXPSK0:
switch h.MessageCounter {
case 1:
tearDown = ixHandshakeStage1(f, addr, newHostinfo, packet, h)
ixHandshakeStage1(f, addr, packet, h)
case 2:
tearDown = ixHandshakeStage2(f, addr, newHostinfo, packet, h)
newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex)
tearDown := ixHandshakeStage2(f, addr, newHostinfo, packet, h)
if tearDown && newHostinfo != nil {
f.handshakeManager.DeleteHostInfo(newHostinfo)
}
}
}
if tearDown && newHostinfo != nil {
f.handshakeManager.DeleteIndex(newHostinfo.localIndexId)
f.handshakeManager.DeleteVpnIP(newHostinfo.hostId)
}
}

View File

@@ -1,7 +1,6 @@
package nebula
import (
"bytes"
"sync/atomic"
"time"
@@ -15,28 +14,24 @@ import (
// Sending is done by the handshake manager
func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
// This queries the lighthouse if we don't know a remote for the host
// We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send
// more quickly, effect is a quicker handshake.
if hostinfo.remote == nil {
ips, err := f.lightHouse.Query(vpnIp, f)
if err != nil {
//l.Debugln(err)
}
for _, ip := range ips {
hostinfo.AddRemote(ip)
}
f.lightHouse.QueryServer(vpnIp, f)
}
myIndex, err := generateIndex()
err := f.handshakeManager.AddIndexHostInfo(hostinfo)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
return
}
ci := hostinfo.ConnectionState
f.handshakeManager.AddIndexHostInfo(myIndex, hostinfo)
hsProto := &NebulaHandshakeDetails{
InitiatorIndex: myIndex,
Time: uint64(time.Now().Unix()),
InitiatorIndex: hostinfo.localIndexId,
Time: uint64(time.Now().UnixNano()),
Cert: ci.certState.rawCertificateNoKey,
}
@@ -48,240 +43,264 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
hsBytes, err = proto.Marshal(hs)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return
}
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, 0, 1)
atomic.AddUint64(ci.messageCounter, 1)
atomic.AddUint64(&ci.atomicMessageCounter, 1)
msg, _, _, err := ci.H.WriteMessage(header, hsBytes)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return
}
// We are sending handshake packet 1, so we don't expect to receive
// handshake packet 1 from the responder
ci.window.Update(f.l, 1)
hostinfo.HandshakePacket[0] = msg
hostinfo.HandshakeReady = true
hostinfo.handshakeStart = time.Now()
}
func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool {
var ip uint32
if h.RemoteIndex == 0 {
ci := f.newConnectionState(false, noise.HandshakeIX, []byte{}, 0)
// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(1)
func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1)
msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
if err != nil {
l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
return true
}
msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
return
}
hs := &NebulaHandshake{}
err = proto.Unmarshal(msg, hs)
/*
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
*/
if err != nil || hs.Details == nil {
l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
return true
}
hs := &NebulaHandshake{}
err = proto.Unmarshal(msg, hs)
/*
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
*/
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
return
}
hostinfo, _ := f.handshakeManager.pendingHostMap.QueryReverseIndex(hs.Details.InitiatorIndex)
if hostinfo != nil && bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) {
if msg, ok := hostinfo.HandshakePacket[2]; ok {
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err := f.outside.WriteTo(msg, addr)
if err != nil {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message")
} else {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
}
return false
}
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
Info("Invalid certificate from host")
return
}
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cached", true).
WithField("packets", hostinfo.HandshakePacket).
Error("Seen this handshake packet already but don't have a cached packet to return")
}
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
if err != nil {
l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
Info("Invalid certificate from host")
return true
}
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
myIndex, err := generateIndex()
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
return true
}
hostinfo, err = f.handshakeManager.AddIndex(myIndex, ci)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Error adding index to connection manager")
return true
}
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
if vpnIP == ip2int(f.certState.certificate.Details.Ips[0].IP) {
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message received")
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
return
}
hostinfo.remoteIndexId = hs.Details.InitiatorIndex
hs.Details.ResponderIndex = myIndex
hs.Details.Cert = ci.certState.rawCertificateNoKey
myIndex, err := generateIndex(f.l)
if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
return
}
hsBytes, err := proto.Marshal(hs)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
hostinfo := &HostInfo{
ConnectionState: ci,
localIndexId: myIndex,
remoteIndexId: hs.Details.InitiatorIndex,
hostId: vpnIP,
HandshakePacket: make(map[uint8][]byte, 0),
lastHandshakeTime: hs.Details.Time,
}
hostinfo.Lock()
defer hostinfo.Unlock()
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message received")
hs.Details.ResponderIndex = myIndex
hs.Details.Cert = ci.certState.rawCertificateNoKey
// Update the time in case their clock is way off from ours
hs.Details.Time = uint64(time.Now().UnixNano())
hsBytes, err := proto.Marshal(hs)
if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return
}
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes)
if err != nil {
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return
} else if dKey == nil || eKey == nil {
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
return
}
hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:]))
copy(hostinfo.HandshakePacket[0], packet[HeaderLen:])
// Regardless of whether you are the sender or receiver, you should arrive here
// and complete standing up the connection.
hostinfo.HandshakePacket[2] = make([]byte, len(msg))
copy(hostinfo.HandshakePacket[2], msg)
// We are sending handshake packet 2, so we don't expect to receive
// handshake packet 2 from the initiator.
ci.window.Update(f.l, 2)
ci.peerCert = remoteCert
ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey)
hostinfo.remotes = f.lightHouse.QueryCache(vpnIP)
hostinfo.SetRemote(addr)
hostinfo.CreateRemoteCIDR(remoteCert)
// Only overwrite existing record if we should win the handshake race
overwrite := vpnIP > ip2int(f.certState.certificate.Details.Ips[0].IP)
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f)
if err != nil {
switch err {
case ErrAlreadySeen:
msg = existing.HandshakePacket[2]
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err := f.outside.WriteTo(msg, addr)
if err != nil {
f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message")
} else {
f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
}
return
case ErrExistingHostInfo:
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("oldHandshakeTime", existing.lastHandshakeTime).
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return true
}
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake too old")
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return true
}
if f.hostMap.CheckHandshakeCompleteIP(vpnIP) && vpnIP < ip2int(f.certState.certificate.Details.Ips[0].IP) {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
return
case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Prevented a handshake race")
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
return true
}
hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:]))
copy(hostinfo.HandshakePacket[0], packet[HeaderLen:])
// Regardless of whether you are the sender or receiver, you should arrive here
// and complete standing up the connection.
if dKey != nil && eKey != nil {
hostinfo.HandshakePacket[2] = make([]byte, len(msg))
copy(hostinfo.HandshakePacket[2], msg)
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err := f.outside.WriteTo(msg, addr)
if err != nil {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake")
} else {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake message sent")
}
ip = ip2int(remoteCert.Details.Ips[0].IP)
ci.peerCert = remoteCert
ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey)
//l.Debugln("got symmetric pairs")
//hostinfo.ClearRemotes()
hostinfo.AddRemote(*addr)
hostinfo.CreateRemoteCIDR(remoteCert)
f.lightHouse.AddRemoteAndReset(ip, addr)
if f.serveDns {
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
}
ho, err := f.hostMap.QueryVpnIP(vpnIP)
if err == nil && ho.localIndexId != 0 {
l.WithField("vpnIp", vpnIP).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("action", "removing stale index").
WithField("index", ho.localIndexId).
Debug("Handshake processing")
f.hostMap.DeleteIndex(ho.localIndexId)
}
f.hostMap.AddIndexHostInfo(hostinfo.localIndexId, hostinfo)
f.hostMap.AddVpnIPHostInfo(vpnIP, hostinfo)
hostinfo.handshakeComplete()
} else {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("localIndex", hostinfo.localIndexId).WithField("collision", IntIp(existing.hostId)).
Error("Failed to add HostInfo due to localIndex collision")
return
case ErrExistingHandshake:
// We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
return true
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Prevented a pending handshake race")
return
default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to add HostInfo to HostMap")
return
}
}
f.hostMap.AddRemote(ip, addr)
return false
// Do the send
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err = f.outside.WriteTo(msg, addr)
if err != nil {
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake")
} else {
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("sentCachedPackets", len(hostinfo.packetStore)).
Info("Handshake message sent")
}
hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
return
}
func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool {
if hostinfo == nil {
// Nothing here to tear down, got a bogus stage 2 packet
return true
}
if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
hostinfo.Lock()
defer hostinfo.Unlock()
ci := hostinfo.ConnectionState
if ci.ready {
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Already seen this handshake packet")
Info("Handshake is already complete")
//TODO: evaluate addr for preference, if we handshook with a less preferred addr we can correct quickly here
// We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets
return false
}
ci := hostinfo.ConnectionState
// Mark packet 2 as seen so it doesn't show up as missed
ci.window.Update(2)
hostinfo.HandshakePacket[2] = make([]byte, len(packet[HeaderLen:]))
copy(hostinfo.HandshakePacket[2], packet[HeaderLen:])
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Failed to call noise.ReadMessage")
@@ -289,89 +308,112 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
// to DOS us. Every other error condition after should to allow a possible good handshake to complete in the
// near future
return false
} else if dKey == nil || eKey == nil {
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
// This should be impossible in IX but just in case, if we get here then there is no chance to recover
// the handshake state machine. Tear it down
return true
}
hs := &NebulaHandshake{}
err = proto.Unmarshal(msg, hs)
if err != nil || hs.Details == nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
return true
}
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Invalid certificate from host")
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
return true
}
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
// Ensure the right host responded
if vpnIP != hostinfo.hostId {
f.l.WithField("intendedVpnIp", IntIp(hostinfo.hostId)).WithField("haveVpnIp", IntIp(vpnIP)).
WithField("udpAddr", addr).WithField("certName", certName).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Incorrect host responded to handshake")
// Release our old handshake from pending, it should not continue
f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
// Create a new hostinfo/handshake for the intended vpn ip
//TODO: this adds it to the timer wheel in a way that aggressively retries
newHostInfo := f.getOrHandshake(hostinfo.hostId)
newHostInfo.Lock()
// Block the current used address
newHostInfo.remotes = hostinfo.remotes
newHostInfo.remotes.BlockRemote(addr)
// Get the correct remote list for the host we did handshake with
hostinfo.remotes = f.lightHouse.QueryCache(vpnIP)
f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", IntIp(vpnIP)).
WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
Info("Blocked addresses for handshakes")
// Swap the packet store to benefit the original intended recipient
hostinfo.ConnectionState.queueLock.Lock()
newHostInfo.packetStore = hostinfo.packetStore
hostinfo.packetStore = []*cachedPacket{}
hostinfo.ConnectionState.queueLock.Unlock()
// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.hostId = vpnIP
f.sendCloseTunnel(hostinfo)
newHostInfo.Unlock()
return true
}
// Mark packet 2 as seen so it doesn't show up as missed
ci.window.Update(f.l, 2)
duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("durationNs", duration).
WithField("sentCachedPackets", len(hostinfo.packetStore)).
Info("Handshake message received")
//ci.remoteIndex = hs.ResponderIndex
hostinfo.remoteIndexId = hs.Details.ResponderIndex
hs.Details.Cert = ci.certState.rawCertificateNoKey
hostinfo.lastHandshakeTime = hs.Details.Time
/*
hsBytes, err := proto.Marshal(hs)
if err != nil {
l.Debugln("Failed to marshal handshake: ", err)
return
}
*/
// Store their cert and our symmetric keys
ci.peerCert = remoteCert
ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey)
// Regardless of whether you are the sender or receiver, you should arrive here
// and complete standing up the connection.
if dKey != nil && eKey != nil {
ip := ip2int(remoteCert.Details.Ips[0].IP)
ci.peerCert = remoteCert
ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey)
//l.Debugln("got symmetric pairs")
// Make sure the current udpAddr being used is set for responding
hostinfo.SetRemote(addr)
//hostinfo.ClearRemotes()
f.hostMap.AddRemote(ip, addr)
hostinfo.CreateRemoteCIDR(remoteCert)
f.lightHouse.AddRemoteAndReset(ip, addr)
if f.serveDns {
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
}
// Build up the radix for the firewall if we have subnets in the cert
hostinfo.CreateRemoteCIDR(remoteCert)
ho, err := f.hostMap.QueryVpnIP(vpnIP)
if err == nil && ho.localIndexId != 0 {
l.WithField("vpnIp", vpnIP).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("action", "removing stale index").
WithField("index", ho.localIndexId).
Debug("Handshake processing")
f.hostMap.DeleteIndex(ho.localIndexId)
}
f.hostMap.AddVpnIPHostInfo(vpnIP, hostinfo)
f.hostMap.AddIndexHostInfo(hostinfo.localIndexId, hostinfo)
hostinfo.handshakeComplete()
f.metricHandshakes.Update(duration)
} else {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
return true
}
// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
//TODO: Complete here does not do a race avoidance, it will just take the new tunnel. Is this ok?
f.handshakeManager.Complete(hostinfo, f)
hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
f.metricHandshakes.Update(duration)
return false
}

View File

@@ -1,22 +1,20 @@
package nebula
import (
"bytes"
"crypto/rand"
"encoding/binary"
"fmt"
"errors"
"net"
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
)
const (
// Total time to try a handshake = sequence of HandshakeTryInterval * HandshakeRetries
// With 100ms interval and 20 retries is 23.5 seconds
DefaultHandshakeTryInterval = time.Millisecond * 100
DefaultHandshakeRetries = 20
// DefaultHandshakeWaitRotation is the number of handshake attempts to do before starting to use other ips addresses
DefaultHandshakeWaitRotation = 5
DefaultHandshakeTryInterval = time.Millisecond * 100
DefaultHandshakeRetries = 10
DefaultHandshakeTriggerBuffer = 64
)
@@ -24,7 +22,6 @@ var (
defaultHandshakeConfig = HandshakeConfig{
tryInterval: DefaultHandshakeTryInterval,
retries: DefaultHandshakeRetries,
waitRotation: DefaultHandshakeWaitRotation,
triggerBuffer: DefaultHandshakeTriggerBuffer,
}
)
@@ -32,43 +29,40 @@ var (
type HandshakeConfig struct {
tryInterval time.Duration
retries int
waitRotation int
triggerBuffer int
messageMetrics *MessageMetrics
}
type HandshakeManager struct {
pendingHostMap *HostMap
mainHostMap *HostMap
lightHouse *LightHouse
outside *udpConn
config HandshakeConfig
pendingHostMap *HostMap
mainHostMap *HostMap
lightHouse *LightHouse
outside *udpConn
config HandshakeConfig
OutboundHandshakeTimer *SystemTimerWheel
messageMetrics *MessageMetrics
metricInitiated metrics.Counter
metricTimedOut metrics.Counter
l *logrus.Logger
// can be used to trigger outbound handshake for the given vpnIP
trigger chan uint32
OutboundHandshakeTimer *SystemTimerWheel
InboundHandshakeTimer *SystemTimerWheel
messageMetrics *MessageMetrics
}
func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
return &HandshakeManager{
pendingHostMap: NewHostMap("pending", tunCidr, preferredRanges),
mainHostMap: mainHostMap,
lightHouse: lightHouse,
outside: outside,
config: config,
trigger: make(chan uint32, config.triggerBuffer),
OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
messageMetrics: config.messageMetrics,
pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges),
mainHostMap: mainHostMap,
lightHouse: lightHouse,
outside: outside,
config: config,
trigger: make(chan uint32, config.triggerBuffer),
OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
messageMetrics: config.messageMetrics,
metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
l: l,
}
}
@@ -77,11 +71,10 @@ func (c *HandshakeManager) Run(f EncWriter) {
for {
select {
case vpnIP := <-c.trigger:
l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
c.l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
c.handleOutbound(vpnIP, f, true)
case now := <-clockSource:
c.NextOutboundHandshakeTimerTick(now, f)
c.NextInboundHandshakeTimerTick(now)
}
}
}
@@ -99,99 +92,94 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWr
}
func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseTriggered bool) {
index, err := c.pendingHostMap.GetIndexByVpnIP(vpnIP)
if err != nil {
return
}
hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP)
if err != nil {
return
}
hostinfo.Lock()
defer hostinfo.Unlock()
// If we haven't finished the handshake and we haven't hit max retries, query
// lighthouse and then send the handshake packet again.
if hostinfo.HandshakeCounter < c.config.retries && !hostinfo.HandshakeComplete {
if hostinfo.remote == nil {
// We continue to query the lighthouse because hosts may
// come online during handshake retries. If the query
// succeeds (no error), add the lighthouse info to hostinfo
ips := c.lightHouse.QueryCache(vpnIP)
// If we have no responses yet, or only one IP (the host hadn't
// finished reporting its own IPs yet), then send another query to
// the LH.
if len(ips) <= 1 {
ips, err = c.lightHouse.Query(vpnIP, f)
}
if err == nil {
for _, ip := range ips {
hostinfo.AddRemote(ip)
}
hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges)
}
} else if lighthouseTriggered {
// We were triggered by a lighthouse HostQueryReply packet, but
// we have already picked a remote for this host (this can happen
// if we are configured with multiple lighthouses). So we can skip
// this trigger and let the timerwheel handle the rest of the
// process
return
}
hostinfo.HandshakeCounter++
// We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through
// all the others until we can stand up a connection.
if hostinfo.HandshakeCounter > c.config.waitRotation {
hostinfo.rotateRemote()
}
// Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation
if hostinfo.HandshakeReady && hostinfo.remote != nil {
c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
if err != nil {
hostinfo.logger().WithField("udpAddr", hostinfo.remote).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake message")
} else {
//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
// keep the real packet struct around for logging purposes
hostinfo.logger().WithField("udpAddr", hostinfo.remote).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message sent")
}
}
// Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try
if !lighthouseTriggered {
//l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter))
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
}
} else {
c.pendingHostMap.DeleteVpnIP(vpnIP)
c.pendingHostMap.DeleteIndex(index)
// We may have raced to completion but now that we have a lock we should ensure we have not yet completed.
if hostinfo.HandshakeComplete {
// Ensure we don't exist in the pending hostmap anymore since we have completed
c.pendingHostMap.DeleteHostInfo(hostinfo)
return
}
}
func (c *HandshakeManager) NextInboundHandshakeTimerTick(now time.Time) {
c.InboundHandshakeTimer.advance(now)
for {
ep := c.InboundHandshakeTimer.Purge()
if ep == nil {
break
}
index := ep.(uint32)
// Check if we have a handshake packet to transmit yet
if !hostinfo.HandshakeReady {
// There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly
// Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
return
}
vpnIP, err := c.pendingHostMap.GetVpnIPByIndex(index)
// If we are out of time, clean up
if hostinfo.HandshakeCounter >= c.config.retries {
hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("durationNs", time.Since(hostinfo.handshakeStart).Nanoseconds()).
Info("Handshake timed out")
c.metricTimedOut.Inc(1)
c.pendingHostMap.DeleteHostInfo(hostinfo)
return
}
// We only care about a lighthouse trigger before the first handshake transmit attempt. This is a very specific
// optimization for a fast lighthouse reply
//TODO: it would feel better to do this once, anytime, as our delay increases over time
if lighthouseTriggered && hostinfo.HandshakeCounter > 0 {
// If we didn't return here a lighthouse could cause us to aggressively send handshakes
return
}
// Get a remotes object if we don't already have one.
// This is mainly to protect us as this should never be the case
if hostinfo.remotes == nil {
hostinfo.remotes = c.lightHouse.QueryCache(vpnIP)
}
//TODO: this will generate a load of queries for hosts with only 1 ip (i'm not using a lighthouse, static mapped)
if hostinfo.remotes.Len(c.pendingHostMap.preferredRanges) <= 1 {
// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
// Our vpnIP here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
// the learned public ip for them. Query again to short circuit the promotion counter
c.lightHouse.QueryServer(vpnIP, f)
}
// Send a the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
var sentTo []*udpAddr
hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udpAddr, _ bool) {
c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
if err != nil {
continue
hostinfo.logger(c.l).WithField("udpAddr", addr).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake message")
} else {
sentTo = append(sentTo, addr)
}
c.pendingHostMap.DeleteIndex(index)
c.pendingHostMap.DeleteVpnIP(vpnIP)
})
// Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout
if len(sentTo) > 0 {
hostinfo.logger(c.l).WithField("udpAddrs", sentTo).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message sent")
}
// Increment the counter to increase our delay, linear backoff
hostinfo.HandshakeCounter++
// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
if !lighthouseTriggered {
//TODO: feel like we dupe handshake real fast in a tight loop, why?
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
}
}
@@ -199,33 +187,166 @@ func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
hostinfo := c.pendingHostMap.AddVpnIP(vpnIP)
// We lock here and use an array to insert items to prevent locking the
// main receive thread for very long by waiting to add items to the pending map
//TODO: what lock?
c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval)
c.metricInitiated.Inc(1)
return hostinfo
}
func (c *HandshakeManager) DeleteVpnIP(vpnIP uint32) {
//l.Debugln("Deleting pending vpn ip :", IntIp(vpnIP))
c.pendingHostMap.DeleteVpnIP(vpnIP)
}
var (
ErrExistingHostInfo = errors.New("existing hostinfo")
ErrAlreadySeen = errors.New("already seen")
ErrLocalIndexCollision = errors.New("local index collision")
ErrExistingHandshake = errors.New("existing handshake")
)
func (c *HandshakeManager) AddIndex(index uint32, ci *ConnectionState) (*HostInfo, error) {
hostinfo, err := c.pendingHostMap.AddIndex(index, ci)
if err != nil {
return nil, fmt.Errorf("Issue adding index: %d", index)
// CheckAndComplete checks for any conflicts in the main and pending hostmap
// before adding hostinfo to main. If err is nil, it was added. Otherwise err will be:
// ErrAlreadySeen if we already have an entry in the hostmap that has seen the
// exact same handshake packet
//
// ErrExistingHostInfo if we already have an entry in the hostmap for this
// VpnIP and the new handshake was older than the one we currently have
//
// ErrLocalIndexCollision if we already have an entry in the main or pending
// hostmap for the hostinfo.localIndexId.
func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) {
c.pendingHostMap.Lock()
defer c.pendingHostMap.Unlock()
c.mainHostMap.Lock()
defer c.mainHostMap.Unlock()
// Check if we already have a tunnel with this vpn ip
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
if found && existingHostInfo != nil {
// Is it just a delayed handshake packet?
if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
return existingHostInfo, ErrAlreadySeen
}
// Is this a newer handshake?
if existingHostInfo.lastHandshakeTime >= hostinfo.lastHandshakeTime {
return existingHostInfo, ErrExistingHostInfo
}
existingHostInfo.logger(c.l).Info("Taking new handshake")
}
//c.mainHostMap.AddIndexHostInfo(index, hostinfo)
c.InboundHandshakeTimer.Add(index, time.Second*10)
return hostinfo, nil
existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
if found {
// We have a collision, but for a different hostinfo
return existingIndex, ErrLocalIndexCollision
}
existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId]
if found && existingIndex != hostinfo {
// We have a collision, but for a different hostinfo
return existingIndex, ErrLocalIndexCollision
}
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger(c.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
Info("New host shadows existing host remoteIndex")
}
// Check if we are also handshaking with this vpn ip
pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.hostId]
if found && pendingHostInfo != nil {
if !overwrite {
// We won, let our pending handshake win
return pendingHostInfo, ErrExistingHandshake
}
// We lost, take this handshake and move any cached packets over so they get sent
pendingHostInfo.ConnectionState.queueLock.Lock()
hostinfo.packetStore = append(hostinfo.packetStore, pendingHostInfo.packetStore...)
c.pendingHostMap.unlockedDeleteHostInfo(pendingHostInfo)
pendingHostInfo.ConnectionState.queueLock.Unlock()
pendingHostInfo.logger(c.l).Info("Handshake race lost, replacing pending handshake with completed tunnel")
}
if existingHostInfo != nil {
// We are going to overwrite this entry, so remove the old references
delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
}
c.mainHostMap.addHostInfo(hostinfo, f)
return existingHostInfo, nil
}
func (c *HandshakeManager) AddIndexHostInfo(index uint32, h *HostInfo) {
c.pendingHostMap.AddIndexHostInfo(index, h)
// Complete is a simpler version of CheckAndComplete when we already know we
// won't have a localIndexId collision because we already have an entry in the
// pendingHostMap
func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
c.pendingHostMap.Lock()
defer c.pendingHostMap.Unlock()
c.mainHostMap.Lock()
defer c.mainHostMap.Unlock()
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
if found && existingHostInfo != nil {
// We are going to overwrite this entry, so remove the old references
delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
}
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
if found && existingRemoteIndex != nil {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger(c.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
Info("New host shadows existing host remoteIndex")
}
c.mainHostMap.addHostInfo(hostinfo, f)
c.pendingHostMap.unlockedDeleteHostInfo(hostinfo)
}
func (c *HandshakeManager) DeleteIndex(index uint32) {
//l.Debugln("Deleting pending index :", index)
c.pendingHostMap.DeleteIndex(index)
// AddIndexHostInfo generates a unique localIndexId for this HostInfo
// and adds it to the pendingHostMap. Will error if we are unable to generate
// a unique localIndexId
func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
c.pendingHostMap.Lock()
defer c.pendingHostMap.Unlock()
c.mainHostMap.RLock()
defer c.mainHostMap.RUnlock()
for i := 0; i < 32; i++ {
index, err := generateIndex(c.l)
if err != nil {
return err
}
_, inPending := c.pendingHostMap.Indexes[index]
_, inMain := c.mainHostMap.Indexes[index]
if !inMain && !inPending {
h.localIndexId = index
c.pendingHostMap.Indexes[index] = h
return nil
}
}
return errors.New("failed to generate unique localIndexId")
}
func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
c.pendingHostMap.addRemoteIndexHostInfo(index, h)
}
func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
//l.Debugln("Deleting pending hostinfo :", hostinfo)
c.pendingHostMap.DeleteHostInfo(hostinfo)
}
func (c *HandshakeManager) QueryIndex(index uint32) (*HostInfo, error) {
@@ -239,18 +360,28 @@ func (c *HandshakeManager) EmitStats() {
// Utility functions below
func generateIndex() (uint32, error) {
func generateIndex(l *logrus.Logger) (uint32, error) {
b := make([]byte, 4)
_, err := rand.Read(b)
if err != nil {
l.Errorln(err)
return 0, err
// Let zero mean we don't know the ID, so don't generate zero
var index uint32
for index == 0 {
_, err := rand.Read(b)
if err != nil {
l.Errorln(err)
return 0, err
}
index = binary.BigEndian.Uint32(b)
}
index := binary.BigEndian.Uint32(b)
if l.Level >= logrus.DebugLevel {
l.WithField("index", index).
Debug("Generated index")
}
return index, nil
}
func hsTimeout(tries int, interval time.Duration) time.Duration {
return time.Duration(tries / 2 * ((2 * int(interval)) + (tries-1)*int(interval)))
}

View File

@@ -8,137 +8,84 @@ import (
"github.com/stretchr/testify/assert"
)
var indexes []uint32 = []uint32{1000, 2000, 3000, 4000}
//var ips []uint32 = []uint32{9000, 9999999, 3, 292394923}
var ips []uint32
func Test_NewHandshakeManagerIndex(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextInboundHandshakeTimerTick(now)
// Add four indexes
for _, v := range indexes {
blah.AddIndex(v, &ConnectionState{})
}
// Confirm they are in the pending index list
for _, v := range indexes {
assert.Contains(t, blah.pendingHostMap.Indexes, uint32(v))
}
// Adding something to pending should not affect the main hostmap
assert.Len(t, mainHM.Indexes, 0)
// Jump ahead 8 seconds
for i := 1; i <= DefaultHandshakeRetries; i++ {
next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i))
blah.NextInboundHandshakeTimerTick(next_tick)
}
// Confirm they are still in the pending index list
for _, v := range indexes {
assert.Contains(t, blah.pendingHostMap.Indexes, uint32(v))
}
// Jump ahead 4 more seconds
next_tick := now.Add(12 * time.Second)
blah.NextInboundHandshakeTimerTick(next_tick)
// Confirm they have been removed
for _, v := range indexes {
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(v))
}
}
func Test_NewHandshakeManagerVpnIP(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
// Add four "IPs" - which are just uint32s
for _, v := range ips {
blah.AddVpnIP(v)
}
// Adding something to pending should not affect the main hostmap
assert.Len(t, mainHM.Hosts, 0)
// Confirm they are in the pending index list
for _, v := range ips {
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v))
}
// Jump ahead `HandshakeRetries` ticks
cumulative := time.Duration(0)
for i := 0; i <= DefaultHandshakeRetries+1; i++ {
cumulative += time.Duration(i)*DefaultHandshakeTryInterval + 1
next_tick := now.Add(cumulative)
//l.Infoln(next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick, mw)
}
// Confirm they are still in the pending index list
for _, v := range ips {
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v))
}
// Jump ahead 1 more second
cumulative += time.Duration(DefaultHandshakeRetries+1) * DefaultHandshakeTryInterval
next_tick := now.Add(cumulative)
//l.Infoln(next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick, mw)
// Confirm they have been removed
for _, v := range ips {
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(v))
}
}
func Test_NewHandshakeManagerTrigger(t *testing.T) {
l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ip := ip2int(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
lh := &LightHouse{}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
i := blah.AddVpnIP(ip)
i.remotes = NewRemoteList()
i.HandshakeReady = true
// Adding something to pending should not affect the main hostmap
assert.Len(t, mainHM.Hosts, 0)
// Confirm they are in the pending index list
assert.Contains(t, blah.pendingHostMap.Hosts, ip)
// Jump ahead `HandshakeRetries` ticks, offset by one to get the sleep logic right
for i := 1; i <= DefaultHandshakeRetries+1; i++ {
now = now.Add(time.Duration(i) * DefaultHandshakeTryInterval)
blah.NextOutboundHandshakeTimerTick(now, mw)
}
// Confirm they are still in the pending index list
assert.Contains(t, blah.pendingHostMap.Hosts, ip)
// Tick 1 more time, a minute will certainly flush it out
blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute), mw)
// Confirm they have been removed
assert.NotContains(t, blah.pendingHostMap.Hosts, ip)
}
func Test_NewHandshakeManagerTrigger(t *testing.T) {
l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ip := ip2int(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
lh := &LightHouse{addrMap: make(map[uint32]*RemoteList), l: l}
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
blah.AddVpnIP(ip)
hi := blah.AddVpnIP(ip)
hi.HandshakeReady = true
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet")
// Trigger the same method the channel will but, this should set our remotes pointer
blah.handleOutbound(ip, mw, true)
assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should have done a handshake attempt")
assert.NotNil(t, hi.remotes, "Manager should have set my remotes pointer")
// Make sure the trigger doesn't double schedule the timer entry
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
// Trigger the same method the channel will
uaddr := NewUDPAddrFromString("10.1.1.1:4242")
hi.remotes.unlockedPrependV4(ip, NewIp4AndPort(uaddr.IP, uint32(uaddr.Port)))
// We now have remotes but only the first trigger should have pushed things forward
blah.handleOutbound(ip, mw, true)
// Make sure the trigger doesn't schedule another timer entry
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
hi := blah.pendingHostMap.Hosts[ip]
assert.Nil(t, hi.remote)
lh.addrMap = map[uint32][]udpAddr{
ip: {*NewUDPAddrFromString("10.1.1.1:4242")},
}
// This should trigger the hostmap to populate the hostinfo
blah.handleOutbound(ip, mw, true)
assert.NotNil(t, hi.remote)
assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should have not done a handshake attempt")
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
}
@@ -153,92 +100,9 @@ func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
return c
}
func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
vpnIP = ip2int(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
hostinfo := blah.AddVpnIP(vpnIP)
// Pretned we have an index too
blah.AddIndexHostInfo(12341234, hostinfo)
assert.Contains(t, blah.pendingHostMap.Indexes, uint32(12341234))
// Jump ahead `HandshakeRetries` ticks. Eviction should happen in pending
// but not main hostmap
cumulative := time.Duration(0)
for i := 1; i <= DefaultHandshakeRetries+2; i++ {
cumulative += DefaultHandshakeTryInterval * time.Duration(i)
next_tick := now.Add(cumulative)
blah.NextOutboundHandshakeTimerTick(next_tick, mw)
}
/*
for i := 0; i <= HandshakeRetries+1; i++ {
next_tick := now.Add(cumulative)
//l.Infoln(next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick)
}
*/
/*
for i := 0; i <= HandshakeRetries+1; i++ {
next_tick := now.Add(time.Duration(i) * time.Second)
blah.NextOutboundHandshakeTimerTick(next_tick)
}
*/
/*
cumulative += HandshakeTryInterval*time.Duration(HandshakeRetries) + 3
next_tick := now.Add(cumulative)
l.Infoln(cumulative, next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick)
*/
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(vpnIP))
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234))
}
func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextInboundHandshakeTimerTick(now)
hostinfo, _ := blah.AddIndex(12341234, &ConnectionState{})
// Pretned we have an index too
blah.pendingHostMap.AddVpnIPHostInfo(101010, hostinfo)
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(101010))
for i := 1; i <= DefaultHandshakeRetries+2; i++ {
next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i))
blah.NextInboundHandshakeTimerTick(next_tick)
}
next_tick := now.Add(DefaultHandshakeTryInterval*DefaultHandshakeRetries + 3)
blah.NextInboundHandshakeTimerTick(next_tick)
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010))
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234))
}
type mockEncWriter struct {
}
func (mw *mockEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
return
}
func (mw *mockEncWriter) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
return
}

View File

@@ -1,11 +1,11 @@
package nebula
import (
"encoding/json"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/rcrowley/go-metrics"
@@ -15,41 +15,55 @@ import (
//const ProbeLen = 100
const PromoteEvery = 1000
const ReQueryEvery = 5000
const MaxRemotes = 10
// How long we should prevent roaming back to the previous IP.
// This helps prevent flapping due to packets already in flight
const RoamingSupressSeconds = 2
const RoamingSuppressSeconds = 2
type HostMap struct {
sync.RWMutex //Because we concurrently read and write to our maps
name string
Indexes map[uint32]*HostInfo
RemoteIndexes map[uint32]*HostInfo
Hosts map[uint32]*HostInfo
preferredRanges []*net.IPNet
vpnCIDR *net.IPNet
defaultRoute uint32
unsafeRoutes *CIDRTree
metricsEnabled bool
l *logrus.Logger
}
type HostInfo struct {
sync.RWMutex
remote *udpAddr
Remotes []*HostInfoDest
remotes *RemoteList
promoteCounter uint32
ConnectionState *ConnectionState
handshakeStart time.Time
HandshakeReady bool
HandshakeCounter int
HandshakeComplete bool
HandshakePacket map[uint8][]byte
packetStore []*cachedPacket
handshakeStart time.Time //todo: this an entry in the handshake manager
HandshakeReady bool //todo: being in the manager means you are ready
HandshakeCounter int //todo: another handshake manager entry
HandshakeComplete bool //todo: this should go away in favor of ConnectionState.ready
HandshakePacket map[uint8][]byte //todo: this is other handshake manager entry
packetStore []*cachedPacket //todo: this is other handshake manager entry
remoteIndexId uint32
localIndexId uint32
hostId uint32
recvError int
remoteCidr *CIDRTree
// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
// with a handshake
lastRebindCount int8
// lastHandshakeTime records the time the remote side told us about at the stage when the handshake was completed locally
// Stage 1 packet will contain it if I am a responder, stage 2 packet if I am an initiator
// This is used to avoid an attack where a handshake packet is replayed after some time
lastHandshakeTime uint64
lastRoam time.Time
lastRoamRemote *udpAddr
}
@@ -63,29 +77,24 @@ type cachedPacket struct {
type packetCallback func(t NebulaMessageType, st NebulaMessageSubType, h *HostInfo, p, nb, out []byte)
type HostInfoDest struct {
active bool
addr *udpAddr
//probes [ProbeLen]bool
probeCounter int
type cachedPacketMetrics struct {
sent metrics.Counter
dropped metrics.Counter
}
type Probe struct {
Addr *net.UDPAddr
Counter int
}
func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
h := map[uint32]*HostInfo{}
i := map[uint32]*HostInfo{}
r := map[uint32]*HostInfo{}
m := HostMap{
name: name,
Indexes: i,
RemoteIndexes: r,
Hosts: h,
preferredRanges: preferredRanges,
vpnCIDR: vpnCIDR,
defaultRoute: 0,
unsafeRoutes: NewCIDRTree(),
l: l,
}
return &m
}
@@ -95,10 +104,12 @@ func (hm *HostMap) EmitStats(name string) {
hm.RLock()
hostLen := len(hm.Hosts)
indexLen := len(hm.Indexes)
remoteIndexLen := len(hm.RemoteIndexes)
hm.RUnlock()
metrics.GetOrRegisterGauge("hostmap."+name+".hosts", nil).Update(int64(hostLen))
metrics.GetOrRegisterGauge("hostmap."+name+".indexes", nil).Update(int64(indexLen))
metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen))
}
func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) {
@@ -112,17 +123,6 @@ func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) {
return 0, errors.New("vpn IP not found")
}
func (hm *HostMap) GetVpnIPByIndex(index uint32) (uint32, error) {
hm.RLock()
if i, ok := hm.Indexes[index]; ok {
vpnIP := i.hostId
hm.RUnlock()
return vpnIP, nil
}
hm.RUnlock()
return 0, errors.New("vpn IP not found")
}
func (hm *HostMap) Add(ip uint32, hostinfo *HostInfo) {
hm.Lock()
hm.Hosts[ip] = hostinfo
@@ -135,7 +135,6 @@ func (hm *HostMap) AddVpnIP(vpnIP uint32) *HostInfo {
if _, ok := hm.Hosts[vpnIP]; !ok {
hm.RUnlock()
h = &HostInfo{
Remotes: []*HostInfoDest{},
promoteCounter: 0,
hostId: vpnIP,
HandshakePacket: make(map[uint8][]byte, 0),
@@ -159,43 +158,23 @@ func (hm *HostMap) DeleteVpnIP(vpnIP uint32) {
}
hm.Unlock()
if l.Level >= logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}).
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}).
Debug("Hostmap vpnIp deleted")
}
}
func (hm *HostMap) AddIndex(index uint32, ci *ConnectionState) (*HostInfo, error) {
// Only used by pendingHostMap when the remote index is not initially known
func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
hm.Lock()
if _, ok := hm.Indexes[index]; !ok {
h := &HostInfo{
ConnectionState: ci,
Remotes: []*HostInfoDest{},
localIndexId: index,
HandshakePacket: make(map[uint8][]byte, 0),
}
hm.Indexes[index] = h
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
"hostinfo": m{"existing": false, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
Debug("Hostmap index added")
hm.Unlock()
return h, nil
}
hm.Unlock()
return nil, fmt.Errorf("refusing to overwrite existing index: %d", index)
}
func (hm *HostMap) AddIndexHostInfo(index uint32, h *HostInfo) {
hm.Lock()
h.localIndexId = index
hm.Indexes[index] = h
h.remoteIndexId = index
hm.RemoteIndexes[index] = h
hm.Unlock()
if l.Level > logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
if hm.l.Level > logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
Debug("Hostmap index added")
Debug("Hostmap remoteIndex added")
}
}
@@ -203,29 +182,101 @@ func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) {
hm.Lock()
h.hostId = vpnIP
hm.Hosts[vpnIP] = h
hm.Indexes[h.localIndexId] = h
hm.RemoteIndexes[h.remoteIndexId] = h
hm.Unlock()
if l.Level > logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts),
if hm.l.Level > logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
Debug("Hostmap vpnIp added")
}
}
// This is only called in pendingHostmap, to cleanup an inbound handshake
func (hm *HostMap) DeleteIndex(index uint32) {
hm.Lock()
delete(hm.Indexes, index)
if len(hm.Indexes) == 0 {
hm.Indexes = map[uint32]*HostInfo{}
hostinfo, ok := hm.Indexes[index]
if ok {
delete(hm.Indexes, index)
delete(hm.RemoteIndexes, hostinfo.remoteIndexId)
// Check if we have an entry under hostId that matches the same hostinfo
// instance. Clean it up as well if we do.
hostinfo2, ok := hm.Hosts[hostinfo.hostId]
if ok && hostinfo2 == hostinfo {
delete(hm.Hosts, hostinfo.hostId)
}
}
hm.Unlock()
if l.Level >= logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
Debug("Hostmap index deleted")
}
}
// This is used to cleanup on recv_error
func (hm *HostMap) DeleteReverseIndex(index uint32) {
hm.Lock()
hostinfo, ok := hm.RemoteIndexes[index]
if ok {
delete(hm.Indexes, hostinfo.localIndexId)
delete(hm.RemoteIndexes, index)
// Check if we have an entry under hostId that matches the same hostinfo
// instance. Clean it up as well if we do (they might not match in pendingHostmap)
var hostinfo2 *HostInfo
hostinfo2, ok = hm.Hosts[hostinfo.hostId]
if ok && hostinfo2 == hostinfo {
delete(hm.Hosts, hostinfo.hostId)
}
}
hm.Unlock()
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
Debug("Hostmap remote index deleted")
}
}
func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
hm.Lock()
defer hm.Unlock()
hm.unlockedDeleteHostInfo(hostinfo)
}
func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
// Check if this same hostId is in the hostmap with a different instance.
// This could happen if we have an entry in the pending hostmap with different
// index values than the one in the main hostmap.
hostinfo2, ok := hm.Hosts[hostinfo.hostId]
if ok && hostinfo2 != hostinfo {
delete(hm.Hosts, hostinfo2.hostId)
delete(hm.Indexes, hostinfo2.localIndexId)
delete(hm.RemoteIndexes, hostinfo2.remoteIndexId)
}
delete(hm.Hosts, hostinfo.hostId)
if len(hm.Hosts) == 0 {
hm.Hosts = map[uint32]*HostInfo{}
}
delete(hm.Indexes, hostinfo.localIndexId)
if len(hm.Indexes) == 0 {
hm.Indexes = map[uint32]*HostInfo{}
}
delete(hm.RemoteIndexes, hostinfo.remoteIndexId)
if len(hm.RemoteIndexes) == 0 {
hm.RemoteIndexes = map[uint32]*HostInfo{}
}
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
"vpnIp": IntIp(hostinfo.hostId), "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Hostmap hostInfo deleted")
}
}
func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) {
//TODO: we probably just want ot return bool instead of error, or at least a static error
hm.RLock()
@@ -238,45 +289,15 @@ func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) {
}
}
// This function needs to range because we don't keep a map of remote indexes.
func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) {
hm.RLock()
for _, h := range hm.Indexes {
if h.ConnectionState != nil && h.remoteIndexId == index {
hm.RUnlock()
return h, nil
}
}
for _, h := range hm.Hosts {
if h.ConnectionState != nil && h.remoteIndexId == index {
hm.RUnlock()
return h, nil
}
}
hm.RUnlock()
return nil, fmt.Errorf("unable to find reverse index or connectionstate nil in %s hostmap", hm.name)
}
func (hm *HostMap) AddRemote(vpnIp uint32, remote *udpAddr) *HostInfo {
hm.Lock()
i, v := hm.Hosts[vpnIp]
if v {
i.AddRemote(*remote)
if h, ok := hm.RemoteIndexes[index]; ok {
hm.RUnlock()
return h, nil
} else {
i = &HostInfo{
Remotes: []*HostInfoDest{NewHostInfoDest(remote)},
promoteCounter: 0,
hostId: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0),
}
i.remote = i.Remotes[0].addr
hm.Hosts[vpnIp] = i
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}).
Debug("Hostmap remote ip added")
hm.RUnlock()
return nil, fmt.Errorf("unable to find reverse index or connectionstate nil in %s hostmap", hm.name)
}
i.ForcePromoteBest(hm.preferredRanges)
hm.Unlock()
return i
}
func (hm *HostMap) QueryVpnIP(vpnIp uint32) (*HostInfo, error) {
@@ -292,23 +313,17 @@ func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostIn
func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) {
hm.RLock()
if h, ok := hm.Hosts[vpnIp]; ok {
if promoteIfce != nil {
hm.RUnlock()
// Do not attempt promotion if you are a lighthouse
if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse {
h.TryPromoteBest(hm.preferredRanges, promoteIfce)
}
//fmt.Println(h.remote)
hm.RUnlock()
return h, nil
} else {
//return &net.UDPAddr{}, nil, errors.New("Unable to find host")
hm.RUnlock()
/*
if lightHouse != nil {
lightHouse.Query(vpnIp)
return nil, errors.New("Unable to find host")
}
*/
return nil, errors.New("unable to find host")
}
hm.RUnlock()
return nil, errors.New("unable to find host")
}
func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 {
@@ -320,70 +335,40 @@ func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 {
}
}
func (hm *HostMap) CheckHandshakeCompleteIP(vpnIP uint32) bool {
hm.RLock()
if i, ok := hm.Hosts[vpnIP]; ok {
if i == nil {
hm.RUnlock()
return false
}
complete := i.HandshakeComplete
hm.RUnlock()
return complete
// We already have the hm Lock when this is called, so make sure to not call
// any other methods that might try to grab it again
func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
if f.serveDns {
remoteCert := hostinfo.ConnectionState.peerCert
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
}
hm.RUnlock()
return false
}
func (hm *HostMap) CheckHandshakeCompleteIndex(index uint32) bool {
hm.RLock()
if i, ok := hm.Indexes[index]; ok {
if i == nil {
hm.RUnlock()
return false
}
complete := i.HandshakeComplete
hm.RUnlock()
return complete
hm.Hosts[hostinfo.hostId] = hostinfo
hm.Indexes[hostinfo.localIndexId] = hostinfo
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}).
Debug("Hostmap vpnIp added")
}
hm.RUnlock()
return false
}
func (hm *HostMap) ClearRemotes(vpnIP uint32) {
hm.Lock()
i := hm.Hosts[vpnIP]
if i == nil {
hm.Unlock()
return
}
i.remote = nil
i.Remotes = nil
hm.Unlock()
}
func (hm *HostMap) SetDefaultRoute(ip uint32) {
hm.defaultRoute = ip
}
func (hm *HostMap) PunchList() []*udpAddr {
var list []*udpAddr
// punchList assembles a list of all non nil RemoteList pointer entries in this hostmap
// The caller can then do the its work outside of the read lock
func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList {
hm.RLock()
defer hm.RUnlock()
for _, v := range hm.Hosts {
for _, r := range v.Remotes {
list = append(list, r.addr)
if v.remotes != nil {
rl = append(rl, v.remotes)
}
// if h, ok := hm.Hosts[vpnIp]; ok {
// hm.Hosts[vpnIp].PromoteBest(hm.preferredRanges, false)
//fmt.Println(h.remote)
// }
}
hm.RUnlock()
return list
return rl
}
// Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them
func (hm *HostMap) Punchy(conn *udpConn) {
var metricsTxPunchy metrics.Counter
if hm.metricsEnabled {
@@ -392,162 +377,83 @@ func (hm *HostMap) Punchy(conn *udpConn) {
metricsTxPunchy = metrics.NilCounter{}
}
var remotes []*RemoteList
b := []byte{1}
for {
for _, addr := range hm.PunchList() {
metricsTxPunchy.Inc(1)
conn.WriteTo([]byte{1}, addr)
remotes = hm.punchList(remotes[:0])
for _, rl := range remotes {
//TODO: CopyAddrs generates garbage but ForEach locks for the work here, figure out which way is better
for _, addr := range rl.CopyAddrs(hm.preferredRanges) {
metricsTxPunchy.Inc(1)
conn.WriteTo(b, addr)
}
}
time.Sleep(time.Second * 30)
time.Sleep(time.Second * 10)
}
}
func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
for _, r := range *routes {
l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
hm.l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via))
}
}
func (i *HostInfo) MarshalJSON() ([]byte, error) {
return json.Marshal(m{
"remote": i.remote,
"remotes": i.Remotes,
"promote_counter": i.promoteCounter,
"connection_state": i.ConnectionState,
"handshake_start": i.handshakeStart,
"handshake_ready": i.HandshakeReady,
"handshake_counter": i.HandshakeCounter,
"handshake_complete": i.HandshakeComplete,
"handshake_packet": i.HandshakePacket,
"packet_store": i.packetStore,
"remote_index": i.remoteIndexId,
"local_index": i.localIndexId,
"host_id": int2ip(i.hostId),
"receive_errors": i.recvError,
"last_roam": i.lastRoam,
"last_roam_remote": i.lastRoamRemote,
})
}
func (i *HostInfo) BindConnectionState(cs *ConnectionState) {
i.ConnectionState = cs
}
// TryPromoteBest handles re-querying lighthouses and probing for better paths
// NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
if i.remote == nil {
i.ForcePromoteBest(preferredRanges)
return
}
c := atomic.AddUint32(&i.promoteCounter, 1)
if c%PromoteEvery == 0 {
// The lock here is currently protecting i.remote access
i.RLock()
defer i.RUnlock()
i.promoteCounter++
if i.promoteCounter%PromoteEvery == 0 {
// return early if we are already on a preferred remote
rIP := udp2ip(i.remote)
rIP := i.remote.IP
for _, l := range preferredRanges {
if l.Contains(rIP) {
return
}
}
// We re-query the lighthouse periodically while sending packets, so
// check for new remotes in our local lighthouse cache
ips := ifce.lightHouse.QueryCache(i.hostId)
for _, ip := range ips {
i.AddRemote(ip)
}
i.remotes.ForEach(preferredRanges, func(addr *udpAddr, preferred bool) {
if addr == nil || !preferred {
return
}
best, preferred := i.getBestRemote(preferredRanges)
if preferred && !best.Equals(i.remote) {
// Try to send a test packet to that host, this should
// cause it to detect a roaming event and switch remotes
ifce.send(test, testRequest, i.ConnectionState, i, best, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}
ifce.send(test, testRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
})
}
// Re query our lighthouses for new remotes occasionally
if c%ReQueryEvery == 0 && ifce.lightHouse != nil {
ifce.lightHouse.QueryServer(i.hostId, ifce)
}
}
func (i *HostInfo) ForcePromoteBest(preferredRanges []*net.IPNet) {
best, _ := i.getBestRemote(preferredRanges)
if best != nil {
i.remote = best
}
}
func (i *HostInfo) getBestRemote(preferredRanges []*net.IPNet) (best *udpAddr, preferred bool) {
if len(i.Remotes) > 0 {
for _, r := range i.Remotes {
rIP := udp2ip(r.addr)
for _, l := range preferredRanges {
if l.Contains(rIP) {
return r.addr, true
}
}
if best == nil || !PrivateIP(rIP) {
best = r.addr
}
/*
for _, r := range i.Remotes {
// Must have > 80% probe success to be considered.
//fmt.Println("GRADE:", r.addr.IP, r.Grade())
if r.Grade() > float64(.8) {
if localToMe.Contains(r.addr.IP) == true {
best = r.addr
break
//i.remote = i.Remotes[c].addr
} else {
//}
}
*/
}
return best, false
}
return nil, false
}
// rotateRemote will move remote to the next ip in the list of remote ips for this host
// This is different than PromoteBest in that what is algorithmically best may not actually work.
// Only known use case is when sending a stage 0 handshake.
// It may be better to just send stage 0 handshakes to all known ips and sort it out in the receiver.
func (i *HostInfo) rotateRemote() {
// We have 0, can't rotate
if len(i.Remotes) < 1 {
return
}
if i.remote == nil {
i.remote = i.Remotes[0].addr
return
}
// We want to look at all but the very last entry since that is handled at the end
for x := 0; x < len(i.Remotes)-1; x++ {
// Find our current position and move to the next one in the list
if i.Remotes[x].addr.Equals(i.remote) {
i.remote = i.Remotes[x+1].addr
return
}
}
// Our current position was likely the last in the list, start over at 0
i.remote = i.Remotes[0].addr
}
func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) {
func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
//TODO: return the error so we can log with more context
if len(i.packetStore) < 100 {
tempPacket := make([]byte, len(packet))
copy(tempPacket, packet)
//l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket)
i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket})
i.logger().
WithField("length", len(i.packetStore)).
WithField("stored", true).
Debugf("Packet store")
if l.Level >= logrus.DebugLevel {
i.logger(l).
WithField("length", len(i.packetStore)).
WithField("stored", true).
Debugf("Packet store")
}
} else if l.Level >= logrus.DebugLevel {
i.logger().
m.dropped.Inc(1)
i.logger(l).
WithField("length", len(i.packetStore)).
WithField("stored", false).
Debugf("Packet store")
@@ -555,7 +461,7 @@ func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, pac
}
// handshakeComplete will set the connection as ready to communicate, as well as flush any stored packets
func (i *HostInfo) handshakeComplete() {
func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) {
//TODO: I'm not certain the distinction between handshake complete and ConnectionState being ready matters because:
//TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send
//TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical
@@ -564,27 +470,28 @@ func (i *HostInfo) handshakeComplete() {
i.HandshakeComplete = true
//TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen.
// Clamping it to 2 gets us out of the woods for now
*i.ConnectionState.messageCounter = 2
i.logger().Debugf("Sending %d stored packets", len(i.packetStore))
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for _, cp := range i.packetStore {
cp.callback(cp.messageType, cp.messageSubType, i, cp.packet, nb, out)
atomic.StoreUint64(&i.ConnectionState.atomicMessageCounter, 2)
if l.Level >= logrus.DebugLevel {
i.logger(l).Debugf("Sending %d stored packets", len(i.packetStore))
}
if len(i.packetStore) > 0 {
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for _, cp := range i.packetStore {
cp.callback(cp.messageType, cp.messageSubType, i, cp.packet, nb, out)
}
m.sent.Inc(int64(len(i.packetStore)))
}
i.remotes.ResetBlockedRemotes()
i.packetStore = make([]*cachedPacket, 0)
i.ConnectionState.ready = true
i.ConnectionState.queueLock.Unlock()
i.ConnectionState.certState = nil
}
func (i *HostInfo) RemoteUDPAddrs() []*udpAddr {
var addrs []*udpAddr
for _, r := range i.Remotes {
addrs = append(addrs, r.addr)
}
return addrs
}
func (i *HostInfo) GetCert() *cert.NebulaCertificate {
if i.ConnectionState != nil {
return i.ConnectionState.peerCert
@@ -592,31 +499,12 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate {
return nil
}
func (i *HostInfo) AddRemote(r udpAddr) *udpAddr {
remote := &r
//add := true
for _, r := range i.Remotes {
if r.addr.Equals(remote) {
return r.addr
//add = false
}
func (i *HostInfo) SetRemote(remote *udpAddr) {
// We copy here because we likely got this remote from a source that reuses the object
if !i.remote.Equals(remote) {
i.remote = remote.Copy()
i.remotes.LearnRemote(i.hostId, remote.Copy())
}
// Trim this down if necessary
if len(i.Remotes) > MaxRemotes {
i.Remotes = i.Remotes[len(i.Remotes)-MaxRemotes:]
}
i.Remotes = append(i.Remotes, NewHostInfoDest(remote))
return remote
//l.Debugf("Added remote %s for vpn ip", remote)
}
func (i *HostInfo) SetRemote(remote udpAddr) {
i.remote = i.AddRemote(remote)
}
func (i *HostInfo) ClearRemotes() {
i.remote = nil
i.Remotes = []*HostInfoDest{}
}
func (i *HostInfo) ClearConnectionState() {
@@ -648,7 +536,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
i.remoteCidr = remoteCidr
}
func (i *HostInfo) logger() *logrus.Entry {
func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
if i == nil {
return logrus.NewEntry(l)
}
@@ -666,21 +554,6 @@ func (i *HostInfo) logger() *logrus.Entry {
//########################
func NewHostInfoDest(addr *udpAddr) *HostInfoDest {
i := &HostInfoDest{
addr: addr,
}
return i
}
func (hid *HostInfoDest) MarshalJSON() ([]byte, error) {
return json.Marshal(m{
"active": hid.active,
"address": hid.addr,
"probe_count": hid.probeCounter,
})
}
/*
func (hm *HostMap) DebugRemotes(vpnIp uint32) string {
@@ -693,40 +566,6 @@ func (hm *HostMap) DebugRemotes(vpnIp uint32) string {
return s
}
func (d *HostInfoDest) Grade() float64 {
c1 := ProbeLen
for n := len(d.probes) - 1; n >= 0; n-- {
if d.probes[n] == true {
c1 -= 1
}
}
return float64(c1) / float64(ProbeLen)
}
func (d *HostInfoDest) Grade() (float64, float64, float64) {
c1 := ProbeLen
c2 := ProbeLen / 2
c2c := ProbeLen - ProbeLen/2
c3 := ProbeLen / 5
c3c := ProbeLen - ProbeLen/5
for n := len(d.probes) - 1; n >= 0; n-- {
if d.probes[n] == true {
c1 -= 1
if n >= c2c {
c2 -= 1
if n >= c3c {
c3 -= 1
}
}
}
//if n >= d {
}
return float64(c3) / float64(ProbeLen/5), float64(c2) / float64(ProbeLen/2), float64(c1) / float64(ProbeLen)
//return float64(c1) / float64(ProbeLen), float64(c2) / float64(ProbeLen/2), float64(c3) / float64(ProbeLen/5)
}
func (i *HostInfo) HandleReply(addr *net.UDPAddr, counter int) {
for _, r := range i.Remotes {
if r.addr.IP.Equal(addr.IP) && r.addr.Port == addr.Port {
@@ -743,34 +582,20 @@ func (i *HostInfo) Probes() []*Probe {
return p
}
func (d *HostInfoDest) Probe() int {
//d.probes = append(d.probes, true)
d.probeCounter++
d.probes[d.probeCounter%ProbeLen] = true
return d.probeCounter
//return d.probeCounter
}
func (d *HostInfoDest) ProbeReceived(probeCount int) {
if probeCount >= (d.probeCounter - ProbeLen) {
//fmt.Println("PROBE WORKED", probeCount)
//fmt.Println(d.addr, d.Grade())
d.probes[probeCount%ProbeLen] = false
}
}
*/
// Utility functions
func localIps(allowList *AllowList) *[]net.IP {
func localIps(l *logrus.Logger, allowList *AllowList) *[]net.IP {
//FIXME: This function is pretty garbage
var ips []net.IP
ifaces, _ := net.Interfaces()
for _, i := range ifaces {
allow := allowList.AllowName(i.Name)
l.WithField("interfaceName", i.Name).WithField("allow", allow).Debug("localAllowList.AllowName")
if l.Level >= logrus.TraceLevel {
l.WithField("interfaceName", i.Name).WithField("allow", allow).Trace("localAllowList.AllowName")
}
if !allow {
continue
}
@@ -784,9 +609,14 @@ func localIps(allowList *AllowList) *[]net.IP {
case *net.IPAddr:
ip = v.IP
}
if ip.To4() != nil && ip.IsLoopback() == false {
allow := allowList.Allow(ip2int(ip))
l.WithField("localIp", ip).WithField("allow", allow).Debug("localAllowList.Allow")
//TODO: Filtering out link local for now, this is probably the most correct thing
//TODO: Would be nice to filter out SLAAC MAC based ips as well
if ip.IsLoopback() == false && !ip.IsLinkLocalUnicast() {
allow := allowList.Allow(ip)
if l.Level >= logrus.TraceLevel {
l.WithField("localIp", ip).WithField("allow", allow).Trace("localAllowList.Allow")
}
if !allow {
continue
}
@@ -797,12 +627,3 @@ func localIps(allowList *AllowList) *[]net.IP {
}
return &ips
}
func PrivateIP(ip net.IP) bool {
private := false
_, private24BitBlock, _ := net.ParseCIDR("10.0.0.0/8")
_, private20BitBlock, _ := net.ParseCIDR("172.16.0.0/12")
_, private16BitBlock, _ := net.ParseCIDR("192.168.0.0/16")
private = private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
return private
}

View File

@@ -1,164 +1 @@
package nebula
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
/*
func TestHostInfoDestProbe(t *testing.T) {
a, _ := net.ResolveUDPAddr("udp", "1.0.0.1:22222")
d := NewHostInfoDest(a)
// 999 probes that all return should give a 100% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
d.ProbeReceived(meh)
}
assert.Equal(t, d.Grade(), float64(1))
// 999 probes of which only half return should give a 50% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
if i%2 == 0 {
d.ProbeReceived(meh)
}
}
assert.Equal(t, d.Grade(), float64(.5))
// 999 probes of which none return should give a 0% success rate
for i := 0; i < 999; i++ {
d.Probe()
}
assert.Equal(t, d.Grade(), float64(0))
// 999 probes of which only 1/4 return should give a 25% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
if i%4 == 0 {
d.ProbeReceived(meh)
}
}
assert.Equal(t, d.Grade(), float64(.25))
// 999 probes of which only half return and are duplicates should give a 50% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
if i%2 == 0 {
d.ProbeReceived(meh)
d.ProbeReceived(meh)
}
}
assert.Equal(t, d.Grade(), float64(.5))
// 999 probes of which only way old replies return should give a 0% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
d.ProbeReceived(meh - 101)
}
assert.Equal(t, d.Grade(), float64(0))
}
*/
func TestHostmap(t *testing.T) {
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
myNets := []*net.IPNet{myNet}
preferredRanges := []*net.IPNet{localToMe}
m := NewHostMap("test", myNet, preferredRanges)
a := NewUDPAddrFromString("10.127.0.3:11111")
b := NewUDPAddrFromString("1.0.0.1:22222")
y := NewUDPAddrFromString("10.128.0.3:11111")
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
info, _ := m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1")))
// There should be three remotes in the host map
assert.Equal(t, 3, len(info.Remotes))
// Adding an identical remote should not change the count
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
assert.Equal(t, 3, len(info.Remotes))
// Adding a fresh remote should add one
y = NewUDPAddrFromString("10.18.0.3:11111")
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
assert.Equal(t, 4, len(info.Remotes))
// Query and reference remote should get the first one (and not nil)
info, _ = m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1")))
assert.NotNil(t, info.remote)
// Promotion should ensure that the best remote is chosen (y)
info.ForcePromoteBest(myNets)
assert.True(t, myNet.Contains(udp2ip(info.remote)))
}
func TestHostmapdebug(t *testing.T) {
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
preferredRanges := []*net.IPNet{localToMe}
m := NewHostMap("test", myNet, preferredRanges)
a := NewUDPAddrFromString("10.127.0.3:11111")
b := NewUDPAddrFromString("1.0.0.1:22222")
y := NewUDPAddrFromString("10.128.0.3:11111")
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
//t.Errorf("%s", m.DebugRemotes(1))
}
func TestHostMap_rotateRemote(t *testing.T) {
h := HostInfo{}
// 0 remotes, no panic
h.rotateRemote()
assert.Nil(t, h.remote)
// 1 remote, no panic
h.AddRemote(*NewUDPAddr(ip2int(net.IP{1, 1, 1, 1}), 0))
h.rotateRemote()
assert.Equal(t, udp2ipInt(h.remote), ip2int(net.IP{1, 1, 1, 1}))
h.AddRemote(*NewUDPAddr(ip2int(net.IP{1, 1, 1, 2}), 0))
h.AddRemote(*NewUDPAddr(ip2int(net.IP{1, 1, 1, 3}), 0))
h.AddRemote(*NewUDPAddr(ip2int(net.IP{1, 1, 1, 4}), 0))
// Rotate through those 3
h.rotateRemote()
assert.Equal(t, udp2ipInt(h.remote), ip2int(net.IP{1, 1, 1, 2}))
h.rotateRemote()
assert.Equal(t, udp2ipInt(h.remote), ip2int(net.IP{1, 1, 1, 3}))
h.rotateRemote()
assert.Equal(t, udp2ipInt(h.remote), ip2int(net.IP{1, 1, 1, 4}))
// Finally, we should start over
h.rotateRemote()
assert.Equal(t, udp2ipInt(h.remote), ip2int(net.IP{1, 1, 1, 1}))
}
func BenchmarkHostmappromote2(b *testing.B) {
for n := 0; n < b.N; n++ {
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
preferredRanges := []*net.IPNet{localToMe}
m := NewHostMap("test", myNet, preferredRanges)
y := NewUDPAddrFromString("10.128.0.3:11111")
a := NewUDPAddrFromString("10.127.0.3:11111")
g := NewUDPAddrFromString("1.0.0.1:22222")
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), g)
m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
}
}

120
inside.go
View File

@@ -7,10 +7,10 @@ import (
"github.com/sirupsen/logrus"
)
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte) {
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) {
err := newPacket(packet, false, fwPacket)
if err != nil {
l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
return
}
@@ -20,7 +20,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
}
// Ignore packets from self to self
if fwPacket.RemoteIP == f.lightHouse.myIp {
if fwPacket.RemoteIP == f.myVpnIp {
return
}
@@ -31,8 +31,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
if hostinfo == nil {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
}
@@ -45,22 +45,19 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
// the packet queue.
ci.queueLock.Lock()
if !ci.ready {
hostinfo.cachePacket(message, 0, packet, f.sendMessageNow)
hostinfo.cachePacket(f.l, message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
ci.queueLock.Unlock()
return
}
ci.queueLock.Unlock()
}
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs)
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
if dropReason == nil {
mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out)
if f.lightHouse != nil && mc%5000 == 0 {
f.lightHouse.Query(fwPacket.RemoteIP, f)
}
f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
} else if l.Level >= logrus.DebugLevel {
hostinfo.logger().
} else if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).
WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping outbound packet")
@@ -84,16 +81,25 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
hostinfo = f.handshakeManager.AddVpnIP(vpnIp)
}
}
ci := hostinfo.ConnectionState
if ci != nil && ci.eKey != nil && ci.ready {
return hostinfo
}
// Handshake is not ready, we need to grab the lock now before we start the handshake process
hostinfo.Lock()
defer hostinfo.Unlock()
// Double check, now that we have the lock
ci = hostinfo.ConnectionState
if ci != nil && ci.eKey != nil && ci.ready {
return hostinfo
}
if ci == nil {
// if we don't have a connection state, then send a handshake initiation
ci = f.newConnectionState(true, noise.HandshakeIX, []byte{}, 0)
ci = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
//ci = f.newConnectionState(true, noise.HandshakeXX, []byte{}, 0)
hostinfo.ConnectionState = ci
@@ -124,33 +130,30 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
fp := &FirewallPacket{}
err := newPacket(p, false, fp)
if err != nil {
l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
return
}
// check if packet is in outbound fw rules
dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs)
dropReason := f.firewall.Drop(p, *fp, false, hostInfo, f.caPool, nil)
if dropReason != nil {
if l.Level >= logrus.DebugLevel {
l.WithField("fwPacket", fp).
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("fwPacket", fp).
WithField("reason", dropReason).
Debugln("dropping cached packet")
}
return
}
f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
if f.lightHouse != nil && *hostInfo.ConnectionState.messageCounter%5000 == 0 {
f.lightHouse.Query(fp.RemoteIP, f)
}
f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
}
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
hostInfo := f.getOrHandshake(vpnIp)
if hostInfo == nil {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(vpnIp)).
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", IntIp(vpnIp)).
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
}
return
@@ -161,7 +164,7 @@ func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
// the packet queue.
hostInfo.ConnectionState.queueLock.Lock()
if !hostInfo.ConnectionState.ready {
hostInfo.cachePacket(t, st, p, f.sendMessageToVpnIp)
hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToVpnIp, f.cachedPacketMetrics)
hostInfo.ConnectionState.queueLock.Unlock()
return
}
@@ -176,76 +179,55 @@ func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
f.send(t, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
}
// SendMessageToAll handles real ip:port lookup and sends to all known addresses for vpnIp
func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
hostInfo := f.getOrHandshake(vpnIp)
if hostInfo == nil {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(vpnIp)).
Debugln("dropping SendMessageToAll, vpnIp not in our CIDR or in unsafe routes")
}
return
}
if hostInfo.ConnectionState.ready == false {
// Because we might be sending stored packets, lock here to stop new things going to
// the packet queue.
hostInfo.ConnectionState.queueLock.Lock()
if !hostInfo.ConnectionState.ready {
hostInfo.cachePacket(t, st, p, f.sendMessageToAll)
hostInfo.ConnectionState.queueLock.Unlock()
return
}
hostInfo.ConnectionState.queueLock.Unlock()
}
f.sendMessageToAll(t, st, hostInfo, p, nb, out)
return
}
func (f *Interface) sendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, b []byte) {
for _, r := range hostInfo.RemoteUDPAddrs() {
f.send(t, st, hostInfo.ConnectionState, hostInfo, r, p, nb, b)
}
}
func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) {
f.messageMetrics.Tx(t, st, 1)
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out)
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
}
func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) uint64 {
func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte, q int) {
if ci.eKey == nil {
//TODO: log warning
return 0
return
}
var err error
//TODO: enable if we do more than 1 tun queue
//ci.writeLock.Lock()
c := atomic.AddUint64(ci.messageCounter, 1)
c := atomic.AddUint64(&ci.atomicMessageCounter, 1)
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
out = HeaderEncode(out, Version, uint8(t), uint8(st), hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo.hostId)
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
// all our IPs and enable a faster roaming.
if t != closeTunnel && hostinfo.lastRebindCount != f.rebindCount {
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
f.lightHouse.QueryServer(hostinfo.hostId, f)
hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter")
}
}
out, err = ci.eKey.EncryptDanger(out, out, p, c, nb)
//TODO: see above note on lock
//ci.writeLock.Unlock()
if err != nil {
hostinfo.logger().WithError(err).
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).WithField("counter", c).
WithField("attemptedCounter", ci.messageCounter).
WithField("attemptedCounter", c).
Error("Failed to encrypt outgoing packet")
return c
return
}
err = f.outside.WriteTo(out, remote)
err = f.writers[q].WriteTo(out, remote)
if err != nil {
hostinfo.logger().WithError(err).
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
}
return c
return
}
func isMulticast(ip uint32) bool {

View File

@@ -5,9 +5,12 @@ import (
"io"
"net"
"os"
"runtime"
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
)
const mtu = 9001
@@ -18,6 +21,7 @@ type Inside interface {
CidrNet() *net.IPNet
DeviceName() string
WriteRaw([]byte) error
NewMultiQueueReader() (io.ReadWriteCloser, error)
}
type InterfaceConfig struct {
@@ -35,10 +39,13 @@ type InterfaceConfig struct {
DropLocalBroadcast bool
DropMulticast bool
UDPBatchSize int
udpQueues int
tunQueues int
routines int
MessageMetrics *MessageMetrics
version string
caPool *cert.NebulaCAPool
ConntrackCacheTimeout time.Duration
l *logrus.Logger
}
type Interface struct {
@@ -54,15 +61,27 @@ type Interface struct {
createTime time.Time
lightHouse *LightHouse
localBroadcast uint32
myVpnIp uint32
dropLocalBroadcast bool
dropMulticast bool
udpBatchSize int
udpQueues int
tunQueues int
version string
routines int
caPool *cert.NebulaCAPool
metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
// rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
rebindCount int8
version string
conntrackCacheTimeout time.Duration
writers []*udpConn
readers []io.ReadWriteCloser
metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics
l *logrus.Logger
}
func NewInterface(c *InterfaceConfig) (*Interface, error) {
@@ -94,81 +113,108 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
dropLocalBroadcast: c.DropLocalBroadcast,
dropMulticast: c.DropMulticast,
udpBatchSize: c.UDPBatchSize,
udpQueues: c.udpQueues,
tunQueues: c.tunQueues,
routines: c.routines,
version: c.version,
writers: make([]*udpConn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines),
caPool: c.caPool,
myVpnIp: ip2int(c.certState.certificate.Details.Ips[0].IP),
conntrackCacheTimeout: c.ConntrackCacheTimeout,
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
messageMetrics: c.MessageMetrics,
cachedPacketMetrics: &cachedPacketMetrics{
sent: metrics.GetOrRegisterCounter("hostinfo.cached_packets.sent", nil),
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
},
l: c.l,
}
ifce.connectionManager = newConnectionManager(ifce, c.checkInterval, c.pendingDeletionInterval)
ifce.connectionManager = newConnectionManager(c.l, ifce, c.checkInterval, c.pendingDeletionInterval)
return ifce, nil
}
func (f *Interface) run() {
// activate creates the interface on the host. After the interface is created, any
// other services that want to bind listeners to its IP may do so successfully. However,
// the interface isn't going to process anything until run() is called.
func (f *Interface) activate() {
// actually turn on tun dev
if err := f.inside.Activate(); err != nil {
l.Fatal(err)
}
addr, err := f.outside.LocalAddr()
if err != nil {
l.WithError(err).Error("Failed to get udp listen address")
f.l.WithError(err).Error("Failed to get udp listen address")
}
l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
f.l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
WithField("build", f.version).WithField("udpAddr", addr).
Info("Nebula interface is active")
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
// Prepare n tun queues
var reader io.ReadWriteCloser = f.inside
for i := 0; i < f.routines; i++ {
if i > 0 {
reader, err = f.inside.NewMultiQueueReader()
if err != nil {
f.l.Fatal(err)
}
}
f.readers[i] = reader
}
if err := f.inside.Activate(); err != nil {
f.l.Fatal(err)
}
}
func (f *Interface) run() {
// Launch n queues to read packets from udp
for i := 0; i < f.udpQueues; i++ {
for i := 0; i < f.routines; i++ {
go f.listenOut(i)
}
// Launch n queues to read packets from tun dev
for i := 0; i < f.tunQueues; i++ {
go f.listenIn(i)
for i := 0; i < f.routines; i++ {
go f.listenIn(f.readers[i], i)
}
}
func (f *Interface) listenOut(i int) {
//TODO: handle error
addr, err := f.outside.LocalAddr()
if err != nil {
l.WithError(err).Error("failed to discover udp listening address")
}
runtime.LockOSThread()
var li *udpConn
// TODO clean this up with a coherent interface for each outside connection
if i > 0 {
//TODO: handle error
li, err = NewListener(udp2ip(addr).String(), int(addr.Port), i > 0)
if err != nil {
l.WithError(err).Error("failed to make a new udp listener")
}
li = f.writers[i]
} else {
li = f.outside
}
li.ListenOut(f)
li.ListenOut(f, i)
}
func (f *Interface) listenIn(i int) {
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread()
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &FirewallPacket{}
nb := make([]byte, 12, 12)
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
n, err := f.inside.Read(packet)
n, err := reader.Read(packet)
if err != nil {
l.WithError(err).Error("Error while reading outbound packet")
f.l.WithError(err).Error("Error while reading outbound packet")
// This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2)
}
f.consumeInsidePacket(packet[:n], fwPacket, nb, out)
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
}
}
@@ -176,27 +222,29 @@ func (f *Interface) RegisterConfigChangeCallbacks(c *Config) {
c.RegisterReloadCallback(f.reloadCA)
c.RegisterReloadCallback(f.reloadCertKey)
c.RegisterReloadCallback(f.reloadFirewall)
c.RegisterReloadCallback(f.outside.reloadConfig)
for _, udpConn := range f.writers {
c.RegisterReloadCallback(udpConn.reloadConfig)
}
}
func (f *Interface) reloadCA(c *Config) {
// reload and check regardless
// todo: need mutex?
newCAs, err := loadCAFromConfig(c)
newCAs, err := loadCAFromConfig(f.l, c)
if err != nil {
l.WithError(err).Error("Could not refresh trusted CA certificates")
f.l.WithError(err).Error("Could not refresh trusted CA certificates")
return
}
trustedCAs = newCAs
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed")
f.caPool = newCAs
f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed")
}
func (f *Interface) reloadCertKey(c *Config) {
// reload and check in all cases
cs, err := NewCertStateFromConfig(c)
if err != nil {
l.WithError(err).Error("Could not refresh client cert")
f.l.WithError(err).Error("Could not refresh client cert")
return
}
@@ -204,24 +252,24 @@ func (f *Interface) reloadCertKey(c *Config) {
oldIPs := f.certState.certificate.Details.Ips
newIPs := cs.certificate.Details.Ips
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
return
}
f.certState = cs
l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
}
func (f *Interface) reloadFirewall(c *Config) {
//TODO: need to trigger/detect if the certificate changed too
if c.HasChanged("firewall") == false {
l.Debug("No firewall config change detected")
f.l.Debug("No firewall config change detected")
return
}
fw, err := NewFirewallFromConfig(f.certState.certificate, c)
fw, err := NewFirewallFromConfig(f.l, f.certState.certificate, c)
if err != nil {
l.WithError(err).Error("Error while creating firewall during reload")
f.l.WithError(err).Error("Error while creating firewall during reload")
return
}
@@ -234,7 +282,7 @@ func (f *Interface) reloadFirewall(c *Config) {
// If rulesVersion is back to zero, we have wrapped all the way around. Be
// safe and just reset conntrack in this case.
if fw.rulesVersion == 0 {
l.WithField("firewallHash", fw.GetRuleHash()).
f.l.WithField("firewallHash", fw.GetRuleHash()).
WithField("oldFirewallHash", oldFw.GetRuleHash()).
WithField("rulesVersion", fw.rulesVersion).
Warn("firewall rulesVersion has overflowed, resetting conntrack")
@@ -245,7 +293,7 @@ func (f *Interface) reloadFirewall(c *Config) {
f.firewall = fw
oldFw.Destroy()
l.WithField("firewallHash", fw.GetRuleHash()).
f.l.WithField("firewallHash", fw.GetRuleHash()).
WithField("oldFirewallHash", oldFw.GetRuleHash()).
WithField("rulesVersion", fw.rulesVersion).
Info("New firewall has been installed")
@@ -253,8 +301,13 @@ func (f *Interface) reloadFirewall(c *Config) {
func (f *Interface) emitStats(i time.Duration) {
ticker := time.NewTicker(i)
udpStats := NewUDPStatsEmitter(f.writers)
for range ticker.C {
f.firewall.EmitStats()
f.handshakeManager.EmitStats()
udpStats()
}
}

View File

@@ -1,6 +1,8 @@
package nebula
import (
"encoding/binary"
"errors"
"fmt"
"net"
"sync"
@@ -8,17 +10,25 @@ import (
"github.com/golang/protobuf/proto"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
"github.com/sirupsen/logrus"
)
//TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake?
//TODO: nodes are roaming lighthouses, this is bad. How are they learning?
var ErrHostNotKnown = errors.New("host not known")
type LightHouse struct {
//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
sync.RWMutex //Because we concurrently read and write to our maps
amLighthouse bool
myIp uint32
myVpnIp uint32
myVpnZeros uint32
punchConn *udpConn
// Local cache of answers from light houses
addrMap map[uint32][]udpAddr
// map of vpn Ip to answers
addrMap map[uint32]*RemoteList
// filters remote addresses allowed for each host
// - When we are a lighthouse, this filters what addresses we store and
@@ -38,24 +48,26 @@ type LightHouse struct {
staticList map[uint32]struct{}
lighthouses map[uint32]struct{}
interval int
nebulaPort int
nebulaPort uint32 // 32 bits because protobuf does not have a uint16
punchBack bool
punchDelay time.Duration
metrics *MessageMetrics
metricHolepunchTx metrics.Counter
l *logrus.Logger
}
type EncWriter interface {
SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
}
func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort int, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
ones, _ := myVpnIpNet.Mask.Size()
h := LightHouse{
amLighthouse: amLighthouse,
myIp: myIp,
addrMap: make(map[uint32][]udpAddr),
myVpnIp: ip2int(myVpnIpNet.IP),
myVpnZeros: uint32(32 - ones),
addrMap: make(map[uint32]*RemoteList),
nebulaPort: nebulaPort,
lighthouses: make(map[uint32]struct{}),
staticList: make(map[uint32]struct{}),
@@ -63,6 +75,7 @@ func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, n
punchConn: pc,
punchBack: punchBack,
punchDelay: punchDelay,
l: l,
}
if metricsEnabled {
@@ -103,40 +116,10 @@ func (lh *LightHouse) ValidateLHStaticEntries() error {
return nil
}
func (lh *LightHouse) Query(ip uint32, f EncWriter) ([]udpAddr, error) {
func (lh *LightHouse) Query(ip uint32, f EncWriter) *RemoteList {
if !lh.IsLighthouseIP(ip) {
lh.QueryServer(ip, f)
}
lh.RLock()
if v, ok := lh.addrMap[ip]; ok {
lh.RUnlock()
return v, nil
}
lh.RUnlock()
return nil, fmt.Errorf("host %s not known, queries sent to lighthouses", IntIp(ip))
}
// This is asynchronous so no reply should be expected
func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
if !lh.amLighthouse {
// Send a query to the lighthouses and hope for the best next time
query, err := proto.Marshal(NewLhQueryByInt(ip))
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
return
}
lh.metricTx(NebulaMeta_HostQuery, int64(len(lh.lighthouses)))
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for n := range lh.lighthouses {
f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out)
}
}
}
// Query our local lighthouse cached results
func (lh *LightHouse) QueryCache(ip uint32) []udpAddr {
lh.RLock()
if v, ok := lh.addrMap[ip]; ok {
lh.RUnlock()
@@ -146,6 +129,71 @@ func (lh *LightHouse) QueryCache(ip uint32) []udpAddr {
return nil
}
// This is asynchronous so no reply should be expected
func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
if lh.amLighthouse {
return
}
if lh.IsLighthouseIP(ip) {
return
}
// Send a query to the lighthouses and hope for the best next time
query, err := proto.Marshal(NewLhQueryByInt(ip))
if err != nil {
lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
return
}
lh.metricTx(NebulaMeta_HostQuery, int64(len(lh.lighthouses)))
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for n := range lh.lighthouses {
f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out)
}
}
func (lh *LightHouse) QueryCache(ip uint32) *RemoteList {
lh.RLock()
if v, ok := lh.addrMap[ip]; ok {
lh.RUnlock()
return v
}
lh.RUnlock()
lh.Lock()
defer lh.Unlock()
// Add an entry if we don't already have one
return lh.unlockedGetRemoteList(ip)
}
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp
// If one is found then f() is called with proper locking, f() must return result of n.MarshalTo()
func (lh *LightHouse) queryAndPrepMessage(vpnIp uint32, f func(*cache) (int, error)) (bool, int, error) {
lh.RLock()
// Do we have an entry in the main cache?
if v, ok := lh.addrMap[vpnIp]; ok {
// Swap lh lock for remote list lock
v.RLock()
defer v.RUnlock()
lh.RUnlock()
// vpnIp should also be the owner here since we are a lighthouse.
c := v.cache[vpnIp]
// Make sure we have
if c != nil {
n, err := f(c)
return true, n, err
}
return false, 0, nil
}
lh.RUnlock()
return false, 0, nil
}
func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
// First we check the static mapping
// and do nothing if it is there
@@ -155,47 +203,87 @@ func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
lh.Lock()
//l.Debugln(lh.addrMap)
delete(lh.addrMap, vpnIP)
l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP))
if lh.l.Level >= logrus.DebugLevel {
lh.l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP))
}
lh.Unlock()
}
func (lh *LightHouse) AddRemote(vpnIP uint32, toIp *udpAddr, static bool) {
// First we check if the sender thinks this is a static entry
// and do nothing if it is not, but should be considered static
if static == false {
if _, ok := lh.staticList[vpnIP]; ok {
return
}
}
// AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner
// We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
// And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) {
lh.Lock()
for _, v := range lh.addrMap[vpnIP] {
if v.Equals(toIp) {
lh.Unlock()
am := lh.unlockedGetRemoteList(vpnIp)
am.Lock()
defer am.Unlock()
lh.Unlock()
if ipv4 := toAddr.IP.To4(); ipv4 != nil {
to := NewIp4AndPort(ipv4, uint32(toAddr.Port))
if !lh.unlockedShouldAddV4(to) {
return
}
am.unlockedPrependV4(lh.myVpnIp, to)
} else {
to := NewIp6AndPort(toAddr.IP, uint32(toAddr.Port))
if !lh.unlockedShouldAddV6(to) {
return
}
am.unlockedPrependV6(lh.myVpnIp, to)
}
allow := lh.remoteAllowList.Allow(udp2ipInt(toIp))
l.WithField("remoteIp", toIp).WithField("allow", allow).Debug("remoteAllowList.Allow")
if !allow {
return
}
//l.Debugf("Adding reply of %s as %s\n", IntIp(vpnIP), toIp)
if static {
lh.staticList[vpnIP] = struct{}{}
}
lh.addrMap[vpnIP] = append(lh.addrMap[vpnIP], *toIp)
lh.Unlock()
// Mark it as static
lh.staticList[vpnIp] = struct{}{}
}
func (lh *LightHouse) AddRemoteAndReset(vpnIP uint32, toIp *udpAddr) {
if lh.amLighthouse {
lh.DeleteVpnIP(vpnIP)
lh.AddRemote(vpnIP, toIp, false)
// unlockedGetRemoteList assumes you have the lh lock
func (lh *LightHouse) unlockedGetRemoteList(vpnIP uint32) *RemoteList {
am, ok := lh.addrMap[vpnIP]
if !ok {
am = NewRemoteList()
lh.addrMap[vpnIP] = am
}
return am
}
// unlockedShouldAddV4 checks if to is allowed by our allow list
func (lh *LightHouse) unlockedShouldAddV4(to *Ip4AndPort) bool {
allow := lh.remoteAllowList.AllowIpV4(to.Ip)
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", IntIp(to.Ip)).WithField("allow", allow).Trace("remoteAllowList.Allow")
}
if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, to.Ip) {
return false
}
return true
}
// unlockedShouldAddV6 checks if to is allowed by our allow list
func (lh *LightHouse) unlockedShouldAddV6(to *Ip6AndPort) bool {
allow := lh.remoteAllowList.AllowIpV6(to.Hi, to.Lo)
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
}
// We don't check our vpn network here because nebula does not support ipv6 on the inside
if !allow {
return false
}
return true
}
func lhIp6ToIp(v *Ip6AndPort) net.IP {
ip := make(net.IP, 16)
binary.BigEndian.PutUint64(ip[:8], v.Hi)
binary.BigEndian.PutUint64(ip[8:], v.Lo)
return ip
}
func (lh *LightHouse) IsLighthouseIP(vpnIP uint32) bool {
@@ -205,12 +293,6 @@ func (lh *LightHouse) IsLighthouseIP(vpnIP uint32) bool {
return false
}
// Quick generators for protobuf
func NewLhQueryByIpString(VpnIp string) *NebulaMeta {
return NewLhQueryByInt(ip2int(net.ParseIP(VpnIp)))
}
func NewLhQueryByInt(VpnIp uint32) *NebulaMeta {
return &NebulaMeta{
Type: NebulaMeta_HostQuery,
@@ -220,26 +302,30 @@ func NewLhQueryByInt(VpnIp uint32) *NebulaMeta {
}
}
func NewLhWhoami() *NebulaMeta {
return &NebulaMeta{
Type: NebulaMeta_HostWhoami,
Details: &NebulaMetaDetails{},
func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort {
ipp := Ip4AndPort{Port: port}
ipp.Ip = ip2int(ip)
return &ipp
}
func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
return &Ip6AndPort{
Hi: binary.BigEndian.Uint64(ip[:8]),
Lo: binary.BigEndian.Uint64(ip[8:]),
Port: port,
}
}
// End Quick generators for protobuf
func NewIpAndPortFromUDPAddr(addr udpAddr) *IpAndPort {
return &IpAndPort{Ip: udp2ipInt(&addr), Port: uint32(addr.Port)}
func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udpAddr {
ip := ipp.Ip
return NewUDPAddr(
net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)),
uint16(ipp.Port),
)
}
func NewIpAndPortsFromNetIps(ips []udpAddr) *[]*IpAndPort {
var iap []*IpAndPort
for _, e := range ips {
// Only add IPs that aren't my VPN/tun IP
iap = append(iap, NewIpAndPortFromUDPAddr(e))
}
return &iap
func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udpAddr {
return NewUDPAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
}
func (lh *LightHouse) LhUpdateWorker(f EncWriter) {
@@ -248,195 +334,308 @@ func (lh *LightHouse) LhUpdateWorker(f EncWriter) {
}
for {
ipp := []*IpAndPort{}
for _, e := range *localIps(lh.localAllowList) {
// Only add IPs that aren't my VPN/tun IP
if ip2int(e) != lh.myIp {
ipp = append(ipp, &IpAndPort{Ip: ip2int(e), Port: uint32(lh.nebulaPort)})
//fmt.Println(e)
}
}
m := &NebulaMeta{
Type: NebulaMeta_HostUpdateNotification,
Details: &NebulaMetaDetails{
VpnIp: lh.myIp,
IpAndPorts: ipp,
},
}
lh.metricTx(NebulaMeta_HostUpdateNotification, int64(len(lh.lighthouses)))
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for vpnIp := range lh.lighthouses {
mm, err := proto.Marshal(m)
if err != nil {
l.Debugf("Invalid marshal to update")
}
//l.Error("LIGHTHOUSE PACKET SEND", mm)
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out)
}
lh.SendUpdate(f)
time.Sleep(time.Second * time.Duration(lh.interval))
}
}
func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *cert.NebulaCertificate, f EncWriter) {
n := &NebulaMeta{}
err := proto.Unmarshal(p, n)
func (lh *LightHouse) SendUpdate(f EncWriter) {
var v4 []*Ip4AndPort
var v6 []*Ip6AndPort
for _, e := range *localIps(lh.l, lh.localAllowList) {
if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip2int(ip4)) {
continue
}
// Only add IPs that aren't my VPN/tun IP
if ip := e.To4(); ip != nil {
v4 = append(v4, NewIp4AndPort(e, lh.nebulaPort))
} else {
v6 = append(v6, NewIp6AndPort(e, lh.nebulaPort))
}
}
m := &NebulaMeta{
Type: NebulaMeta_HostUpdateNotification,
Details: &NebulaMetaDetails{
VpnIp: lh.myVpnIp,
Ip4AndPorts: v4,
Ip6AndPorts: v6,
},
}
lh.metricTx(NebulaMeta_HostUpdateNotification, int64(len(lh.lighthouses)))
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
mm, err := proto.Marshal(m)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
lh.l.WithError(err).Error("Error while marshaling for lighthouse update")
return
}
for vpnIp := range lh.lighthouses {
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out)
}
}
type LightHouseHandler struct {
lh *LightHouse
nb []byte
out []byte
pb []byte
meta *NebulaMeta
l *logrus.Logger
}
func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
lhh := &LightHouseHandler{
lh: lh,
nb: make([]byte, 12, 12),
out: make([]byte, mtu),
l: lh.l,
pb: make([]byte, mtu),
meta: &NebulaMeta{
Details: &NebulaMetaDetails{},
},
}
return lhh
}
func (lh *LightHouse) metricRx(t NebulaMeta_MessageType, i int64) {
lh.metrics.Rx(NebulaMessageType(t), 0, i)
}
func (lh *LightHouse) metricTx(t NebulaMeta_MessageType, i int64) {
lh.metrics.Tx(NebulaMessageType(t), 0, i)
}
// This method is similar to Reset(), but it re-uses the pointer structs
// so that we don't have to re-allocate them
func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
details := lhh.meta.Details
lhh.meta.Reset()
// Keep the array memory around
details.Ip4AndPorts = details.Ip4AndPorts[:0]
details.Ip6AndPorts = details.Ip6AndPorts[:0]
lhh.meta.Details = details
return lhh.meta
}
func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, w EncWriter) {
n := lhh.resetMeta()
err := n.Unmarshal(p)
if err != nil {
lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
Error("Failed to unmarshal lighthouse packet")
//TODO: send recv_error?
return
}
if n.Details == nil {
l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
lhh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
Error("Invalid lighthouse update")
//TODO: send recv_error?
return
}
lh.metricRx(n.Type, 1)
lhh.lh.metricRx(n.Type, 1)
switch n.Type {
case NebulaMeta_HostQuery:
// Exit if we don't answer queries
if !lh.amLighthouse {
l.Debugln("I don't answer queries, but received from: ", rAddr)
return
}
//l.Debugln("Got Query")
ips, err := lh.Query(n.Details.VpnIp, f)
if err != nil {
//l.Debugf("Can't answer query %s from %s because error: %s", IntIp(n.Details.VpnIp), rAddr, err)
return
} else {
iap := NewIpAndPortsFromNetIps(ips)
answer := &NebulaMeta{
Type: NebulaMeta_HostQueryReply,
Details: &NebulaMetaDetails{
VpnIp: n.Details.VpnIp,
IpAndPorts: *iap,
},
}
reply, err := proto.Marshal(answer)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
return
}
lh.metricTx(NebulaMeta_HostQueryReply, 1)
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, reply, make([]byte, 12, 12), make([]byte, mtu))
// This signals the other side to punch some zero byte udp packets
ips, err = lh.Query(vpnIp, f)
if err != nil {
l.WithField("vpnIp", IntIp(vpnIp)).Debugln("Can't notify host to punch")
return
} else {
//l.Debugln("Notify host to punch", iap)
iap = NewIpAndPortsFromNetIps(ips)
answer = &NebulaMeta{
Type: NebulaMeta_HostPunchNotification,
Details: &NebulaMetaDetails{
VpnIp: vpnIp,
IpAndPorts: *iap,
},
}
reply, _ := proto.Marshal(answer)
lh.metricTx(NebulaMeta_HostPunchNotification, 1)
f.SendMessageToVpnIp(lightHouse, 0, n.Details.VpnIp, reply, make([]byte, 12, 12), make([]byte, mtu))
}
//fmt.Println(reply, remoteaddr)
}
lhh.handleHostQuery(n, vpnIp, rAddr, w)
case NebulaMeta_HostQueryReply:
if !lh.IsLighthouseIP(vpnIp) {
return
}
for _, a := range n.Details.IpAndPorts {
//first := n.Details.IpAndPorts[0]
ans := NewUDPAddr(a.Ip, uint16(a.Port))
lh.AddRemote(n.Details.VpnIp, ans, false)
}
// Non-blocking attempt to trigger, skip if it would block
select {
case lh.handshakeTrigger <- n.Details.VpnIp:
default:
}
lhh.handleHostQueryReply(n, vpnIp)
case NebulaMeta_HostUpdateNotification:
//Simple check that the host sent this not someone else
if n.Details.VpnIp != vpnIp {
l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
return
}
for _, a := range n.Details.IpAndPorts {
ans := NewUDPAddr(a.Ip, uint16(a.Port))
lh.AddRemote(n.Details.VpnIp, ans, false)
}
lhh.handleHostUpdateNotification(n, vpnIp)
case NebulaMeta_HostMovedNotification:
case NebulaMeta_HostPunchNotification:
if !lh.IsLighthouseIP(vpnIp) {
lhh.handleHostPunchNotification(n, vpnIp, w)
}
}
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr *udpAddr, w EncWriter) {
// Exit if we don't answer queries
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugln("I don't answer queries, but received from: ", addr)
}
return
}
//TODO: we can DRY this further
reqVpnIP := n.Details.VpnIp
//TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data
found, ln, err := lhh.lh.queryAndPrepMessage(n.Details.VpnIp, func(c *cache) (int, error) {
n = lhh.resetMeta()
n.Type = NebulaMeta_HostQueryReply
n.Details.VpnIp = reqVpnIP
lhh.coalesceAnswers(c, n)
return n.MarshalTo(lhh.pb)
})
if !found {
return
}
if err != nil {
lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
return
}
lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
w.SendMessageToVpnIp(lightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
// This signals the other side to punch some zero byte udp packets
found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) {
n = lhh.resetMeta()
n.Type = NebulaMeta_HostPunchNotification
n.Details.VpnIp = vpnIp
lhh.coalesceAnswers(c, n)
return n.MarshalTo(lhh.pb)
})
if !found {
return
}
if err != nil {
lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host was queried for")
return
}
lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
w.SendMessageToVpnIp(lightHouse, 0, reqVpnIP, lhh.pb[:ln], lhh.nb, lhh.out[:0])
}
func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
if c.v4 != nil {
if c.v4.learned != nil {
n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.learned)
}
if c.v4.reported != nil && len(c.v4.reported) > 0 {
n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.reported...)
}
}
if c.v6 != nil {
if c.v6.learned != nil {
n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.learned)
}
if c.v6.reported != nil && len(c.v6.reported) > 0 {
n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.reported...)
}
}
}
func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32) {
if !lhh.lh.IsLighthouseIP(vpnIp) {
return
}
lhh.lh.Lock()
am := lhh.lh.unlockedGetRemoteList(n.Details.VpnIp)
am.Lock()
lhh.lh.Unlock()
am.unlockedSetV4(vpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(vpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
am.Unlock()
// Non-blocking attempt to trigger, skip if it would block
select {
case lhh.lh.handshakeTrigger <- n.Details.VpnIp:
default:
}
}
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp uint32) {
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp)
}
return
}
//Simple check that the host sent this not someone else
if n.Details.VpnIp != vpnIp {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
}
return
}
lhh.lh.Lock()
am := lhh.lh.unlockedGetRemoteList(vpnIp)
am.Lock()
lhh.lh.Unlock()
am.unlockedSetV4(vpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(vpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
am.Unlock()
}
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp uint32, w EncWriter) {
if !lhh.lh.IsLighthouseIP(vpnIp) {
return
}
empty := []byte{0}
punch := func(vpnPeer *udpAddr) {
if vpnPeer == nil {
return
}
empty := []byte{0}
for _, a := range n.Details.IpAndPorts {
vpnPeer := NewUDPAddr(a.Ip, uint16(a.Port))
go func() {
time.Sleep(lh.punchDelay)
lh.metricHolepunchTx.Inc(1)
lh.punchConn.WriteTo(empty, vpnPeer)
go func() {
time.Sleep(lhh.lh.punchDelay)
lhh.lh.metricHolepunchTx.Inc(1)
lhh.lh.punchConn.WriteTo(empty, vpnPeer)
}()
}()
l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
}
// This sends a nebula test packet to the host trying to contact us. In the case
// of a double nat or other difficult scenario, this may help establish
// a tunnel.
if lh.punchBack {
go func() {
time.Sleep(time.Second * 5)
l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp))
f.SendMessageToVpnIp(test, testRequest, n.Details.VpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}()
if lhh.l.Level >= logrus.DebugLevel {
//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, IntIp(n.Details.VpnIp))
}
}
}
func (lh *LightHouse) metricRx(t NebulaMeta_MessageType, i int64) {
lh.metrics.Rx(NebulaMessageType(t), 0, i)
}
func (lh *LightHouse) metricTx(t NebulaMeta_MessageType, i int64) {
lh.metrics.Tx(NebulaMessageType(t), 0, i)
}
for _, a := range n.Details.Ip4AndPorts {
punch(NewUDPAddrFromLH4(a))
}
/*
func (f *Interface) sendPathCheck(ci *ConnectionState, endpoint *net.UDPAddr, counter int) {
c := ci.messageCounter
b := HeaderEncode(nil, Version, uint8(path_check), 0, ci.remoteIndex, c)
ci.messageCounter++
for _, a := range n.Details.Ip6AndPorts {
punch(NewUDPAddrFromLH6(a))
}
if ci.eKey != nil {
msg := ci.eKey.EncryptDanger(b, nil, []byte(strconv.Itoa(counter)), c)
//msg := ci.eKey.EncryptDanger(b, nil, []byte(fmt.Sprintf("%d", counter)), c)
f.outside.WriteTo(msg, endpoint)
l.Debugf("path_check sent, remote index: %d, pathCounter %d", ci.remoteIndex, counter)
// This sends a nebula test packet to the host trying to contact us. In the case
// of a double nat or other difficult scenario, this may help establish
// a tunnel.
if lhh.lh.punchBack {
go func() {
time.Sleep(time.Second * 5)
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp))
}
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
// managed by a channel.
w.SendMessageToVpnIp(test, testRequest, n.Details.VpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}()
}
}
func (f *Interface) sendPathCheckReply(ci *ConnectionState, endpoint *net.UDPAddr, counter []byte) {
c := ci.messageCounter
b := HeaderEncode(nil, Version, uint8(path_check_reply), 0, ci.remoteIndex, c)
ci.messageCounter++
if ci.eKey != nil {
msg := ci.eKey.EncryptDanger(b, nil, counter, c)
f.outside.WriteTo(msg, endpoint)
l.Debugln("path_check sent, remote index: ", ci.remoteIndex)
}
// ipMaskContains checks if testIp is contained by ip after applying a cidr
// zeros is 32 - bits from net.IPMask.Size()
func ipMaskContains(ip uint32, zeros uint32, testIp uint32) bool {
return (testIp^ip)>>zeros == 0
}
*/

View File

@@ -1,6 +1,7 @@
package nebula
import (
"fmt"
"net"
"testing"
@@ -8,6 +9,17 @@ import (
"github.com/stretchr/testify/assert"
)
//TODO: Add a test to ensure udpAddr is copied and not reused
func TestOldIPv4Only(t *testing.T) {
// This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility
b := []byte{8, 129, 130, 132, 80, 16, 10}
var m Ip4AndPort
err := proto.Unmarshal(b, &m)
assert.NoError(t, err)
assert.Equal(t, "10.1.1.1", int2ip(m.GetIp()).String())
}
func TestNewLhQuery(t *testing.T) {
myIp := net.ParseIP("192.1.1.1")
myIpint := ip2int(myIp)
@@ -29,70 +41,344 @@ func TestNewLhQuery(t *testing.T) {
}
func TestNewipandportfromudpaddr(t *testing.T) {
blah := NewUDPAddrFromString("1.2.2.3:12345")
meh := NewIpAndPortFromUDPAddr(*blah)
assert.Equal(t, uint32(16908803), meh.Ip)
assert.Equal(t, uint32(12345), meh.Port)
}
func TestNewipandportsfromudpaddrs(t *testing.T) {
blah := NewUDPAddrFromString("1.2.2.3:12345")
blah2 := NewUDPAddrFromString("9.9.9.9:47828")
group := []udpAddr{*blah, *blah2}
hah := NewIpAndPortsFromNetIps(group)
assert.IsType(t, &[]*IpAndPort{}, hah)
//t.Error(reflect.TypeOf(hah))
}
func Test_lhStaticMapping(t *testing.T) {
l := NewTestLogger()
lh1 := "10.128.0.2"
lh1IP := net.ParseIP(lh1)
udpServer, _ := NewListener("0.0.0.0", 0, true)
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
meh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true)
meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)))
err := meh.ValidateLHStaticEntries()
assert.Nil(t, err)
lh2 := "10.128.0.3"
lh2IP := net.ParseIP(lh2)
meh = NewLightHouse(true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true)
meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)))
err = meh.ValidateLHStaticEntries()
assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
}
//func NewLightHouse(amLighthouse bool, myIp uint32, ips []string, interval int, nebulaPort int, pc *udpConn, punchBack bool) *LightHouse {
func BenchmarkLighthouseHandleRequest(b *testing.B) {
l := NewTestLogger()
lh1 := "10.128.0.2"
lh1IP := net.ParseIP(lh1)
/*
func TestLHQuery(t *testing.T) {
//n := NewLhQueryByIpString("10.128.0.3")
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
m := NewHostMap(myNet)
y, _ := net.ResolveUDPAddr("udp", "10.128.0.3:11111")
m.Add(ip2int(net.ParseIP("127.0.0.1")), y)
//t.Errorf("%s", m)
_ = m
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
_, n, _ := net.ParseCIDR("127.0.0.1/8")
lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
/*udpServer, err := net.ListenUDP("udp", &net.UDPAddr{Port: 10009})
if err != nil {
t.Errorf("%s", err)
}
hAddr := NewUDPAddrFromString("4.5.6.7:12345")
hAddr2 := NewUDPAddrFromString("4.5.6.7:12346")
lh.addrMap[3] = NewRemoteList()
lh.addrMap[3].unlockedSetV4(
3,
[]*Ip4AndPort{
NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)),
},
func(*Ip4AndPort) bool { return true },
)
meh := NewLightHouse(n, m, []string{"10.128.0.2"}, false, 10, 10003, 10004)
//t.Error(m.Hosts)
meh2, err := meh.Query(ip2int(net.ParseIP("10.128.0.3")))
t.Error(err)
if err != nil {
return
}
t.Errorf("%s", meh2)
t.Errorf("%s", n)
rAddr := NewUDPAddrFromString("1.2.2.3:12345")
rAddr2 := NewUDPAddrFromString("1.2.2.3:12346")
lh.addrMap[2] = NewRemoteList()
lh.addrMap[2].unlockedSetV4(
3,
[]*Ip4AndPort{
NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)),
},
func(*Ip4AndPort) bool { return true },
)
mw := &mockEncWriter{}
b.Run("notfound", func(b *testing.B) {
lhh := lh.NewRequestHandler()
req := &NebulaMeta{
Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{
VpnIp: 4,
Ip4AndPorts: nil,
},
}
p, err := proto.Marshal(req)
assert.NoError(b, err)
for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, 2, p, mw)
}
})
b.Run("found", func(b *testing.B) {
lhh := lh.NewRequestHandler()
req := &NebulaMeta{
Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{
VpnIp: 3,
Ip4AndPorts: nil,
},
}
p, err := proto.Marshal(req)
assert.NoError(b, err)
for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, 2, p, mw)
}
})
}
func TestLighthouse_Memory(t *testing.T) {
l := NewTestLogger()
myUdpAddr0 := &udpAddr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
myUdpAddr1 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
myUdpAddr2 := &udpAddr{IP: net.ParseIP("172.16.0.2"), Port: 4242}
myUdpAddr3 := &udpAddr{IP: net.ParseIP("100.152.0.2"), Port: 4242}
myUdpAddr4 := &udpAddr{IP: net.ParseIP("24.15.0.2"), Port: 4242}
myUdpAddr5 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4243}
myUdpAddr6 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4244}
myUdpAddr7 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4245}
myUdpAddr8 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4246}
myUdpAddr9 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4247}
myUdpAddr10 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4248}
myUdpAddr11 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4249}
myVpnIp := ip2int(net.ParseIP("10.128.0.2"))
theirUdpAddr0 := &udpAddr{IP: net.ParseIP("10.0.0.3"), Port: 4242}
theirUdpAddr1 := &udpAddr{IP: net.ParseIP("192.168.0.3"), Port: 4242}
theirUdpAddr2 := &udpAddr{IP: net.ParseIP("172.16.0.3"), Port: 4242}
theirUdpAddr3 := &udpAddr{IP: net.ParseIP("100.152.0.3"), Port: 4242}
theirUdpAddr4 := &udpAddr{IP: net.ParseIP("24.15.0.3"), Port: 4242}
theirVpnIp := ip2int(net.ParseIP("10.128.0.3"))
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{}, 10, 10003, udpServer, false, 1, false)
lhh := lh.NewRequestHandler()
// Test that my first update responds with just that
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr1, myUdpAddr2}, lhh)
r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
// Ensure we don't accumulate addresses
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr3}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
// Grow it back to 2
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr1, myUdpAddr4}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
// Update a different host
newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udpAddr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
// Make sure we didn't get changed
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
// Ensure proper ordering and limiting
// Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
newLHHostUpdate(
myUdpAddr0,
myVpnIp,
[]*udpAddr{
myUdpAddr1,
myUdpAddr2,
myUdpAddr3,
myUdpAddr4,
myUdpAddr5,
myUdpAddr5, //Duplicated on purpose
myUdpAddr6,
myUdpAddr7,
myUdpAddr8,
myUdpAddr9,
myUdpAddr10,
myUdpAddr11, // This should get cut
}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(
t,
r.msg.Details.Ip4AndPorts,
myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
)
// Make sure we won't add ips in our vpn network
bad1 := &udpAddr{IP: net.ParseIP("10.128.0.99"), Port: 4242}
bad2 := &udpAddr{IP: net.ParseIP("10.128.0.100"), Port: 4242}
good := &udpAddr{IP: net.ParseIP("1.128.0.99"), Port: 4242}
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{bad1, bad2, good}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
}
func newLHHostRequest(fromAddr *udpAddr, myVpnIp, queryVpnIp uint32, lhh *LightHouseHandler) testLhReply {
req := &NebulaMeta{
Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{
VpnIp: queryVpnIp,
},
}
b, err := req.Marshal()
if err != nil {
panic(err)
}
w := &testEncWriter{}
lhh.HandleRequest(fromAddr, myVpnIp, b, w)
return w.lastReply
}
func newLHHostUpdate(fromAddr *udpAddr, vpnIp uint32, addrs []*udpAddr, lhh *LightHouseHandler) {
req := &NebulaMeta{
Type: NebulaMeta_HostUpdateNotification,
Details: &NebulaMetaDetails{
VpnIp: vpnIp,
Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
},
}
for k, v := range addrs {
req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: ip2int(v.IP), Port: uint32(v.Port)}
}
b, err := req.Marshal()
if err != nil {
panic(err)
}
w := &testEncWriter{}
lhh.HandleRequest(fromAddr, vpnIp, b, w)
}
//TODO: this is a RemoteList test
//func Test_lhRemoteAllowList(t *testing.T) {
// l := NewTestLogger()
// c := NewConfig(l)
// c.Settings["remoteallowlist"] = map[interface{}]interface{}{
// "10.20.0.0/12": false,
// }
// allowList, err := c.GetAllowList("remoteallowlist", false)
// assert.Nil(t, err)
//
// lh1 := "10.128.0.2"
// lh1IP := net.ParseIP(lh1)
//
// udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
//
// lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
// lh.SetRemoteAllowList(allowList)
//
// // A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap
// remote1IP := net.ParseIP("10.20.0.3")
// remotes := lh.unlockedGetRemoteList(ip2int(remote1IP))
// remotes.unlockedPrependV4(ip2int(remote1IP), NewIp4AndPort(remote1IP, 4242))
// assert.NotNil(t, lh.addrMap[ip2int(remote1IP)])
// assert.Empty(t, lh.addrMap[ip2int(remote1IP)].CopyAddrs([]*net.IPNet{}))
//
// // Make sure a good ip enters the cache and addrMap
// remote2IP := net.ParseIP("10.128.0.3")
// remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242))
// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote2UDPAddr.IP, uint32(remote2UDPAddr.Port)), false, false)
// assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr)
//
// // Another good ip gets into the cache, ordering is inverted
// remote3IP := net.ParseIP("10.128.0.4")
// remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243))
// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote3UDPAddr.IP, uint32(remote3UDPAddr.Port)), false, false)
// assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr, remote3UDPAddr)
//
// // If we exceed the length limit we should only have the most recent addresses
// addedAddrs := []*udpAddr{}
// for i := 0; i < 11; i++ {
// remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i))
// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remoteUDPAddr.IP, uint32(remoteUDPAddr.Port)), false, false)
// // The first entry here is a duplicate, don't add it to the assert list
// if i != 0 {
// addedAddrs = append(addedAddrs, remoteUDPAddr)
// }
// }
//
// // We should only have the last 10 of what we tried to add
// assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses")
// assertUdpAddrInArray(
// t,
// lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}),
// addedAddrs[0],
// addedAddrs[1],
// addedAddrs[2],
// addedAddrs[3],
// addedAddrs[4],
// addedAddrs[5],
// addedAddrs[6],
// addedAddrs[7],
// addedAddrs[8],
// addedAddrs[9],
// )
//}
func Test_ipMaskContains(t *testing.T) {
assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.0.255"))))
assert.False(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.1.1"))))
assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32, ip2int(net.ParseIP("10.0.1.1"))))
}
type testLhReply struct {
nebType NebulaMessageType
nebSubType NebulaMessageSubType
vpnIp uint32
msg *NebulaMeta
}
type testEncWriter struct {
lastReply testLhReply
}
func (tw *testEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, _, _ []byte) {
tw.lastReply = testLhReply{
nebType: t,
nebSubType: st,
vpnIp: vpnIp,
msg: &NebulaMeta{},
}
err := proto.Unmarshal(p, tw.lastReply.msg)
if err != nil {
panic(err)
}
}
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udpAddr) {
assert.Len(t, have, len(want))
for k, w := range want {
if !(have[k].Ip == ip2int(w.IP) && have[k].Port == uint32(w.Port)) {
assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have)))
}
}
}
// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
func assertUdpAddrInArray(t *testing.T, have []*udpAddr, want ...*udpAddr) {
assert.Len(t, have, len(want))
for k, w := range want {
if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have))
}
}
}
func translateV4toUdpAddr(ips []*Ip4AndPort) []*udpAddr {
addrs := make([]*udpAddr, len(ips))
for k, v := range ips {
addrs[k] = NewUDPAddrFromLH4(v)
}
return addrs
}
*/

174
main.go
View File

@@ -4,8 +4,6 @@ import (
"encoding/binary"
"fmt"
"net"
"strconv"
"strings"
"time"
"github.com/sirupsen/logrus"
@@ -13,13 +11,10 @@ import (
"gopkg.in/yaml.v2"
)
// The caller should provide a real logger, we have one just in case
var l = logrus.New()
type m map[string]interface{}
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) {
l = logger
l := logger
l.Formatter = &logrus.TextFormatter{
FullTimestamp: true,
}
@@ -47,13 +42,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}
})
// trustedCAs is currently a global, so loadCA operates on that global directly
trustedCAs, err = loadCAFromConfig(config)
caPool, err := loadCAFromConfig(l, config)
if err != nil {
//The errors coming out of loadCA are already nicely formatted
return nil, NewContextualError("Failed to load ca from config", nil, err)
}
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
cs, err := NewCertStateFromConfig(config)
if err != nil {
@@ -62,7 +56,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
fw, err := NewFirewallFromConfig(cs.certificate, config)
fw, err := NewFirewallFromConfig(l, cs.certificate, config)
if err != nil {
return nil, NewContextualError("Error while loading firewall rules", nil, err)
}
@@ -80,9 +74,10 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
wireSSHReload(ssh, config)
wireSSHReload(l, ssh, config)
var sshStart func()
if config.GetBool("sshd.enabled", false) {
err = configSSH(ssh, config)
sshStart, err = configSSH(l, ssh, config)
if err != nil {
return nil, NewContextualError("Error while configuring the sshd", nil, err)
}
@@ -93,15 +88,52 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
// tun config, listeners, anything modifying the computer should be below
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
var routines int
// If `routines` is set, use that and ignore the specific values
if routines = config.GetInt("routines", 0); routines != 0 {
if routines < 1 {
routines = 1
}
if routines > 1 {
l.WithField("routines", routines).Info("Using multiple routines")
}
} else {
// deprecated and undocumented
tunQueues := config.GetInt("tun.routines", 1)
udpQueues := config.GetInt("listen.routines", 1)
if tunQueues > udpQueues {
routines = tunQueues
} else {
routines = udpQueues
}
if routines != 1 {
l.WithField("routines", routines).Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead")
}
}
// EXPERIMENTAL
// Intentionally not documented yet while we do more testing and determine
// a good default value.
conntrackCacheTimeout := config.GetDuration("firewall.conntrack.routine_cache_timeout", 0)
if routines > 1 && !config.IsSet("firewall.conntrack.routine_cache_timeout") {
// Use a different default if we are running with multiple routines
conntrackCacheTimeout = 1 * time.Second
}
if conntrackCacheTimeout > 0 {
l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache")
}
var tun Inside
if !configTest {
config.CatchHUP()
switch {
case config.GetBool("tun.disabled", false):
tun = newDisabledTun(tunCidr, l)
tun = newDisabledTun(tunCidr, config.GetInt("tun.tx_queue", 500), config.GetBool("stats.message_metrics", false), l)
case tunFd != nil:
tun, err = newTunFromFd(
l,
*tunFd,
tunCidr,
config.GetInt("tun.mtu", DEFAULT_MTU),
@@ -111,12 +143,14 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
)
default:
tun, err = newTun(
l,
config.GetString("tun.dev", ""),
tunCidr,
config.GetInt("tun.mtu", DEFAULT_MTU),
routes,
unsafeRoutes,
config.GetInt("tun.tx_queue", 500),
routines > 1,
)
}
@@ -126,15 +160,27 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}
// set up our UDP listener
udpQueues := config.GetInt("listen.routines", 1)
var udpServer *udpConn
udpConns := make([]*udpConn, routines)
port := config.GetInt("listen.port", 0)
if !configTest {
udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1)
if err != nil {
return nil, NewContextualError("Failed to open udp listener", nil, err)
for i := 0; i < routines; i++ {
udpServer, err := NewListener(l, config.GetString("listen.host", "0.0.0.0"), port, routines > 1)
if err != nil {
return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
}
udpServer.reloadConfig(config)
udpConns[i] = udpServer
// If port is dynamic, discover it
if port == 0 {
uPort, err := udpServer.LocalAddr()
if err != nil {
return nil, NewContextualError("Failed to get listening port", nil, err)
}
port = int(uPort.Port)
}
}
udpServer.reloadConfig(config)
}
// Set up my internal host map
@@ -175,8 +221,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}
}
hostMap := NewHostMap("main", tunCidr, preferredRanges)
hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
hostMap.addUnsafeRoutes(&unsafeRoutes)
hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
@@ -190,21 +236,16 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
punchy := NewPunchyFromConfig(config)
if punchy.Punch && !configTest {
l.Info("UDP hole punching enabled")
go hostMap.Punchy(udpServer)
}
port := config.GetInt("listen.port", 0)
// If port is dynamic, discover it
if port == 0 && !configTest {
uPort, err := udpServer.LocalAddr()
if err != nil {
return nil, NewContextualError("Failed to get listening port", nil, err)
}
port = int(uPort.Port)
go hostMap.Punchy(udpConns[0])
}
amLighthouse := config.GetBool("lighthouse.am_lighthouse", false)
// fatal if am_lighthouse is enabled but we are using an ephemeral port
if amLighthouse && (config.GetInt("listen.port", 0) == 0) {
return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
}
// warn if am_lighthouse is enabled but upstream lighthouses exists
rawLighthouseHosts := config.GetStringSlice("lighthouse.hosts", []string{})
if amLighthouse && len(rawLighthouseHosts) != 0 {
@@ -224,13 +265,14 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}
lightHouse := NewLightHouse(
l,
amLighthouse,
ip2int(tunCidr.IP),
tunCidr,
lighthouseHosts,
//TODO: change to a duration
config.GetInt("lighthouse.interval", 10),
port,
udpServer,
uint32(port),
udpConns[0],
punchy.Respond,
punchy.Delay,
config.GetBool("stats.lighthouse_metrics", false),
@@ -257,29 +299,18 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
vals, ok := v.([]interface{})
if ok {
for _, v := range vals {
parts := strings.Split(fmt.Sprintf("%v", v), ":")
addr, err := net.ResolveIPAddr("ip", parts[0])
if err == nil {
ip := addr.IP
port, err := strconv.Atoi(parts[1])
if err != nil {
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
}
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
}
}
} else {
//TODO: make this all a helper
parts := strings.Split(fmt.Sprintf("%v", v), ":")
addr, err := net.ResolveIPAddr("ip", parts[0])
if err == nil {
ip := addr.IP
port, err := strconv.Atoi(parts[1])
ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v))
if err != nil {
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
}
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
}
} else {
ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v))
if err != nil {
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
}
lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
}
}
@@ -298,26 +329,33 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
handshakeConfig := HandshakeConfig{
tryInterval: config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
retries: config.GetInt("handshakes.retries", DefaultHandshakeRetries),
waitRotation: config.GetInt("handshakes.wait_rotation", DefaultHandshakeWaitRotation),
triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
messageMetrics: messageMetrics,
}
handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpServer, handshakeConfig)
handshakeManager := NewHandshakeManager(l, tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig)
lightHouse.handshakeTrigger = handshakeManager.trigger
//TODO: These will be reused for psk
//handshakeMACKey := config.GetString("handshake_mac.key", "")
//handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{})
serveDns := config.GetBool("lighthouse.serve_dns", false)
serveDns := false
if config.GetBool("lighthouse.serve_dns", false) {
if config.GetBool("lighthouse.am_lighthouse", false) {
serveDns = true
} else {
l.Warn("DNS server refusing to run because this host is not a lighthouse.")
}
}
checkInterval := config.GetInt("timers.connection_alive_interval", 5)
pendingDeletionInterval := config.GetInt("timers.pending_deletion_interval", 10)
ifConfig := &InterfaceConfig{
HostMap: hostMap,
Inside: tun,
Outside: udpServer,
Outside: udpConns[0],
certState: cs,
Cipher: config.GetString("cipher", "aes"),
Firewall: fw,
@@ -329,10 +367,13 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false),
DropMulticast: config.GetBool("tun.drop_multicast", false),
UDPBatchSize: config.GetInt("listen.batch", 64),
udpQueues: udpQueues,
tunQueues: config.GetInt("tun.routines", 1),
routines: routines,
MessageMetrics: messageMetrics,
version: buildVersion,
caPool: caPool,
ConntrackCacheTimeout: conntrackCacheTimeout,
l: l,
}
switch ifConfig.Cipher {
@@ -351,13 +392,17 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
return nil, fmt.Errorf("failed to initialize interface: %s", err)
}
// TODO: Better way to attach these, probably want a new interface in InterfaceConfig
// I don't want to make this initial commit too far-reaching though
ifce.writers = udpConns
ifce.RegisterConfigChangeCallbacks(config)
go handshakeManager.Run(ifce)
go lightHouse.LhUpdateWorker(ifce)
}
err = startStats(config, configTest)
statsStart, err := startStats(l, config, buildVersion, configTest)
if err != nil {
return nil, NewContextualError("Failed to start stats emitter", nil, err)
}
@@ -369,13 +414,14 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
//TODO: check if we _should_ be emitting stats
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
var dnsStart func()
if amLighthouse && serveDns {
l.Debugln("Starting dns server")
go dnsMain(hostMap, config)
dnsStart = dnsMain(l, hostMap, config)
}
return &Control{ifce, l}, nil
return &Control{ifce, l, sshStart, statsStart, dnsStart}, nil
}

View File

@@ -1 +1,30 @@
package nebula
import (
"io/ioutil"
"os"
"github.com/sirupsen/logrus"
)
func NewTestLogger() *logrus.Logger {
l := logrus.New()
v := os.Getenv("TEST_LOGS")
if v == "" {
l.SetOutput(ioutil.Discard)
return l
}
switch v {
case "1":
// This is the default level but we are being explicit
l.SetLevel(logrus.InfoLevel)
case "2":
l.SetLevel(logrus.DebugLevel)
case "3":
l.SetLevel(logrus.TraceLevel)
}
return l
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,8 @@
syntax = "proto3";
package nebula;
option go_package = "github.com/slackhq/nebula";
message NebulaMeta {
enum MessageType {
None = 0;
@@ -20,19 +22,23 @@ message NebulaMeta {
NebulaMetaDetails Details = 2;
}
message NebulaMetaDetails {
uint32 VpnIp = 1;
repeated IpAndPort IpAndPorts = 2;
repeated Ip4AndPort Ip4AndPorts = 2;
repeated Ip6AndPort Ip6AndPorts = 4;
uint32 counter = 3;
}
message IpAndPort {
message Ip4AndPort {
uint32 Ip = 1;
uint32 Port = 2;
}
message Ip6AndPort {
uint64 Hi = 1;
uint64 Lo = 2;
uint32 Port = 3;
}
message NebulaPing {
enum MessageType {

View File

@@ -17,14 +17,14 @@ const (
minFwPacketLen = 4
)
func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, nb []byte) {
func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, lhh *LightHouseHandler, nb []byte, q int, localCache ConntrackCache) {
err := header.Parse(packet)
if err != nil {
// TODO: best if we return this and let caller log
// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(packet) > 1 {
l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err)
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err)
}
return
}
@@ -45,7 +45,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
return
}
f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb)
f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb, q, localCache)
// Fallthrough to the bottom to record incoming traffic
@@ -57,7 +57,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
if err != nil {
hostinfo.logger().WithError(err).WithField("udpAddr", addr).
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
WithField("packet", packet).
Error("Failed to decrypt lighthouse packet")
@@ -66,7 +66,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
return
}
f.lightHouse.HandleRequest(addr, hostinfo.hostId, d, hostinfo.GetCert(), f)
lhh.HandleRequest(addr, hostinfo.hostId, d, f)
// Fallthrough to the bottom to record incoming traffic
@@ -78,7 +78,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
if err != nil {
hostinfo.logger().WithError(err).WithField("udpAddr", addr).
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
WithField("packet", packet).
Error("Failed to decrypt test packet")
@@ -106,7 +106,6 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
case recvError:
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
// TODO: Remove this with recv_error deprecation
f.handleRecvError(addr, header)
return
@@ -116,15 +115,15 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
return
}
hostinfo.logger().WithField("udpAddr", addr).
hostinfo.logger(f.l).WithField("udpAddr", addr).
Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo)
f.closeTunnel(hostinfo, false)
return
default:
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
hostinfo.logger().Debugf("Unexpected packet received from %s", addr)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr)
return
}
@@ -133,38 +132,45 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
f.connectionManager.In(hostinfo.hostId)
}
func (f *Interface) closeTunnel(hostInfo *HostInfo) {
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
func (f *Interface) closeTunnel(hostInfo *HostInfo, hasHostMapLock bool) {
//TODO: this would be better as a single function in ConnectionManager that handled locks appropriately
f.connectionManager.ClearIP(hostInfo.hostId)
f.connectionManager.ClearPendingDeletion(hostInfo.hostId)
f.lightHouse.DeleteVpnIP(hostInfo.hostId)
f.hostMap.DeleteVpnIP(hostInfo.hostId)
f.hostMap.DeleteIndex(hostInfo.localIndexId)
if hasHostMapLock {
f.hostMap.unlockedDeleteHostInfo(hostInfo)
} else {
f.hostMap.DeleteHostInfo(hostInfo)
}
}
// sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote
func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
}
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
if hostDidRoam(hostinfo.remote, addr) {
if !f.lightHouse.remoteAllowList.Allow(udp2ipInt(addr)) {
hostinfo.logger().WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
if !f.lightHouse.remoteAllowList.Allow(addr.IP) {
hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
return
}
if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSupressSeconds*time.Second {
if l.Level >= logrus.DebugLevel {
hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
Debugf("Supressing roam back to previous remote for %d seconds", RoamingSupressSeconds)
if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
}
return
}
hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now()
remoteCopy := *hostinfo.remote
hostinfo.lastRoamRemote = &remoteCopy
hostinfo.SetRemote(*addr)
if f.lightHouse.amLighthouse {
f.lightHouse.AddRemote(hostinfo.hostId, addr, false)
}
hostinfo.SetRemote(addr)
}
}
@@ -172,7 +178,7 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *Header) bool {
// If connectionstate exists and the replay protector allows, process packet
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
if ci == nil || !ci.window.Check(header.MessageCounter) {
if ci == nil || !ci.window.Check(f.l, header.MessageCounter) {
f.sendRecvError(addr, header.RemoteIndex)
return false
}
@@ -249,8 +255,8 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return nil, err
}
if !hostinfo.ConnectionState.window.Update(mc) {
hostinfo.logger().WithField("header", header).
if !hostinfo.ConnectionState.window.Update(f.l, mc) {
hostinfo.logger(f.l).WithField("header", header).
Debugln("dropping out of window packet")
return nil, errors.New("out of window packet")
}
@@ -258,12 +264,12 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return out, nil
}
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte, q int, localCache ConntrackCache) {
var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
if err != nil {
hostinfo.logger().WithError(err).Error("Failed to decrypt packet")
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(hostinfo.remote, header.RemoteIndex)
return
@@ -271,21 +277,21 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
err = newPacket(out, true, fwPacket)
if err != nil {
hostinfo.logger().WithError(err).WithField("packet", out).
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
Warnf("Error while validating inbound packet")
return
}
if !hostinfo.ConnectionState.window.Update(messageCounter) {
hostinfo.logger().WithField("fwPacket", fwPacket).
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
Debugln("dropping out of window packet")
return
}
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs)
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.caPool, localCache)
if dropReason != nil {
if l.Level >= logrus.DebugLevel {
hostinfo.logger().WithField("fwPacket", fwPacket).
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping inbound packet")
}
@@ -293,9 +299,9 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
}
f.connectionManager.In(hostinfo.hostId)
err = f.inside.WriteRaw(out)
_, err = f.readers[q].Write(out)
if err != nil {
l.WithError(err).Error("Failed to write to tun")
f.l.WithError(err).Error("Failed to write to tun")
}
}
@@ -305,47 +311,47 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
//TODO: this should be a signed message so we can trust that we should drop the index
b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0)
f.outside.WriteTo(b, endpoint)
if l.Level >= logrus.DebugLevel {
l.WithField("index", index).
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", index).
WithField("udpAddr", endpoint).
Debug("Recv error sent")
}
}
func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
// This flag is to stop caring about recv_error from old versions
// This should go away when the old version is gone from prod
if l.Level >= logrus.DebugLevel {
l.WithField("index", h.RemoteIndex).
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", h.RemoteIndex).
WithField("udpAddr", addr).
Debug("Recv error received")
}
// First, clean up in the pending hostmap
f.handshakeManager.pendingHostMap.DeleteReverseIndex(h.RemoteIndex)
hostinfo, err := f.hostMap.QueryReverseIndex(h.RemoteIndex)
if err != nil {
l.Debugln(err, ": ", h.RemoteIndex)
f.l.Debugln(err, ": ", h.RemoteIndex)
return
}
hostinfo.Lock()
defer hostinfo.Unlock()
if !hostinfo.RecvErrorExceeded() {
return
}
if hostinfo.remote != nil && hostinfo.remote.String() != addr.String() {
l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
if hostinfo.remote != nil && hostinfo.remote.Equals(addr) {
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
return
}
id := hostinfo.localIndexId
host := hostinfo.hostId
// We delete this host from the main hostmap
f.hostMap.DeleteIndex(id)
f.hostMap.DeleteVpnIP(host)
f.hostMap.DeleteHostInfo(hostinfo)
// We also delete it from pending to allow for
// fast reconnect. We must null the connectionstate
// or a counter reuse may happen
hostinfo.ConnectionState = nil
f.handshakeManager.DeleteIndex(id)
f.handshakeManager.DeleteVpnIP(host)
f.handshakeManager.DeleteHostInfo(hostinfo)
}
/*
@@ -370,7 +376,7 @@ func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *N
}
*/
func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte) (*cert.NebulaCertificate, error) {
func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.NebulaCAPool) (*cert.NebulaCertificate, error) {
pk := h.PeerStatic()
if pk == nil {
@@ -399,7 +405,7 @@ func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte) (*ce
}
c, _ := cert.UnmarshalNebulaCertificate(recombined)
isValid, err := c.Verify(time.Now(), trustedCAs)
isValid, err := c.Verify(time.Now(), caPool)
if err != nil {
return c, fmt.Errorf("certificate validation failed: %s", err)
} else if !isValid {

View File

@@ -8,7 +8,8 @@ import (
)
func TestNewPunchyFromConfig(t *testing.T) {
c := NewConfig()
l := NewTestLogger()
c := NewConfig(l)
// Test defaults
p := NewPunchyFromConfig(c)

504
remote_list.go Normal file
View File

@@ -0,0 +1,504 @@
package nebula
import (
"bytes"
"net"
"sort"
"sync"
)
// forEachFunc is used to benefit folks that want to do work inside the lock
type forEachFunc func(addr *udpAddr, preferred bool)
// The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
type checkFuncV4 func(to *Ip4AndPort) bool
type checkFuncV6 func(to *Ip6AndPort) bool
// CacheMap is a struct that better represents the lighthouse cache for humans
// The string key is the owners vpnIp
type CacheMap map[string]*Cache
// Cache is the other part of CacheMap to better represent the lighthouse cache for humans
// We don't reason about ipv4 vs ipv6 here
type Cache struct {
Learned []*udpAddr `json:"learned,omitempty"`
Reported []*udpAddr `json:"reported,omitempty"`
}
//TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion
// We will never clean learned/reported information for them as it stands today
// cache is an internal struct that splits v4 and v6 addresses inside the cache map
type cache struct {
v4 *cacheV4
v6 *cacheV6
}
// cacheV4 stores learned and reported ipv4 records under cache
type cacheV4 struct {
learned *Ip4AndPort
reported []*Ip4AndPort
}
// cacheV4 stores learned and reported ipv6 records under cache
type cacheV6 struct {
learned *Ip6AndPort
reported []*Ip6AndPort
}
// RemoteList is a unifying concept for lighthouse servers and clients as well as hostinfos.
// It serves as a local cache of query replies, host update notifications, and locally learned addresses
type RemoteList struct {
// Every interaction with internals requires a lock!
sync.RWMutex
// A deduplicated set of addresses. Any accessor should lock beforehand.
addrs []*udpAddr
// These are maps to store v4 and v6 addresses per lighthouse
// Map key is the vpnIp of the person that told us about this the cached entries underneath.
// For learned addresses, this is the vpnIp that sent the packet
cache map[uint32]*cache
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake
badRemotes []*udpAddr
// A flag that the cache may have changed and addrs needs to be rebuilt
shouldRebuild bool
}
// NewRemoteList creates a new empty RemoteList
func NewRemoteList() *RemoteList {
return &RemoteList{
addrs: make([]*udpAddr, 0),
cache: make(map[uint32]*cache),
}
}
// Len locks and reports the size of the deduplicated address list
// The deduplication work may need to occur here, so you must pass preferredRanges
func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
r.Rebuild(preferredRanges)
r.RLock()
defer r.RUnlock()
return len(r.addrs)
}
// ForEach locks and will call the forEachFunc for every deduplicated address in the list
// The deduplication work may need to occur here, so you must pass preferredRanges
func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc) {
r.Rebuild(preferredRanges)
r.RLock()
for _, v := range r.addrs {
forEach(v, isPreferred(v.IP, preferredRanges))
}
r.RUnlock()
}
// CopyAddrs locks and makes a deep copy of the deduplicated address list
// The deduplication work may need to occur here, so you must pass preferredRanges
func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr {
if r == nil {
return nil
}
r.Rebuild(preferredRanges)
r.RLock()
defer r.RUnlock()
c := make([]*udpAddr, len(r.addrs))
for i, v := range r.addrs {
c[i] = v.Copy()
}
return c
}
// LearnRemote locks and sets the learned slot for the owner vpn ip to the provided addr
// Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
// It will mark the deduplicated address list as dirty, so do not call it unless new information is available
//TODO: this needs to support the allow list list
func (r *RemoteList) LearnRemote(ownerVpnIp uint32, addr *udpAddr) {
r.Lock()
defer r.Unlock()
if v4 := addr.IP.To4(); v4 != nil {
r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPort(v4, uint32(addr.Port)))
} else {
r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPort(addr.IP, uint32(addr.Port)))
}
}
// CopyCache locks and creates a more human friendly form of the internal address cache.
// This may contain duplicates and blocked addresses
func (r *RemoteList) CopyCache() *CacheMap {
r.RLock()
defer r.RUnlock()
cm := make(CacheMap)
getOrMake := func(vpnIp string) *Cache {
c := cm[vpnIp]
if c == nil {
c = &Cache{
Learned: make([]*udpAddr, 0),
Reported: make([]*udpAddr, 0),
}
cm[vpnIp] = c
}
return c
}
for owner, mc := range r.cache {
c := getOrMake(IntIp(owner).String())
if mc.v4 != nil {
if mc.v4.learned != nil {
c.Learned = append(c.Learned, NewUDPAddrFromLH4(mc.v4.learned))
}
for _, a := range mc.v4.reported {
c.Reported = append(c.Reported, NewUDPAddrFromLH4(a))
}
}
if mc.v6 != nil {
if mc.v6.learned != nil {
c.Learned = append(c.Learned, NewUDPAddrFromLH6(mc.v6.learned))
}
for _, a := range mc.v6.reported {
c.Reported = append(c.Reported, NewUDPAddrFromLH6(a))
}
}
}
return &cm
}
// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
func (r *RemoteList) BlockRemote(bad *udpAddr) {
r.Lock()
defer r.Unlock()
// Check if we already blocked this addr
if r.unlockedIsBad(bad) {
return
}
// We copy here because we are taking something else's memory and we can't trust everything
r.badRemotes = append(r.badRemotes, bad.Copy())
// Mark the next interaction must recollect/dedupe
r.shouldRebuild = true
}
// CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list
func (r *RemoteList) CopyBlockedRemotes() []*udpAddr {
r.RLock()
defer r.RUnlock()
c := make([]*udpAddr, len(r.badRemotes))
for i, v := range r.badRemotes {
c[i] = v.Copy()
}
return c
}
// ResetBlockedRemotes locks and clears the blocked remotes list
func (r *RemoteList) ResetBlockedRemotes() {
r.Lock()
r.badRemotes = nil
r.Unlock()
}
// Rebuild locks and generates the deduplicated address list only if there is work to be done
// There is generally no reason to call this directly but it is safe to do so
func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
r.Lock()
defer r.Unlock()
// Only rebuild if the cache changed
//TODO: shouldRebuild is probably pointless as we don't check for actual change when lighthouse updates come in
if r.shouldRebuild {
r.unlockedCollect()
r.shouldRebuild = false
}
// Always re-sort, preferredRanges can change via HUP
r.unlockedSort(preferredRanges)
}
// unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
func (r *RemoteList) unlockedIsBad(remote *udpAddr) bool {
for _, v := range r.badRemotes {
if v.Equals(remote) {
return true
}
}
return false
}
// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty
func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp uint32, to *Ip4AndPort) {
r.shouldRebuild = true
r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
}
// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty
func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, to []*Ip4AndPort, check checkFuncV4) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp)
// Reset the slice
c.reported = c.reported[:0]
// We can't take their array but we can take their pointers
for _, v := range to[:minInt(len(to), MaxRemotes)] {
if check(v) {
c.reported = append(c.reported, v)
}
}
}
// unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts
func (r *RemoteList) unlockedPrependV4(ownerVpnIp uint32, to *Ip4AndPort) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp)
// We are doing the easy append because this is rarely called
c.reported = append([]*Ip4AndPort{to}, c.reported...)
if len(c.reported) > MaxRemotes {
c.reported = c.reported[:MaxRemotes]
}
}
// unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty
func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp uint32, to *Ip6AndPort) {
r.shouldRebuild = true
r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
}
// unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty
func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, to []*Ip6AndPort, check checkFuncV6) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp)
// Reset the slice
c.reported = c.reported[:0]
// We can't take their array but we can take their pointers
for _, v := range to[:minInt(len(to), MaxRemotes)] {
if check(v) {
c.reported = append(c.reported, v)
}
}
}
// unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts
func (r *RemoteList) unlockedPrependV6(ownerVpnIp uint32, to *Ip6AndPort) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp)
// We are doing the easy append because this is rarely called
c.reported = append([]*Ip6AndPort{to}, c.reported...)
if len(c.reported) > MaxRemotes {
c.reported = c.reported[:MaxRemotes]
}
}
// unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
// The caller must dirty the learned address cache if required
func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp uint32) *cacheV4 {
am := r.cache[ownerVpnIp]
if am == nil {
am = &cache{}
r.cache[ownerVpnIp] = am
}
// Avoid occupying memory for v6 addresses if we never have any
if am.v4 == nil {
am.v4 = &cacheV4{}
}
return am.v4
}
// unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
// The caller must dirty the learned address cache if required
func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp uint32) *cacheV6 {
am := r.cache[ownerVpnIp]
if am == nil {
am = &cache{}
r.cache[ownerVpnIp] = am
}
// Avoid occupying memory for v4 addresses if we never have any
if am.v6 == nil {
am.v6 = &cacheV6{}
}
return am.v6
}
// unlockedCollect assumes you have the write lock and collects/transforms the cache into the deduped address list.
// The result of this function can contain duplicates. unlockedSort handles cleaning it.
func (r *RemoteList) unlockedCollect() {
addrs := r.addrs[:0]
for _, c := range r.cache {
if c.v4 != nil {
if c.v4.learned != nil {
u := NewUDPAddrFromLH4(c.v4.learned)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
}
for _, v := range c.v4.reported {
u := NewUDPAddrFromLH4(v)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
}
}
if c.v6 != nil {
if c.v6.learned != nil {
u := NewUDPAddrFromLH6(c.v6.learned)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
}
for _, v := range c.v6.reported {
u := NewUDPAddrFromLH6(v)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
}
}
}
r.addrs = addrs
}
// unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list
func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
n := len(r.addrs)
if n < 2 {
return
}
lessFunc := func(i, j int) bool {
a := r.addrs[i]
b := r.addrs[j]
// Preferred addresses first
aPref := isPreferred(a.IP, preferredRanges)
bPref := isPreferred(b.IP, preferredRanges)
switch {
case aPref && !bPref:
// If i is preferred and j is not, i is less than j
return true
case !aPref && bPref:
// If j is preferred then i is not due to the else, i is not less than j
return false
default:
// Both i an j are either preferred or not, sort within that
}
// ipv6 addresses 2nd
a4 := a.IP.To4()
b4 := b.IP.To4()
switch {
case a4 == nil && b4 != nil:
// If i is v6 and j is v4, i is less than j
return true
case a4 != nil && b4 == nil:
// If j is v6 and i is v4, i is not less than j
return false
case a4 != nil && b4 != nil:
// Special case for ipv4, a4 and b4 are not nil
aPrivate := isPrivateIP(a4)
bPrivate := isPrivateIP(b4)
switch {
case !aPrivate && bPrivate:
// If i is a public ip (not private) and j is a private ip, i is less then j
return true
case aPrivate && !bPrivate:
// If j is public (not private) then i is private due to the else, i is not less than j
return false
default:
// Both i an j are either public or private, sort within that
}
default:
// Both i an j are either ipv4 or ipv6, sort within that
}
// lexical order of ips 3rd
c := bytes.Compare(a.IP, b.IP)
if c == 0 {
// Ips are the same, Lexical order of ports 4th
return a.Port < b.Port
}
// Ip wasn't the same
return c < 0
}
// Sort it
sort.Slice(r.addrs, lessFunc)
// Deduplicate
a, b := 0, 1
for b < n {
if !r.addrs[a].Equals(r.addrs[b]) {
a++
if a != b {
r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a]
}
}
b++
}
r.addrs = r.addrs[:a+1]
return
}
// minInt returns the minimum integer of a or b
func minInt(a, b int) int {
if a < b {
return a
}
return b
}
// isPreferred returns true of the ip is contained in the preferredRanges list
func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool {
//TODO: this would be better in a CIDR6Tree
for _, p := range preferredRanges {
if p.Contains(ip) {
return true
}
}
return false
}
var _, private24BitBlock, _ = net.ParseCIDR("10.0.0.0/8")
var _, private20BitBlock, _ = net.ParseCIDR("172.16.0.0/12")
var _, private16BitBlock, _ = net.ParseCIDR("192.168.0.0/16")
// isPrivateIP returns true if the ip is contained by a rfc 1918 private range
func isPrivateIP(ip net.IP) bool {
//TODO: another great cidrtree option
//TODO: Private for ipv6 or just let it ride?
return private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
}

228
remote_list_test.go Normal file
View File

@@ -0,0 +1,228 @@
package nebula
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestRemoteList_Rebuild(t *testing.T) {
rl := NewRemoteList()
rl.unlockedSetV4(
0,
[]*Ip4AndPort{
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is duped
{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is duped
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is duped
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // almost dupe of 0 with a diff port
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is a dupe
},
func(*Ip4AndPort) bool { return true },
)
rl.unlockedSetV6(
1,
[]*Ip6AndPort{
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is duped
NewIp6AndPort(net.ParseIP("1::1"), 2), // almost dupe of 0 with a diff port, also gets duped
NewIp6AndPort(net.ParseIP("1:100::1"), 1),
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe
},
func(*Ip6AndPort) bool { return true },
)
rl.Rebuild([]*net.IPNet{})
assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
// ipv6 first, sorted lexically within
assert.Equal(t, "[1::1]:1", rl.addrs[0].String())
assert.Equal(t, "[1::1]:2", rl.addrs[1].String())
assert.Equal(t, "[1:100::1]:1", rl.addrs[2].String())
// ipv4 last, sorted by public first, then private, lexically within them
assert.Equal(t, "70.199.182.92:1475", rl.addrs[3].String())
assert.Equal(t, "70.199.182.92:1476", rl.addrs[4].String())
assert.Equal(t, "172.17.0.182:10101", rl.addrs[5].String())
assert.Equal(t, "172.17.1.1:10101", rl.addrs[6].String())
assert.Equal(t, "172.18.0.1:10101", rl.addrs[7].String())
assert.Equal(t, "172.19.0.1:10101", rl.addrs[8].String())
assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
// Now ensure we can hoist ipv4 up
_, ipNet, err := net.ParseCIDR("0.0.0.0/0")
assert.NoError(t, err)
rl.Rebuild([]*net.IPNet{ipNet})
assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
// ipv4 first, public then private, lexically within them
assert.Equal(t, "70.199.182.92:1475", rl.addrs[0].String())
assert.Equal(t, "70.199.182.92:1476", rl.addrs[1].String())
assert.Equal(t, "172.17.0.182:10101", rl.addrs[2].String())
assert.Equal(t, "172.17.1.1:10101", rl.addrs[3].String())
assert.Equal(t, "172.18.0.1:10101", rl.addrs[4].String())
assert.Equal(t, "172.19.0.1:10101", rl.addrs[5].String())
assert.Equal(t, "172.31.0.1:10101", rl.addrs[6].String())
// ipv6 last, sorted by public first, then private, lexically within them
assert.Equal(t, "[1::1]:1", rl.addrs[7].String())
assert.Equal(t, "[1::1]:2", rl.addrs[8].String())
assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String())
// Ensure we can hoist a specific ipv4 range over anything else
_, ipNet, err = net.ParseCIDR("172.17.0.0/16")
assert.NoError(t, err)
rl.Rebuild([]*net.IPNet{ipNet})
assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
// Preferred ipv4 first
assert.Equal(t, "172.17.0.182:10101", rl.addrs[0].String())
assert.Equal(t, "172.17.1.1:10101", rl.addrs[1].String())
// ipv6 next
assert.Equal(t, "[1::1]:1", rl.addrs[2].String())
assert.Equal(t, "[1::1]:2", rl.addrs[3].String())
assert.Equal(t, "[1:100::1]:1", rl.addrs[4].String())
// the remaining ipv4 last
assert.Equal(t, "70.199.182.92:1475", rl.addrs[5].String())
assert.Equal(t, "70.199.182.92:1476", rl.addrs[6].String())
assert.Equal(t, "172.18.0.1:10101", rl.addrs[7].String())
assert.Equal(t, "172.19.0.1:10101", rl.addrs[8].String())
assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
}
func BenchmarkFullRebuild(b *testing.B) {
rl := NewRemoteList()
rl.unlockedSetV4(
0,
[]*Ip4AndPort{
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
},
func(*Ip4AndPort) bool { return true },
)
rl.unlockedSetV6(
0,
[]*Ip6AndPort{
NewIp6AndPort(net.ParseIP("1::1"), 1),
NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
NewIp6AndPort(net.ParseIP("1:100::1"), 1),
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
},
func(*Ip6AndPort) bool { return true },
)
b.Run("no preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{})
}
})
_, ipNet, err := net.ParseCIDR("172.17.0.0/16")
assert.NoError(b, err)
b.Run("1 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{ipNet})
}
})
_, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
assert.NoError(b, err)
b.Run("2 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
}
})
_, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
assert.NoError(b, err)
b.Run("3 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
}
})
}
func BenchmarkSortRebuild(b *testing.B) {
rl := NewRemoteList()
rl.unlockedSetV4(
0,
[]*Ip4AndPort{
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is a dupe
{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
},
func(*Ip4AndPort) bool { return true },
)
rl.unlockedSetV6(
0,
[]*Ip6AndPort{
NewIp6AndPort(net.ParseIP("1::1"), 1),
NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
NewIp6AndPort(net.ParseIP("1:100::1"), 1),
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
},
func(*Ip6AndPort) bool { return true },
)
b.Run("no preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{})
}
})
_, ipNet, err := net.ParseCIDR("172.17.0.0/16")
rl.Rebuild([]*net.IPNet{ipNet})
assert.NoError(b, err)
b.Run("1 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.Rebuild([]*net.IPNet{ipNet})
}
})
_, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
assert.NoError(b, err)
b.Run("2 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
}
})
_, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
assert.NoError(b, err)
b.Run("3 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
}
})
}

152
ssh.go
View File

@@ -10,6 +10,7 @@ import (
"os"
"reflect"
"runtime/pprof"
"sort"
"strings"
"syscall"
@@ -43,51 +44,58 @@ type sshCreateTunnelFlags struct {
Address string
}
func wireSSHReload(ssh *sshd.SSHServer, c *Config) {
func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) {
c.RegisterReloadCallback(func(c *Config) {
if c.GetBool("sshd.enabled", false) {
err := configSSH(ssh, c)
sshRun, err := configSSH(l, ssh, c)
if err != nil {
l.WithError(err).Error("Failed to reconfigure the sshd")
ssh.Stop()
}
if sshRun != nil {
go sshRun()
}
} else {
ssh.Stop()
}
})
}
func configSSH(ssh *sshd.SSHServer, c *Config) error {
// configSSH reads the ssh info out of the passed-in Config and
// updates the passed-in SSHServer. On success, it returns a function
// that callers may invoke to run the configured ssh server. On
// failure, it returns nil, error.
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) (func(), error) {
//TODO conntrack list
//TODO print firewall rules or hash?
listen := c.GetString("sshd.listen", "")
if listen == "" {
return fmt.Errorf("sshd.listen must be provided")
return nil, fmt.Errorf("sshd.listen must be provided")
}
_, port, err := net.SplitHostPort(listen)
if err != nil {
return fmt.Errorf("invalid sshd.listen address: %s", err)
return nil, fmt.Errorf("invalid sshd.listen address: %s", err)
}
if port == "22" {
return fmt.Errorf("sshd.listen can not use port 22")
return nil, fmt.Errorf("sshd.listen can not use port 22")
}
//TODO: no good way to reload this right now
hostKeyFile := c.GetString("sshd.host_key", "")
if hostKeyFile == "" {
return fmt.Errorf("sshd.host_key must be provided")
return nil, fmt.Errorf("sshd.host_key must be provided")
}
hostKeyBytes, err := ioutil.ReadFile(hostKeyFile)
if err != nil {
return fmt.Errorf("error while loading sshd.host_key file: %s", err)
return nil, fmt.Errorf("error while loading sshd.host_key file: %s", err)
}
err = ssh.SetHostKey(hostKeyBytes)
if err != nil {
return fmt.Errorf("error while adding sshd.host_key: %s", err)
return nil, fmt.Errorf("error while adding sshd.host_key: %s", err)
}
rawKeys := c.Get("sshd.authorized_users")
@@ -138,17 +146,22 @@ func configSSH(ssh *sshd.SSHServer, c *Config) error {
l.Info("no ssh users to authorize")
}
var runner func()
if c.GetBool("sshd.enabled", false) {
ssh.Stop()
go ssh.Run(listen)
runner = func() {
if err := ssh.Run(listen); err != nil {
l.WithField("err", err).Warn("Failed to run the SSH server")
}
}
} else {
ssh.Stop()
}
return nil
return runner, nil
}
func attachCommands(ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) {
func attachCommands(l *logrus.Logger, ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) {
ssh.RegisterCommand(&sshd.Command{
Name: "list-hostmap",
ShortDescription: "List all known previously connected hosts",
@@ -224,13 +237,17 @@ func attachCommands(ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostM
ssh.RegisterCommand(&sshd.Command{
Name: "log-level",
ShortDescription: "Gets or sets the current log level",
Callback: sshLogLevel,
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshLogLevel(l, fs, a, w)
},
})
ssh.RegisterCommand(&sshd.Command{
Name: "log-format",
ShortDescription: "Gets or sets the current log format",
Callback: sshLogFormat,
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshLogFormat(l, fs, a, w)
},
})
ssh.RegisterCommand(&sshd.Command{
@@ -330,8 +347,10 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
return nil
}
hostMap.RLock()
defer hostMap.RUnlock()
hm := listHostMap(hostMap)
sort.Slice(hm, func(i, j int) bool {
return bytes.Compare(hm[i].VpnIP, hm[j].VpnIP) < 0
})
if fs.Json || fs.Pretty {
js := json.NewEncoder(w.GetWriter())
@@ -339,35 +358,15 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
js.SetIndent("", " ")
}
d := make([]m, len(hostMap.Hosts))
x := 0
var h m
for _, v := range hostMap.Hosts {
h = m{
"vpnIp": int2ip(v.hostId),
"localIndex": v.localIndexId,
"remoteIndex": v.remoteIndexId,
"remoteAddrs": v.RemoteUDPAddrs(),
"cachedPackets": len(v.packetStore),
"cert": v.GetCert(),
}
if v.ConnectionState != nil {
h["messageCounter"] = v.ConnectionState.messageCounter
}
d[x] = h
x++
}
err := js.Encode(d)
err := js.Encode(hm)
if err != nil {
//TODO
return nil
}
} else {
for i, v := range hostMap.Hosts {
err := w.WriteLine(fmt.Sprintf("%s: %s", int2ip(i), v.RemoteUDPAddrs()))
for _, v := range hm {
err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, v.RemoteAddrs))
if err != nil {
return err
}
@@ -384,8 +383,26 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
return nil
}
type lighthouseInfo struct {
VpnIP net.IP `json:"vpnIp"`
Addrs *CacheMap `json:"addrs"`
}
lightHouse.RLock()
defer lightHouse.RUnlock()
addrMap := make([]lighthouseInfo, len(lightHouse.addrMap))
x := 0
for k, v := range lightHouse.addrMap {
addrMap[x] = lighthouseInfo{
VpnIP: int2ip(k),
Addrs: v.CopyCache(),
}
x++
}
lightHouse.RUnlock()
sort.Slice(addrMap, func(i, j int) bool {
return bytes.Compare(addrMap[i].VpnIP, addrMap[j].VpnIP) < 0
})
if fs.Json || fs.Pretty {
js := json.NewEncoder(w.GetWriter())
@@ -393,36 +410,19 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
js.SetIndent("", " ")
}
d := make([]m, len(lightHouse.addrMap))
x := 0
var h m
for vpnIp, v := range lightHouse.addrMap {
ips := make([]string, len(v))
for i, ip := range v {
ips[i] = ip.String()
}
h = m{
"vpnIp": int2ip(vpnIp),
"addrs": ips,
}
d[x] = h
x++
}
err := js.Encode(d)
err := js.Encode(addrMap)
if err != nil {
//TODO
return nil
}
} else {
for vpnIp, v := range lightHouse.addrMap {
ips := make([]string, len(v))
for i, ip := range v {
ips[i] = ip.String()
for _, v := range addrMap {
b, err := json.Marshal(v.Addrs)
if err != nil {
return err
}
err := w.WriteLine(fmt.Sprintf("%s: %s", int2ip(vpnIp), ips))
err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, string(b)))
if err != nil {
return err
}
@@ -473,8 +473,12 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
ips, _ := ifce.lightHouse.Query(vpnIp, ifce)
return json.NewEncoder(w.GetWriter()).Encode(ips)
var cm *CacheMap
rl := ifce.lightHouse.Query(vpnIp, ifce)
if rl != nil {
cm = rl.CopyCache()
}
return json.NewEncoder(w.GetWriter()).Encode(cm)
}
func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
@@ -516,7 +520,7 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
)
}
ifce.closeTunnel(hostInfo)
ifce.closeTunnel(hostInfo, false)
return w.WriteLine("Closed")
}
@@ -561,7 +565,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
hostInfo = ifce.handshakeManager.AddVpnIP(vpnIp)
if addr != nil {
hostInfo.SetRemote(*addr)
hostInfo.SetRemote(addr)
}
ifce.getOrHandshake(vpnIp)
@@ -603,7 +607,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
hostInfo.SetRemote(*addr)
hostInfo.SetRemote(addr)
return w.WriteLine("Changed")
}
@@ -628,7 +632,7 @@ func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error {
return err
}
func sshLogLevel(fs interface{}, a []string, w sshd.StringWriter) error {
func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
}
@@ -642,7 +646,7 @@ func sshLogLevel(fs interface{}, a []string, w sshd.StringWriter) error {
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
}
func sshLogFormat(fs interface{}, a []string, w sshd.StringWriter) error {
func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
}
@@ -731,7 +735,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
hostInfo, err := ifce.hostMap.QueryVpnIP(vpnIp)
if err != nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
@@ -741,7 +745,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
enc.SetIndent("", " ")
}
return enc.Encode(hostInfo)
return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.preferredRanges))
}
func sshReload(fs interface{}, a []string, w sshd.StringWriter) error {

View File

@@ -1,8 +1,10 @@
package sshd
import (
"errors"
"fmt"
"net"
"sync"
"github.com/armon/go-radix"
"github.com/sirupsen/logrus"
@@ -20,8 +22,11 @@ type SSHServer struct {
helpCommand *Command
commands *radix.Tree
listener net.Listener
conns map[int]*session
counter int
// Locks the conns/counter to avoid concurrent map access
connsLock sync.Mutex
conns map[int]*session
counter int
}
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
@@ -97,11 +102,24 @@ func (s *SSHServer) Run(addr string) error {
}
s.l.WithField("sshListener", addr).Info("SSH server is listening")
// Run loops until there is an error
s.run()
s.closeSessions()
s.l.Info("SSH server stopped listening")
// We don't return an error because run logs for us
return nil
}
func (s *SSHServer) run() {
for {
c, err := s.listener.Accept()
if err != nil {
s.l.WithError(err).Warn("Error in listener, shutting down")
return nil
if !errors.Is(err, net.ErrClosed) {
s.l.WithError(err).Warn("Error in listener, shutting down")
}
return
}
conn, chans, reqs, err := ssh.NewServerConn(c, s.config)
@@ -127,36 +145,38 @@ func (s *SSHServer) Run(addr string) error {
l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in")
session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session"))
s.connsLock.Lock()
s.counter++
counter := s.counter
s.conns[counter] = session
s.connsLock.Unlock()
go ssh.DiscardRequests(reqs)
go func() {
<-session.exitChan
s.l.WithField("id", counter).Debug("closing conn")
s.connsLock.Lock()
delete(s.conns, counter)
s.connsLock.Unlock()
}()
}
}
func (s *SSHServer) Stop() {
// Close the listener, this will cause all session to terminate as well, see SSHServer.Run
if s.listener != nil {
if err := s.listener.Close(); err != nil {
s.l.WithError(err).Warn("Failed to close the sshd listener")
}
}
}
func (s *SSHServer) closeSessions() {
s.connsLock.Lock()
for _, c := range s.conns {
c.Close()
}
if s.listener == nil {
return
}
err := s.listener.Close()
if err != nil {
s.l.WithError(err).Warn("Failed to close the sshd listener")
return
}
s.l.Info("SSH server stopped listening")
return
s.connsLock.Unlock()
}
func (s *SSHServer) matchPubKey(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {

View File

@@ -6,6 +6,7 @@ import (
"log"
"net"
"net/http"
"runtime"
"time"
graphite "github.com/cyberdelia/go-metrics-graphite"
@@ -13,26 +14,38 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
)
func startStats(c *Config, configTest bool) error {
// startStats initializes stats from config. On success, if any futher work
// is needed to serve stats, it returns a func to handle that work. If no
// work is needed, it'll return nil. On failure, it returns nil, error.
func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest bool) (func(), error) {
mType := c.GetString("stats.type", "")
if mType == "" || mType == "none" {
return nil
return nil, nil
}
interval := c.GetDuration("stats.interval", 0)
if interval == 0 {
return fmt.Errorf("stats.interval was an invalid duration: %s", c.GetString("stats.interval", ""))
return nil, fmt.Errorf("stats.interval was an invalid duration: %s", c.GetString("stats.interval", ""))
}
var startFn func()
switch mType {
case "graphite":
startGraphiteStats(interval, c, configTest)
err := startGraphiteStats(l, interval, c, configTest)
if err != nil {
return nil, err
}
case "prometheus":
startPrometheusStats(interval, c, configTest)
var err error
startFn, err = startPrometheusStats(l, interval, c, buildVersion, configTest)
if err != nil {
return nil, err
}
default:
return fmt.Errorf("stats.type was not understood: %s", mType)
return nil, fmt.Errorf("stats.type was not understood: %s", mType)
}
metrics.RegisterDebugGCStats(metrics.DefaultRegistry)
@@ -41,10 +54,10 @@ func startStats(c *Config, configTest bool) error {
go metrics.CaptureDebugGCStats(metrics.DefaultRegistry, interval)
go metrics.CaptureRuntimeMemStats(metrics.DefaultRegistry, interval)
return nil
return startFn, nil
}
func startGraphiteStats(i time.Duration, c *Config, configTest bool) error {
func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error {
proto := c.GetString("stats.protocol", "tcp")
host := c.GetString("stats.host", "")
if host == "" {
@@ -57,38 +70,53 @@ func startGraphiteStats(i time.Duration, c *Config, configTest bool) error {
return fmt.Errorf("error while setting up graphite sink: %s", err)
}
l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr)
if !configTest {
l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr)
go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr)
}
return nil
}
func startPrometheusStats(i time.Duration, c *Config, configTest bool) error {
func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVersion string, configTest bool) (func(), error) {
namespace := c.GetString("stats.namespace", "")
subsystem := c.GetString("stats.subsystem", "")
listen := c.GetString("stats.listen", "")
if listen == "" {
return fmt.Errorf("stats.listen should not be empty")
return nil, fmt.Errorf("stats.listen should not be empty")
}
path := c.GetString("stats.path", "")
if path == "" {
return fmt.Errorf("stats.path should not be empty")
return nil, fmt.Errorf("stats.path should not be empty")
}
pr := prometheus.NewRegistry()
pClient := mp.NewPrometheusProvider(metrics.DefaultRegistry, namespace, subsystem, pr, i)
go pClient.UpdatePrometheusMetrics()
// Export our version information as labels on a static gauge
g := prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: namespace,
Subsystem: subsystem,
Name: "info",
Help: "Version information for the Nebula binary",
ConstLabels: prometheus.Labels{
"version": buildVersion,
"goversion": runtime.Version(),
},
})
pr.MustRegister(g)
g.Set(1)
var startFn func()
if !configTest {
go func() {
startFn = func() {
l.Infof("Prometheus stats listening on %s at %s", listen, path)
http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l}))
log.Fatal(http.ListenAndServe(listen, nil))
}()
}
}
return nil
return startFn, nil
}

View File

@@ -1,3 +1,5 @@
// +build !e2e_testing
package nebula
import (
@@ -6,6 +8,7 @@ import (
"net"
"os"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
@@ -19,9 +22,10 @@ type Tun struct {
TXQueueLen int
Routes []route
UnsafeRoutes []route
l *logrus.Logger
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
ifce = &Tun{
@@ -33,11 +37,12 @@ func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route,
TXQueueLen: txQueueLen,
Routes: routes,
UnsafeRoutes: unsafeRoutes,
l: l,
}
return
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTun not supported in Android")
}
@@ -74,3 +79,7 @@ func (c *Tun) CidrNet() *net.IPNet {
func (c *Tun) DeviceName() string {
return c.Device
}
func (t *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
}

View File

@@ -1,13 +1,16 @@
// +build !ios
// +build !e2e_testing
package nebula
import (
"fmt"
"io"
"net"
"os/exec"
"strconv"
"github.com/sirupsen/logrus"
"github.com/songgao/water"
)
@@ -16,11 +19,11 @@ type Tun struct {
Cidr *net.IPNet
MTU int
UnsafeRoutes []route
l *logrus.Logger
*water.Interface
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("route MTU not supported in Darwin")
}
@@ -30,10 +33,11 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
Cidr: cidr,
MTU: defaultMTU,
UnsafeRoutes: unsafeRoutes,
l: l,
}, nil
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}
@@ -80,3 +84,7 @@ func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b)
return err
}
func (t *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
}

View File

@@ -1,26 +1,42 @@
package nebula
import (
"encoding/binary"
"fmt"
"io"
"net"
"strings"
log "github.com/sirupsen/logrus"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
)
type disabledTun struct {
block chan struct{}
cidr *net.IPNet
logger *log.Logger
read chan []byte
cidr *net.IPNet
// Track these metrics since we don't have the tun device to do it for us
tx metrics.Counter
rx metrics.Counter
l *logrus.Logger
}
func newDisabledTun(cidr *net.IPNet, l *log.Logger) *disabledTun {
return &disabledTun{
cidr: cidr,
block: make(chan struct{}),
logger: l,
func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
tun := &disabledTun{
cidr: cidr,
read: make(chan []byte, queueLen),
l: l,
}
if metricsEnabled {
tun.tx = metrics.GetOrRegisterCounter("messages.tx.message", nil)
tun.rx = metrics.GetOrRegisterCounter("messages.rx.message", nil)
} else {
tun.tx = &metrics.NilCounter{}
tun.rx = &metrics.NilCounter{}
}
return tun
}
func (*disabledTun) Activate() error {
@@ -36,12 +52,73 @@ func (*disabledTun) DeviceName() string {
}
func (t *disabledTun) Read(b []byte) (int, error) {
<-t.block
return 0, io.EOF
r, ok := <-t.read
if !ok {
return 0, io.EOF
}
if len(r) > len(b) {
return 0, fmt.Errorf("packet larger than mtu: %d > %d bytes", len(r), len(b))
}
t.tx.Inc(1)
if t.l.Level >= logrus.DebugLevel {
t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload")
}
return copy(b, r), nil
}
func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
// Return early if this is not a simple ICMP Echo Request
if !(len(b) >= 28 && len(b) <= mtu && b[0] == 0x45 && b[9] == 0x01 && b[20] == 0x08) {
return false
}
// We don't support fragmented packets
if b[7] != 0 || (b[6]&0x2F != 0) {
return false
}
buf := make([]byte, len(b))
copy(buf, b)
// Swap dest / src IPs and recalculate checksum
ipv4 := buf[0:20]
copy(ipv4[12:16], b[16:20])
copy(ipv4[16:20], b[12:16])
ipv4[10] = 0
ipv4[11] = 0
binary.BigEndian.PutUint16(ipv4[10:], ipChecksum(ipv4))
// Change type to ICMP Echo Reply and recalculate checksum
icmp := buf[20:]
icmp[0] = 0
icmp[2] = 0
icmp[3] = 0
binary.BigEndian.PutUint16(icmp[2:], ipChecksum(icmp))
// attempt to write it, but don't block
select {
case t.read <- buf:
default:
t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response")
}
return true
}
func (t *disabledTun) Write(b []byte) (int, error) {
t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
t.rx.Inc(1)
// Check for ICMP Echo Request before spending time doing the full parsing
if t.handleICMPEchoRequest(b) {
if t.l.Level >= logrus.DebugLevel {
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request")
}
} else if t.l.Level >= logrus.DebugLevel {
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
}
return len(b), nil
}
@@ -50,10 +127,14 @@ func (t *disabledTun) WriteRaw(b []byte) error {
return err
}
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return t, nil
}
func (t *disabledTun) Close() error {
if t.block != nil {
close(t.block)
t.block = nil
if t.read != nil {
close(t.read)
t.read = nil
}
return nil
}
@@ -72,3 +153,22 @@ func (p prettyPacket) String() string {
return s.String()
}
func ipChecksum(b []byte) uint16 {
var c uint32
sz := len(b) - 1
for i := 0; i < sz; i += 2 {
c += uint32(b[i]) << 8
c += uint32(b[i+1])
}
if sz%2 == 0 {
c += uint32(b[sz]) << 8
}
for (c >> 16) > 0 {
c = (c & 0xffff) + (c >> 16)
}
return ^uint16(c)
}

View File

@@ -1,3 +1,5 @@
// +build !e2e_testing
package nebula
import (
@@ -9,6 +11,8 @@ import (
"regexp"
"strconv"
"strings"
"github.com/sirupsen/logrus"
)
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
@@ -18,15 +22,16 @@ type Tun struct {
Cidr *net.IPNet
MTU int
UnsafeRoutes []route
l *logrus.Logger
io.ReadWriteCloser
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in FreeBSD")
}
@@ -41,6 +46,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
Cidr: cidr,
MTU: defaultMTU,
UnsafeRoutes: unsafeRoutes,
l: l,
}, nil
}
@@ -52,21 +58,21 @@ func (c *Tun) Activate() error {
}
// TODO use syscalls instead of exec.Command
l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String())
c.l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String())
if err = exec.Command("/sbin/ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device)
c.l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU))
c.l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU))
if err = exec.Command("/sbin/ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
for _, r := range c.UnsafeRoutes {
l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device)
c.l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.route.String(), err)
}
@@ -87,3 +93,7 @@ func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b)
return err
}
func (t *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
}

View File

@@ -1,4 +1,5 @@
// +build ios
// +build !e2e_testing
package nebula
@@ -10,6 +11,8 @@ import (
"os"
"sync"
"syscall"
"github.com/sirupsen/logrus"
)
type Tun struct {
@@ -18,11 +21,11 @@ type Tun struct {
Cidr *net.IPNet
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTun not supported in iOS")
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("route MTU not supported in Darwin")
}
@@ -111,3 +114,7 @@ func (c *Tun) CidrNet() *net.IPNet {
func (c *Tun) DeviceName() string {
return c.Device
}
func (t *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
}

View File

@@ -1,4 +1,5 @@
// +build !android
// +build !e2e_testing
package nebula
@@ -10,6 +11,7 @@ import (
"strings"
"unsafe"
"github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
@@ -24,6 +26,7 @@ type Tun struct {
TXQueueLen int
Routes []route
UnsafeRoutes []route
l *logrus.Logger
}
type ifReq struct {
@@ -55,8 +58,9 @@ func ipv4(addr string) (o [4]byte, err error) {
*/
const (
cIFF_TUN = 0x0001
cIFF_NO_PI = 0x1000
cIFF_TUN = 0x0001
cIFF_NO_PI = 0x1000
cIFF_MULTI_QUEUE = 0x0100
)
type ifreqAddr struct {
@@ -77,7 +81,7 @@ type ifreqQLEN struct {
pad [8]byte
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@@ -90,11 +94,12 @@ func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route,
TXQueueLen: txQueueLen,
Routes: routes,
UnsafeRoutes: unsafeRoutes,
l: l,
}
return
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
@@ -102,9 +107,12 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
var req ifReq
req.Flags = uint16(cIFF_TUN | cIFF_NO_PI)
if multiqueue {
req.Flags |= cIFF_MULTI_QUEUE
}
copy(req.Name[:], deviceName)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
return
return nil, err
}
name := strings.Trim(string(req.Name[:]), "\x00")
@@ -127,10 +135,29 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
TXQueueLen: txQueueLen,
Routes: routes,
UnsafeRoutes: unsafeRoutes,
l: l,
}
return
}
func (c *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
}
var req ifReq
req.Flags = uint16(cIFF_TUN | cIFF_NO_PI | cIFF_MULTI_QUEUE)
copy(req.Name[:], c.Device)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
return nil, err
}
file := os.NewFile(uintptr(fd), "/dev/net/tun")
return file, nil
}
func (c *Tun) WriteRaw(b []byte) error {
var nn int
for {
@@ -153,6 +180,10 @@ func (c *Tun) WriteRaw(b []byte) error {
}
}
func (c *Tun) Write(b []byte) (int, error) {
return len(b), c.WriteRaw(b)
}
func (c Tun) deviceBytes() (o [16]byte) {
for i, c := range c.Device {
o[i] = byte(c)
@@ -207,14 +238,14 @@ func (c Tun) Activate() error {
ifm := ifreqMTU{Name: devName, MTU: int32(c.MaxMTU)}
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
l.WithError(err).Error("Failed to set tun mtu")
c.l.WithError(err).Error("Failed to set tun mtu")
}
// Set the transmit queue length
ifrq := ifreqQLEN{Name: devName, Value: int32(c.TXQueueLen)}
if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
// If we can't set the queue length nebula will still work but it may lead to packet loss
l.WithError(err).Error("Failed to set tun tx queue length")
c.l.WithError(err).Error("Failed to set tun tx queue length")
}
// Bring up the interface

View File

@@ -1,3 +1,5 @@
// +build !e2e_testing
package nebula
import "testing"

View File

@@ -9,7 +9,8 @@ import (
)
func Test_parseRoutes(t *testing.T) {
c := NewConfig()
l := NewTestLogger()
c := NewConfig(l)
_, n, _ := net.ParseCIDR("10.0.0.0/24")
// test no routes config
@@ -104,7 +105,8 @@ func Test_parseRoutes(t *testing.T) {
}
func Test_parseUnsafeRoutes(t *testing.T) {
c := NewConfig()
l := NewTestLogger()
c := NewConfig(l)
_, n, _ := net.ParseCIDR("10.0.0.0/24")
// test no routes config

104
tun_tester.go Normal file
View File

@@ -0,0 +1,104 @@
// +build e2e_testing
package nebula
import (
"fmt"
"io"
"net"
"github.com/sirupsen/logrus"
)
type Tun struct {
Device string
Cidr *net.IPNet
MTU int
UnsafeRoutes []route
l *logrus.Logger
rxPackets chan []byte // Packets to receive into nebula
txPackets chan []byte // Packets transmitted outside by nebula
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, _ []route, unsafeRoutes []route, _ int, _ bool) (ifce *Tun, err error) {
return &Tun{
Device: deviceName,
Cidr: cidr,
MTU: defaultMTU,
UnsafeRoutes: unsafeRoutes,
l: l,
rxPackets: make(chan []byte, 1),
txPackets: make(chan []byte, 1),
}, nil
}
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []route, _ []route, _ int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported")
}
// Send will place a byte array onto the receive queue for nebula to consume
// These are unencrypted ip layer frames destined for another nebula node.
// packets should exit the udp side, capture them with udpConn.Get
func (c *Tun) Send(packet []byte) {
c.l.WithField("dataLen", len(packet)).Info("Tun receiving injected packet")
c.rxPackets <- packet
}
// Get will pull an unencrypted ip layer frame from the transmit queue
// nebula meant to send this message to some application on the local system
// packets were ingested from the udp side, you can send them with udpConn.Send
func (c *Tun) Get(block bool) []byte {
if block {
return <-c.txPackets
}
select {
case p := <-c.txPackets:
return p
default:
return nil
}
}
//********************************************************************************************************************//
// Below this is boilerplate implementation to make nebula actually work
//********************************************************************************************************************//
func (c *Tun) Activate() error {
return nil
}
func (c *Tun) CidrNet() *net.IPNet {
return c.Cidr
}
func (c *Tun) DeviceName() string {
return c.Device
}
func (c *Tun) Write(b []byte) (n int, err error) {
return len(b), c.WriteRaw(b)
}
func (c *Tun) Close() error {
close(c.rxPackets)
return nil
}
func (c *Tun) WriteRaw(b []byte) error {
packet := make([]byte, len(b), len(b))
copy(packet, b)
c.txPackets <- packet
return nil
}
func (c *Tun) Read(b []byte) (int, error) {
p := <-c.rxPackets
copy(b, p)
return len(p), nil
}
func (c *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented")
}

View File

@@ -1,11 +1,15 @@
// +build !e2e_testing
package nebula
import (
"fmt"
"io"
"net"
"os/exec"
"strconv"
"github.com/sirupsen/logrus"
"github.com/songgao/water"
)
@@ -14,15 +18,16 @@ type Tun struct {
Cidr *net.IPNet
MTU int
UnsafeRoutes []route
l *logrus.Logger
*water.Interface
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("route MTU not supported in Windows")
}
@@ -32,6 +37,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
Cidr: cidr,
MTU: defaultMTU,
UnsafeRoutes: unsafeRoutes,
l: l,
}, nil
}
@@ -100,3 +106,7 @@ func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b)
return err
}
func (t *Tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}

82
udp_all.go Normal file
View File

@@ -0,0 +1,82 @@
package nebula
import (
"encoding/json"
"fmt"
"net"
"strconv"
)
type udpAddr struct {
IP net.IP
Port uint16
}
func NewUDPAddr(ip net.IP, port uint16) *udpAddr {
addr := udpAddr{IP: make([]byte, net.IPv6len), Port: port}
copy(addr.IP, ip.To16())
return &addr
}
func NewUDPAddrFromString(s string) *udpAddr {
ip, port, err := parseIPAndPort(s)
//TODO: handle err
_ = err
return &udpAddr{IP: ip.To16(), Port: port}
}
func (ua *udpAddr) Equals(t *udpAddr) bool {
if t == nil || ua == nil {
return t == nil && ua == nil
}
return ua.IP.Equal(t.IP) && ua.Port == t.Port
}
func (ua *udpAddr) String() string {
if ua == nil {
return "<nil>"
}
return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port))
}
func (ua *udpAddr) MarshalJSON() ([]byte, error) {
if ua == nil {
return nil, nil
}
return json.Marshal(m{"ip": ua.IP, "port": ua.Port})
}
func (ua *udpAddr) Copy() *udpAddr {
if ua == nil {
return nil
}
nu := udpAddr{
Port: ua.Port,
IP: make(net.IP, len(ua.IP)),
}
copy(nu.IP, ua.IP)
return &nu
}
func parseIPAndPort(s string) (net.IP, uint16, error) {
rIp, sPort, err := net.SplitHostPort(s)
if err != nil {
return nil, 0, err
}
addr, err := net.ResolveIPAddr("ip", rIp)
if err != nil {
return nil, 0, err
}
iPort, err := strconv.Atoi(sPort)
if err != nil {
return nil, 0, err
}
return addr.IP, uint16(iPort), nil
}

View File

@@ -1,3 +1,5 @@
// +build !e2e_testing
package nebula
import (

View File

@@ -1,3 +1,5 @@
// +build !e2e_testing
package nebula
// Darwin support is primarily implemented in udp_generic, besides NewListenConfig
@@ -28,6 +30,7 @@ func NewListenConfig(multi bool) net.ListenConfig {
return controlErr
}
}
return nil
},
}
@@ -39,5 +42,5 @@ func (u *udpConn) Rebind() error {
return err
}
return syscall.SetsockoptInt(int(file.Fd()), unix.IPPROTO_IP, unix.IP_BOUND_IF, 0)
return syscall.SetsockoptInt(int(file.Fd()), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0)
}

View File

@@ -1,3 +1,5 @@
// +build !e2e_testing
package nebula
// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig

View File

@@ -1,4 +1,5 @@
// +build !linux android
// +build !e2e_testing
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
// means it can be used on platforms like Darwin and Windows.
@@ -7,77 +8,31 @@ package nebula
import (
"context"
"encoding/binary"
"fmt"
"net"
"strconv"
"strings"
)
type udpAddr struct {
net.UDPAddr
}
"github.com/sirupsen/logrus"
)
type udpConn struct {
*net.UDPConn
l *logrus.Logger
}
func NewUDPAddr(ip uint32, port uint16) *udpAddr {
return &udpAddr{
UDPAddr: net.UDPAddr{
IP: int2ip(ip),
Port: int(port),
},
}
}
func NewUDPAddrFromString(s string) *udpAddr {
p := strings.Split(s, ":")
if len(p) < 2 {
return nil
}
port, _ := strconv.Atoi(p[1])
return &udpAddr{
UDPAddr: net.UDPAddr{
IP: net.ParseIP(p[0]),
Port: port,
},
}
}
func NewListener(ip string, port int, multi bool) (*udpConn, error) {
func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
lc := NewListenConfig(multi)
pc, err := lc.ListenPacket(context.TODO(), "udp4", fmt.Sprintf("%s:%d", ip, port))
pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port))
if err != nil {
return nil, err
}
if uc, ok := pc.(*net.UDPConn); ok {
return &udpConn{UDPConn: uc}, nil
return &udpConn{UDPConn: uc, l: l}, nil
}
return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
}
func (ua *udpAddr) Equals(t *udpAddr) bool {
if t == nil || ua == nil {
return t == nil && ua == nil
}
return ua.IP.Equal(t.IP) && ua.Port == t.Port
}
func (ua *udpAddr) Copy() udpAddr {
nu := udpAddr{net.UDPAddr{
Port: ua.Port,
Zone: ua.Zone,
IP: make(net.IP, len(ua.IP)),
}}
copy(nu.IP, ua.IP)
return nu
}
func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error {
_, err := uc.UDPConn.WriteToUDP(b, &addr.UDPAddr)
_, err := uc.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)})
return err
}
@@ -86,7 +41,11 @@ func (uc *udpConn) LocalAddr() (*udpAddr, error) {
switch v := a.(type) {
case *net.UDPAddr:
return &udpAddr{UDPAddr: *v}, nil
addr := &udpAddr{IP: make([]byte, len(v.IP))}
copy(addr.IP, v.IP)
addr.Port = uint16(v.Port)
return addr, nil
default:
return nil, fmt.Errorf("LocalAddr returned: %#v", a)
}
@@ -96,39 +55,41 @@ func (u *udpConn) reloadConfig(c *Config) {
// TODO
}
func NewUDPStatsEmitter(udpConns []*udpConn) func() {
// No UDP stats for non-linux
return func() {}
}
type rawMessage struct {
Len uint32
}
func (u *udpConn) ListenOut(f *Interface) {
func (u *udpConn) ListenOut(f *Interface, q int) {
plaintext := make([]byte, mtu)
buffer := make([]byte, mtu)
header := &Header{}
fwPacket := &FirewallPacket{}
udpAddr := &udpAddr{}
udpAddr := &udpAddr{IP: make([]byte, 16)}
nb := make([]byte, 12, 12)
lhh := f.lightHouse.NewRequestHandler()
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
// Just read one packet at a time
n, rua, err := u.ReadFromUDP(buffer)
if err != nil {
l.WithError(err).Error("Failed to read packets")
f.l.WithError(err).Error("Failed to read packets")
continue
}
udpAddr.UDPAddr = *rua
f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, nb)
udpAddr.IP = rua.IP
udpAddr.Port = uint16(rua.Port)
f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get(f.l))
}
}
func udp2ip(addr *udpAddr) net.IP {
return addr.IP
}
func udp2ipInt(addr *udpAddr) uint32 {
return binary.BigEndian.Uint32(addr.IP.To4())
}
func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {
return !addr.Equals(newaddr)
}

View File

@@ -1,17 +1,17 @@
// +build !android
// +build !e2e_testing
package nebula
import (
"encoding/binary"
"encoding/json"
"fmt"
"net"
"strconv"
"strings"
"syscall"
"unsafe"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
@@ -19,45 +19,31 @@ import (
type udpConn struct {
sysFd int
}
type udpAddr struct {
IP uint32
Port uint16
}
func NewUDPAddr(ip uint32, port uint16) *udpAddr {
return &udpAddr{IP: ip, Port: port}
}
func NewUDPAddrFromString(s string) *udpAddr {
p := strings.Split(s, ":")
if len(p) < 2 {
return nil
}
port, _ := strconv.Atoi(p[1])
return &udpAddr{
IP: ip2int(net.ParseIP(p[0])),
Port: uint16(port),
}
}
type rawSockaddr struct {
Family uint16
Data [14]uint8
}
type rawSockaddrAny struct {
Addr rawSockaddr
Pad [96]int8
l *logrus.Logger
}
var x int
func NewListener(ip string, port int, multi bool) (*udpConn, error) {
// From linux/sock_diag.h
const (
_SK_MEMINFO_RMEM_ALLOC = iota
_SK_MEMINFO_RCVBUF
_SK_MEMINFO_WMEM_ALLOC
_SK_MEMINFO_SNDBUF
_SK_MEMINFO_FWD_ALLOC
_SK_MEMINFO_WMEM_QUEUED
_SK_MEMINFO_OPTMEM
_SK_MEMINFO_BACKLOG
_SK_MEMINFO_DROPS
_SK_MEMINFO_VARS
)
type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
syscall.ForkLock.RLock()
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
if err == nil {
unix.CloseOnExec(fd)
}
@@ -68,8 +54,8 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
return nil, fmt.Errorf("unable to open socket: %s", err)
}
var lip [4]byte
copy(lip[:], net.ParseIP(ip).To4())
var lip [16]byte
copy(lip[:], net.ParseIP(ip))
if multi {
if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
@@ -77,7 +63,8 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
}
}
if err = unix.Bind(fd, &unix.SockaddrInet4{Addr: lip, Port: port}); err != nil {
//TODO: support multiple listening IPs (for limiting ipv6)
if err = unix.Bind(fd, &unix.SockaddrInet6{Addr: lip, Port: port}); err != nil {
return nil, fmt.Errorf("unable to bind to socket: %s", err)
}
@@ -86,17 +73,13 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
//l.Println(v, err)
return &udpConn{sysFd: fd}, err
return &udpConn{sysFd: fd, l: l}, err
}
func (u *udpConn) Rebind() error {
return nil
}
func (ua *udpAddr) Copy() udpAddr {
return *ua
}
func (u *udpConn) SetRecvBuffer(n int) error {
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
}
@@ -114,38 +97,33 @@ func (u *udpConn) GetSendBuffer() (int, error) {
}
func (u *udpConn) LocalAddr() (*udpAddr, error) {
var rsa rawSockaddrAny
var rLen = unix.SizeofSockaddrAny
_, _, err := unix.Syscall(
unix.SYS_GETSOCKNAME,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&rsa)),
uintptr(unsafe.Pointer(&rLen)),
)
if err != 0 {
sa, err := unix.Getsockname(u.sysFd)
if err != nil {
return nil, err
}
addr := &udpAddr{}
if rsa.Addr.Family == unix.AF_INET {
addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1])
addr.IP = uint32(rsa.Addr.Data[2])<<24 + uint32(rsa.Addr.Data[3])<<16 + uint32(rsa.Addr.Data[4])<<8 + uint32(rsa.Addr.Data[5])
} else {
addr.Port = 0
addr.IP = 0
switch sa := sa.(type) {
case *unix.SockaddrInet4:
addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16()
addr.Port = uint16(sa.Port)
case *unix.SockaddrInet6:
addr.IP = sa.Addr[0:]
addr.Port = uint16(sa.Port)
}
return addr, nil
}
func (u *udpConn) ListenOut(f *Interface) {
func (u *udpConn) ListenOut(f *Interface, q int) {
plaintext := make([]byte, mtu)
header := &Header{}
fwPacket := &FirewallPacket{}
udpAddr := &udpAddr{}
nb := make([]byte, 12, 12)
lhh := f.lightHouse.NewRequestHandler()
//TODO: should we track this?
//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
msgs, buffers, names := u.PrepareRawMessages(f.udpBatchSize)
@@ -154,19 +132,20 @@ func (u *udpConn) ListenOut(f *Interface) {
read = u.ReadSingle
}
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
n, err := read(msgs)
if err != nil {
l.WithError(err).Error("Failed to read packets")
u.l.WithError(err).Error("Failed to read packets")
continue
}
//metric.Update(int64(n))
for i := 0; i < n; i++ {
udpAddr.IP = binary.BigEndian.Uint32(names[i][4:8])
udpAddr.IP = names[i][8:24]
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, nb)
f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l))
}
}
}
@@ -213,18 +192,13 @@ func (u *udpConn) ReadMulti(msgs []rawMessage) (int, error) {
}
func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error {
var rsa unix.RawSockaddrInet4
//TODO: sometimes addr is nil!
rsa.Family = unix.AF_INET
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
p := (*[2]byte)(unsafe.Pointer(&rsa.Port))
p[0] = byte(addr.Port >> 8)
p[1] = byte(addr.Port)
rsa.Addr[0] = byte(addr.IP & 0xff000000 >> 24)
rsa.Addr[1] = byte(addr.IP & 0x00ff0000 >> 16)
rsa.Addr[2] = byte(addr.IP & 0x0000ff00 >> 8)
rsa.Addr[3] = byte(addr.IP & 0x000000ff)
copy(rsa.Addr[:], addr.IP)
for {
_, _, err := unix.Syscall6(
@@ -234,7 +208,7 @@ func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error {
uintptr(len(b)),
uintptr(0),
uintptr(unsafe.Pointer(&rsa)),
uintptr(unix.SizeofSockaddrInet4),
uintptr(unix.SizeofSockaddrInet6),
)
if err != 0 {
@@ -254,12 +228,12 @@ func (u *udpConn) reloadConfig(c *Config) {
if err == nil {
s, err := u.GetRecvBuffer()
if err == nil {
l.WithField("size", s).Info("listen.read_buffer was set")
u.l.WithField("size", s).Info("listen.read_buffer was set")
} else {
l.WithError(err).Warn("Failed to get listen.read_buffer")
u.l.WithError(err).Warn("Failed to get listen.read_buffer")
}
} else {
l.WithError(err).Error("Failed to set listen.read_buffer")
u.l.WithError(err).Error("Failed to set listen.read_buffer")
}
}
@@ -269,37 +243,55 @@ func (u *udpConn) reloadConfig(c *Config) {
if err == nil {
s, err := u.GetSendBuffer()
if err == nil {
l.WithField("size", s).Info("listen.write_buffer was set")
u.l.WithField("size", s).Info("listen.write_buffer was set")
} else {
l.WithError(err).Warn("Failed to get listen.write_buffer")
u.l.WithError(err).Warn("Failed to get listen.write_buffer")
}
} else {
l.WithError(err).Error("Failed to set listen.write_buffer")
u.l.WithError(err).Error("Failed to set listen.write_buffer")
}
}
}
func (ua *udpAddr) Equals(t *udpAddr) bool {
if t == nil || ua == nil {
return t == nil && ua == nil
func (u *udpConn) getMemInfo(meminfo *_SK_MEMINFO) error {
var vallen uint32 = 4 * _SK_MEMINFO_VARS
_, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)
if err != 0 {
return err
}
return ua.IP == t.IP && ua.Port == t.Port
return nil
}
func (ua *udpAddr) String() string {
return fmt.Sprintf("%s:%v", int2ip(ua.IP), ua.Port)
}
func NewUDPStatsEmitter(udpConns []*udpConn) func() {
// Check if our kernel supports SO_MEMINFO before registering the gauges
var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge
var meminfo _SK_MEMINFO
if err := udpConns[0].getMemInfo(&meminfo); err == nil {
udpGauges = make([][_SK_MEMINFO_VARS]metrics.Gauge, len(udpConns))
for i := range udpConns {
udpGauges[i] = [_SK_MEMINFO_VARS]metrics.Gauge{
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.sndbuf", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.fwd_alloc", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_queued", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.optmem", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.backlog", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.drops", i), nil),
}
}
}
func (ua *udpAddr) MarshalJSON() ([]byte, error) {
return json.Marshal(m{"ip": int2ip(ua.IP), "port": ua.Port})
}
func udp2ip(addr *udpAddr) net.IP {
return int2ip(addr.IP)
}
func udp2ipInt(addr *udpAddr) uint32 {
return addr.IP
return func() {
for i, gauges := range udpGauges {
if err := udpConns[i].getMemInfo(&meminfo); err == nil {
for j := 0; j < _SK_MEMINFO_VARS; j++ {
gauges[j].Update(int64(meminfo[j]))
}
}
}
}
}
func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {

View File

@@ -1,10 +1,13 @@
// +build linux
// +build 386 amd64p32 arm mips mipsle
// +build !android
// +build !e2e_testing
package nebula
import "unsafe"
import (
"golang.org/x/sys/unix"
)
type iovec struct {
Base *byte
@@ -33,17 +36,17 @@ func (u *udpConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
for i := range msgs {
buffers[i] = make([]byte, mtu)
names[i] = make([]byte, 0x1c) //TODO = sizeofSockaddrInet6
names[i] = make([]byte, unix.SizeofSockaddrInet6)
//TODO: this is still silly, no need for an array
vs := []iovec{
{Base: (*byte)(unsafe.Pointer(&buffers[i][0])), Len: uint32(len(buffers[i]))},
{Base: &buffers[i][0], Len: uint32(len(buffers[i]))},
}
msgs[i].Hdr.Iov = &vs[0]
msgs[i].Hdr.Iovlen = uint32(len(vs))
msgs[i].Hdr.Name = (*byte)(unsafe.Pointer(&names[i][0]))
msgs[i].Hdr.Name = &names[i][0]
msgs[i].Hdr.Namelen = uint32(len(names[i]))
}

View File

@@ -1,10 +1,13 @@
// +build linux
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x
// +build !android
// +build !e2e_testing
package nebula
import "unsafe"
import (
"golang.org/x/sys/unix"
)
type iovec struct {
Base *byte
@@ -36,17 +39,17 @@ func (u *udpConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
for i := range msgs {
buffers[i] = make([]byte, mtu)
names[i] = make([]byte, 0x1c) //TODO = sizeofSockaddrInet6
names[i] = make([]byte, unix.SizeofSockaddrInet6)
//TODO: this is still silly, no need for an array
vs := []iovec{
{Base: (*byte)(unsafe.Pointer(&buffers[i][0])), Len: uint64(len(buffers[i]))},
{Base: &buffers[i][0], Len: uint64(len(buffers[i]))},
}
msgs[i].Hdr.Iov = &vs[0]
msgs[i].Hdr.Iovlen = uint64(len(vs))
msgs[i].Hdr.Name = (*byte)(unsafe.Pointer(&names[i][0]))
msgs[i].Hdr.Name = &names[i][0]
msgs[i].Hdr.Namelen = uint32(len(names[i]))
}

140
udp_tester.go Normal file
View File

@@ -0,0 +1,140 @@
// +build e2e_testing
package nebula
import (
"fmt"
"net"
"github.com/sirupsen/logrus"
)
type UdpPacket struct {
ToIp net.IP
ToPort uint16
FromIp net.IP
FromPort uint16
Data []byte
}
func (u *UdpPacket) Copy() *UdpPacket {
n := &UdpPacket{
ToIp: make(net.IP, len(u.ToIp)),
ToPort: u.ToPort,
FromIp: make(net.IP, len(u.FromIp)),
FromPort: u.FromPort,
Data: make([]byte, len(u.Data)),
}
copy(n.ToIp, u.ToIp)
copy(n.FromIp, u.FromIp)
copy(n.Data, u.Data)
return n
}
type udpConn struct {
addr *udpAddr
rxPackets chan *UdpPacket // Packets to receive into nebula
txPackets chan *UdpPacket // Packets transmitted outside by nebula
l *logrus.Logger
}
func NewListener(l *logrus.Logger, ip string, port int, _ bool) (*udpConn, error) {
return &udpConn{
addr: &udpAddr{net.ParseIP(ip), uint16(port)},
rxPackets: make(chan *UdpPacket, 1),
txPackets: make(chan *UdpPacket, 1),
l: l,
}, nil
}
// Send will place a UdpPacket onto the receive queue for nebula to consume
// this is an encrypted packet or a handshake message in most cases
// packets were transmitted from another nebula node, you can send them with Tun.Send
func (u *udpConn) Send(packet *UdpPacket) {
h := &Header{}
if err := h.Parse(packet.Data); err != nil {
panic(err)
}
u.l.WithField("header", h).
WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
WithField("dataLen", len(packet.Data)).
Info("UDP receiving injected packet")
u.rxPackets <- packet
}
// Get will pull a UdpPacket from the transmit queue
// nebula meant to send this message on the network, it will be encrypted
// packets were ingested from the tun side (in most cases), you can send them with Tun.Send
func (u *udpConn) Get(block bool) *UdpPacket {
if block {
return <-u.txPackets
}
select {
case p := <-u.txPackets:
return p
default:
return nil
}
}
//********************************************************************************************************************//
// Below this is boilerplate implementation to make nebula actually work
//********************************************************************************************************************//
func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error {
p := &UdpPacket{
Data: make([]byte, len(b), len(b)),
FromIp: make([]byte, 16),
FromPort: u.addr.Port,
ToIp: make([]byte, 16),
ToPort: addr.Port,
}
copy(p.Data, b)
copy(p.ToIp, addr.IP.To16())
copy(p.FromIp, u.addr.IP.To16())
u.txPackets <- p
return nil
}
func (u *udpConn) ListenOut(f *Interface, q int) {
plaintext := make([]byte, mtu)
header := &Header{}
fwPacket := &FirewallPacket{}
ua := &udpAddr{IP: make([]byte, 16)}
nb := make([]byte, 12, 12)
lhh := f.lightHouse.NewRequestHandler()
conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
for {
p := <-u.rxPackets
ua.Port = p.FromPort
copy(ua.IP, p.FromIp.To16())
f.readOutsidePackets(ua, plaintext[:0], p.Data, header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l))
}
}
func (u *udpConn) reloadConfig(*Config) {}
func NewUDPStatsEmitter(_ []*udpConn) func() {
// No UDP stats for non-linux
return func() {}
}
func (u *udpConn) LocalAddr() (*udpAddr, error) {
return u.addr, nil
}
func (u *udpConn) Rebind() error {
return nil
}
func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {
return !addr.Equals(newaddr)
}

View File

@@ -1,3 +1,5 @@
// +build !e2e_testing
package nebula
// Windows support is primarily implemented in udp_generic, besides NewListenConfig