Compare commits

..

34 Commits

Author SHA1 Message Date
brad-defined
105e0ec66c v1.9.6 (#1434)
Update CHANGELOG for Nebula v1.9.6
2025-07-18 08:39:33 -04:00
Nate Brown
4870bb680d Darwin udp fix (#1426) 2025-07-01 16:41:29 -05:00
brad-defined
a1498ca8f8 Store relay states in a slice for consistent ordering (#1422) 2025-06-24 12:04:00 -04:00
Nate Brown
9877648da9 Drop inactive tunnels (#1413) 2025-06-23 11:32:50 -05:00
brad-defined
8e0a7bcbb7 Disable UDP receive error returns due to ICMP messages on Windows. (#1412) 2025-05-22 08:55:45 -04:00
brad-defined
8c29b15c6d fix relay migration panic (#1403) 2025-05-13 14:58:58 -04:00
brad-defined
04d7a8ccba Retry UDP receive on Windows in some receive error cases (#1404) 2025-05-13 14:58:37 -04:00
Nate Brown
b55b9019a7 v1.9.5 (#1285)
Update CHANGELOG for Nebula v1.9.5
2024-12-06 09:50:24 -05:00
Nate Brown
2e85d138cd [v1.9.x] do not panic when loading a V2 CA certificate (#1282)
Co-authored-by: Jack Doan <jackdoan@rivian.com>
2024-12-03 09:49:54 -06:00
brad-defined
9bfdfbafc1 Backport reestablish relays from cert-v2 to release-1.9 (#1277) 2024-11-20 21:49:53 -06:00
Wade Simmons
ab81b62ea0 v1.9.4 (#1210)
Update CHANGELOG for Nebula v1.9.4
2024-09-09 14:11:44 -04:00
dependabot[bot]
45bbad2f21 Bump the golang-x-dependencies group with 4 updates (#1195)
Bumps the golang-x-dependencies group with 4 updates: [golang.org/x/crypto](https://github.com/golang/crypto), [golang.org/x/net](https://github.com/golang/net), [golang.org/x/sys](https://github.com/golang/sys) and [golang.org/x/term](https://github.com/golang/term).


Updates `golang.org/x/crypto` from 0.25.0 to 0.26.0
- [Commits](https://github.com/golang/crypto/compare/v0.25.0...v0.26.0)

Updates `golang.org/x/net` from 0.27.0 to 0.28.0
- [Commits](https://github.com/golang/net/compare/v0.27.0...v0.28.0)

Updates `golang.org/x/sys` from 0.23.0 to 0.24.0
- [Commits](https://github.com/golang/sys/compare/v0.23.0...v0.24.0)

Updates `golang.org/x/term` from 0.22.0 to 0.23.0
- [Commits](https://github.com/golang/term/compare/v0.22.0...v0.23.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/net
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/term
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-09-03 16:47:36 -04:00
Jack Doan
3dc56e1184 Support UDP dialling with gvisor (#1181) 2024-08-26 12:38:32 -05:00
Wade Simmons
0736cfa562 udp: fix endianness for port (#1194)
If the host OS is already big endian, we were swapping bytes when we
shouldn't have. Use the Go helper to make sure we do the endianness
correctly

Fixes: #1189
2024-08-14 12:53:00 -04:00
Jack Doan
248cf194cd fix integer wraparound in the calculation of handshake timeouts on 32-bit targets (#1185)
Fixes: #1169
2024-08-13 09:25:18 -04:00
dependabot[bot]
8a6a0f0636 Bump the golang-x-dependencies group with 2 updates (#1190)
Bumps the golang-x-dependencies group with 2 updates: [golang.org/x/sync](https://github.com/golang/sync) and [golang.org/x/sys](https://github.com/golang/sys).


Updates `golang.org/x/sync` from 0.7.0 to 0.8.0
- [Commits](https://github.com/golang/sync/compare/v0.7.0...v0.8.0)

Updates `golang.org/x/sys` from 0.22.0 to 0.23.0
- [Commits](https://github.com/golang/sys/compare/v0.22.0...v0.23.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sync
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-08-07 11:58:46 -04:00
Wade Simmons
f5f6c269ac fix rare panic when local index collision happens (#1191)
A local index collision happens when two tunnels attempt to use the same
random int32 index ID. This is a rare chance, and we have code to deal
with it, but we have a panic because we return the wrong thing in this
case. This change should fix the panic.
2024-08-07 11:53:32 -04:00
brad-defined
9a63fa0a07 Make some Nebula state programmatically available via control object (#1188) 2024-08-01 13:40:05 -04:00
Nate Brown
e264a0ff88 Switch most everything to netip in prep for ipv6 in the overlay (#1173) 2024-07-31 10:18:56 -05:00
dependabot[bot]
00458302ca Bump the golang-x-dependencies group with 4 updates (#1174)
Bumps the golang-x-dependencies group with 4 updates: [golang.org/x/crypto](https://github.com/golang/crypto), [golang.org/x/net](https://github.com/golang/net), [golang.org/x/sys](https://github.com/golang/sys) and [golang.org/x/term](https://github.com/golang/term).


Updates `golang.org/x/crypto` from 0.24.0 to 0.25.0
- [Commits](https://github.com/golang/crypto/compare/v0.24.0...v0.25.0)

Updates `golang.org/x/net` from 0.26.0 to 0.27.0
- [Commits](https://github.com/golang/net/compare/v0.26.0...v0.27.0)

Updates `golang.org/x/sys` from 0.21.0 to 0.22.0
- [Commits](https://github.com/golang/sys/compare/v0.21.0...v0.22.0)

Updates `golang.org/x/term` from 0.21.0 to 0.22.0
- [Commits](https://github.com/golang/term/compare/v0.21.0...v0.22.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/net
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/term
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-07-29 11:42:33 -04:00
Wade Simmons
e6009b8491 github actions: use macos-latest (#1171)
macos-11 was deprecated and removed:

> The macos-11 label has been deprecated and will no longer be available after 28 June 2024.

We can just use macos-latest instead.
2024-07-02 11:50:51 -04:00
dependabot[bot]
b9aace1e58 Bump github.com/prometheus/client_golang from 1.19.0 to 1.19.1 (#1147)
Bumps [github.com/prometheus/client_golang](https://github.com/prometheus/client_golang) from 1.19.0 to 1.19.1.
- [Release notes](https://github.com/prometheus/client_golang/releases)
- [Changelog](https://github.com/prometheus/client_golang/blob/main/CHANGELOG.md)
- [Commits](https://github.com/prometheus/client_golang/compare/v1.19.0...v1.19.1)

---
updated-dependencies:
- dependency-name: github.com/prometheus/client_golang
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-06-24 14:54:51 -04:00
dependabot[bot]
a76723eaf5 Bump Apple-Actions/import-codesign-certs from 2 to 3 (#1146)
Bumps [Apple-Actions/import-codesign-certs](https://github.com/apple-actions/import-codesign-certs) from 2 to 3.
- [Release notes](https://github.com/apple-actions/import-codesign-certs/releases)
- [Commits](https://github.com/apple-actions/import-codesign-certs/compare/v2...v3)

---
updated-dependencies:
- dependency-name: Apple-Actions/import-codesign-certs
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-06-24 14:54:05 -04:00
Caleb Jasik
8109cf2170 Add puncuation to doc comment (#1164)
* Add puncuation to doc comment

* Fix list formatting inside `EncryptDanger` doc comment
2024-06-24 14:50:17 -04:00
Wade Simmons
97e9834f82 cleanup SK_MEMINFO vars (#1162)
We had to manually define these types before, but the latest release of
`golang.org/x/sys` adds these definitions:

- 6dfb94eaa3

Since we just updated with this PR, we can clean this up now:

- https://github.com/slackhq/nebula/pull/1161
2024-06-24 14:47:14 -04:00
dependabot[bot]
506ba5ab5b Bump github.com/miekg/dns from 1.1.59 to 1.1.61 (#1168)
Bumps [github.com/miekg/dns](https://github.com/miekg/dns) from 1.1.59 to 1.1.61.
- [Changelog](https://github.com/miekg/dns/blob/master/Makefile.release)
- [Commits](https://github.com/miekg/dns/compare/v1.1.59...v1.1.61)

---
updated-dependencies:
- dependency-name: github.com/miekg/dns
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-06-24 14:46:27 -04:00
dependabot[bot]
d372df56ab Bump google.golang.org/protobuf in the protobuf-dependencies group (#1167)
Bumps the protobuf-dependencies group with 1 update: google.golang.org/protobuf.


Updates `google.golang.org/protobuf` from 1.34.1 to 1.34.2

---
updated-dependencies:
- dependency-name: google.golang.org/protobuf
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: protobuf-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-06-24 14:45:52 -04:00
dependabot[bot]
40cfd00e87 Bump the golang-x-dependencies group with 4 updates (#1161)
Bumps the golang-x-dependencies group with 4 updates: [golang.org/x/crypto](https://github.com/golang/crypto), [golang.org/x/net](https://github.com/golang/net), [golang.org/x/sys](https://github.com/golang/sys) and [golang.org/x/term](https://github.com/golang/term).


Updates `golang.org/x/crypto` from 0.23.0 to 0.24.0
- [Commits](https://github.com/golang/crypto/compare/v0.23.0...v0.24.0)

Updates `golang.org/x/net` from 0.25.0 to 0.26.0
- [Commits](https://github.com/golang/net/compare/v0.25.0...v0.26.0)

Updates `golang.org/x/sys` from 0.20.0 to 0.21.0
- [Commits](https://github.com/golang/sys/compare/v0.20.0...v0.21.0)

Updates `golang.org/x/term` from 0.20.0 to 0.21.0
- [Commits](https://github.com/golang/term/compare/v0.20.0...v0.21.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/net
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/term
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-06-10 16:08:43 -04:00
Wade Simmons
b14bad586a v1.9.3 (#1160)
Update CHANGELOG for Nebula v1.9.3
2024-06-06 13:17:07 -04:00
Wade Simmons
4c066d8c32 initialize messageCounter to 2 instead of verifying later (#1156)
Clean up the messageCounter checks added in #1154. Instead of checking that
messageCounter is still at 2, just initialize it to 2 and only increment for
non-handshake messages. Handshake packets will always be packets 1 and 2.
2024-06-06 13:03:07 -04:00
Wade Simmons
249ae41fec v1.9.2 (#1155)
Update CHANGELOG for Nebula v1.9.2
2024-06-03 15:50:02 -04:00
Wade Simmons
d9cae9e062 ensure messageCounter is set before handshake is complete (#1154)
Ensure we set messageCounter to 2 before the handshake is marked as
complete.
2024-06-03 15:40:51 -04:00
Wade Simmons
a92056a7db v1.9.1 (#1152)
Update CHANGELOG for Nebula v1.9.1
2024-05-29 14:06:46 -04:00
Wade Simmons
4eb1da0958 remove deadlock in GetOrHandshake (#1151)
We had a rare deadlock in GetOrHandshake because we kept the hostmap
lock when we do the call to StartHandshake. StartHandshake can block
while sending to the lighthouse query worker channel, and that worker
needs to be able to grab the hostmap lock to do its work. Other calls
for StartHandshake don't hold the hostmap lock so we should be able to
drop it here.

This lock was originally added with: https://github.com/slackhq/nebula/pull/954
2024-05-29 12:52:52 -04:00
92 changed files with 3179 additions and 3230 deletions

View File

@@ -64,7 +64,7 @@ jobs:
name: Build Universal Darwin name: Build Universal Darwin
env: env:
HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }} HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
runs-on: macos-11 runs-on: macos-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
@@ -75,7 +75,7 @@ jobs:
- name: Import certificates - name: Import certificates
if: env.HAS_SIGNING_CREDS == 'true' if: env.HAS_SIGNING_CREDS == 'true'
uses: Apple-Actions/import-codesign-certs@v2 uses: Apple-Actions/import-codesign-certs@v3
with: with:
p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }} p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }}
p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }} p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }}

View File

@@ -72,7 +72,7 @@ jobs:
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
strategy: strategy:
matrix: matrix:
os: [windows-latest, macos-11] os: [windows-latest, macos-latest]
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4

View File

@@ -7,6 +7,67 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
## [1.9.6] - 2025-7-15
### Added
- Support dropping inactive tunnels. This is disabled by default in this release but can be enabled with `tunnels.drop_inactive`. See example config for more details. (#1413)
### Fixed
- Fix Darwin freeze due to presence of some Network Extensions (#1426)
- Ensure the same relay tunnel is always used when multiple relay tunnels are present (#1422)
- Fix Windows freeze due to ICMP error handling (#1412)
- Fix relay migration panic (#1403)
## [1.9.5] - 2024-12-05
### Added
- Gracefully ignore v2 certificates. (#1282)
### Fixed
- Fix relays that refuse to re-establish after one of the remote tunnel pairs breaks. (#1277)
## [1.9.4] - 2024-09-09
### Added
- Support UDP dialing with gVisor. (#1181)
### Changed
- Make some Nebula state programmatically available via control object. (#1188)
- Switch internal representation of IPs to netip, to prepare for IPv6 support
in the overlay. (#1173)
- Minor build and cleanup changes. (#1171, #1164, #1162)
- Various dependency updates. (#1195, #1190, #1174, #1168, #1167, #1161, #1147, #1146)
### Fixed
- Fix a bug on big endian hosts, like mips. (#1194)
- Fix a rare panic if a local index collision happens. (#1191)
- Fix integer wraparound in the calculation of handshake timeouts on 32-bit targets. (#1185)
## [1.9.3] - 2024-06-06
### Fixed
- Initialize messageCounter to 2 instead of verifying later. (#1156)
## [1.9.2] - 2024-06-03
### Fixed
- Ensure messageCounter is set before handshake is complete. (#1154)
## [1.9.1] - 2024-05-29
### Fixed
- Fixed a potential deadlock in GetOrHandshake. (#1151)
## [1.9.0] - 2024-05-07 ## [1.9.0] - 2024-05-07
### Deprecated ### Deprecated
@@ -626,7 +687,13 @@ created.)
- Initial public release. - Initial public release.
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.0...HEAD [Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.6...HEAD
[1.9.6]: https://github.com/slackhq/nebula/releases/tag/v1.9.6
[1.9.5]: https://github.com/slackhq/nebula/releases/tag/v1.9.5
[1.9.4]: https://github.com/slackhq/nebula/releases/tag/v1.9.4
[1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3
[1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2
[1.9.1]: https://github.com/slackhq/nebula/releases/tag/v1.9.1
[1.9.0]: https://github.com/slackhq/nebula/releases/tag/v1.9.0 [1.9.0]: https://github.com/slackhq/nebula/releases/tag/v1.9.0
[1.8.2]: https://github.com/slackhq/nebula/releases/tag/v1.8.2 [1.8.2]: https://github.com/slackhq/nebula/releases/tag/v1.8.2
[1.8.1]: https://github.com/slackhq/nebula/releases/tag/v1.8.1 [1.8.1]: https://github.com/slackhq/nebula/releases/tag/v1.8.1

View File

@@ -2,17 +2,16 @@ package nebula
import ( import (
"fmt" "fmt"
"net" "net/netip"
"regexp" "regexp"
"github.com/slackhq/nebula/cidr" "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
) )
type AllowList struct { type AllowList struct {
// The values of this cidrTree are `bool`, signifying allow/deny // The values of this cidrTree are `bool`, signifying allow/deny
cidrTree *cidr.Tree6[bool] cidrTree *bart.Table[bool]
} }
type RemoteAllowList struct { type RemoteAllowList struct {
@@ -20,7 +19,7 @@ type RemoteAllowList struct {
// Inside Range Specific, keys of this tree are inside CIDRs and values // Inside Range Specific, keys of this tree are inside CIDRs and values
// are *AllowList // are *AllowList
insideAllowLists *cidr.Tree6[*AllowList] insideAllowLists *bart.Table[*AllowList]
} }
type LocalAllowList struct { type LocalAllowList struct {
@@ -88,7 +87,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw) return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
} }
tree := cidr.NewTree6[bool]() tree := new(bart.Table[bool])
// Keep track of the rules we have added for both ipv4 and ipv6 // Keep track of the rules we have added for both ipv4 and ipv6
type allowListRules struct { type allowListRules struct {
@@ -122,18 +121,20 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue) return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
} }
_, ipNet, err := net.ParseCIDR(rawCIDR) ipNet, err := netip.ParsePrefix(rawCIDR)
if err != nil { if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err)
} }
// TODO: should we error on duplicate CIDRs in the config? ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits())
tree.AddCIDR(ipNet, value)
maskBits, maskSize := ipNet.Mask.Size() // TODO: should we error on duplicate CIDRs in the config?
tree.Insert(ipNet, value)
maskBits := ipNet.Bits()
var rules *allowListRules var rules *allowListRules
if maskSize == 32 { if ipNet.Addr().Is4() {
rules = &rules4 rules = &rules4
} else { } else {
rules = &rules6 rules = &rules6
@@ -156,8 +157,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
if !rules4.defaultSet { if !rules4.defaultSet {
if rules4.allValuesMatch { if rules4.allValuesMatch {
_, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0") tree.Insert(netip.PrefixFrom(netip.IPv4Unspecified(), 0), !rules4.allValues)
tree.AddCIDR(zeroCIDR, !rules4.allValues)
} else { } else {
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k) return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
} }
@@ -165,8 +165,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
if !rules6.defaultSet { if !rules6.defaultSet {
if rules6.allValuesMatch { if rules6.allValuesMatch {
_, zeroCIDR, _ := net.ParseCIDR("::/0") tree.Insert(netip.PrefixFrom(netip.IPv6Unspecified(), 0), !rules6.allValues)
tree.AddCIDR(zeroCIDR, !rules6.allValues)
} else { } else {
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k) return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
} }
@@ -218,13 +217,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error
return nameRules, nil return nameRules, nil
} }
func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) { func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error) {
value := c.Get(k) value := c.Get(k)
if value == nil { if value == nil {
return nil, nil return nil, nil
} }
remoteAllowRanges := cidr.NewTree6[*AllowList]() remoteAllowRanges := new(bart.Table[*AllowList])
rawMap, ok := value.(map[interface{}]interface{}) rawMap, ok := value.(map[interface{}]interface{})
if !ok { if !ok {
@@ -241,45 +240,27 @@ func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error
return nil, err return nil, err
} }
_, ipNet, err := net.ParseCIDR(rawCIDR) ipNet, err := netip.ParsePrefix(rawCIDR)
if err != nil { if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err)
} }
remoteAllowRanges.AddCIDR(ipNet, allowList) remoteAllowRanges.Insert(netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits()), allowList)
} }
return remoteAllowRanges, nil return remoteAllowRanges, nil
} }
func (al *AllowList) Allow(ip net.IP) bool { func (al *AllowList) Allow(ip netip.Addr) bool {
if al == nil { if al == nil {
return true return true
} }
_, result := al.cidrTree.MostSpecificContains(ip) result, _ := al.cidrTree.Lookup(ip)
return result return result
} }
func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool { func (al *LocalAllowList) Allow(ip netip.Addr) bool {
if al == nil {
return true
}
_, result := al.cidrTree.MostSpecificContainsIpV4(ip)
return result
}
func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
if al == nil {
return true
}
_, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
return result
}
func (al *LocalAllowList) Allow(ip net.IP) bool {
if al == nil { if al == nil {
return true return true
} }
@@ -301,43 +282,23 @@ func (al *LocalAllowList) AllowName(name string) bool {
return !al.nameRules[0].Allow return !al.nameRules[0].Allow
} }
func (al *RemoteAllowList) AllowUnknownVpnIp(ip net.IP) bool { func (al *RemoteAllowList) AllowUnknownVpnIp(ip netip.Addr) bool {
if al == nil { if al == nil {
return true return true
} }
return al.AllowList.Allow(ip) return al.AllowList.Allow(ip)
} }
func (al *RemoteAllowList) Allow(vpnIp iputil.VpnIp, ip net.IP) bool { func (al *RemoteAllowList) Allow(vpnIp netip.Addr, ip netip.Addr) bool {
if !al.getInsideAllowList(vpnIp).Allow(ip) { if !al.getInsideAllowList(vpnIp).Allow(ip) {
return false return false
} }
return al.AllowList.Allow(ip) return al.AllowList.Allow(ip)
} }
func (al *RemoteAllowList) AllowIpV4(vpnIp iputil.VpnIp, ip iputil.VpnIp) bool { func (al *RemoteAllowList) getInsideAllowList(vpnIp netip.Addr) *AllowList {
if al == nil {
return true
}
if !al.getInsideAllowList(vpnIp).AllowIpV4(ip) {
return false
}
return al.AllowList.AllowIpV4(ip)
}
func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool {
if al == nil {
return true
}
if !al.getInsideAllowList(vpnIp).AllowIpV6(hi, lo) {
return false
}
return al.AllowList.AllowIpV6(hi, lo)
}
func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
if al.insideAllowLists != nil { if al.insideAllowLists != nil {
ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp) inside, ok := al.insideAllowLists.Lookup(vpnIp)
if ok { if ok {
return inside return inside
} }

View File

@@ -1,11 +1,11 @@
package nebula package nebula
import ( import (
"net" "net/netip"
"regexp" "regexp"
"testing" "testing"
"github.com/slackhq/nebula/cidr" "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -18,7 +18,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
"192.168.0.0": true, "192.168.0.0": true,
} }
r, err := newAllowListFromConfig(c, "allowlist", nil) r, err := newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0") assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
assert.Nil(t, r) assert.Nil(t, r)
c.Settings["allowlist"] = map[interface{}]interface{}{ c.Settings["allowlist"] = map[interface{}]interface{}{
@@ -98,26 +98,26 @@ func TestNewAllowListFromConfig(t *testing.T) {
} }
func TestAllowList_Allow(t *testing.T) { func TestAllowList_Allow(t *testing.T) {
assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1"))) assert.Equal(t, true, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1")))
tree := cidr.NewTree6[bool]() tree := new(bart.Table[bool])
tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true) tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true)
tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false) tree.Insert(netip.MustParsePrefix("10.0.0.0/8"), false)
tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true) tree.Insert(netip.MustParsePrefix("10.42.42.42/32"), true)
tree.AddCIDR(cidr.Parse("10.42.0.0/16"), true) tree.Insert(netip.MustParsePrefix("10.42.0.0/16"), true)
tree.AddCIDR(cidr.Parse("10.42.42.0/24"), true) tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), true)
tree.AddCIDR(cidr.Parse("10.42.42.0/24"), false) tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), false)
tree.AddCIDR(cidr.Parse("::1/128"), true) tree.Insert(netip.MustParsePrefix("::1/128"), true)
tree.AddCIDR(cidr.Parse("::2/128"), false) tree.Insert(netip.MustParsePrefix("::2/128"), false)
al := &AllowList{cidrTree: tree} al := &AllowList{cidrTree: tree}
assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1"))) assert.Equal(t, true, al.Allow(netip.MustParseAddr("1.1.1.1")))
assert.Equal(t, false, al.Allow(net.ParseIP("10.0.0.4"))) assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.0.0.4")))
assert.Equal(t, true, al.Allow(net.ParseIP("10.42.42.42"))) assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.42.42")))
assert.Equal(t, false, al.Allow(net.ParseIP("10.42.42.41"))) assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.42.42.41")))
assert.Equal(t, true, al.Allow(net.ParseIP("10.42.0.1"))) assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.0.1")))
assert.Equal(t, true, al.Allow(net.ParseIP("::1"))) assert.Equal(t, true, al.Allow(netip.MustParseAddr("::1")))
assert.Equal(t, false, al.Allow(net.ParseIP("::2"))) assert.Equal(t, false, al.Allow(netip.MustParseAddr("::2")))
} }
func TestLocalAllowList_AllowName(t *testing.T) { func TestLocalAllowList_AllowName(t *testing.T) {

View File

@@ -1,40 +1,35 @@
package nebula package nebula
import ( import (
"encoding/binary"
"fmt" "fmt"
"math" "math"
"net" "net"
"net/netip"
"strconv" "strconv"
"github.com/slackhq/nebula/cidr" "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
) )
// This allows us to "guess" what the remote might be for a host while we wait // This allows us to "guess" what the remote might be for a host while we wait
// for the lighthouse response. See "lighthouse.calculated_remotes" in the // for the lighthouse response. See "lighthouse.calculated_remotes" in the
// example config file. // example config file.
type calculatedRemote struct { type calculatedRemote struct {
ipNet net.IPNet ipNet netip.Prefix
maskIP iputil.VpnIp mask netip.Prefix
mask iputil.VpnIp
port uint32 port uint32
} }
func newCalculatedRemote(ipNet *net.IPNet, port int) (*calculatedRemote, error) { func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
// Ensure this is an IPv4 mask that we expect masked := maskCidr.Masked()
ones, bits := ipNet.Mask.Size()
if ones == 0 || bits != 32 {
return nil, fmt.Errorf("invalid mask: %v", ipNet)
}
if port < 0 || port > math.MaxUint16 { if port < 0 || port > math.MaxUint16 {
return nil, fmt.Errorf("invalid port: %d", port) return nil, fmt.Errorf("invalid port: %d", port)
} }
return &calculatedRemote{ return &calculatedRemote{
ipNet: *ipNet, ipNet: maskCidr,
maskIP: iputil.Ip2VpnIp(ipNet.IP), mask: masked,
mask: iputil.Ip2VpnIp(ipNet.Mask),
port: uint32(port), port: uint32(port),
}, nil }, nil
} }
@@ -43,21 +38,41 @@ func (c *calculatedRemote) String() string {
return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port) return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
} }
func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort { func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort {
// Combine the masked bytes of the "mask" IP with the unmasked bytes // Combine the masked bytes of the "mask" IP with the unmasked bytes
// of the overlay IP // of the overlay IP
masked := (c.maskIP & c.mask) | (ip & ^c.mask) if c.ipNet.Addr().Is4() {
return c.apply4(ip)
return &Ip4AndPort{Ip: uint32(masked), Port: c.port} }
return c.apply6(ip)
} }
func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) { func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort {
//TODO: IPV6-WORK this can be less crappy
maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
mask := binary.BigEndian.Uint32(maskb[:])
b := c.mask.Addr().As4()
maskIp := binary.BigEndian.Uint32(b[:])
b = ip.As4()
intIp := binary.BigEndian.Uint32(b[:])
return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port}
}
func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort {
//TODO: IPV6-WORK
panic("Can not calculate ipv6 remote addresses")
}
func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
value := c.Get(k) value := c.Get(k)
if value == nil { if value == nil {
return nil, nil return nil, nil
} }
calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]() calculatedRemotes := new(bart.Table[[]*calculatedRemote])
rawMap, ok := value.(map[any]any) rawMap, ok := value.(map[any]any)
if !ok { if !ok {
@@ -69,17 +84,18 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calcu
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
} }
_, ipNet, err := net.ParseCIDR(rawCIDR) cidr, err := netip.ParsePrefix(rawCIDR)
if err != nil { if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
} }
//TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here
entry, err := newCalculatedRemotesListFromConfig(rawValue) entry, err := newCalculatedRemotesListFromConfig(rawValue)
if err != nil { if err != nil {
return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err) return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
} }
calculatedRemotes.AddCIDR(ipNet, entry) calculatedRemotes.Insert(cidr, entry)
} }
return calculatedRemotes, nil return calculatedRemotes, nil
@@ -117,7 +133,7 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
if !ok { if !ok {
return nil, fmt.Errorf("invalid mask (type %T): %v", rawValue, rawValue) return nil, fmt.Errorf("invalid mask (type %T): %v", rawValue, rawValue)
} }
_, ipNet, err := net.ParseCIDR(rawMask) maskCidr, err := netip.ParsePrefix(rawMask)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid mask: %s", rawMask) return nil, fmt.Errorf("invalid mask: %s", rawMask)
} }
@@ -139,5 +155,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue) return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
} }
return newCalculatedRemote(ipNet, port) return newCalculatedRemote(maskCidr, port)
} }

View File

@@ -1,27 +1,25 @@
package nebula package nebula
import ( import (
"net" "net/netip"
"testing" "testing"
"github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestCalculatedRemoteApply(t *testing.T) { func TestCalculatedRemoteApply(t *testing.T) {
_, ipNet, err := net.ParseCIDR("192.168.1.0/24") ipNet, err := netip.ParsePrefix("192.168.1.0/24")
require.NoError(t, err) require.NoError(t, err)
c, err := newCalculatedRemote(ipNet, 4242) c, err := newCalculatedRemote(ipNet, 4242)
require.NoError(t, err) require.NoError(t, err)
input := iputil.Ip2VpnIp([]byte{10, 0, 10, 182}) input, err := netip.ParseAddr("10.0.10.182")
assert.NoError(t, err)
expected := &Ip4AndPort{ expected, err := netip.ParseAddr("192.168.1.182")
Ip: uint32(iputil.Ip2VpnIp([]byte{192, 168, 1, 182})), assert.NoError(t, err)
Port: 4242,
}
assert.Equal(t, expected, c.Apply(input)) assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input))
} }

View File

@@ -24,31 +24,39 @@ func NewCAPool() *NebulaCAPool {
// NewCAPoolFromBytes will create a new CA pool from the provided // NewCAPoolFromBytes will create a new CA pool from the provided
// input bytes, which must be a PEM-encoded set of nebula certificates. // input bytes, which must be a PEM-encoded set of nebula certificates.
// If the pool contains unsupported certificates, they will generate warnings
// in the []error return arg.
// If the pool contains any expired certificates, an ErrExpired will be // If the pool contains any expired certificates, an ErrExpired will be
// returned along with the pool. The caller must handle any such errors. // returned along with the pool. The caller must handle any such errors.
func NewCAPoolFromBytes(caPEMs []byte) (*NebulaCAPool, error) { func NewCAPoolFromBytes(caPEMs []byte) (*NebulaCAPool, []error, error) {
pool := NewCAPool() pool := NewCAPool()
var err error var err error
var expired bool var warnings []error
good := 0
for { for {
caPEMs, err = pool.AddCACertificate(caPEMs) caPEMs, err = pool.AddCACertificate(caPEMs)
if errors.Is(err, ErrExpired) { if errors.Is(err, ErrExpired) {
expired = true warnings = append(warnings, err)
err = nil } else if errors.Is(err, ErrInvalidPEMCertificateUnsupported) {
} warnings = append(warnings, err)
if err != nil { } else if err != nil {
return nil, err return nil, warnings, err
} else {
// Only consider a good certificate if there were no errors present
good++
} }
if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" { if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
break break
} }
} }
if expired { if good == 0 {
return pool, ErrExpired return nil, warnings, errors.New("no valid CA certificates present")
} }
return pool, nil return pool, warnings, nil
} }
// AddCACertificate verifies a Nebula CA certificate and adds it to the pool // AddCACertificate verifies a Nebula CA certificate and adds it to the pool

View File

@@ -28,6 +28,7 @@ const publicKeyLen = 32
const ( const (
CertBanner = "NEBULA CERTIFICATE" CertBanner = "NEBULA CERTIFICATE"
CertificateV2Banner = "NEBULA CERTIFICATE V2"
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY" X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY" EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
@@ -163,6 +164,9 @@ func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, er
if p == nil { if p == nil {
return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
} }
if p.Type == CertificateV2Banner {
return nil, r, fmt.Errorf("%w: %s", ErrInvalidPEMCertificateUnsupported, p.Type)
}
if p.Type != CertBanner { if p.Type != CertBanner {
return nil, r, fmt.Errorf("bytes did not contain a proper nebula certificate banner") return nil, r, fmt.Errorf("bytes did not contain a proper nebula certificate banner")
} }

View File

@@ -5,6 +5,7 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
"crypto/rand" "crypto/rand"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@@ -572,6 +573,13 @@ CmYKEG5lYnVsYSBQMjU2IHRlc3Qo4s+7mgYw4tXrsAc6QQRkaW2jFmllYvN4+/k2
76gvQAGgBgESRzBFAiEAib0/te6eMiZOKD8gdDeloMTS0wGuX2t0C7TFdUhAQzgC 76gvQAGgBgESRzBFAiEAib0/te6eMiZOKD8gdDeloMTS0wGuX2t0C7TFdUhAQzgC
IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX
-----END NEBULA CERTIFICATE----- -----END NEBULA CERTIFICATE-----
`
v2 := `
# valid PEM with the V2 header
-----BEGIN NEBULA CERTIFICATE V2-----
CmYKEG5lYnVsYSBQMjU2IHRlc3Qo4s+7mgYw4tXrsAc6QQRkaW2jFmllYvN4+/k2
-----END NEBULA CERTIFICATE V2-----
` `
rootCA := NebulaCertificate{ rootCA := NebulaCertificate{
@@ -592,33 +600,46 @@ IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX
}, },
} }
p, err := NewCAPoolFromBytes([]byte(noNewLines)) p, warn, err := NewCAPoolFromBytes([]byte(noNewLines))
assert.Nil(t, err) assert.Nil(t, err)
assert.Nil(t, warn)
assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name) assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name) assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
pp, err := NewCAPoolFromBytes([]byte(withNewLines)) pp, warn, err := NewCAPoolFromBytes([]byte(withNewLines))
assert.Nil(t, err) assert.Nil(t, err)
assert.Nil(t, warn)
assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name) assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name) assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
// expired cert, no valid certs // expired cert, no valid certs
ppp, err := NewCAPoolFromBytes([]byte(expired)) ppp, warn, err := NewCAPoolFromBytes([]byte(expired))
assert.Equal(t, ErrExpired, err) assert.Error(t, err, "no valid CA certificates present")
assert.Equal(t, ppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired") assert.Len(t, warn, 1)
assert.Error(t, warn[0], ErrExpired)
assert.Nil(t, ppp)
// expired cert, with valid certs // expired cert, with valid certs
pppp, err := NewCAPoolFromBytes(append([]byte(expired), noNewLines...)) pppp, warn, err := NewCAPoolFromBytes(append([]byte(expired), noNewLines...))
assert.Equal(t, ErrExpired, err) assert.Len(t, warn, 1)
assert.Nil(t, err)
assert.Error(t, warn[0], ErrExpired)
assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name) assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name) assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired") assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired")
assert.Equal(t, len(pppp.CAs), 3) assert.Equal(t, len(pppp.CAs), 3)
ppppp, err := NewCAPoolFromBytes([]byte(p256)) ppppp, warn, err := NewCAPoolFromBytes([]byte(p256))
assert.Nil(t, err) assert.Nil(t, err)
assert.Nil(t, warn)
assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Details.Name, rootCAP256.Details.Name) assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Details.Name, rootCAP256.Details.Name)
assert.Equal(t, len(ppppp.CAs), 1) assert.Equal(t, len(ppppp.CAs), 1)
pppppp, warn, err := NewCAPoolFromBytes(append([]byte(p256), []byte(v2)...))
assert.Nil(t, err)
assert.True(t, errors.Is(warn[0], ErrInvalidPEMCertificateUnsupported))
assert.Equal(t, pppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Details.Name, rootCAP256.Details.Name)
assert.Equal(t, len(pppppp.CAs), 1)
} }
func appendByteSlices(b ...[]byte) []byte { func appendByteSlices(b ...[]byte) []byte {

View File

@@ -11,4 +11,5 @@ var (
ErrNotSelfSigned = errors.New("certificate is not self-signed") ErrNotSelfSigned = errors.New("certificate is not self-signed")
ErrBlockListed = errors.New("certificate is in the block list") ErrBlockListed = errors.New("certificate is in the block list")
ErrSignatureMismatch = errors.New("certificate signature did not match") ErrSignatureMismatch = errors.New("certificate signature did not match")
ErrInvalidPEMCertificateUnsupported = errors.New("bytes contain an unsupported certificate format")
) )

View File

@@ -1,10 +0,0 @@
package cidr
import "net"
// Parse is a convenience function that returns only the IPNet
// This function ignores errors since it is primarily a test helper, the result could be nil
func Parse(s string) *net.IPNet {
_, c, _ := net.ParseCIDR(s)
return c
}

View File

@@ -1,203 +0,0 @@
package cidr
import (
"net"
"github.com/slackhq/nebula/iputil"
)
type Node[T any] struct {
left *Node[T]
right *Node[T]
parent *Node[T]
hasValue bool
value T
}
type entry[T any] struct {
CIDR *net.IPNet
Value T
}
type Tree4[T any] struct {
root *Node[T]
list []entry[T]
}
const (
startbit = iputil.VpnIp(0x80000000)
)
func NewTree4[T any]() *Tree4[T] {
tree := new(Tree4[T])
tree.root = &Node[T]{}
tree.list = []entry[T]{}
return tree
}
func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) {
bit := startbit
node := tree.root
next := tree.root
ip := iputil.Ip2VpnIp(cidr.IP)
mask := iputil.Ip2VpnIp(cidr.Mask)
// Find our last ancestor in the tree
for bit&mask != 0 {
if ip&bit != 0 {
next = node.right
} else {
next = node.left
}
if next == nil {
break
}
bit = bit >> 1
node = next
}
// We already have this range so update the value
if next != nil {
addCIDR := cidr.String()
for i, v := range tree.list {
if addCIDR == v.CIDR.String() {
tree.list = append(tree.list[:i], tree.list[i+1:]...)
break
}
}
tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
node.value = val
node.hasValue = true
return
}
// Build up the rest of the tree we don't already have
for bit&mask != 0 {
next = &Node[T]{}
next.parent = node
if ip&bit != 0 {
node.right = next
} else {
node.left = next
}
bit >>= 1
node = next
}
// Final node marks our cidr, set the value
node.value = val
node.hasValue = true
tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
}
// Contains finds the first match, which may be the least specific
func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) {
bit := startbit
node := tree.root
for node != nil {
if node.hasValue {
return true, node.value
}
if ip&bit != 0 {
node = node.right
} else {
node = node.left
}
bit >>= 1
}
return false, value
}
// MostSpecificContains finds the most specific match
func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
bit := startbit
node := tree.root
for node != nil {
if node.hasValue {
value = node.value
ok = true
}
if ip&bit != 0 {
node = node.right
} else {
node = node.left
}
bit >>= 1
}
return ok, value
}
type eachFunc[T any] func(T) bool
// EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete
// The final return value will be true if the provided function returned true
func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool {
bit := startbit
node := tree.root
for node != nil {
if node.hasValue {
// If the each func returns true then we can exit the loop
if each(node.value) {
return true
}
}
if ip&bit != 0 {
node = node.right
} else {
node = node.left
}
bit >>= 1
}
return false
}
// GetCIDR returns the entry added by the most recent matching AddCIDR call
func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) {
bit := startbit
node := tree.root
ip := iputil.Ip2VpnIp(cidr.IP)
mask := iputil.Ip2VpnIp(cidr.Mask)
// Find our last ancestor in the tree
for node != nil && bit&mask != 0 {
if ip&bit != 0 {
node = node.right
} else {
node = node.left
}
bit = bit >> 1
}
if bit&mask == 0 && node != nil {
value = node.value
ok = node.hasValue
}
return ok, value
}
// List will return all CIDRs and their current values. Do not modify the contents!
func (tree *Tree4[T]) List() []entry[T] {
return tree.list
}

View File

@@ -1,170 +0,0 @@
package cidr
import (
"net"
"testing"
"github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert"
)
func TestCIDRTree_List(t *testing.T) {
tree := NewTree4[string]()
tree.AddCIDR(Parse("1.0.0.0/16"), "1")
tree.AddCIDR(Parse("1.0.0.0/8"), "2")
tree.AddCIDR(Parse("1.0.0.0/16"), "3")
tree.AddCIDR(Parse("1.0.0.0/16"), "4")
list := tree.List()
assert.Len(t, list, 2)
assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String())
assert.Equal(t, "2", list[0].Value)
assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String())
assert.Equal(t, "4", list[1].Value)
}
func TestCIDRTree_Contains(t *testing.T) {
tree := NewTree4[string]()
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
tests := []struct {
Found bool
Result interface{}
IP string
}{
{true, "1", "1.0.0.0"},
{true, "1", "1.255.255.255"},
{true, "2", "2.1.0.0"},
{true, "2", "2.1.255.255"},
{true, "3", "3.1.1.0"},
{true, "3", "3.1.1.255"},
{true, "4a", "4.1.1.255"},
{true, "4a", "4.1.1.1"},
{true, "5", "240.0.0.0"},
{true, "5", "255.255.255.255"},
{false, "", "239.0.0.0"},
{false, "", "4.1.2.2"},
}
for _, tt := range tests {
ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
assert.Equal(t, tt.Found, ok)
assert.Equal(t, tt.Result, r)
}
tree = NewTree4[string]()
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
assert.True(t, ok)
assert.Equal(t, "cool", r)
ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
assert.True(t, ok)
assert.Equal(t, "cool", r)
}
func TestCIDRTree_MostSpecificContains(t *testing.T) {
tree := NewTree4[string]()
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
tree.AddCIDR(Parse("4.1.1.0/30"), "4b")
tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
tests := []struct {
Found bool
Result interface{}
IP string
}{
{true, "1", "1.0.0.0"},
{true, "1", "1.255.255.255"},
{true, "2", "2.1.0.0"},
{true, "2", "2.1.255.255"},
{true, "3", "3.1.1.0"},
{true, "3", "3.1.1.255"},
{true, "4a", "4.1.1.255"},
{true, "4b", "4.1.1.2"},
{true, "4c", "4.1.1.1"},
{true, "5", "240.0.0.0"},
{true, "5", "255.255.255.255"},
{false, "", "239.0.0.0"},
{false, "", "4.1.2.2"},
}
for _, tt := range tests {
ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
assert.Equal(t, tt.Found, ok)
assert.Equal(t, tt.Result, r)
}
tree = NewTree4[string]()
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
assert.True(t, ok)
assert.Equal(t, "cool", r)
ok, r = tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
assert.True(t, ok)
assert.Equal(t, "cool", r)
}
func TestTree4_GetCIDR(t *testing.T) {
tree := NewTree4[string]()
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
tests := []struct {
Found bool
Result interface{}
IPNet *net.IPNet
}{
{true, "1", Parse("1.0.0.0/8")},
{true, "2", Parse("2.1.0.0/16")},
{true, "3", Parse("3.1.1.0/24")},
{true, "4a", Parse("4.1.1.0/24")},
{true, "4b", Parse("4.1.1.1/32")},
{true, "4c", Parse("4.1.2.1/32")},
{true, "5", Parse("254.0.0.0/4")},
{false, "", Parse("2.0.0.0/8")},
}
for _, tt := range tests {
ok, r := tree.GetCIDR(tt.IPNet)
assert.Equal(t, tt.Found, ok)
assert.Equal(t, tt.Result, r)
}
}
func BenchmarkCIDRTree_Contains(b *testing.B) {
tree := NewTree4[string]()
tree.AddCIDR(Parse("1.1.0.0/16"), "1")
tree.AddCIDR(Parse("1.2.1.1/32"), "1")
tree.AddCIDR(Parse("192.2.1.1/32"), "1")
tree.AddCIDR(Parse("172.2.1.1/32"), "1")
ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
b.Run("found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Contains(ip)
}
})
ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
b.Run("not found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Contains(ip)
}
})
}

View File

@@ -1,189 +0,0 @@
package cidr
import (
"net"
"github.com/slackhq/nebula/iputil"
)
const startbit6 = uint64(1 << 63)
type Tree6[T any] struct {
root4 *Node[T]
root6 *Node[T]
}
func NewTree6[T any]() *Tree6[T] {
tree := new(Tree6[T])
tree.root4 = &Node[T]{}
tree.root6 = &Node[T]{}
return tree
}
func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) {
var node, next *Node[T]
cidrIP, ipv4 := isIPV4(cidr.IP)
if ipv4 {
node = tree.root4
next = tree.root4
} else {
node = tree.root6
next = tree.root6
}
for i := 0; i < len(cidrIP); i += 4 {
ip := iputil.Ip2VpnIp(cidrIP[i : i+4])
mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4])
bit := startbit
// Find our last ancestor in the tree
for bit&mask != 0 {
if ip&bit != 0 {
next = node.right
} else {
next = node.left
}
if next == nil {
break
}
bit = bit >> 1
node = next
}
// Build up the rest of the tree we don't already have
for bit&mask != 0 {
next = &Node[T]{}
next.parent = node
if ip&bit != 0 {
node.right = next
} else {
node.left = next
}
bit >>= 1
node = next
}
}
// Final node marks our cidr, set the value
node.value = val
node.hasValue = true
}
// Finds the most specific match
func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) {
var node *Node[T]
wholeIP, ipv4 := isIPV4(ip)
if ipv4 {
node = tree.root4
} else {
node = tree.root6
}
for i := 0; i < len(wholeIP); i += 4 {
ip := iputil.Ip2VpnIp(wholeIP[i : i+4])
bit := startbit
for node != nil {
if node.hasValue {
value = node.value
ok = true
}
if bit == 0 {
break
}
if ip&bit != 0 {
node = node.right
} else {
node = node.left
}
bit >>= 1
}
}
return ok, value
}
func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) {
bit := startbit
node := tree.root4
for node != nil {
if node.hasValue {
value = node.value
ok = true
}
if ip&bit != 0 {
node = node.right
} else {
node = node.left
}
bit >>= 1
}
return ok, value
}
func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) {
ip := hi
node := tree.root6
for i := 0; i < 2; i++ {
bit := startbit6
for node != nil {
if node.hasValue {
value = node.value
ok = true
}
if bit == 0 {
break
}
if ip&bit != 0 {
node = node.right
} else {
node = node.left
}
bit >>= 1
}
ip = lo
}
return ok, value
}
func isIPV4(ip net.IP) (net.IP, bool) {
if len(ip) == net.IPv4len {
return ip, true
}
if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff {
return ip[12:16], true
}
return ip, false
}
func isZeros(p net.IP) bool {
for i := 0; i < len(p); i++ {
if p[i] != 0 {
return false
}
}
return true
}

View File

@@ -1,98 +0,0 @@
package cidr
import (
"encoding/binary"
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
tree := NewTree6[string]()
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
tree.AddCIDR(Parse("4.1.1.1/24"), "4a")
tree.AddCIDR(Parse("4.1.1.1/30"), "4b")
tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
tests := []struct {
Found bool
Result interface{}
IP string
}{
{true, "1", "1.0.0.0"},
{true, "1", "1.255.255.255"},
{true, "2", "2.1.0.0"},
{true, "2", "2.1.255.255"},
{true, "3", "3.1.1.0"},
{true, "3", "3.1.1.255"},
{true, "4a", "4.1.1.255"},
{true, "4b", "4.1.1.2"},
{true, "4c", "4.1.1.1"},
{true, "5", "240.0.0.0"},
{true, "5", "255.255.255.255"},
{true, "6a", "1:2:0:4:1:1:1:1"},
{true, "6b", "1:2:0:4:5:1:1:1"},
{true, "6c", "1:2:0:4:5:0:0:0"},
{false, "", "239.0.0.0"},
{false, "", "4.1.2.2"},
}
for _, tt := range tests {
ok, r := tree.MostSpecificContains(net.ParseIP(tt.IP))
assert.Equal(t, tt.Found, ok)
assert.Equal(t, tt.Result, r)
}
tree = NewTree6[string]()
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
tree.AddCIDR(Parse("::/0"), "cool6")
ok, r := tree.MostSpecificContains(net.ParseIP("0.0.0.0"))
assert.True(t, ok)
assert.Equal(t, "cool", r)
ok, r = tree.MostSpecificContains(net.ParseIP("255.255.255.255"))
assert.True(t, ok)
assert.Equal(t, "cool", r)
ok, r = tree.MostSpecificContains(net.ParseIP("::"))
assert.True(t, ok)
assert.Equal(t, "cool6", r)
ok, r = tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8"))
assert.True(t, ok)
assert.Equal(t, "cool6", r)
}
func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
tree := NewTree6[string]()
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
tests := []struct {
Found bool
Result interface{}
IP string
}{
{true, "6a", "1:2:0:4:1:1:1:1"},
{true, "6b", "1:2:0:4:5:1:1:1"},
{true, "6c", "1:2:0:4:5:0:0:0"},
}
for _, tt := range tests {
ip := net.ParseIP(tt.IP)
hi := binary.BigEndian.Uint64(ip[:8])
lo := binary.BigEndian.Uint64(ip[8:])
ok, r := tree.MostSpecificContainsIpV6(hi, lo)
assert.Equal(t, tt.Found, ok)
assert.Equal(t, tt.Result, r)
}
}

View File

@@ -3,15 +3,18 @@ package nebula
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/binary"
"fmt"
"net/netip"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
) )
type trafficDecision int type trafficDecision int
@@ -27,12 +30,6 @@ const (
) )
type connectionManager struct { type connectionManager struct {
in map[uint32]struct{}
inLock *sync.RWMutex
out map[uint32]struct{}
outLock *sync.RWMutex
// relayUsed holds which relay localIndexs are in use // relayUsed holds which relay localIndexs are in use
relayUsed map[uint32]struct{} relayUsed map[uint32]struct{}
relayUsedLock *sync.RWMutex relayUsedLock *sync.RWMutex
@@ -40,117 +37,117 @@ type connectionManager struct {
hostMap *HostMap hostMap *HostMap
trafficTimer *LockingTimerWheel[uint32] trafficTimer *LockingTimerWheel[uint32]
intf *Interface intf *Interface
pendingDeletion map[uint32]struct{}
punchy *Punchy punchy *Punchy
// Configuration settings
checkInterval time.Duration checkInterval time.Duration
pendingDeletionInterval time.Duration pendingDeletionInterval time.Duration
inactivityTimeout atomic.Int64
dropInactive atomic.Bool
metricsTxPunchy metrics.Counter metricsTxPunchy metrics.Counter
l *logrus.Logger l *logrus.Logger
} }
func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager { func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
var max time.Duration cm := &connectionManager{
if checkInterval < pendingDeletionInterval { hostMap: hm,
max = pendingDeletionInterval l: l,
} else { punchy: p,
max = checkInterval
}
nc := &connectionManager{
hostMap: intf.hostMap,
in: make(map[uint32]struct{}),
inLock: &sync.RWMutex{},
out: make(map[uint32]struct{}),
outLock: &sync.RWMutex{},
relayUsed: make(map[uint32]struct{}), relayUsed: make(map[uint32]struct{}),
relayUsedLock: &sync.RWMutex{}, relayUsedLock: &sync.RWMutex{},
trafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, max),
intf: intf,
pendingDeletion: make(map[uint32]struct{}),
checkInterval: checkInterval,
pendingDeletionInterval: pendingDeletionInterval,
punchy: punchy,
metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil), metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
l: l,
} }
nc.Start(ctx) cm.reload(c, true)
return nc c.RegisterReloadCallback(func(c *config.C) {
cm.reload(c, false)
})
return cm
} }
func (n *connectionManager) In(localIndex uint32) { func (cm *connectionManager) reload(c *config.C, initial bool) {
n.inLock.RLock() if initial {
cm.checkInterval = time.Duration(c.GetInt("timers.connection_alive_interval", 5)) * time.Second
cm.pendingDeletionInterval = time.Duration(c.GetInt("timers.pending_deletion_interval", 10)) * time.Second
// We want at least a minimum resolution of 500ms per tick so that we can hit these intervals
// pretty close to their configured duration.
// The inactivity duration is checked each time a hostinfo ticks through so we don't need the wheel to contain it.
minDuration := min(time.Millisecond*500, cm.checkInterval, cm.pendingDeletionInterval)
maxDuration := max(cm.checkInterval, cm.pendingDeletionInterval)
cm.trafficTimer = NewLockingTimerWheel[uint32](minDuration, maxDuration)
}
if initial || c.HasChanged("tunnels.inactivity_timeout") {
old := cm.getInactivityTimeout()
cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
if !initial {
cm.l.WithField("oldDuration", old).
WithField("newDuration", cm.getInactivityTimeout()).
Info("Inactivity timeout has changed")
}
}
if initial || c.HasChanged("tunnels.drop_inactive") {
old := cm.dropInactive.Load()
cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
if !initial {
cm.l.WithField("oldBool", old).
WithField("newBool", cm.dropInactive.Load()).
Info("Drop inactive setting has changed")
}
}
}
func (cm *connectionManager) getInactivityTimeout() time.Duration {
return (time.Duration)(cm.inactivityTimeout.Load())
}
func (cm *connectionManager) In(h *HostInfo) {
h.in.Store(true)
}
func (cm *connectionManager) Out(h *HostInfo) {
h.out.Store(true)
}
func (cm *connectionManager) RelayUsed(localIndex uint32) {
cm.relayUsedLock.RLock()
// If this already exists, return // If this already exists, return
if _, ok := n.in[localIndex]; ok { if _, ok := cm.relayUsed[localIndex]; ok {
n.inLock.RUnlock() cm.relayUsedLock.RUnlock()
return return
} }
n.inLock.RUnlock() cm.relayUsedLock.RUnlock()
n.inLock.Lock() cm.relayUsedLock.Lock()
n.in[localIndex] = struct{}{} cm.relayUsed[localIndex] = struct{}{}
n.inLock.Unlock() cm.relayUsedLock.Unlock()
}
func (n *connectionManager) Out(localIndex uint32) {
n.outLock.RLock()
// If this already exists, return
if _, ok := n.out[localIndex]; ok {
n.outLock.RUnlock()
return
}
n.outLock.RUnlock()
n.outLock.Lock()
n.out[localIndex] = struct{}{}
n.outLock.Unlock()
}
func (n *connectionManager) RelayUsed(localIndex uint32) {
n.relayUsedLock.RLock()
// If this already exists, return
if _, ok := n.relayUsed[localIndex]; ok {
n.relayUsedLock.RUnlock()
return
}
n.relayUsedLock.RUnlock()
n.relayUsedLock.Lock()
n.relayUsed[localIndex] = struct{}{}
n.relayUsedLock.Unlock()
} }
// getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and // getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
// resets the state for this local index // resets the state for this local index
func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) { func (cm *connectionManager) getAndResetTrafficCheck(h *HostInfo, now time.Time) (bool, bool) {
n.inLock.Lock() in := h.in.Swap(false)
n.outLock.Lock() out := h.out.Swap(false)
_, in := n.in[localIndex] if in || out {
_, out := n.out[localIndex] h.lastUsed = now
delete(n.in, localIndex) }
delete(n.out, localIndex)
n.inLock.Unlock()
n.outLock.Unlock()
return in, out return in, out
} }
func (n *connectionManager) AddTrafficWatch(localIndex uint32) { // AddTrafficWatch must be called for every new HostInfo.
// Use a write lock directly because it should be incredibly rare that we are ever already tracking this index // We will continue to monitor the HostInfo until the tunnel is dropped.
n.outLock.Lock() func (cm *connectionManager) AddTrafficWatch(h *HostInfo) {
if _, ok := n.out[localIndex]; ok { if h.out.Swap(true) == false {
n.outLock.Unlock() cm.trafficTimer.Add(h.localIndexId, cm.checkInterval)
return
} }
n.out[localIndex] = struct{}{}
n.trafficTimer.Add(localIndex, n.checkInterval)
n.outLock.Unlock()
} }
func (n *connectionManager) Start(ctx context.Context) { func (cm *connectionManager) Start(ctx context.Context) {
go n.Run(ctx) clockSource := time.NewTicker(cm.trafficTimer.t.tickDuration)
}
func (n *connectionManager) Run(ctx context.Context) {
//TODO: this tick should be based on the min wheel tick? Check firewall
clockSource := time.NewTicker(500 * time.Millisecond)
defer clockSource.Stop() defer clockSource.Stop()
p := []byte("") p := []byte("")
@@ -163,128 +160,137 @@ func (n *connectionManager) Run(ctx context.Context) {
return return
case now := <-clockSource.C: case now := <-clockSource.C:
n.trafficTimer.Advance(now) cm.trafficTimer.Advance(now)
for { for {
localIndex, has := n.trafficTimer.Purge() localIndex, has := cm.trafficTimer.Purge()
if !has { if !has {
break break
} }
n.doTrafficCheck(localIndex, p, nb, out, now) cm.doTrafficCheck(localIndex, p, nb, out, now)
} }
} }
} }
} }
func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) { func (cm *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now) decision, hostinfo, primary := cm.makeTrafficDecision(localIndex, now)
switch decision { switch decision {
case deleteTunnel: case deleteTunnel:
if n.hostMap.DeleteHostInfo(hostinfo) { if cm.hostMap.DeleteHostInfo(hostinfo) {
// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap // Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp) cm.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp)
} }
case closeTunnel: case closeTunnel:
n.intf.sendCloseTunnel(hostinfo) cm.intf.sendCloseTunnel(hostinfo)
n.intf.closeTunnel(hostinfo) cm.intf.closeTunnel(hostinfo)
case swapPrimary: case swapPrimary:
n.swapPrimary(hostinfo, primary) cm.swapPrimary(hostinfo, primary)
case migrateRelays: case migrateRelays:
n.migrateRelayUsed(hostinfo, primary) cm.migrateRelayUsed(hostinfo, primary)
case tryRehandshake: case tryRehandshake:
n.tryRehandshake(hostinfo) cm.tryRehandshake(hostinfo)
case sendTestPacket: case sendTestPacket:
n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out) cm.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
} }
n.resetRelayTrafficCheck(hostinfo) cm.resetRelayTrafficCheck(hostinfo)
} }
func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) { func (cm *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
if hostinfo != nil { if hostinfo != nil {
n.relayUsedLock.Lock() cm.relayUsedLock.Lock()
defer n.relayUsedLock.Unlock() defer cm.relayUsedLock.Unlock()
// No need to migrate any relays, delete usage info now. // No need to migrate any relays, delete usage info now.
for _, idx := range hostinfo.relayState.CopyRelayForIdxs() { for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
delete(n.relayUsed, idx) delete(cm.relayUsed, idx)
} }
} }
} }
func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) { func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
relayFor := oldhostinfo.relayState.CopyAllRelayFor() relayFor := oldhostinfo.relayState.CopyAllRelayFor()
for _, r := range relayFor { for _, r := range relayFor {
existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp) existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
var index uint32 var index uint32
var relayFrom iputil.VpnIp var relayFrom netip.Addr
var relayTo iputil.VpnIp var relayTo netip.Addr
switch { switch {
case ok && existing.State == Established: case ok:
switch existing.State {
case Established, PeerRequested, Disestablished:
// This relay already exists in newhostinfo, then do nothing. // This relay already exists in newhostinfo, then do nothing.
continue continue
case ok && existing.State == Requested: case Requested:
// The relay exists in a Requested state; re-send the request // The relayed connection exists in a Requested state; re-send the request
index = existing.LocalIndex index = existing.LocalIndex
switch r.Type { switch r.Type {
case TerminalType: case TerminalType:
relayFrom = n.intf.myVpnIp relayFrom = cm.intf.myVpnNet.Addr()
relayTo = existing.PeerIp relayTo = existing.PeerIp
case ForwardingType: case ForwardingType:
relayFrom = existing.PeerIp relayFrom = existing.PeerIp
relayTo = newhostinfo.vpnIp relayTo = newhostinfo.vpnIp
default: default:
// should never happen // should never happen
panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
}
} }
case !ok: case !ok:
n.relayUsedLock.RLock() cm.relayUsedLock.RLock()
if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed { if _, relayUsed := cm.relayUsed[r.LocalIndex]; !relayUsed {
// The relay hasn't been used; don't migrate it. // The relay hasn't been used; don't migrate it.
n.relayUsedLock.RUnlock() cm.relayUsedLock.RUnlock()
continue continue
} }
n.relayUsedLock.RUnlock() cm.relayUsedLock.RUnlock()
// The relay doesn't exist at all; create some relay state and send the request. // The relay doesn't exist at all; create some relay state and send the request.
var err error var err error
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested) index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerIp, nil, r.Type, Requested)
if err != nil { if err != nil {
n.l.WithError(err).Error("failed to migrate relay to new hostinfo") cm.l.WithError(err).Error("failed to migrate relay to new hostinfo")
continue continue
} }
switch r.Type { switch r.Type {
case TerminalType: case TerminalType:
relayFrom = n.intf.myVpnIp relayFrom = cm.intf.myVpnNet.Addr()
relayTo = r.PeerIp relayTo = r.PeerIp
case ForwardingType: case ForwardingType:
relayFrom = r.PeerIp relayFrom = r.PeerIp
relayTo = newhostinfo.vpnIp relayTo = newhostinfo.vpnIp
default: default:
// should never happen // should never happen
panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
} }
} }
//TODO: IPV6-WORK
relayFromB := relayFrom.As4()
relayToB := relayTo.As4()
// Send a CreateRelayRequest to the peer. // Send a CreateRelayRequest to the peer.
req := NebulaControl{ req := NebulaControl{
Type: NebulaControl_CreateRelayRequest, Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: index, InitiatorRelayIndex: index,
RelayFromIp: uint32(relayFrom), RelayFromIp: binary.BigEndian.Uint32(relayFromB[:]),
RelayToIp: uint32(relayTo), RelayToIp: binary.BigEndian.Uint32(relayToB[:]),
} }
msg, err := req.Marshal() msg, err := req.Marshal()
if err != nil { if err != nil {
n.l.WithError(err).Error("failed to marshal Control message to migrate relay") cm.l.WithError(err).Error("failed to marshal Control message to migrate relay")
} else { } else {
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
n.l.WithFields(logrus.Fields{ cm.l.WithFields(logrus.Fields{
"relayFrom": iputil.VpnIp(req.RelayFromIp), "relayFrom": req.RelayFromIp,
"relayTo": iputil.VpnIp(req.RelayToIp), "relayTo": req.RelayToIp,
"initiatorRelayIndex": req.InitiatorRelayIndex, "initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex, "responderRelayIndex": req.ResponderRelayIndex,
"vpnIp": newhostinfo.vpnIp}). "vpnIp": newhostinfo.vpnIp}).
@@ -293,46 +299,45 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
} }
} }
func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) { func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
n.hostMap.RLock() // Read lock the main hostmap to order decisions based on tunnels being the primary tunnel
defer n.hostMap.RUnlock() cm.hostMap.RLock()
defer cm.hostMap.RUnlock()
hostinfo := n.hostMap.Indexes[localIndex] hostinfo := cm.hostMap.Indexes[localIndex]
if hostinfo == nil { if hostinfo == nil {
n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap") cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap")
delete(n.pendingDeletion, localIndex)
return doNothing, nil, nil return doNothing, nil, nil
} }
if n.isInvalidCertificate(now, hostinfo) { if cm.isInvalidCertificate(now, hostinfo) {
delete(n.pendingDeletion, hostinfo.localIndexId)
return closeTunnel, hostinfo, nil return closeTunnel, hostinfo, nil
} }
primary := n.hostMap.Hosts[hostinfo.vpnIp] primary := cm.hostMap.Hosts[hostinfo.vpnIp]
mainHostInfo := true mainHostInfo := true
if primary != nil && primary != hostinfo { if primary != nil && primary != hostinfo {
mainHostInfo = false mainHostInfo = false
} }
// Check for traffic on this hostinfo // Check for traffic on this hostinfo
inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex) inTraffic, outTraffic := cm.getAndResetTrafficCheck(hostinfo, now)
// A hostinfo is determined alive if there is incoming traffic // A hostinfo is determined alive if there is incoming traffic
if inTraffic { if inTraffic {
decision := doNothing decision := doNothing
if n.l.Level >= logrus.DebugLevel { if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l). hostinfo.logger(cm.l).
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
Debug("Tunnel status") Debug("Tunnel status")
} }
delete(n.pendingDeletion, hostinfo.localIndexId) hostinfo.pendingDeletion.Store(false)
if mainHostInfo { if mainHostInfo {
decision = tryRehandshake decision = tryRehandshake
} else { } else {
if n.shouldSwapPrimary(hostinfo, primary) { if cm.shouldSwapPrimary(hostinfo, primary) {
decision = swapPrimary decision = swapPrimary
} else { } else {
// migrate the relays to the primary, if in use. // migrate the relays to the primary, if in use.
@@ -340,46 +345,55 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
} }
} }
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
if !outTraffic { if !outTraffic {
// Send a punch packet to keep the NAT state alive // Send a punch packet to keep the NAT state alive
n.sendPunch(hostinfo) cm.sendPunch(hostinfo)
} }
return decision, hostinfo, primary return decision, hostinfo, primary
} }
if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok { if hostinfo.pendingDeletion.Load() {
// We have already sent a test packet and nothing was returned, this hostinfo is dead // We have already sent a test packet and nothing was returned, this hostinfo is dead
hostinfo.logger(n.l). hostinfo.logger(cm.l).
WithField("tunnelCheck", m{"state": "dead", "method": "active"}). WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
Info("Tunnel status") Info("Tunnel status")
delete(n.pendingDeletion, hostinfo.localIndexId)
return deleteTunnel, hostinfo, nil return deleteTunnel, hostinfo, nil
} }
decision := doNothing decision := doNothing
if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo { if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
if !outTraffic { if !outTraffic {
// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel. inactiveFor, isInactive := cm.isInactive(hostinfo, now)
// Just maintain NAT state if configured to do so. if isInactive {
n.sendPunch(hostinfo) // Tunnel is inactive, tear it down
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) hostinfo.logger(cm.l).
return doNothing, nil, nil WithField("inactiveDuration", inactiveFor).
WithField("primary", mainHostInfo).
Info("Dropping tunnel due to inactivity")
return closeTunnel, hostinfo, primary
} }
if n.punchy.GetTargetEverything() { // If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
// Just maintain NAT state if configured to do so.
cm.sendPunch(hostinfo)
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
return doNothing, nil, nil
}
if cm.punchy.GetTargetEverything() {
// This is similar to the old punchy behavior with a slight optimization. // This is similar to the old punchy behavior with a slight optimization.
// We aren't receiving traffic but we are sending it, punch on all known // We aren't receiving traffic but we are sending it, punch on all known
// ips in case we need to re-prime NAT state // ips in case we need to re-prime NAT state
n.sendPunch(hostinfo) cm.sendPunch(hostinfo)
} }
if n.l.Level >= logrus.DebugLevel { if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l). hostinfo.logger(cm.l).
WithField("tunnelCheck", m{"state": "testing", "method": "active"}). WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
Debug("Tunnel status") Debug("Tunnel status")
} }
@@ -388,95 +402,118 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
decision = sendTestPacket decision = sendTestPacket
} else { } else {
if n.l.Level >= logrus.DebugLevel { if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l).Debugf("Hostinfo sadness") hostinfo.logger(cm.l).Debugf("Hostinfo sadness")
} }
} }
n.pendingDeletion[hostinfo.localIndexId] = struct{}{} hostinfo.pendingDeletion.Store(true)
n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval) cm.trafficTimer.Add(hostinfo.localIndexId, cm.pendingDeletionInterval)
return decision, hostinfo, nil return decision, hostinfo, nil
} }
func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time.Duration, bool) {
if cm.dropInactive.Load() == false {
// We aren't configured to drop inactive tunnels
return 0, false
}
inactiveDuration := now.Sub(hostinfo.lastUsed)
if inactiveDuration < cm.getInactivityTimeout() {
// It's not considered inactive
return inactiveDuration, false
}
// The tunnel is inactive
return inactiveDuration, true
}
func (cm *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
// The primary tunnel is the most recent handshake to complete locally and should work entirely fine. // The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
// Let's sort this out. // Let's sort this out.
if current.vpnIp < n.intf.myVpnIp { if current.vpnIp.Compare(cm.intf.myVpnNet.Addr()) < 0 {
// Only one side should flip primary because if both flip then we may never resolve to a single tunnel. // Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping. // vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
// The remotes vpn ip is lower than mine. I will not flip. // The remotes vpn ip is lower than mine. I will not flip.
return false return false
} }
certState := n.intf.pki.GetCertState() certState := cm.intf.pki.GetCertState()
return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature) return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature)
} }
func (n *connectionManager) swapPrimary(current, primary *HostInfo) { func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
n.hostMap.Lock() cm.hostMap.Lock()
// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake. // Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
if n.hostMap.Hosts[current.vpnIp] == primary { if cm.hostMap.Hosts[current.vpnIp] == primary {
n.hostMap.unlockedMakePrimary(current) cm.hostMap.unlockedMakePrimary(current)
} }
n.hostMap.Unlock() cm.hostMap.Unlock()
} }
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and // isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid // the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
// check and return true. // check and return true.
func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool { func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
remoteCert := hostinfo.GetCert() remoteCert := hostinfo.GetCert()
if remoteCert == nil { if remoteCert == nil {
return false return false
} }
valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool()) valid, err := remoteCert.VerifyWithCache(now, cm.intf.pki.GetCAPool())
if valid { if valid {
return false return false
} }
if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed { if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
// Block listed certificates should always be disconnected // Block listed certificates should always be disconnected
return false return false
} }
fingerprint, _ := remoteCert.Sha256Sum() fingerprint, _ := remoteCert.Sha256Sum()
hostinfo.logger(n.l).WithError(err). hostinfo.logger(cm.l).WithError(err).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
Info("Remote certificate is no longer valid, tearing down the tunnel") Info("Remote certificate is no longer valid, tearing down the tunnel")
return true return true
} }
func (n *connectionManager) sendPunch(hostinfo *HostInfo) { func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
if !n.punchy.GetPunch() { if !cm.punchy.GetPunch() {
// Punching is disabled // Punching is disabled
return return
} }
if n.punchy.GetTargetEverything() { if cm.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) {
hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) { // Do not punch to lighthouses, we assume our lighthouse update interval is good enough.
n.metricsTxPunchy.Inc(1) // In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse
n.intf.outside.WriteTo([]byte{1}, addr) // would lose the ability to notify us and punchy.respond would become unreliable.
return
}
if cm.punchy.GetTargetEverything() {
hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
cm.metricsTxPunchy.Inc(1)
cm.intf.outside.WriteTo([]byte{1}, addr)
}) })
} else if hostinfo.remote != nil { } else if hostinfo.remote.IsValid() {
n.metricsTxPunchy.Inc(1) cm.metricsTxPunchy.Inc(1)
n.intf.outside.WriteTo([]byte{1}, hostinfo.remote) cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
} }
} }
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
certState := n.intf.pki.GetCertState() certState := cm.intf.pki.GetCertState()
if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) { if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) {
return return
} }
n.l.WithField("vpnIp", hostinfo.vpnIp). cm.l.WithField("vpnIp", hostinfo.vpnIp).
WithField("reason", "local certificate is not current"). WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote") Info("Re-handshaking with remote")
n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil)
} }

View File

@@ -1,32 +1,29 @@
package nebula package nebula
import ( import (
"context"
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"net" "net"
"net/netip"
"testing" "testing"
"time" "time"
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
var vpnIp iputil.VpnIp
func newTestLighthouse() *LightHouse { func newTestLighthouse() *LightHouse {
lh := &LightHouse{ lh := &LightHouse{
l: test.NewLogger(), l: test.NewLogger(),
addrMap: map[iputil.VpnIp]*RemoteList{}, addrMap: map[netip.Addr]*RemoteList{},
queryChan: make(chan iputil.VpnIp, 10), queryChan: make(chan netip.Addr, 10),
} }
lighthouses := map[iputil.VpnIp]struct{}{} lighthouses := map[netip.Addr]struct{}{}
staticList := map[iputil.VpnIp]struct{}{} staticList := map[netip.Addr]struct{}{}
lh.lighthouses.Store(&lighthouses) lh.lighthouses.Store(&lighthouses)
lh.staticList.Store(&staticList) lh.staticList.Store(&staticList)
@@ -37,10 +34,10 @@ func newTestLighthouse() *LightHouse {
func Test_NewConnectionManagerTest(t *testing.T) { func Test_NewConnectionManagerTest(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") vpncidr := netip.MustParsePrefix("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24")
vpnIp = iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) vpnIp := netip.MustParseAddr("172.1.1.2")
preferredRanges := []*net.IPNet{localrange} preferredRanges := []netip.Prefix{localrange}
// Very incomplete mock objects // Very incomplete mock objects
hostMap := newHostMap(l, vpncidr) hostMap := newHostMap(l, vpncidr)
@@ -67,10 +64,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
ifce.pki.cs.Store(cs) ifce.pki.cs.Store(cs)
// Create manager // Create manager
ctx, cancel := context.WithCancel(context.Background()) conf := config.NewC(l)
defer cancel() punchy := NewPunchyFromConfig(l, conf)
punchy := NewPunchyFromConfig(l, config.NewC(l)) nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) nc.intf = ifce
p := []byte("") p := []byte("")
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
@@ -88,31 +85,32 @@ func Test_NewConnectionManagerTest(t *testing.T) {
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
// We saw traffic out to vpnIp // We saw traffic out to vpnIp
nc.Out(hostinfo.localIndexId) nc.Out(hostinfo)
nc.In(hostinfo.localIndexId) nc.In(hostinfo)
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.False(t, hostinfo.pendingDeletion.Load())
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.out, hostinfo.localIndexId) assert.True(t, hostinfo.out.Load())
assert.True(t, hostinfo.in.Load())
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.False(t, hostinfo.pendingDeletion.Load())
assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.False(t, hostinfo.out.Load())
assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.False(t, hostinfo.in.Load())
// Do another traffic check tick, this host should be pending deletion now // Do another traffic check tick, this host should be pending deletion now
nc.Out(hostinfo.localIndexId) nc.Out(hostinfo)
assert.True(t, hostinfo.out.Load())
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.True(t, hostinfo.pendingDeletion.Load())
assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.False(t, hostinfo.out.Load())
assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
// Do a final traffic check tick, the host should now be removed // Do a final traffic check tick, the host should now be removed
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
} }
@@ -120,9 +118,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
func Test_NewConnectionManagerTest2(t *testing.T) { func Test_NewConnectionManagerTest2(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") vpncidr := netip.MustParsePrefix("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange} vpnIp := netip.MustParseAddr("172.1.1.2")
preferredRanges := []netip.Prefix{localrange}
// Very incomplete mock objects // Very incomplete mock objects
hostMap := newHostMap(l, vpncidr) hostMap := newHostMap(l, vpncidr)
@@ -149,10 +148,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
ifce.pki.cs.Store(cs) ifce.pki.cs.Store(cs)
// Create manager // Create manager
ctx, cancel := context.WithCancel(context.Background()) conf := config.NewC(l)
defer cancel() punchy := NewPunchyFromConfig(l, conf)
punchy := NewPunchyFromConfig(l, config.NewC(l)) nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) nc.intf = ifce
p := []byte("") p := []byte("")
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
@@ -170,33 +169,130 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
// We saw traffic out to vpnIp // We saw traffic out to vpnIp
nc.Out(hostinfo.localIndexId) nc.Out(hostinfo)
nc.In(hostinfo.localIndexId) nc.In(hostinfo)
assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp) assert.True(t, hostinfo.in.Load())
assert.True(t, hostinfo.out.Load())
assert.False(t, hostinfo.pendingDeletion.Load())
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.False(t, hostinfo.pendingDeletion.Load())
assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.False(t, hostinfo.out.Load())
assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.False(t, hostinfo.in.Load())
// Do another traffic check tick, this host should be pending deletion now // Do another traffic check tick, this host should be pending deletion now
nc.Out(hostinfo.localIndexId) nc.Out(hostinfo)
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.True(t, hostinfo.pendingDeletion.Load())
assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.False(t, hostinfo.out.Load())
assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
// We saw traffic, should no longer be pending deletion // We saw traffic, should no longer be pending deletion
nc.In(hostinfo.localIndexId) nc.In(hostinfo)
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.False(t, hostinfo.pendingDeletion.Load())
assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.False(t, hostinfo.out.Load())
assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
}
func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
l := test.NewLogger()
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
localrange := netip.MustParsePrefix("10.1.1.1/24")
vpnIp := netip.MustParseAddr("172.1.1.2")
preferredRanges := []netip.Prefix{localrange}
// Very incomplete mock objects
hostMap := newHostMap(l, vpncidr)
hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{
RawCertificate: []byte{},
PrivateKey: []byte{},
Certificate: &cert.NebulaCertificate{},
RawCertificateNoKey: []byte{},
}
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
pki: &PKI{},
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
}
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
conf.Settings["tunnels"] = map[interface{}]interface{}{
"drop_inactive": true,
}
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
assert.True(t, nc.dropInactive.Load())
nc.intf = ifce
// Add an ip we have established a connection w/ to hostmap
hostinfo := &HostInfo{
vpnIp: vpnIp,
localIndexId: 1099,
remoteIndexId: 9901,
}
hostinfo.ConnectionState = &ConnectionState{
myCert: &cert.NebulaCertificate{},
H: &noise.HandshakeState{},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
// Do a traffic check tick, in and out should be cleared but should not be pending deletion
nc.Out(hostinfo)
nc.In(hostinfo)
assert.True(t, hostinfo.out.Load())
assert.True(t, hostinfo.in.Load())
now := time.Now()
decision, _, _ := nc.makeTrafficDecision(hostinfo.localIndexId, now)
assert.Equal(t, tryRehandshake, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*5))
assert.Equal(t, doNothing, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
// Do another traffic check tick, should still not be pending deletion
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*10))
assert.Equal(t, doNothing, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
// Finally advance beyond the inactivity timeout
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Minute*10))
assert.Equal(t, closeTunnel, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
} }
@@ -211,9 +307,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
IP: net.IPv4(172, 1, 1, 2), IP: net.IPv4(172, 1, 1, 2),
Mask: net.IPMask{255, 255, 255, 0}, Mask: net.IPMask{255, 255, 255, 0},
} }
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") vpncidr := netip.MustParsePrefix("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange} vpnIp := netip.MustParseAddr("172.1.1.2")
preferredRanges := []netip.Prefix{localrange}
hostMap := newHostMap(l, vpncidr) hostMap := newHostMap(l, vpncidr)
hostMap.preferredRanges.Store(&preferredRanges) hostMap.preferredRanges.Store(&preferredRanges)
@@ -273,10 +370,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
ifce.disconnectInvalid.Store(true) ifce.disconnectInvalid.Store(true)
// Create manager // Create manager
ctx, cancel := context.WithCancel(context.Background()) conf := config.NewC(l)
defer cancel() punchy := NewPunchyFromConfig(l, conf)
punchy := NewPunchyFromConfig(l, config.NewC(l)) nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) nc.intf = ifce
ifce.connectionManager = nc ifce.connectionManager = nc
hostinfo := &HostInfo{ hostinfo := &HostInfo{

View File

@@ -72,6 +72,8 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i
window: b, window: b,
myCert: certState.Certificate, myCert: certState.Certificate,
} }
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
ci.messageCounter.Add(2)
return ci return ci
} }

View File

@@ -2,7 +2,7 @@ package nebula
import ( import (
"context" "context"
"net" "net/netip"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
@@ -10,9 +10,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp"
) )
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
@@ -21,10 +19,10 @@ import (
type controlEach func(h *HostInfo) type controlEach func(h *HostInfo)
type controlHostLister interface { type controlHostLister interface {
QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo QueryVpnIp(vpnIp netip.Addr) *HostInfo
ForEachIndex(each controlEach) ForEachIndex(each controlEach)
ForEachVpnIp(each controlEach) ForEachVpnIp(each controlEach)
GetPreferredRanges() []*net.IPNet GetPreferredRanges() []netip.Prefix
} }
type Control struct { type Control struct {
@@ -36,18 +34,19 @@ type Control struct {
statsStart func() statsStart func()
dnsStart func() dnsStart func()
lighthouseStart func() lighthouseStart func()
connectionManagerStart func(context.Context)
} }
type ControlHostInfo struct { type ControlHostInfo struct {
VpnIp net.IP `json:"vpnIp"` VpnIp netip.Addr `json:"vpnIp"`
LocalIndex uint32 `json:"localIndex"` LocalIndex uint32 `json:"localIndex"`
RemoteIndex uint32 `json:"remoteIndex"` RemoteIndex uint32 `json:"remoteIndex"`
RemoteAddrs []*udp.Addr `json:"remoteAddrs"` RemoteAddrs []netip.AddrPort `json:"remoteAddrs"`
Cert *cert.NebulaCertificate `json:"cert"` Cert *cert.NebulaCertificate `json:"cert"`
MessageCounter uint64 `json:"messageCounter"` MessageCounter uint64 `json:"messageCounter"`
CurrentRemote *udp.Addr `json:"currentRemote"` CurrentRemote netip.AddrPort `json:"currentRemote"`
CurrentRelaysToMe []iputil.VpnIp `json:"currentRelaysToMe"` CurrentRelaysToMe []netip.Addr `json:"currentRelaysToMe"`
CurrentRelaysThroughMe []iputil.VpnIp `json:"currentRelaysThroughMe"` CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
} }
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
@@ -65,6 +64,9 @@ func (c *Control) Start() {
if c.dnsStart != nil { if c.dnsStart != nil {
go c.dnsStart() go c.dnsStart()
} }
if c.connectionManagerStart != nil {
go c.connectionManagerStart(c.ctx)
}
if c.lighthouseStart != nil { if c.lighthouseStart != nil {
c.lighthouseStart() c.lighthouseStart()
} }
@@ -131,8 +133,45 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
} }
} }
// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) *cert.NebulaCertificate {
if c.f.myVpnNet.Addr() == vpnIp {
return c.f.pki.GetCertState().Certificate
}
hi := c.f.hostMap.QueryVpnIp(vpnIp)
if hi == nil {
return nil
}
return hi.GetCert()
}
// CreateTunnel creates a new tunnel to the given vpn ip.
func (c *Control) CreateTunnel(vpnIp netip.Addr) {
c.f.handshakeManager.StartHandshake(vpnIp, nil)
}
// PrintTunnel creates a new tunnel to the given vpn ip.
func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo {
hi := c.f.hostMap.QueryVpnIp(vpnIp)
if hi == nil {
return nil
}
chi := copyHostInfo(hi, c.f.hostMap.GetPreferredRanges())
return &chi
}
// QueryLighthouse queries the lighthouse.
func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap {
hi := c.f.lightHouse.Query(vpnIp)
if hi == nil {
return nil
}
return hi.CopyCache()
}
// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found // GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo { // Caller should take care to Unmap() any 4in6 addresses prior to calling.
func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo {
var hl controlHostLister var hl controlHostLister
if pending { if pending {
hl = c.f.handshakeManager hl = c.f.handshakeManager
@@ -150,19 +189,21 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH
} }
// SetRemoteForTunnel forces a tunnel to use a specific remote // SetRemoteForTunnel forces a tunnel to use a specific remote
func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo { // Caller should take care to Unmap() any 4in6 addresses prior to calling.
func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil { if hostInfo == nil {
return nil return nil
} }
hostInfo.SetRemote(addr.Copy()) hostInfo.SetRemote(addr)
ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges()) ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges())
return &ch return &ch
} }
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well. // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool { // Caller should take care to Unmap() any 4in6 addresses prior to calling.
func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil { if hostInfo == nil {
return false return false
@@ -205,7 +246,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
} }
// Learn which hosts are being used as relays, so we can shut them down last. // Learn which hosts are being used as relays, so we can shut them down last.
relayingHosts := map[iputil.VpnIp]*HostInfo{} relayingHosts := map[netip.Addr]*HostInfo{}
// Grab the hostMap lock to access the Relays map // Grab the hostMap lock to access the Relays map
c.f.hostMap.Lock() c.f.hostMap.Lock()
for _, relayingHost := range c.f.hostMap.Relays { for _, relayingHost := range c.f.hostMap.Relays {
@@ -236,15 +277,16 @@ func (c *Control) Device() overlay.Device {
return c.f.inside return c.f.inside
} }
func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
chi := ControlHostInfo{ chi := ControlHostInfo{
VpnIp: h.vpnIp.ToIP(), VpnIp: h.vpnIp,
LocalIndex: h.localIndexId, LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId, RemoteIndex: h.remoteIndexId,
RemoteAddrs: h.remotes.CopyAddrs(preferredRanges), RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),
CurrentRelaysToMe: h.relayState.CopyRelayIps(), CurrentRelaysToMe: h.relayState.CopyRelayIps(),
CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(), CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(),
CurrentRemote: h.remote,
} }
if h.ConnectionState != nil { if h.ConnectionState != nil {
@@ -255,10 +297,6 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
chi.Cert = c.Copy() chi.Cert = c.Copy()
} }
if h.remote != nil {
chi.CurrentRemote = h.remote.Copy()
}
return chi return chi
} }

View File

@@ -2,15 +2,14 @@ package nebula
import ( import (
"net" "net"
"net/netip"
"reflect" "reflect"
"testing" "testing"
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -18,18 +17,19 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller // To properly ensure we are not exposing core memory to the caller
hm := newHostMap(l, &net.IPNet{}) hm := newHostMap(l, netip.Prefix{})
hm.preferredRanges.Store(&[]*net.IPNet{}) hm.preferredRanges.Store(&[]netip.Prefix{})
remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
remote2 := netip.MustParseAddrPort("[1:2:3:4:5:6:7:8]:4444")
remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
ipNet := net.IPNet{ ipNet := net.IPNet{
IP: net.IPv4(1, 2, 3, 4), IP: remote1.Addr().AsSlice(),
Mask: net.IPMask{255, 255, 255, 0}, Mask: net.IPMask{255, 255, 255, 0},
} }
ipNet2 := net.IPNet{ ipNet2 := net.IPNet{
IP: net.ParseIP("1:2:3:4:5:6:7:8"), IP: remote2.Addr().AsSlice(),
Mask: net.IPMask{255, 255, 255, 0}, Mask: net.IPMask{255, 255, 255, 0},
} }
@@ -50,8 +50,12 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
} }
remotes := NewRemoteList(nil) remotes := NewRemoteList(nil)
remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port))) remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port()))
remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port))) remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port()))
vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
assert.True(t, ok)
hm.unlockedAddHostInfo(&HostInfo{ hm.unlockedAddHostInfo(&HostInfo{
remote: remote1, remote: remote1,
remotes: remotes, remotes: remotes,
@@ -60,14 +64,17 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
}, },
remoteIndexId: 200, remoteIndexId: 200,
localIndexId: 201, localIndexId: 201,
vpnIp: iputil.Ip2VpnIp(ipNet.IP), vpnIp: vpnIp,
relayState: RelayState{ relayState: RelayState{
relays: map[iputil.VpnIp]struct{}{}, relays: nil,
relayForByIp: map[iputil.VpnIp]*Relay{}, relayForByIp: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{}, relayForByIdx: map[uint32]*Relay{},
}, },
}, &Interface{}) }, &Interface{})
vpnIp2, ok := netip.AddrFromSlice(ipNet2.IP)
assert.True(t, ok)
hm.unlockedAddHostInfo(&HostInfo{ hm.unlockedAddHostInfo(&HostInfo{
remote: remote1, remote: remote1,
remotes: remotes, remotes: remotes,
@@ -76,10 +83,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
}, },
remoteIndexId: 200, remoteIndexId: 200,
localIndexId: 201, localIndexId: 201,
vpnIp: iputil.Ip2VpnIp(ipNet2.IP), vpnIp: vpnIp2,
relayState: RelayState{ relayState: RelayState{
relays: map[iputil.VpnIp]struct{}{}, relays: nil,
relayForByIp: map[iputil.VpnIp]*Relay{}, relayForByIp: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{}, relayForByIdx: map[uint32]*Relay{},
}, },
}, &Interface{}) }, &Interface{})
@@ -91,27 +98,29 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l: logrus.New(), l: logrus.New(),
} }
thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false) thi := c.GetHostInfoByVpnIp(vpnIp, false)
expectedInfo := ControlHostInfo{ expectedInfo := ControlHostInfo{
VpnIp: net.IPv4(1, 2, 3, 4).To4(), VpnIp: vpnIp,
LocalIndex: 201, LocalIndex: 201,
RemoteIndex: 200, RemoteIndex: 200,
RemoteAddrs: []*udp.Addr{remote2, remote1}, RemoteAddrs: []netip.AddrPort{remote2, remote1},
Cert: crt.Copy(), Cert: crt.Copy(),
MessageCounter: 0, MessageCounter: 0,
CurrentRemote: udp.NewAddr(net.ParseIP("0.0.0.100"), 4444), CurrentRemote: remote1,
CurrentRelaysToMe: []iputil.VpnIp{}, CurrentRelaysToMe: []netip.Addr{},
CurrentRelaysThroughMe: []iputil.VpnIp{}, CurrentRelaysThroughMe: []netip.Addr{},
} }
// Make sure we don't have any unexpected fields // Make sure we don't have any unexpected fields
assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
test.AssertDeepCopyEqual(t, &expectedInfo, thi) assert.EqualValues(t, &expectedInfo, thi)
//TODO: netip.Addr reuses global memory for zone identifiers which breaks our "no reused memory check" here
//test.AssertDeepCopyEqual(t, &expectedInfo, thi)
// Make sure we don't panic if the host info doesn't have a cert yet // Make sure we don't panic if the host info doesn't have a cert yet
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false) thi = c.GetHostInfoByVpnIp(vpnIp2, false)
}) })
} }

View File

@@ -4,14 +4,13 @@
package nebula package nebula
import ( import (
"net" "net/netip"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
) )
@@ -50,37 +49,30 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
// InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp // InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp
// This is necessary if you did not configure static hosts or are not running a lighthouse // This is necessary if you did not configure static hosts or are not running a lighthouse
func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) { func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
c.f.lightHouse.Lock() c.f.lightHouse.Lock()
remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp)) remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
remoteList.Lock() remoteList.Lock()
defer remoteList.Unlock() defer remoteList.Unlock()
c.f.lightHouse.Unlock() c.f.lightHouse.Unlock()
iVpnIp := iputil.Ip2VpnIp(vpnIp) if toAddr.Addr().Is4() {
if v4 := toAddr.IP.To4(); v4 != nil { remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
} else { } else {
remoteList.unlockedPrependV6(iVpnIp, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port))) remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
} }
} }
// InjectRelays will push relayVpnIps into the local lighthouse cache for the vpnIp // InjectRelays will push relayVpnIps into the local lighthouse cache for the vpnIp
// This is necessary to inform an initiator of possible relays for communicating with a responder // This is necessary to inform an initiator of possible relays for communicating with a responder
func (c *Control) InjectRelays(vpnIp net.IP, relayVpnIps []net.IP) { func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
c.f.lightHouse.Lock() c.f.lightHouse.Lock()
remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp)) remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
remoteList.Lock() remoteList.Lock()
defer remoteList.Unlock() defer remoteList.Unlock()
c.f.lightHouse.Unlock() c.f.lightHouse.Unlock()
iVpnIp := iputil.Ip2VpnIp(vpnIp) remoteList.unlockedSetRelay(vpnIp, vpnIp, relayVpnIps)
uVpnIp := []uint32{}
for _, rVPnIp := range relayVpnIps {
uVpnIp = append(uVpnIp, uint32(iputil.Ip2VpnIp(rVPnIp)))
}
remoteList.unlockedSetRelay(iVpnIp, iVpnIp, uVpnIp)
} }
// GetFromTun will pull a packet off the tun side of nebula // GetFromTun will pull a packet off the tun side of nebula
@@ -107,13 +99,14 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) {
} }
// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16, data []byte) { func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) {
//TODO: IPV6-WORK
ip := layers.IPv4{ ip := layers.IPv4{
Version: 4, Version: 4,
TTL: 64, TTL: 64,
Protocol: layers.IPProtocolUDP, Protocol: layers.IPProtocolUDP,
SrcIP: c.f.inside.Cidr().IP, SrcIP: c.f.inside.Cidr().Addr().Unmap().AsSlice(),
DstIP: toIp, DstIP: toIp.Unmap().AsSlice(),
} }
udp := layers.UDP{ udp := layers.UDP{
@@ -138,16 +131,16 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
c.f.inside.(*overlay.TestTun).Send(buffer.Bytes()) c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
} }
func (c *Control) GetVpnIp() iputil.VpnIp { func (c *Control) GetVpnIp() netip.Addr {
return c.f.myVpnIp return c.f.myVpnNet.Addr()
} }
func (c *Control) GetUDPAddr() string { func (c *Control) GetUDPAddr() netip.AddrPort {
return c.f.outside.(*udp.TesterConn).Addr.String() return c.f.outside.(*udp.TesterConn).Addr
} }
func (c *Control) KillPendingTunnel(vpnIp net.IP) bool { func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp)) hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp)
if hostinfo == nil { if hostinfo == nil {
return false return false
} }
@@ -164,6 +157,6 @@ func (c *Control) GetCert() *cert.NebulaCertificate {
return c.f.pki.GetCertState().Certificate return c.f.pki.GetCertState().Certificate
} }
func (c *Control) ReHandshake(vpnIp iputil.VpnIp) { func (c *Control) ReHandshake(vpnIp netip.Addr) {
c.f.handshakeManager.StartHandshake(vpnIp, nil) c.f.handshakeManager.StartHandshake(vpnIp, nil)
} }

View File

@@ -3,6 +3,7 @@ package nebula
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -10,7 +11,6 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
) )
// This whole thing should be rewritten to use context // This whole thing should be rewritten to use context
@@ -42,19 +42,21 @@ func (d *dnsRecords) Query(data string) string {
} }
func (d *dnsRecords) QueryCert(data string) string { func (d *dnsRecords) QueryCert(data string) string {
ip := net.ParseIP(data[:len(data)-1]) ip, err := netip.ParseAddr(data[:len(data)-1])
if ip == nil { if err != nil {
return "" return ""
} }
iip := iputil.Ip2VpnIp(ip)
hostinfo := d.hostMap.QueryVpnIp(iip) hostinfo := d.hostMap.QueryVpnIp(ip)
if hostinfo == nil { if hostinfo == nil {
return "" return ""
} }
q := hostinfo.GetCert() q := hostinfo.GetCert()
if q == nil { if q == nil {
return "" return ""
} }
cert := q.Details cert := q.Details
c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer) c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
return c return c
@@ -80,7 +82,11 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
} }
case dns.TypeTXT: case dns.TypeTXT:
a, _, _ := net.SplitHostPort(w.RemoteAddr().String()) a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
b := net.ParseIP(a) b, err := netip.ParseAddr(a)
if err != nil {
return
}
// We don't answer these queries from non nebula nodes or localhost // We don't answer these queries from non nebula nodes or localhost
//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR) //l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" { if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {

View File

@@ -4,28 +4,29 @@
package e2e package e2e
import ( import (
"fmt" "net/netip"
"net" "slices"
"testing" "testing"
"time" "time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
func BenchmarkHotPath(b *testing.B) { func BenchmarkHotPath(b *testing.B) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) myControl, _, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse // Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
// Start the servers // Start the servers
myControl.Start() myControl.Start()
@@ -35,7 +36,7 @@ func BenchmarkHotPath(b *testing.B) {
r.CancelFlowLogs() r.CancelFlowLogs()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl) _ = r.RouteForAllUntilTxTun(theirControl)
} }
@@ -44,19 +45,19 @@ func BenchmarkHotPath(b *testing.B) {
} }
func TestGoodHandshake(t *testing.T) { func TestGoodHandshake(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse // Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
// Start the servers // Start the servers
myControl.Start() myControl.Start()
theirControl.Start() theirControl.Start()
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
t.Log("Have them consume my stage 0 packet. They have a tunnel now") t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
@@ -77,16 +78,16 @@ func TestGoodHandshake(t *testing.T) {
myControl.WaitForType(1, 0, theirControl) myControl.WaitForType(1, 0, theirControl)
t.Log("Make sure our host infos are correct") t.Log("Make sure our host infos are correct")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl) assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
t.Log("Get that cached packet and make sure it looks right") t.Log("Get that cached packet and make sure it looks right")
myCachedPacket := theirControl.GetFromTun(true) myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
t.Log("Do a bidirectional tunnel test") t.Log("Do a bidirectional tunnel test")
r := router.NewR(t, myControl, theirControl) r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow() defer r.RenderFlow()
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
r.RenderHostmaps("Final hostmaps", myControl, theirControl) r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop() myControl.Stop()
@@ -95,20 +96,20 @@ func TestGoodHandshake(t *testing.T) {
} }
func TestWrongResponderHandshake(t *testing.T) { func TestWrongResponderHandshake(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
// The IPs here are chosen on purpose: // The IPs here are chosen on purpose:
// The current remote handling will sort by preference, public, and then lexically. // The current remote handling will sort by preference, public, and then lexically.
// So we need them to have a higher address than evil (we could apply a preference though) // So we need them to have a higher address than evil (we could apply a preference though)
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil)
evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil) evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil)
// Add their real udp addr, which should be tried after evil. // Add their real udp addr, which should be tried after evil.
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, evilUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl, evilControl) r := router.NewR(t, myControl, theirControl, evilControl)
@@ -120,7 +121,7 @@ func TestWrongResponderHandshake(t *testing.T) {
evilControl.Start() evilControl.Start()
t.Log("Start the handshake process, we will route until we see our cached packet get sent to them") t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
h := &header.H{} h := &header.H{}
err := h.Parse(p.Data) err := h.Parse(p.Data)
@@ -128,7 +129,7 @@ func TestWrongResponderHandshake(t *testing.T) {
panic(err) panic(err)
} }
if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 { if p.To == theirUdpAddr && h.Type == 1 {
return router.RouteAndExit return router.RouteAndExit
} }
@@ -139,18 +140,18 @@ func TestWrongResponderHandshake(t *testing.T) {
t.Log("My cached packet should be received by them") t.Log("My cached packet should be received by them")
myCachedPacket := theirControl.GetFromTun(true) myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
t.Log("Test the tunnel with them") t.Log("Test the tunnel with them")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl) assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
t.Log("Flush all packets from all controllers") t.Log("Flush all packets from all controllers")
r.FlushAll() r.FlushAll()
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), true), "My pending hostmap should not contain evil") assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil")
assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), false), "My main hostmap should not contain evil") assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil")
//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete //NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
//TODO: assert hostmaps for everyone //TODO: assert hostmaps for everyone
@@ -164,13 +165,13 @@ func TestStage1Race(t *testing.T) {
// This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow // This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow
// But will eventually collapse down to a single tunnel // But will eventually collapse down to a single tunnel
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse and vice versa // Put their info in our lighthouse and vice versa
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl) r := router.NewR(t, myControl, theirControl)
@@ -181,8 +182,8 @@ func TestStage1Race(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Trigger a handshake to start on both me and them") t.Log("Trigger a handshake to start on both me and them")
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
t.Log("Get both stage 1 handshake packets") t.Log("Get both stage 1 handshake packets")
myHsForThem := myControl.GetFromUDP(true) myHsForThem := myControl.GetFromUDP(true)
@@ -194,14 +195,14 @@ func TestStage1Race(t *testing.T) {
r.Log("Route until they receive a message packet") r.Log("Route until they receive a message packet")
myCachedPacket := r.RouteForAllUntilTxTun(theirControl) myCachedPacket := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.Log("Their cached packet should be received by me") r.Log("Their cached packet should be received by me")
theirCachedPacket := r.RouteForAllUntilTxTun(myControl) theirCachedPacket := r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80) assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80)
r.Log("Do a bidirectional tunnel test") r.Log("Do a bidirectional tunnel test")
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
myHostmapHosts := myControl.ListHostmapHosts(false) myHostmapHosts := myControl.ListHostmapHosts(false)
myHostmapIndexes := myControl.ListHostmapIndexes(false) myHostmapIndexes := myControl.ListHostmapIndexes(false)
@@ -219,7 +220,7 @@ func TestStage1Race(t *testing.T) {
r.Log("Spin until connection manager tears down a tunnel") r.Log("Spin until connection manager tears down a tunnel")
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet") t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@@ -241,13 +242,13 @@ func TestStage1Race(t *testing.T) {
} }
func TestUncleanShutdownRaceLoser(t *testing.T) { func TestUncleanShutdownRaceLoser(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
// Teach my how to get to the relay and that their can be reached via the relay // Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl) r := router.NewR(t, myControl, theirControl)
@@ -258,28 +259,28 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
theirControl.Start() theirControl.Start()
r.Log("Trigger a handshake from me to them") r.Log("Trigger a handshake from me to them")
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.Log("Nuke my hostmap") r.Log("Nuke my hostmap")
myHostmap := myControl.GetHostmap() myHostmap := myControl.GetHostmap()
myHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{} myHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{}
myHostmap.Indexes = map[uint32]*nebula.HostInfo{} myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me again")) myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me again"))
p = r.RouteForAllUntilTxTun(theirControl) p = r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
r.Log("Wait for the dead index to go away") r.Log("Wait for the dead index to go away")
start := len(theirControl.GetHostmap().Indexes) start := len(theirControl.GetHostmap().Indexes)
for { for {
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
if len(theirControl.GetHostmap().Indexes) < start { if len(theirControl.GetHostmap().Indexes) < start {
break break
} }
@@ -290,13 +291,13 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
} }
func TestUncleanShutdownRaceWinner(t *testing.T) { func TestUncleanShutdownRaceWinner(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
// Teach my how to get to the relay and that their can be reached via the relay // Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl) r := router.NewR(t, myControl, theirControl)
@@ -307,30 +308,30 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
theirControl.Start() theirControl.Start()
r.Log("Trigger a handshake from me to them") r.Log("Trigger a handshake from me to them")
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, theirControl) r.RenderHostmaps("Final hostmaps", myControl, theirControl)
r.Log("Nuke my hostmap") r.Log("Nuke my hostmap")
theirHostmap := theirControl.GetHostmap() theirHostmap := theirControl.GetHostmap()
theirHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{} theirHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{}
theirHostmap.Indexes = map[uint32]*nebula.HostInfo{} theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them again")) theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them again"))
p = r.RouteForAllUntilTxTun(myControl) p = r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80) assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80)
r.RenderHostmaps("Derp hostmaps", myControl, theirControl) r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
r.Log("Wait for the dead index to go away") r.Log("Wait for the dead index to go away")
start := len(myControl.GetHostmap().Indexes) start := len(myControl.GetHostmap().Indexes)
for { for {
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
if len(myControl.GetHostmap().Indexes) < start { if len(myControl.GetHostmap().Indexes) < start {
break break
} }
@@ -341,15 +342,15 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
} }
func TestRelays(t *testing.T) { func TestRelays(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay // Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl) r := router.NewR(t, myControl, relayControl, theirControl)
@@ -361,31 +362,162 @@ func TestRelays(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay") t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
//TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it //TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it
} }
func TestStage1RaceRelays(t *testing.T) { func TestReestablishRelays(t *testing.T) {
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay // Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) // Build a router so we don't have to reason who gets which packet
theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) r := router.NewR(t, myControl, relayControl, theirControl)
defer r.RenderFlow()
relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) // Start the servers
relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) myControl.Start()
relayControl.Start()
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
t.Log("Ensure packet traversal from them to me via the relay")
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
p = r.RouteForAllUntilTxTun(myControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80)
// If we break the relay's connection to 'them', 'me' needs to detect and recover the connection
r.Log("Close the tunnel")
relayControl.CloseTunnel(theirVpnIpNet.Addr(), true)
start := len(myControl.GetHostmap().Indexes)
curIndexes := len(myControl.GetHostmap().Indexes)
for curIndexes >= start {
curIndexes = len(myControl.GetHostmap().Indexes)
r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes)
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me should fail"))
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
return router.RouteAndExit
})
time.Sleep(2 * time.Second)
}
r.Log("Dead index went away. Woot!")
r.RenderHostmaps("Me removed hostinfo", myControl, relayControl, theirControl)
// Next packet should re-establish a relayed connection and work just great.
t.Logf("Assert the tunnel...")
for {
t.Log("RouteForAllUntilTxTun")
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
p = r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
packet := gopacket.NewPacket(p, layers.LayerTypeIPv4, gopacket.Lazy)
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
if slices.Compare(v4.SrcIP, myVpnIpNet.Addr().AsSlice()) != 0 {
t.Logf("SrcIP is unexpected...this is not the packet I'm looking for. Keep looking")
continue
}
if slices.Compare(v4.DstIP, theirVpnIpNet.Addr().AsSlice()) != 0 {
t.Logf("DstIP is unexpected...this is not the packet I'm looking for. Keep looking")
continue
}
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
if udp == nil {
t.Log("Not a UDP packet. This is not the packet I'm looking for. Keep looking")
continue
}
data := packet.ApplicationLayer()
if data == nil {
t.Log("No data found in packet. This is not the packet I'm looking for. Keep looking.")
continue
}
if string(data.Payload()) != "Hi from me" {
t.Logf("Unexpected payload: '%v', keep looking", string(data.Payload()))
continue
}
t.Log("I found my lost packet. I am so happy.")
break
}
t.Log("Assert the tunnel works the other way, too")
for {
t.Log("RouteForAllUntilTxTun")
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
p = r.RouteForAllUntilTxTun(myControl)
r.Log("Assert the tunnel works")
packet := gopacket.NewPacket(p, layers.LayerTypeIPv4, gopacket.Lazy)
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
if slices.Compare(v4.DstIP, myVpnIpNet.Addr().AsSlice()) != 0 {
t.Logf("Dst is unexpected...this is not the packet I'm looking for. Keep looking")
continue
}
if slices.Compare(v4.SrcIP, theirVpnIpNet.Addr().AsSlice()) != 0 {
t.Logf("SrcIP is unexpected...this is not the packet I'm looking for. Keep looking")
continue
}
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
if udp == nil {
t.Log("Not a UDP packet. This is not the packet I'm looking for. Keep looking")
continue
}
data := packet.ApplicationLayer()
if data == nil {
t.Log("No data found in packet. This is not the packet I'm looking for. Keep looking.")
continue
}
if string(data.Payload()) != "Hi from them" {
t.Logf("Unexpected payload: '%v', keep looking", string(data.Payload()))
continue
}
t.Log("I found my lost packet. I am so happy.")
break
}
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
}
func TestStage1RaceRelays(t *testing.T) {
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl) r := router.NewR(t, myControl, relayControl, theirControl)
@@ -397,14 +529,14 @@ func TestStage1RaceRelays(t *testing.T) {
theirControl.Start() theirControl.Start()
r.Log("Get a tunnel between me and relay") r.Log("Get a tunnel between me and relay")
assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
r.Log("Get a tunnel between them and relay") r.Log("Get a tunnel between them and relay")
assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
r.Log("Trigger a handshake from both them and me via relay to them and me") r.Log("Trigger a handshake from both them and me via relay to them and me")
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
r.Log("Wait for a packet from them to me") r.Log("Wait for a packet from them to me")
p := r.RouteForAllUntilTxTun(myControl) p := r.RouteForAllUntilTxTun(myControl)
@@ -421,21 +553,21 @@ func TestStage1RaceRelays(t *testing.T) {
func TestStage1RaceRelays2(t *testing.T) { func TestStage1RaceRelays2(t *testing.T) {
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
l := NewTestLogger() l := NewTestLogger()
// Teach my how to get to the relay and that their can be reached via the relay // Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl) r := router.NewR(t, myControl, relayControl, theirControl)
@@ -448,16 +580,16 @@ func TestStage1RaceRelays2(t *testing.T) {
r.Log("Get a tunnel between me and relay") r.Log("Get a tunnel between me and relay")
l.Info("Get a tunnel between me and relay") l.Info("Get a tunnel between me and relay")
assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
r.Log("Get a tunnel between them and relay") r.Log("Get a tunnel between them and relay")
l.Info("Get a tunnel between them and relay") l.Info("Get a tunnel between them and relay")
assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
r.Log("Trigger a handshake from both them and me via relay to them and me") r.Log("Trigger a handshake from both them and me via relay to them and me")
l.Info("Trigger a handshake from both them and me via relay to them and me") l.Info("Trigger a handshake from both them and me via relay to them and me")
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
@@ -470,7 +602,7 @@ func TestStage1RaceRelays2(t *testing.T) {
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works") l.Info("Assert the tunnel works")
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
t.Log("Wait until we remove extra tunnels") t.Log("Wait until we remove extra tunnels")
l.Info("Wait until we remove extra tunnels") l.Info("Wait until we remove extra tunnels")
@@ -490,7 +622,7 @@ func TestStage1RaceRelays2(t *testing.T) {
"theirControl": len(theirControl.GetHostmap().Indexes), "theirControl": len(theirControl.GetHostmap().Indexes),
"relayControl": len(relayControl.GetHostmap().Indexes), "relayControl": len(relayControl.GetHostmap().Indexes),
}).Info("Waiting for hostinfos to be removed...") }).Info("Waiting for hostinfos to be removed...")
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet") t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second) time.Sleep(time.Second)
retries-- retries--
@@ -498,7 +630,7 @@ func TestStage1RaceRelays2(t *testing.T) {
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works") l.Info("Assert the tunnel works")
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
myControl.Stop() myControl.Stop()
theirControl.Stop() theirControl.Stop()
@@ -507,16 +639,17 @@ func TestStage1RaceRelays2(t *testing.T) {
// //
////TODO: assert hostmaps ////TODO: assert hostmaps
} }
func TestRehandshakingRelays(t *testing.T) { func TestRehandshakingRelays(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay // Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl) r := router.NewR(t, myControl, relayControl, theirControl)
@@ -528,11 +661,11 @@ func TestRehandshakingRelays(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay") t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
@@ -556,8 +689,8 @@ func TestRehandshakingRelays(t *testing.T) {
for { for {
r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 { if len(c.Cert.Details.Groups) != 0 {
// We have a new certificate now // We have a new certificate now
r.Log("Certificate between my and relay is updated!") r.Log("Certificate between my and relay is updated!")
@@ -569,8 +702,8 @@ func TestRehandshakingRelays(t *testing.T) {
for { for {
r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 { if len(c.Cert.Details.Groups) != 0 {
// We have a new certificate now // We have a new certificate now
r.Log("Certificate between their and relay is updated!") r.Log("Certificate between their and relay is updated!")
@@ -581,13 +714,13 @@ func TestRehandshakingRelays(t *testing.T) {
} }
r.Log("Assert the relay tunnel still works") r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// We should have two hostinfos on all sides // We should have two hostinfos on all sides
for len(myControl.GetHostmap().Indexes) != 2 { for len(myControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works") r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.Log("yupitdoes") r.Log("yupitdoes")
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@@ -595,7 +728,7 @@ func TestRehandshakingRelays(t *testing.T) {
for len(theirControl.GetHostmap().Indexes) != 2 { for len(theirControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works") r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.Log("yupitdoes") r.Log("yupitdoes")
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@@ -603,7 +736,7 @@ func TestRehandshakingRelays(t *testing.T) {
for len(relayControl.GetHostmap().Indexes) != 2 { for len(relayControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works") r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.Log("yupitdoes") r.Log("yupitdoes")
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@@ -612,15 +745,15 @@ func TestRehandshakingRelays(t *testing.T) {
func TestRehandshakingRelaysPrimary(t *testing.T) { func TestRehandshakingRelaysPrimary(t *testing.T) {
// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner // This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}}) myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}}) relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay // Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl) r := router.NewR(t, myControl, relayControl, theirControl)
@@ -632,11 +765,11 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay") t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
@@ -660,8 +793,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for { for {
r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 { if len(c.Cert.Details.Groups) != 0 {
// We have a new certificate now // We have a new certificate now
r.Log("Certificate between my and relay is updated!") r.Log("Certificate between my and relay is updated!")
@@ -673,8 +806,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for { for {
r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 { if len(c.Cert.Details.Groups) != 0 {
// We have a new certificate now // We have a new certificate now
r.Log("Certificate between their and relay is updated!") r.Log("Certificate between their and relay is updated!")
@@ -685,13 +818,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
} }
r.Log("Assert the relay tunnel still works") r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// We should have two hostinfos on all sides // We should have two hostinfos on all sides
for len(myControl.GetHostmap().Indexes) != 2 { for len(myControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works") r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.Log("yupitdoes") r.Log("yupitdoes")
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@@ -699,7 +832,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for len(theirControl.GetHostmap().Indexes) != 2 { for len(theirControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works") r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.Log("yupitdoes") r.Log("yupitdoes")
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@@ -707,7 +840,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for len(relayControl.GetHostmap().Indexes) != 2 { for len(relayControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works") r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.Log("yupitdoes") r.Log("yupitdoes")
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@@ -715,13 +848,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
} }
func TestRehandshaking(t *testing.T) { func TestRehandshaking(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil)
// Put their info in our lighthouse and vice versa // Put their info in our lighthouse and vice versa
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl) r := router.NewR(t, myControl, theirControl)
@@ -732,7 +865,7 @@ func TestRehandshaking(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Stand up a tunnel between me and them") t.Log("Stand up a tunnel between me and them")
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
@@ -754,8 +887,8 @@ func TestRehandshaking(t *testing.T) {
myConfig.ReloadConfigString(string(rc)) myConfig.ReloadConfigString(string(rc))
for { for {
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 { if len(c.Cert.Details.Groups) != 0 {
// We have a new certificate now // We have a new certificate now
break break
@@ -781,19 +914,19 @@ func TestRehandshaking(t *testing.T) {
r.Log("Spin until there is only 1 tunnel") r.Log("Spin until there is only 1 tunnel")
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet") t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second) time.Sleep(time.Second)
} }
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapHosts := myControl.ListHostmapHosts(false)
myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
// Make sure the correct tunnel won // Make sure the correct tunnel won
c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
assert.Contains(t, c.Cert.Details.Groups, "new group") assert.Contains(t, c.Cert.Details.Groups, "new group")
// We should only have a single tunnel now on both sides // We should only have a single tunnel now on both sides
@@ -811,13 +944,13 @@ func TestRehandshaking(t *testing.T) {
func TestRehandshakingLoser(t *testing.T) { func TestRehandshakingLoser(t *testing.T) {
// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel // The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
// Should be the one with the new certificate // Should be the one with the new certificate
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil)
// Put their info in our lighthouse and vice versa // Put their info in our lighthouse and vice versa
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl) r := router.NewR(t, myControl, theirControl)
@@ -828,11 +961,10 @@ func TestRehandshakingLoser(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Stand up a tunnel between me and them") t.Log("Stand up a tunnel between me and them")
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
tt1 := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
tt2 := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
fmt.Println(tt1.LocalIndex, tt2.LocalIndex)
r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
@@ -854,8 +986,8 @@ func TestRehandshakingLoser(t *testing.T) {
theirConfig.ReloadConfigString(string(rc)) theirConfig.ReloadConfigString(string(rc))
for { for {
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
_, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"] _, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"]
if theirNewGroup { if theirNewGroup {
@@ -882,19 +1014,19 @@ func TestRehandshakingLoser(t *testing.T) {
r.Log("Spin until there is only 1 tunnel") r.Log("Spin until there is only 1 tunnel")
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet") t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second) time.Sleep(time.Second)
} }
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapHosts := myControl.ListHostmapHosts(false)
myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
// Make sure the correct tunnel won // Make sure the correct tunnel won
theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group") assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group")
// We should only have a single tunnel now on both sides // We should only have a single tunnel now on both sides
@@ -912,13 +1044,13 @@ func TestRaceRegression(t *testing.T) {
// This test forces stage 1, stage 2, stage 1 to be received by me from them // This test forces stage 1, stage 2, stage 1 to be received by me from them
// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which // We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
// caused a cross-linked hostinfo // caused a cross-linked hostinfo
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse // Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Start the servers // Start the servers
myControl.Start() myControl.Start()
@@ -932,8 +1064,8 @@ func TestRaceRegression(t *testing.T) {
//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089 //them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
t.Log("Start both handshakes") t.Log("Start both handshakes")
myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
t.Log("Get both stage 1") t.Log("Get both stage 1")
myStage1ForThem := myControl.GetFromUDP(true) myStage1ForThem := myControl.GetFromUDP(true)
@@ -963,7 +1095,7 @@ func TestRaceRegression(t *testing.T) {
r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
t.Log("Make sure the tunnel still works") t.Log("Make sure the tunnel still works")
assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
myControl.Stop() myControl.Stop()
theirControl.Stop() theirControl.Stop()

View File

@@ -4,6 +4,7 @@ import (
"crypto/rand" "crypto/rand"
"io" "io"
"net" "net"
"net/netip"
"time" "time"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
@@ -12,7 +13,7 @@ import (
) )
// NewTestCaCert will generate a CA cert // NewTestCaCert will generate a CA cert
func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
pub, priv, err := ed25519.GenerateKey(rand.Reader) pub, priv, err := ed25519.GenerateKey(rand.Reader)
if before.IsZero() { if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second) before = time.Now().Add(time.Second * -60).Round(time.Second)
@@ -33,11 +34,17 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
} }
if len(ips) > 0 { if len(ips) > 0 {
nc.Details.Ips = ips nc.Details.Ips = make([]*net.IPNet, len(ips))
for i, ip := range ips {
nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
}
} }
if len(subnets) > 0 { if len(subnets) > 0 {
nc.Details.Subnets = subnets nc.Details.Subnets = make([]*net.IPNet, len(subnets))
for i, ip := range subnets {
nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
}
} }
if len(groups) > 0 { if len(groups) > 0 {
@@ -59,7 +66,7 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
// NewTestCert will generate a signed certificate with the provided details. // NewTestCert will generate a signed certificate with the provided details.
// Expiry times are defaulted if you do not pass them in // Expiry times are defaulted if you do not pass them in
func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip netip.Prefix, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
issuer, err := ca.Sha256Sum() issuer, err := ca.Sha256Sum()
if err != nil { if err != nil {
panic(err) panic(err)
@@ -74,12 +81,12 @@ func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, af
} }
pub, rawPriv := x25519Keypair() pub, rawPriv := x25519Keypair()
ipb := ip.Addr().AsSlice()
nc := &cert.NebulaCertificate{ nc := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{ Details: cert.NebulaCertificateDetails{
Name: name, Name: name,
Ips: []*net.IPNet{ip}, Ips: []*net.IPNet{{IP: ipb[:], Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}},
Subnets: subnets, //Subnets: subnets,
Groups: groups, Groups: groups,
NotBefore: time.Unix(before.Unix(), 0), NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0), NotAfter: time.Unix(after.Unix(), 0),

View File

@@ -6,7 +6,7 @@ package e2e
import ( import (
"fmt" "fmt"
"io" "io"
"net" "net/netip"
"os" "os"
"testing" "testing"
"time" "time"
@@ -19,7 +19,6 @@ import (
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
@@ -27,15 +26,23 @@ import (
type m map[string]interface{} type m map[string]interface{}
// newSimpleServer creates a nebula instance with many assumptions // newSimpleServer creates a nebula instance with many assumptions
func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr, *config.C) { func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger() l := NewTestLogger()
vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} vpnIpNet, err := netip.ParsePrefix(sVpnIpNet)
copy(vpnIpNet.IP, udpIp) if err != nil {
vpnIpNet.IP[1] += 128 panic(err)
udpAddr := net.UDPAddr{ }
IP: udpIp,
Port: 4242, var udpAddr netip.AddrPort
if vpnIpNet.Addr().Is4() {
budpIp := vpnIpNet.Addr().As4()
budpIp[1] -= 128
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
} else {
budpIp := vpnIpNet.Addr().As16()
budpIp[13] -= 128
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
} }
_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
@@ -67,8 +74,8 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
// "try_interval": "1s", // "try_interval": "1s",
//}, //},
"listen": m{ "listen": m{
"host": udpAddr.IP.String(), "host": udpAddr.Addr().String(),
"port": udpAddr.Port, "port": udpAddr.Port(),
}, },
"logging": m{ "logging": m{
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name), "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
@@ -102,7 +109,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
panic(err) panic(err)
} }
return control, vpnIpNet, &udpAddr, c return control, vpnIpNet, udpAddr, c
} }
type doneCb func() type doneCb func()
@@ -123,7 +130,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
} }
} }
func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) { func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
// Send a packet from them to me // Send a packet from them to me
controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B")) controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
bPacket := r.RouteForAllUntilTxTun(controlA) bPacket := r.RouteForAllUntilTxTun(controlA)
@@ -135,23 +142,20 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebul
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
} }
func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) { func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) {
// Get both host infos // Get both host infos
hBinA := controlA.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpB), false) hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false)
assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA") assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
hAinB := controlB.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpA), false) hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false)
assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB") assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
// Check that both vpn and real addr are correct // Check that both vpn and real addr are correct
assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A") assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B") assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A") assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B") assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
assert.Equal(t, addrB.Port, int(hBinA.CurrentRemote.Port), "Host B remote port is wrong in control A")
assert.Equal(t, addrA.Port, int(hAinB.CurrentRemote.Port), "Host A remote port is wrong in control B")
// Check that our indexes match // Check that our indexes match
assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index") assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index")
@@ -174,13 +178,13 @@ func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB
//checkIndexes("hmB", hmB, hAinB) //checkIndexes("hmB", hmB, hAinB)
} }
func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp net.IP, fromPort, toPort uint16) { func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy) packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
assert.NotNil(t, v4, "No ipv4 data found") assert.NotNil(t, v4, "No ipv4 data found")
assert.Equal(t, fromIp, v4.SrcIP, "Source ip was incorrect") assert.Equal(t, fromIp.AsSlice(), []byte(v4.SrcIP), "Source ip was incorrect")
assert.Equal(t, toIp, v4.DstIP, "Dest ip was incorrect") assert.Equal(t, toIp.AsSlice(), []byte(v4.DstIP), "Dest ip was incorrect")
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
assert.NotNil(t, udp, "No udp data found") assert.NotNil(t, udp, "No udp data found")

View File

@@ -5,11 +5,11 @@ package router
import ( import (
"fmt" "fmt"
"net/netip"
"sort" "sort"
"strings" "strings"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/iputil"
) )
type edge struct { type edge struct {
@@ -118,14 +118,14 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
return r, globalLines return r, globalLines
} }
func sortedHosts(hosts map[iputil.VpnIp]*nebula.HostInfo) []iputil.VpnIp { func sortedHosts(hosts map[netip.Addr]*nebula.HostInfo) []netip.Addr {
keys := make([]iputil.VpnIp, 0, len(hosts)) keys := make([]netip.Addr, 0, len(hosts))
for key := range hosts { for key := range hosts {
keys = append(keys, key) keys = append(keys, key)
} }
sort.SliceStable(keys, func(i, j int) bool { sort.SliceStable(keys, func(i, j int) bool {
return keys[i] > keys[j] return keys[i].Compare(keys[j]) > 0
}) })
return keys return keys

View File

@@ -6,12 +6,11 @@ package router
import ( import (
"context" "context"
"fmt" "fmt"
"net" "net/netip"
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"sort" "sort"
"strconv"
"strings" "strings"
"sync" "sync"
"testing" "testing"
@@ -21,7 +20,6 @@ import (
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
) )
@@ -29,18 +27,18 @@ import (
type R struct { type R struct {
// Simple map of the ip:port registered on a control to the control // Simple map of the ip:port registered on a control to the control
// Basically a router, right? // Basically a router, right?
controls map[string]*nebula.Control controls map[netip.AddrPort]*nebula.Control
// A map for inbound packets for a control that doesn't know about this address // A map for inbound packets for a control that doesn't know about this address
inNat map[string]*nebula.Control inNat map[netip.AddrPort]*nebula.Control
// A last used map, if an inbound packet hit the inNat map then // A last used map, if an inbound packet hit the inNat map then
// all return packets should use the same last used inbound address for the outbound sender // all return packets should use the same last used inbound address for the outbound sender
// map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver // map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver
outNat map[string]net.UDPAddr outNat map[string]netip.AddrPort
// A map of vpn ip to the nebula control it belongs to // A map of vpn ip to the nebula control it belongs to
vpnControls map[iputil.VpnIp]*nebula.Control vpnControls map[netip.Addr]*nebula.Control
ignoreFlows []ignoreFlow ignoreFlows []ignoreFlow
flow []flowEntry flow []flowEntry
@@ -118,10 +116,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
} }
r := &R{ r := &R{
controls: make(map[string]*nebula.Control), controls: make(map[netip.AddrPort]*nebula.Control),
vpnControls: make(map[iputil.VpnIp]*nebula.Control), vpnControls: make(map[netip.Addr]*nebula.Control),
inNat: make(map[string]*nebula.Control), inNat: make(map[netip.AddrPort]*nebula.Control),
outNat: make(map[string]net.UDPAddr), outNat: make(map[string]netip.AddrPort),
flow: []flowEntry{}, flow: []flowEntry{},
ignoreFlows: []ignoreFlow{}, ignoreFlows: []ignoreFlow{},
fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())), fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())),
@@ -135,7 +133,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
for _, c := range controls { for _, c := range controls {
addr := c.GetUDPAddr() addr := c.GetUDPAddr()
if _, ok := r.controls[addr]; ok { if _, ok := r.controls[addr]; ok {
panic("Duplicate listen address: " + addr) panic("Duplicate listen address: " + addr.String())
} }
r.vpnControls[c.GetVpnIp()] = c r.vpnControls[c.GetVpnIp()] = c
@@ -165,13 +163,13 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
// It does not look at the addr attached to the instance. // It does not look at the addr attached to the instance.
// If a route is used, this will behave like a NAT for the return path. // If a route is used, this will behave like a NAT for the return path.
// Rewriting the source ip:port to what was last sent to from the origin // Rewriting the source ip:port to what was last sent to from the origin
func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) { func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) {
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
inAddr := net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)) inAddr := netip.AddrPortFrom(ip, port)
if _, ok := r.inNat[inAddr]; ok { if _, ok := r.inNat[inAddr]; ok {
panic("Duplicate listen address inNat: " + inAddr) panic("Duplicate listen address inNat: " + inAddr.String())
} }
r.inNat[inAddr] = c r.inNat[inAddr] = c
} }
@@ -198,7 +196,7 @@ func (r *R) renderFlow() {
panic(err) panic(err)
} }
var participants = map[string]struct{}{} var participants = map[netip.AddrPort]struct{}{}
var participantsVals []string var participantsVals []string
fmt.Fprintln(f, "```mermaid") fmt.Fprintln(f, "```mermaid")
@@ -215,7 +213,7 @@ func (r *R) renderFlow() {
continue continue
} }
participants[addr] = struct{}{} participants[addr] = struct{}{}
sanAddr := strings.Replace(addr, ":", "-", 1) sanAddr := strings.Replace(addr.String(), ":", "-", 1)
participantsVals = append(participantsVals, sanAddr) participantsVals = append(participantsVals, sanAddr)
fmt.Fprintf( fmt.Fprintf(
f, " participant %s as Nebula: %s<br/>UDP: %s\n", f, " participant %s as Nebula: %s<br/>UDP: %s\n",
@@ -252,9 +250,9 @@ func (r *R) renderFlow() {
fmt.Fprintf(f, fmt.Fprintf(f,
" %s%s%s: %s(%s), index %v, counter: %v\n", " %s%s%s: %s(%s), index %v, counter: %v\n",
strings.Replace(p.from.GetUDPAddr(), ":", "-", 1), strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1),
line, line,
strings.Replace(p.to.GetUDPAddr(), ":", "-", 1), strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter, h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
) )
} }
@@ -305,7 +303,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
func (r *R) renderHostmaps(title string) { func (r *R) renderHostmaps(title string) {
c := maps.Values(r.controls) c := maps.Values(r.controls)
sort.SliceStable(c, func(i, j int) bool { sort.SliceStable(c, func(i, j int) bool {
return c[i].GetVpnIp() > c[j].GetVpnIp() return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0
}) })
s := renderHostmaps(c...) s := renderHostmaps(c...)
@@ -420,10 +418,8 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
// Nope, lets push the sender along // Nope, lets push the sender along
case p := <-udpTx: case p := <-udpTx:
outAddr := sender.GetUDPAddr()
r.Lock() r.Lock()
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) c := r.getControl(sender.GetUDPAddr(), p.To, p)
c := r.getControl(outAddr, inAddr, p)
if c == nil { if c == nil {
r.Unlock() r.Unlock()
panic("No control for udp tx") panic("No control for udp tx")
@@ -479,10 +475,7 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
} else { } else {
// we are a udp tx, route and continue // we are a udp tx, route and continue
p := rx.Interface().(*udp.Packet) p := rx.Interface().(*udp.Packet)
outAddr := cm[x].GetUDPAddr() c := r.getControl(cm[x].GetUDPAddr(), p.To, p)
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
c := r.getControl(outAddr, inAddr, p)
if c == nil { if c == nil {
r.Unlock() r.Unlock()
panic("No control for udp tx") panic("No control for udp tx")
@@ -509,12 +502,10 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
panic(err) panic(err)
} }
outAddr := sender.GetUDPAddr() receiver := r.getControl(sender.GetUDPAddr(), p.To, p)
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
receiver := r.getControl(outAddr, inAddr, p)
if receiver == nil { if receiver == nil {
r.Unlock() r.Unlock()
panic("Can't route for host: " + inAddr) panic("Can't RouteExitFunc for host: " + p.To.String())
} }
e := whatDo(p, receiver) e := whatDo(p, receiver)
@@ -590,13 +581,13 @@ func (r *R) InjectUDPPacket(sender, receiver *nebula.Control, packet *udp.Packet
// RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr // RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr
// finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit` // finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit`
// If the router doesn't have the nebula controller for that address, we panic // If the router doesn't have the nebula controller for that address, we panic
func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) { func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr netip.AddrPort, finish ExitType) {
if finish == KeepRouting { if finish == KeepRouting {
finish = RouteAndExit finish = RouteAndExit
} }
r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType { r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) { if p.To == toAddr {
return finish return finish
} }
@@ -630,13 +621,10 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
r.Lock() r.Lock()
p := rx.Interface().(*udp.Packet) p := rx.Interface().(*udp.Packet)
receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p)
outAddr := cm[x].GetUDPAddr()
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
receiver := r.getControl(outAddr, inAddr, p)
if receiver == nil { if receiver == nil {
r.Unlock() r.Unlock()
panic("Can't route for host: " + inAddr) panic("Can't RouteForAllExitFunc for host: " + p.To.String())
} }
e := whatDo(p, receiver) e := whatDo(p, receiver)
@@ -697,41 +685,26 @@ func (r *R) FlushAll() {
p := rx.Interface().(*udp.Packet) p := rx.Interface().(*udp.Packet)
outAddr := cm[x].GetUDPAddr() receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p)
inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
receiver := r.getControl(outAddr, inAddr, p)
if receiver == nil { if receiver == nil {
r.Unlock() r.Unlock()
panic("Can't route for host: " + inAddr) panic("Can't FlushAll for host: " + p.To.String())
} }
receiver.InjectUDPPacket(p)
r.Unlock() r.Unlock()
} }
} }
// getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
// This is an internal router function, the caller must hold the lock // This is an internal router function, the caller must hold the lock
func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control { func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control {
if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok { if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok {
p.FromIp = newAddr.IP p.From = newAddr
p.FromPort = uint16(newAddr.Port)
} }
c, ok := r.inNat[toAddr] c, ok := r.inNat[toAddr]
if ok { if ok {
sHost, sPort, err := net.SplitHostPort(toAddr) r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr
if err != nil {
panic(err)
}
port, err := strconv.Atoi(sPort)
if err != nil {
panic(err)
}
r.outNat[c.GetUDPAddr()+":"+fromAddr] = net.UDPAddr{
IP: net.ParseIP(sHost),
Port: port,
}
return c return c
} }
@@ -746,8 +719,9 @@ func (r *R) formatUdpPacket(p *packet) string {
} }
from := "unknown" from := "unknown"
if c, ok := r.vpnControls[iputil.Ip2VpnIp(v4.SrcIP)]; ok { srcAddr, _ := netip.AddrFromSlice(v4.SrcIP)
from = c.GetUDPAddr() if c, ok := r.vpnControls[srcAddr]; ok {
from = c.GetUDPAddr().String()
} }
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
@@ -759,7 +733,7 @@ func (r *R) formatUdpPacket(p *packet) string {
return fmt.Sprintf( return fmt.Sprintf(
" %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n", " %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
strings.Replace(from, ":", "-", 1), strings.Replace(from, ":", "-", 1),
strings.Replace(p.to.GetUDPAddr(), ":", "-", 1), strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
udp.SrcPort, udp.SrcPort,
udp.DstPort, udp.DstPort,
string(data.Payload()), string(data.Payload()),

55
e2e/tunnels_test.go Normal file
View File

@@ -0,0 +1,55 @@
//go:build e2e_testing
// +build e2e_testing
package e2e
import (
"testing"
"time"
"github.com/slackhq/nebula/e2e/router"
)
func TestDropInactiveTunnels(t *testing.T) {
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
r.Log("Assert the tunnel between me and them works")
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
r.Log("Go inactive and wait for the tunnels to get dropped")
waitStart := time.Now()
for {
myIndexes := len(myControl.GetHostmap().Indexes)
theirIndexes := len(theirControl.GetHostmap().Indexes)
if myIndexes == 0 && theirIndexes == 0 {
break
}
since := time.Since(waitStart)
r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
if since > time.Second*30 {
t.Fatal("Tunnel should have been declared inactive after 5 seconds and before 30 seconds")
}
time.Sleep(1 * time.Second)
r.FlushAll()
}
r.Logf("Inactive tunnels were dropped within %v", time.Since(waitStart))
myControl.Stop()
theirControl.Stop()
}

View File

@@ -303,6 +303,18 @@ logging:
# after receiving the response for lighthouse queries # after receiving the response for lighthouse queries
#trigger_buffer: 64 #trigger_buffer: 64
# Tunnel manager settings
#tunnels:
# drop_inactive controls whether inactive tunnels are maintained or dropped after the inactive_timeout period has
# elapsed.
# In general, it is a good idea to enable this setting. It will be enabled by default in a future release.
# This setting is reloadable
#drop_inactive: false
# inactivity_timeout controls how long a tunnel MUST NOT see any inbound or outbound traffic before being considered
# inactive and eligible to be dropped.
# This setting is reloadable
#inactivity_timeout: 10m
# Nebula security group configuration # Nebula security group configuration
firewall: firewall:

View File

@@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"fmt" "fmt"
"log" "log"
"net"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/service" "github.com/slackhq/nebula/service"
@@ -54,16 +55,16 @@ pki:
cert: /home/rice/Developer/nebula-config/app.crt cert: /home/rice/Developer/nebula-config/app.crt
key: /home/rice/Developer/nebula-config/app.key key: /home/rice/Developer/nebula-config/app.key
` `
var config config.C var cfg config.C
if err := config.LoadString(configStr); err != nil { if err := cfg.LoadString(configStr); err != nil {
return err return err
} }
service, err := service.New(&config) svc, err := service.New(&cfg)
if err != nil { if err != nil {
return err return err
} }
ln, err := service.Listen("tcp", ":1234") ln, err := svc.Listen("tcp", ":1234")
if err != nil { if err != nil {
return err return err
} }
@@ -73,16 +74,24 @@ pki:
log.Printf("accept error: %s", err) log.Printf("accept error: %s", err)
break break
} }
defer conn.Close() defer func(conn net.Conn) {
_ = conn.Close()
}(conn)
log.Printf("got connection") log.Printf("got connection")
conn.Write([]byte("hello world\n")) _, err = conn.Write([]byte("hello world\n"))
if err != nil {
log.Printf("write error: %s", err)
}
scanner := bufio.NewScanner(conn) scanner := bufio.NewScanner(conn)
for scanner.Scan() { for scanner.Scan() {
message := scanner.Text() message := scanner.Text()
fmt.Fprintf(conn, "echo: %q\n", message) _, err = fmt.Fprintf(conn, "echo: %q\n", message)
if err != nil {
log.Printf("write error: %s", err)
}
log.Printf("got message %q", message) log.Printf("got message %q", message)
} }
@@ -92,8 +101,8 @@ pki:
} }
} }
service.Close() _ = svc.Close()
if err := service.Wait(); err != nil { if err := svc.Wait(); err != nil {
return err return err
} }
return nil return nil

View File

@@ -6,23 +6,23 @@ import (
"errors" "errors"
"fmt" "fmt"
"hash/fnv" "hash/fnv"
"net" "net/netip"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
) )
type FirewallInterface interface { type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error
} }
type conn struct { type conn struct {
@@ -52,8 +52,8 @@ type Firewall struct {
DefaultTimeout time.Duration //linux: 600s DefaultTimeout time.Duration //linux: 600s
// Used to ensure we don't emit local packets for ips we don't own // Used to ensure we don't emit local packets for ips we don't own
localIps *cidr.Tree4[struct{}] localIps *bart.Table[struct{}]
assignedCIDR *net.IPNet assignedCIDR netip.Prefix
hasSubnets bool hasSubnets bool
rules string rules string
@@ -108,7 +108,7 @@ type FirewallRule struct {
Any *firewallLocalCIDR Any *firewallLocalCIDR
Hosts map[string]*firewallLocalCIDR Hosts map[string]*firewallLocalCIDR
Groups []*firewallGroups Groups []*firewallGroups
CIDR *cidr.Tree4[*firewallLocalCIDR] CIDR *bart.Table[*firewallLocalCIDR]
} }
type firewallGroups struct { type firewallGroups struct {
@@ -122,7 +122,7 @@ type firewallPort map[int32]*FirewallCA
type firewallLocalCIDR struct { type firewallLocalCIDR struct {
Any bool Any bool
LocalCIDR *cidr.Tree4[struct{}] LocalCIDR *bart.Table[struct{}]
} }
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
@@ -144,20 +144,28 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
max = defaultTimeout max = defaultTimeout
} }
localIps := cidr.NewTree4[struct{}]() localIps := new(bart.Table[struct{}])
var assignedCIDR *net.IPNet var assignedCIDR netip.Prefix
var assignedSet bool
for _, ip := range c.Details.Ips { for _, ip := range c.Details.Ips {
ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}} //TODO: IPV6-WORK the unmap is a bit unfortunate
localIps.AddCIDR(ipNet, struct{}{}) nip, _ := netip.AddrFromSlice(ip.IP)
nip = nip.Unmap()
nprefix := netip.PrefixFrom(nip, nip.BitLen())
localIps.Insert(nprefix, struct{}{})
if assignedCIDR == nil { if !assignedSet {
// Only grabbing the first one in the cert since any more than that currently has undefined behavior // Only grabbing the first one in the cert since any more than that currently has undefined behavior
assignedCIDR = ipNet assignedCIDR = nprefix
assignedSet = true
} }
} }
for _, n := range c.Details.Subnets { for _, n := range c.Details.Subnets {
localIps.AddCIDR(n, struct{}{}) nip, _ := netip.AddrFromSlice(n.IP)
ones, _ := n.Mask.Size()
nip = nip.Unmap()
localIps.Insert(netip.PrefixFrom(nip, ones), struct{}{})
} }
return &Firewall{ return &Firewall{
@@ -237,15 +245,15 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf
} }
// AddRule properly creates the in memory rule structure for a firewall table. // AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS // Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
// https://github.com/golang/go/issues/14131 // https://github.com/golang/go/issues/14131
sIp := "" sIp := ""
if ip != nil { if ip.IsValid() {
sIp = ip.String() sIp = ip.String()
} }
lIp := "" lIp := ""
if localIp != nil { if localIp.IsValid() {
lIp = localIp.String() lIp = localIp.String()
} }
@@ -382,17 +390,17 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
} }
var cidr *net.IPNet var cidr netip.Prefix
if r.Cidr != "" { if r.Cidr != "" {
_, cidr, err = net.ParseCIDR(r.Cidr) cidr, err = netip.ParsePrefix(r.Cidr)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err) return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
} }
} }
var localCidr *net.IPNet var localCidr netip.Prefix
if r.LocalCidr != "" { if r.LocalCidr != "" {
_, localCidr, err = net.ParseCIDR(r.LocalCidr) localCidr, err = netip.ParsePrefix(r.LocalCidr)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err) return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
} }
@@ -421,7 +429,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
// Make sure remote address matches nebula certificate // Make sure remote address matches nebula certificate
if remoteCidr := h.remoteCidr; remoteCidr != nil { if remoteCidr := h.remoteCidr; remoteCidr != nil {
ok, _ := remoteCidr.Contains(fp.RemoteIP) //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
_, ok := remoteCidr.Lookup(fp.RemoteIP)
if !ok { if !ok {
f.metrics(incoming).droppedRemoteIP.Inc(1) f.metrics(incoming).droppedRemoteIP.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
@@ -435,7 +444,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
} }
// Make sure we are supposed to be handling this local ip address // Make sure we are supposed to be handling this local ip address
ok, _ := f.localIps.Contains(fp.LocalIP) //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
_, ok := f.localIps.Lookup(fp.LocalIP)
if !ok { if !ok {
f.metrics(incoming).droppedLocalIP.Inc(1) f.metrics(incoming).droppedLocalIP.Inc(1)
return ErrInvalidLocalIP return ErrInvalidLocalIP
@@ -589,7 +599,6 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
// Caller must own the connMutex lock! // Caller must own the connMutex lock!
func (f *Firewall) evict(p firewall.Packet) { func (f *Firewall) evict(p firewall.Packet) {
//TODO: report a stat if the tcp rtt tracking was never resolved?
// Are we still tracking this conn? // Are we still tracking this conn?
conntrack := f.Conntrack conntrack := f.Conntrack
t, ok := conntrack.Conns[p] t, ok := conntrack.Conns[p]
@@ -633,7 +642,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
return false return false
} }
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
if startPort > endPort { if startPort > endPort {
return fmt.Errorf("start port was lower than end port") return fmt.Errorf("start port was lower than end port")
} }
@@ -677,12 +686,12 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
return fp[firewall.PortAny].match(p, c, caPool) return fp[firewall.PortAny].match(p, c, caPool)
} }
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error { func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error {
fr := func() *FirewallRule { fr := func() *FirewallRule {
return &FirewallRule{ return &FirewallRule{
Hosts: make(map[string]*firewallLocalCIDR), Hosts: make(map[string]*firewallLocalCIDR),
Groups: make([]*firewallGroups, 0), Groups: make([]*firewallGroups, 0),
CIDR: cidr.NewTree4[*firewallLocalCIDR](), CIDR: new(bart.Table[*firewallLocalCIDR]),
} }
} }
@@ -740,10 +749,10 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
return fc.CANames[s.Details.Name].match(p, c) return fc.CANames[s.Details.Name].match(p, c)
} }
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error { func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
flc := func() *firewallLocalCIDR { flc := func() *firewallLocalCIDR {
return &firewallLocalCIDR{ return &firewallLocalCIDR{
LocalCIDR: cidr.NewTree4[struct{}](), LocalCIDR: new(bart.Table[struct{}]),
} }
} }
@@ -780,8 +789,8 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n
fr.Hosts[host] = nlc fr.Hosts[host] = nlc
} }
if ip != nil { if ip.IsValid() {
_, nlc := fr.CIDR.GetCIDR(ip) nlc, _ := fr.CIDR.Get(ip)
if nlc == nil { if nlc == nil {
nlc = flc() nlc = flc()
} }
@@ -789,14 +798,14 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n
if err != nil { if err != nil {
return err return err
} }
fr.CIDR.AddCIDR(ip, nlc) fr.CIDR.Insert(ip, nlc)
} }
return nil return nil
} }
func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool { func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool {
if len(groups) == 0 && host == "" && ip == nil { if len(groups) == 0 && host == "" && !ip.IsValid() {
return true return true
} }
@@ -810,7 +819,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
return true return true
} }
if ip != nil && ip.Contains(net.IPv4(0, 0, 0, 0)) { if ip.IsValid() && ip.Bits() == 0 {
return true return true
} }
@@ -853,24 +862,31 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
} }
} }
return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool { matched := false
return flc.match(p, c) prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen())
fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
if prefix.Contains(p.RemoteIP) && val.match(p, c) {
matched = true
return false
}
return true
}) })
return matched
} }
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error { func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
if localIp == nil { if !localIp.IsValid() {
if !f.hasSubnets || f.defaultLocalCIDRAny { if !f.hasSubnets || f.defaultLocalCIDRAny {
flc.Any = true flc.Any = true
return nil return nil
} }
localIp = f.assignedCIDR localIp = f.assignedCIDR
} else if localIp.Contains(net.IPv4(0, 0, 0, 0)) { } else if localIp.Bits() == 0 {
flc.Any = true flc.Any = true
} }
flc.LocalCIDR.AddCIDR(localIp, struct{}{}) flc.LocalCIDR.Insert(localIp, struct{}{})
return nil return nil
} }
@@ -883,7 +899,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate
return true return true
} }
ok, _ := flc.LocalCIDR.Contains(p.LocalIP) _, ok := flc.LocalCIDR.Lookup(p.LocalIP)
return ok return ok
} }

View File

@@ -3,8 +3,7 @@ package firewall
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/netip"
"github.com/slackhq/nebula/iputil"
) )
type m map[string]interface{} type m map[string]interface{}
@@ -20,8 +19,8 @@ const (
) )
type Packet struct { type Packet struct {
LocalIP iputil.VpnIp LocalIP netip.Addr
RemoteIP iputil.VpnIp RemoteIP netip.Addr
LocalPort uint16 LocalPort uint16
RemotePort uint16 RemotePort uint16
Protocol uint8 Protocol uint8

View File

@@ -5,13 +5,13 @@ import (
"errors" "errors"
"math" "math"
"net" "net"
"net/netip"
"testing" "testing"
"time" "time"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -65,59 +65,62 @@ func TestFirewall_AddRule(t *testing.T) {
assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.InRules)
assert.NotNil(t, fw.OutRules) assert.NotNil(t, fw.OutRules)
_, ti, _ := net.ParseCIDR("1.2.3.4/32") ti, err := netip.ParsePrefix("1.2.3.4/32")
assert.NoError(t, err)
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
// An empty rule is any // An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti) _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha")) assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0") anyIp, err := netip.ParsePrefix("0.0.0.0/0")
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", "")) assert.NoError(t, err)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions // Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", "")) assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", "")) assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
} }
func TestFirewall_Drop(t *testing.T) { func TestFirewall_Drop(t *testing.T) {
@@ -126,8 +129,8 @@ func TestFirewall_Drop(t *testing.T) {
l.SetOutput(ob) l.SetOutput(ob)
p := firewall.Packet{ p := firewall.Packet{
LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), LocalIP: netip.MustParseAddr("1.2.3.4"),
RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10, LocalPort: 10,
RemotePort: 90, RemotePort: 90,
Protocol: firewall.ProtoUDP, Protocol: firewall.ProtoUDP,
@@ -152,16 +155,16 @@ func TestFirewall_Drop(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c, peerCert: &c,
}, },
vpnIp: iputil.Ip2VpnIp(ipNet.IP), vpnIp: netip.MustParseAddr("1.2.3.4"),
} }
h.CreateRemoteCIDR(&c) h.CreateRemoteCIDR(&c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
// Allow inbound // Allow inbound
resetConntrack(fw) resetConntrack(fw)
assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
@@ -170,34 +173,34 @@ func TestFirewall_Drop(t *testing.T) {
// test remote mismatch // test remote mismatch
oldRemote := p.RemoteIP oldRemote := p.RemoteIP
p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10)) p.RemoteIP = netip.MustParseAddr("1.2.3.10")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteIP = oldRemote p.RemoteIP = oldRemote
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
@@ -207,10 +210,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
TCP: firewallPort{}, TCP: firewallPort{},
} }
_, n, _ := net.ParseCIDR("172.1.1.1/32") pfix := netip.MustParsePrefix("172.1.1.1/32")
goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP) _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "") _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "")
cp := cert.NewCAPool() cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) { b.Run("fail on proto", func(b *testing.B) {
@@ -231,10 +233,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) { b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
c := &cert.NebulaCertificate{} c := &cert.NebulaCertificate{}
ip, _, _ := net.ParseCIDR("9.254.254.254/32") ip := netip.MustParsePrefix("9.254.254.254/32")
lip := iputil.Ip2VpnIp(ip)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp)) assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp))
} }
}) })
@@ -262,7 +263,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}, },
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp)) assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
} }
}) })
@@ -286,7 +287,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}, },
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp)) assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
} }
}) })
@@ -363,8 +364,8 @@ func TestFirewall_Drop2(t *testing.T) {
l.SetOutput(ob) l.SetOutput(ob)
p := firewall.Packet{ p := firewall.Packet{
LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), LocalIP: netip.MustParseAddr("1.2.3.4"),
RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10, LocalPort: 10,
RemotePort: 90, RemotePort: 90,
Protocol: firewall.ProtoUDP, Protocol: firewall.ProtoUDP,
@@ -387,7 +388,7 @@ func TestFirewall_Drop2(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c, peerCert: &c,
}, },
vpnIp: iputil.Ip2VpnIp(ipNet.IP), vpnIp: netip.MustParseAddr(ipNet.IP.String()),
} }
h.CreateRemoteCIDR(&c) h.CreateRemoteCIDR(&c)
@@ -406,7 +407,7 @@ func TestFirewall_Drop2(t *testing.T) {
h1.CreateRemoteCIDR(&c1) h1.CreateRemoteCIDR(&c1)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// h1/c1 lacks the proper groups // h1/c1 lacks the proper groups
@@ -422,8 +423,8 @@ func TestFirewall_Drop3(t *testing.T) {
l.SetOutput(ob) l.SetOutput(ob)
p := firewall.Packet{ p := firewall.Packet{
LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), LocalIP: netip.MustParseAddr("1.2.3.4"),
RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalPort: 1, LocalPort: 1,
RemotePort: 1, RemotePort: 1,
Protocol: firewall.ProtoUDP, Protocol: firewall.ProtoUDP,
@@ -453,7 +454,7 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c1, peerCert: &c1,
}, },
vpnIp: iputil.Ip2VpnIp(ipNet.IP), vpnIp: netip.MustParseAddr(ipNet.IP.String()),
} }
h1.CreateRemoteCIDR(&c1) h1.CreateRemoteCIDR(&c1)
@@ -468,7 +469,7 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c2, peerCert: &c2,
}, },
vpnIp: iputil.Ip2VpnIp(ipNet.IP), vpnIp: netip.MustParseAddr(ipNet.IP.String()),
} }
h2.CreateRemoteCIDR(&c2) h2.CreateRemoteCIDR(&c2)
@@ -483,13 +484,13 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c3, peerCert: &c3,
}, },
vpnIp: iputil.Ip2VpnIp(ipNet.IP), vpnIp: netip.MustParseAddr(ipNet.IP.String()),
} }
h3.CreateRemoteCIDR(&c3) h3.CreateRemoteCIDR(&c3)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// c1 should pass because host match // c1 should pass because host match
@@ -508,8 +509,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
l.SetOutput(ob) l.SetOutput(ob)
p := firewall.Packet{ p := firewall.Packet{
LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), LocalIP: netip.MustParseAddr("1.2.3.4"),
RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10, LocalPort: 10,
RemotePort: 90, RemotePort: 90,
Protocol: firewall.ProtoUDP, Protocol: firewall.ProtoUDP,
@@ -534,12 +535,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c, peerCert: &c,
}, },
vpnIp: iputil.Ip2VpnIp(ipNet.IP), vpnIp: netip.MustParseAddr(ipNet.IP.String()),
} }
h.CreateRemoteCIDR(&c) h.CreateRemoteCIDR(&c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@@ -552,7 +553,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw := fw oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@@ -561,7 +562,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw = fw oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@@ -725,13 +726,13 @@ func TestNewFirewallFromConfig(t *testing.T) {
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh") assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test local_cidr parse error // Test local_cidr parse error
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh") assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test both group and groups // Test both group and groups
conf = config.NewC(l) conf = config.NewC(l)
@@ -747,78 +748,78 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf := &mockFirewall{} mf := &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding udp rule // Test adding udp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding icmp rule // Test adding icmp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding any rule // Test adding any rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with cidr // Test adding rule with cidr
cidr := &net.IPNet{IP: net.ParseIP("10.0.0.0").To4(), Mask: net.IPv4Mask(255, 0, 0, 0)} cidr := netip.MustParsePrefix("10.0.0.0/8")
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with local_cidr // Test adding rule with local_cidr
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
// Test adding rule with ca_sha // Test adding rule with ca_sha
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name // Test adding rule with ca_name
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
// Test single group // Test single group
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test single groups // Test single groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test multiple AND groups // Test multiple AND groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test Add error // Test Add error
conf = config.NewC(l) conf = config.NewC(l)
@@ -871,8 +872,8 @@ type addRuleCall struct {
endPort int32 endPort int32
groups []string groups []string
host string host string
ip *net.IPNet ip netip.Prefix
localIp *net.IPNet localIp netip.Prefix
caName string caName string
caSha string caSha string
} }
@@ -882,7 +883,7 @@ type mockFirewall struct {
nextCallReturn error nextCallReturn error
} }
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
mf.lastCall = addRuleCall{ mf.lastCall = addRuleCall{
incoming: incoming, incoming: incoming,
proto: proto, proto: proto,

22
go.mod
View File

@@ -10,34 +10,36 @@ require (
github.com/armon/go-radix v1.0.0 github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.1.0 github.com/flynn/noise v1.1.0
github.com/gaissmai/bart v0.11.1
github.com/gogo/protobuf v1.3.2 github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19 github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.2 github.com/kardianos/service v1.2.2
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.61
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
github.com/prometheus/client_golang v1.19.0 github.com/prometheus/client_golang v1.19.1
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
github.com/vishvananda/netlink v1.2.1-beta.2 github.com/vishvananda/netlink v1.2.1-beta.2
golang.org/x/crypto v0.23.0 golang.org/x/crypto v0.26.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.25.0 golang.org/x/net v0.28.0
golang.org/x/sync v0.7.0 golang.org/x/sync v0.8.0
golang.org/x/sys v0.20.0 golang.org/x/sys v0.24.0
golang.org/x/term v0.20.0 golang.org/x/term v0.23.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/protobuf v1.34.1 google.golang.org/protobuf v1.34.2
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
) )
require ( require (
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/bits-and-blooms/bitset v1.13.0 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/btree v1.1.2 // indirect github.com/google/btree v1.1.2 // indirect
@@ -46,8 +48,8 @@ require (
github.com/prometheus/common v0.48.0 // indirect github.com/prometheus/common v0.48.0 // indirect
github.com/prometheus/procfs v0.12.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect github.com/vishvananda/netns v0.0.4 // indirect
golang.org/x/mod v0.16.0 // indirect golang.org/x/mod v0.18.0 // indirect
golang.org/x/time v0.5.0 // indirect golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.19.0 // indirect golang.org/x/tools v0.22.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

44
go.sum
View File

@@ -14,6 +14,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE=
github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
@@ -24,6 +26,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc=
github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -77,8 +81,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs= github.com/miekg/dns v1.1.61 h1:nLxbwF3XxhwVSm8g9Dghm9MHPaUZuqhPiGL+675ZmEs=
github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk= github.com/miekg/dns v1.1.61/go.mod h1:mnAarhS3nWaW+NVP2wTkYVIZyHNJ098SJZUki3eykwQ=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
@@ -96,8 +100,8 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU= github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE=
github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k= github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
@@ -147,16 +151,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=
golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -167,8 +171,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -176,8 +180,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -195,11 +199,11 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw= golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -210,8 +214,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA=
golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -230,8 +234,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -1,13 +1,12 @@
package nebula package nebula
import ( import (
"net/netip"
"time" "time"
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
) )
// NOISE IX Handshakes // NOISE IX Handshakes
@@ -46,7 +45,6 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
} }
h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1) h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1)
ci.messageCounter.Add(1)
msg, _, _, err := ci.H.WriteMessage(h, hsBytes) msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
if err != nil { if err != nil {
@@ -64,7 +62,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
return true return true
} }
func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
certState := f.pki.GetCertState() certState := f.pki.GetCertState()
ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0) ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
// Mark packet 1 as seen so it doesn't show up as missed // Mark packet 1 as seen so it doesn't show up as missed
@@ -100,12 +98,26 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
e.Info("Invalid certificate from host") e.Info("Invalid certificate from host")
return return
} }
vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
if !ok {
e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
if f.l.Level > logrus.DebugLevel {
e = e.WithField("cert", remoteCert)
}
e.Info("Invalid vpn ip from host")
return
}
vpnIp = vpnIp.Unmap()
certName := remoteCert.Details.Name certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum() fingerprint, _ := remoteCert.Sha256Sum()
issuer := remoteCert.Details.Issuer issuer := remoteCert.Details.Issuer
if vpnIp == f.myVpnIp { if vpnIp == f.myVpnNet.Addr() {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -114,8 +126,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
return return
} }
if addr != nil { if addr.IsValid() {
if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.IP) { if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return return
} }
@@ -139,8 +151,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
HandshakePacket: make(map[uint8][]byte, 0), HandshakePacket: make(map[uint8][]byte, 0),
lastHandshakeTime: hs.Details.Time, lastHandshakeTime: hs.Details.Time,
relayState: RelayState{ relayState: RelayState{
relays: map[iputil.VpnIp]struct{}{}, relays: nil,
relayForByIp: map[iputil.VpnIp]*Relay{}, relayForByIp: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{}, relayForByIdx: map[uint32]*Relay{},
}, },
} }
@@ -219,7 +231,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
msg = existing.HandshakePacket[2] msg = existing.HandshakePacket[2]
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
if addr != nil { if addr.IsValid() {
err := f.outside.WriteTo(msg, addr) err := f.outside.WriteTo(msg, addr)
if err != nil { if err != nil {
f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
@@ -285,7 +297,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
// Do the send // Do the send
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
if addr != nil { if addr.IsValid() {
err = f.outside.WriteTo(msg, addr) err = f.outside.WriteTo(msg, addr)
if err != nil { if err != nil {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
@@ -310,6 +322,9 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
return return
} }
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
// I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure
// it's correctly marked as working.
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp). f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp).
WithField("certName", certName). WithField("certName", certName).
@@ -320,14 +335,14 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
Info("Handshake message sent") Info("Handshake message sent")
} }
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) f.connectionManager.AddTrafficWatch(hostinfo)
hostinfo.ConnectionState.messageCounter.Store(2)
hostinfo.remotes.ResetBlockedRemotes() hostinfo.remotes.ResetBlockedRemotes()
return return
} }
func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
if hh == nil { if hh == nil {
// Nothing here to tear down, got a bogus stage 2 packet // Nothing here to tear down, got a bogus stage 2 packet
return true return true
@@ -337,8 +352,8 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
defer hh.Unlock() defer hh.Unlock()
hostinfo := hh.hostinfo hostinfo := hh.hostinfo
if addr != nil { if addr.IsValid() {
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) { if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) {
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return false return false
} }
@@ -390,7 +405,20 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
return true return true
} }
vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP) vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
if !ok {
e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
if f.l.Level > logrus.DebugLevel {
e = e.WithField("cert", remoteCert)
}
e.Info("Invalid vpn ip from host")
return true
}
vpnIp = vpnIp.Unmap()
certName := remoteCert.Details.Name certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum() fingerprint, _ := remoteCert.Sha256Sum()
issuer := remoteCert.Details.Issuer issuer := remoteCert.Details.Issuer
@@ -454,7 +482,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
ci.eKey = NewNebulaCipherState(eKey) ci.eKey = NewNebulaCipherState(eKey)
// Make sure the current udpAddr being used is set for responding // Make sure the current udpAddr being used is set for responding
if addr != nil { if addr.IsValid() {
hostinfo.SetRemote(addr) hostinfo.SetRemote(addr)
} else { } else {
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
@@ -465,9 +493,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
f.handshakeManager.Complete(hostinfo, f) f.handshakeManager.Complete(hostinfo, f)
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) f.connectionManager.AddTrafficWatch(hostinfo)
hostinfo.ConnectionState.messageCounter.Store(2)
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore)) hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))

View File

@@ -6,15 +6,15 @@ import (
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
"errors" "errors"
"net" "net/netip"
"sync" "sync"
"time" "time"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"golang.org/x/exp/slices"
) )
const ( const (
@@ -35,7 +35,7 @@ var (
type HandshakeConfig struct { type HandshakeConfig struct {
tryInterval time.Duration tryInterval time.Duration
retries int retries int64
triggerBuffer int triggerBuffer int
useRelays bool useRelays bool
@@ -46,14 +46,14 @@ type HandshakeManager struct {
// Mutex for interacting with the vpnIps and indexes maps // Mutex for interacting with the vpnIps and indexes maps
sync.RWMutex sync.RWMutex
vpnIps map[iputil.VpnIp]*HandshakeHostInfo vpnIps map[netip.Addr]*HandshakeHostInfo
indexes map[uint32]*HandshakeHostInfo indexes map[uint32]*HandshakeHostInfo
mainHostMap *HostMap mainHostMap *HostMap
lightHouse *LightHouse lightHouse *LightHouse
outside udp.Conn outside udp.Conn
config HandshakeConfig config HandshakeConfig
OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp] OutboundHandshakeTimer *LockingTimerWheel[netip.Addr]
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
metricInitiated metrics.Counter metricInitiated metrics.Counter
metricTimedOut metrics.Counter metricTimedOut metrics.Counter
@@ -61,7 +61,7 @@ type HandshakeManager struct {
l *logrus.Logger l *logrus.Logger
// can be used to trigger outbound handshake for the given vpnIp // can be used to trigger outbound handshake for the given vpnIp
trigger chan iputil.VpnIp trigger chan netip.Addr
} }
type HandshakeHostInfo struct { type HandshakeHostInfo struct {
@@ -69,8 +69,8 @@ type HandshakeHostInfo struct {
startTime time.Time // Time that we first started trying with this handshake startTime time.Time // Time that we first started trying with this handshake
ready bool // Is the handshake ready ready bool // Is the handshake ready
counter int // How many attempts have we made so far counter int64 // How many attempts have we made so far
lastRemotes []*udp.Addr // Remotes that we sent to during the previous attempt lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
hostinfo *HostInfo hostinfo *HostInfo
@@ -103,14 +103,14 @@ func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType,
func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
return &HandshakeManager{ return &HandshakeManager{
vpnIps: map[iputil.VpnIp]*HandshakeHostInfo{}, vpnIps: map[netip.Addr]*HandshakeHostInfo{},
indexes: map[uint32]*HandshakeHostInfo{}, indexes: map[uint32]*HandshakeHostInfo{},
mainHostMap: mainHostMap, mainHostMap: mainHostMap,
lightHouse: lightHouse, lightHouse: lightHouse,
outside: outside, outside: outside,
config: config, config: config,
trigger: make(chan iputil.VpnIp, config.triggerBuffer), trigger: make(chan netip.Addr, config.triggerBuffer),
OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)), OutboundHandshakeTimer: NewLockingTimerWheel[netip.Addr](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
messageMetrics: config.messageMetrics, messageMetrics: config.messageMetrics,
metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil), metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil), metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
@@ -134,10 +134,10 @@ func (c *HandshakeManager) Run(ctx context.Context) {
} }
} }
func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
// First remote allow list check before we know the vpnIp // First remote allow list check before we know the vpnIp
if addr != nil { if addr.IsValid() {
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) { if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) {
hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return return
} }
@@ -170,7 +170,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
} }
} }
func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) { func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered bool) {
hh := hm.queryVpnIp(vpnIp) hh := hm.queryVpnIp(vpnIp)
if hh == nil { if hh == nil {
return return
@@ -212,7 +212,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
} }
remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()) remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes) remotesHaveChanged := !slices.Equal(remotes, hh.lastRemotes)
// We only care about a lighthouse trigger if we have new remotes to send to. // We only care about a lighthouse trigger if we have new remotes to send to.
// This is a very specific optimization for a fast lighthouse reply. // This is a very specific optimization for a fast lighthouse reply.
@@ -234,8 +234,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
} }
// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
var sentTo []*udp.Addr var sentTo []netip.AddrPort
hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr *udp.Addr, _ bool) { hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) {
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
if err != nil { if err != nil {
@@ -268,30 +268,72 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
// Send a RelayRequest to all known Relay IP's // Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays { for _, relay := range hostinfo.remotes.relays {
// Don't relay to myself, and don't relay through the host I'm trying to connect to // Don't relay to myself, and don't relay through the host I'm trying to connect to
if *relay == vpnIp || *relay == hm.lightHouse.myVpnIp { if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() {
continue continue
} }
relayHostInfo := hm.mainHostMap.QueryVpnIp(*relay) relayHostInfo := hm.mainHostMap.QueryVpnIp(relay)
if relayHostInfo == nil || relayHostInfo.remote == nil { if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
hm.f.Handshake(*relay) hm.f.Handshake(relay)
continue continue
} }
// Check the relay HostInfo to see if we already established a relay through it // Check the relay HostInfo to see if we already established a relay through it
if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok { existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp)
if !ok {
// No relays exist or requested yet.
if relayHostInfo.remote.IsValid() {
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
if err != nil {
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
}
//TODO: IPV6-WORK
myVpnIpB := hm.f.myVpnNet.Addr().As4()
theirVpnIpB := vpnIp.As4()
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: idx,
RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]),
RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]),
}
msg, err := m.Marshal()
if err != nil {
hostinfo.logger(hm.l).
WithError(err).
Error("Failed to marshal Control message to create relay")
} else {
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{
"relayFrom": hm.f.myVpnNet.Addr(),
"relayTo": vpnIp,
"initiatorRelayIndex": idx,
"relay": relay}).
Info("send CreateRelayRequest")
}
}
continue
}
switch existingRelay.State { switch existingRelay.State {
case Established: case Established:
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay") hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
case Disestablished:
// Mark this relay as 'requested'
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
fallthrough
case Requested: case Requested:
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
// Re-send the CreateRelay request, in case the previous one was lost. // Re-send the CreateRelay request, in case the previous one was lost.
relayFrom := hm.f.myVpnNet.Addr().As4()
relayTo := vpnIp.As4()
m := NebulaControl{ m := NebulaControl{
Type: NebulaControl_CreateRelayRequest, Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: existingRelay.LocalIndex, InitiatorRelayIndex: existingRelay.LocalIndex,
RelayFromIp: uint32(hm.lightHouse.myVpnIp), RelayFromIp: binary.BigEndian.Uint32(relayFrom[:]),
RelayToIp: uint32(vpnIp), RelayToIp: binary.BigEndian.Uint32(relayTo[:]),
} }
msg, err := m.Marshal() msg, err := m.Marshal()
if err != nil { if err != nil {
hostinfo.logger(hm.l). hostinfo.logger(hm.l).
@@ -301,49 +343,22 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
// This must send over the hostinfo, not over hm.Hosts[ip] // This must send over the hostinfo, not over hm.Hosts[ip]
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{ hm.l.WithFields(logrus.Fields{
"relayFrom": hm.lightHouse.myVpnIp, "relayFrom": hm.f.myVpnNet,
"relayTo": vpnIp, "relayTo": vpnIp,
"initiatorRelayIndex": existingRelay.LocalIndex, "initiatorRelayIndex": existingRelay.LocalIndex,
"relay": *relay}). "relay": relay}).
Info("send CreateRelayRequest") Info("send CreateRelayRequest")
} }
case PeerRequested:
// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
fallthrough
default: default:
hostinfo.logger(hm.l). hostinfo.logger(hm.l).
WithField("vpnIp", vpnIp). WithField("vpnIp", vpnIp).
WithField("state", existingRelay.State). WithField("state", existingRelay.State).
WithField("relay", relayHostInfo.vpnIp). WithField("relay", relay).
Errorf("Relay unexpected state") Errorf("Relay unexpected state")
} }
} else {
// No relays exist or requested yet.
if relayHostInfo.remote != nil {
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
if err != nil {
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
}
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: idx,
RelayFromIp: uint32(hm.lightHouse.myVpnIp),
RelayToIp: uint32(vpnIp),
}
msg, err := m.Marshal()
if err != nil {
hostinfo.logger(hm.l).
WithError(err).
Error("Failed to marshal Control message to create relay")
} else {
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{
"relayFrom": hm.lightHouse.myVpnIp,
"relayTo": vpnIp,
"initiatorRelayIndex": idx,
"relay": *relay}).
Info("send CreateRelayRequest")
}
}
}
} }
} }
@@ -355,11 +370,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
// GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present // GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present
// The 2nd argument will be true if the hostinfo is ready to transmit traffic // The 2nd argument will be true if the hostinfo is ready to transmit traffic
func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) { func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) {
// Check the main hostmap and maintain a read lock if our host is not there
hm.mainHostMap.RLock() hm.mainHostMap.RLock()
if h, ok := hm.mainHostMap.Hosts[vpnIp]; ok { h, ok := hm.mainHostMap.Hosts[vpnIp]
hm.mainHostMap.RUnlock() hm.mainHostMap.RUnlock()
if ok {
// Do not attempt promotion if you are a lighthouse // Do not attempt promotion if you are a lighthouse
if !hm.lightHouse.amLighthouse { if !hm.lightHouse.amLighthouse {
h.TryPromoteBest(hm.mainHostMap.GetPreferredRanges(), hm.f) h.TryPromoteBest(hm.mainHostMap.GetPreferredRanges(), hm.f)
@@ -367,12 +383,11 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
return h, true return h, true
} }
defer hm.mainHostMap.RUnlock()
return hm.StartHandshake(vpnIp, cacheCb), false return hm.StartHandshake(vpnIp, cacheCb), false
} }
// StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo { func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
hm.Lock() hm.Lock()
if hh, ok := hm.vpnIps[vpnIp]; ok { if hh, ok := hm.vpnIps[vpnIp]; ok {
@@ -388,8 +403,8 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
vpnIp: vpnIp, vpnIp: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0), HandshakePacket: make(map[uint8][]byte, 0),
relayState: RelayState{ relayState: RelayState{
relays: map[iputil.VpnIp]struct{}{}, relays: nil,
relayForByIp: map[iputil.VpnIp]*Relay{}, relayForByIp: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{}, relayForByIdx: map[uint32]*Relay{},
}, },
} }
@@ -479,7 +494,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
existingPendingIndex, found := c.indexes[hostinfo.localIndexId] existingPendingIndex, found := c.indexes[hostinfo.localIndexId]
if found && existingPendingIndex.hostinfo != hostinfo { if found && existingPendingIndex.hostinfo != hostinfo {
// We have a collision, but for a different hostinfo // We have a collision, but for a different hostinfo
return existingIndex, ErrLocalIndexCollision return existingPendingIndex.hostinfo, ErrLocalIndexCollision
} }
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
@@ -555,7 +570,7 @@ func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
delete(c.vpnIps, hostinfo.vpnIp) delete(c.vpnIps, hostinfo.vpnIp)
if len(c.vpnIps) == 0 { if len(c.vpnIps) == 0 {
c.vpnIps = map[iputil.VpnIp]*HandshakeHostInfo{} c.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
} }
delete(c.indexes, hostinfo.localIndexId) delete(c.indexes, hostinfo.localIndexId)
@@ -570,7 +585,7 @@ func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
} }
} }
func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
hh := hm.queryVpnIp(vpnIp) hh := hm.queryVpnIp(vpnIp)
if hh != nil { if hh != nil {
return hh.hostinfo return hh.hostinfo
@@ -579,7 +594,7 @@ func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
} }
func (hm *HandshakeManager) queryVpnIp(vpnIp iputil.VpnIp) *HandshakeHostInfo { func (hm *HandshakeManager) queryVpnIp(vpnIp netip.Addr) *HandshakeHostInfo {
hm.RLock() hm.RLock()
defer hm.RUnlock() defer hm.RUnlock()
return hm.vpnIps[vpnIp] return hm.vpnIps[vpnIp]
@@ -599,7 +614,7 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
return hm.indexes[index] return hm.indexes[index]
} }
func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet { func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix {
return c.mainHostMap.GetPreferredRanges() return c.mainHostMap.GetPreferredRanges()
} }
@@ -656,6 +671,6 @@ func generateIndex(l *logrus.Logger) (uint32, error) {
return index, nil return index, nil
} }
func hsTimeout(tries int, interval time.Duration) time.Duration { func hsTimeout(tries int64, interval time.Duration) time.Duration {
return time.Duration(tries / 2 * ((2 * int(interval)) + (tries-1)*int(interval))) return time.Duration(tries / 2 * ((2 * int64(interval)) + (tries-1)*int64(interval)))
} }

View File

@@ -1,13 +1,12 @@
package nebula package nebula
import ( import (
"net" "net/netip"
"testing" "testing"
"time" "time"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -15,10 +14,11 @@ import (
func Test_NewHandshakeManagerVpnIp(t *testing.T) { func Test_NewHandshakeManagerVpnIp(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") vpncidr := netip.MustParsePrefix("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24")
ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) ip := netip.MustParseAddr("172.1.1.2")
preferredRanges := []*net.IPNet{localrange}
preferredRanges := []netip.Prefix{localrange}
mainHM := newHostMap(l, vpncidr) mainHM := newHostMap(l, vpncidr)
mainHM.preferredRanges.Store(&preferredRanges) mainHM.preferredRanges.Store(&preferredRanges)
@@ -66,7 +66,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
assert.NotContains(t, blah.vpnIps, ip) assert.NotContains(t, blah.vpnIps, ip)
} }
func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) { func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
for _, i := range tw.t.wheel { for _, i := range tw.t.wheel {
n := i.Head n := i.Head
for n != nil { for n != nil {
@@ -80,7 +80,7 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {
type mockEncWriter struct { type mockEncWriter struct {
} }
func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) { func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
return return
} }
@@ -92,4 +92,4 @@ func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
return return
} }
func (mw *mockEncWriter) Handshake(vpnIP iputil.VpnIp) {} func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {}

View File

@@ -3,18 +3,18 @@ package nebula
import ( import (
"errors" "errors"
"net" "net"
"net/netip"
"slices"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
) )
// const ProbeLen = 100 // const ProbeLen = 100
@@ -36,6 +36,7 @@ const (
Requested = iota Requested = iota
PeerRequested PeerRequested
Established Established
Disestablished
) )
const ( const (
@@ -49,7 +50,7 @@ type Relay struct {
State int State int
LocalIndex uint32 LocalIndex uint32
RemoteIndex uint32 RemoteIndex uint32
PeerIp iputil.VpnIp PeerIp netip.Addr
} }
type HostMap struct { type HostMap struct {
@@ -57,9 +58,9 @@ type HostMap struct {
Indexes map[uint32]*HostInfo Indexes map[uint32]*HostInfo
Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
RemoteIndexes map[uint32]*HostInfo RemoteIndexes map[uint32]*HostInfo
Hosts map[iputil.VpnIp]*HostInfo Hosts map[netip.Addr]*HostInfo
preferredRanges atomic.Pointer[[]*net.IPNet] preferredRanges atomic.Pointer[[]netip.Prefix]
vpnCIDR *net.IPNet vpnCIDR netip.Prefix
l *logrus.Logger l *logrus.Logger
} }
@@ -69,15 +70,42 @@ type HostMap struct {
type RelayState struct { type RelayState struct {
sync.RWMutex sync.RWMutex
relays map[iputil.VpnIp]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer relays []netip.Addr // Ordered set of VpnIp's of Hosts to use as relays to access this peer
relayForByIp map[iputil.VpnIp]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info
} }
func (rs *RelayState) DeleteRelay(ip iputil.VpnIp) { func (rs *RelayState) DeleteRelay(ip netip.Addr) {
rs.Lock() rs.Lock()
defer rs.Unlock() defer rs.Unlock()
delete(rs.relays, ip) for idx, val := range rs.relays {
if val == ip {
rs.relays = append(rs.relays[:idx], rs.relays[idx+1:]...)
return
}
}
}
func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
rs.Lock()
defer rs.Unlock()
if r, ok := rs.relayForByIp[vpnIp]; ok {
newRelay := *r
newRelay.State = state
rs.relayForByIp[newRelay.PeerIp] = &newRelay
rs.relayForByIdx[newRelay.LocalIndex] = &newRelay
}
}
func (rs *RelayState) UpdateRelayForByIdxState(idx uint32, state int) {
rs.Lock()
defer rs.Unlock()
if r, ok := rs.relayForByIdx[idx]; ok {
newRelay := *r
newRelay.State = state
rs.relayForByIp[newRelay.PeerIp] = &newRelay
rs.relayForByIdx[newRelay.LocalIndex] = &newRelay
}
} }
func (rs *RelayState) CopyAllRelayFor() []*Relay { func (rs *RelayState) CopyAllRelayFor() []*Relay {
@@ -90,33 +118,33 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay {
return ret return ret
} }
func (rs *RelayState) GetRelayForByIp(ip iputil.VpnIp) (*Relay, bool) { func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) {
rs.RLock() rs.RLock()
defer rs.RUnlock() defer rs.RUnlock()
r, ok := rs.relayForByIp[ip] r, ok := rs.relayForByIp[ip]
return r, ok return r, ok
} }
func (rs *RelayState) InsertRelayTo(ip iputil.VpnIp) { func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
rs.Lock() rs.Lock()
defer rs.Unlock() defer rs.Unlock()
rs.relays[ip] = struct{}{} if !slices.Contains(rs.relays, ip) {
rs.relays = append(rs.relays, ip)
}
} }
func (rs *RelayState) CopyRelayIps() []iputil.VpnIp { func (rs *RelayState) CopyRelayIps() []netip.Addr {
ret := make([]netip.Addr, len(rs.relays))
rs.RLock() rs.RLock()
defer rs.RUnlock() defer rs.RUnlock()
ret := make([]iputil.VpnIp, 0, len(rs.relays)) copy(ret, rs.relays)
for ip := range rs.relays {
ret = append(ret, ip)
}
return ret return ret
} }
func (rs *RelayState) CopyRelayForIps() []iputil.VpnIp { func (rs *RelayState) CopyRelayForIps() []netip.Addr {
rs.RLock() rs.RLock()
defer rs.RUnlock() defer rs.RUnlock()
currentRelays := make([]iputil.VpnIp, 0, len(rs.relayForByIp)) currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp))
for relayIp := range rs.relayForByIp { for relayIp := range rs.relayForByIp {
currentRelays = append(currentRelays, relayIp) currentRelays = append(currentRelays, relayIp)
} }
@@ -133,19 +161,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 {
return ret return ret
} }
func (rs *RelayState) RemoveRelay(localIdx uint32) (iputil.VpnIp, bool) { func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
rs.Lock()
defer rs.Unlock()
r, ok := rs.relayForByIdx[localIdx]
if !ok {
return iputil.VpnIp(0), false
}
delete(rs.relayForByIdx, localIdx)
delete(rs.relayForByIp, r.PeerIp)
return r.PeerIp, true
}
func (rs *RelayState) CompleteRelayByIP(vpnIp iputil.VpnIp, remoteIdx uint32) bool {
rs.Lock() rs.Lock()
defer rs.Unlock() defer rs.Unlock()
r, ok := rs.relayForByIp[vpnIp] r, ok := rs.relayForByIp[vpnIp]
@@ -175,7 +191,7 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re
return &newRelay, true return &newRelay, true
} }
func (rs *RelayState) QueryRelayForByIp(vpnIp iputil.VpnIp) (*Relay, bool) { func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
rs.RLock() rs.RLock()
defer rs.RUnlock() defer rs.RUnlock()
r, ok := rs.relayForByIp[vpnIp] r, ok := rs.relayForByIp[vpnIp]
@@ -189,7 +205,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) {
return r, ok return r, ok
} }
func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) { func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
rs.Lock() rs.Lock()
defer rs.Unlock() defer rs.Unlock()
rs.relayForByIp[ip] = r rs.relayForByIp[ip] = r
@@ -197,15 +213,15 @@ func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) {
} }
type HostInfo struct { type HostInfo struct {
remote *udp.Addr remote netip.AddrPort
remotes *RemoteList remotes *RemoteList
promoteCounter atomic.Uint32 promoteCounter atomic.Uint32
ConnectionState *ConnectionState ConnectionState *ConnectionState
remoteIndexId uint32 remoteIndexId uint32
localIndexId uint32 localIndexId uint32
vpnIp iputil.VpnIp vpnIp netip.Addr
recvError atomic.Uint32 recvError atomic.Uint32
remoteCidr *cidr.Tree4[struct{}] remoteCidr *bart.Table[struct{}]
relayState RelayState relayState RelayState
// HandshakePacket records the packets used to create this hostinfo // HandshakePacket records the packets used to create this hostinfo
@@ -227,11 +243,19 @@ type HostInfo struct {
lastHandshakeTime uint64 lastHandshakeTime uint64
lastRoam time.Time lastRoam time.Time
lastRoamRemote *udp.Addr lastRoamRemote netip.AddrPort
// Used to track other hostinfos for this vpn ip since only 1 can be primary // Used to track other hostinfos for this vpn ip since only 1 can be primary
// Synchronised via hostmap lock and not the hostinfo lock. // Synchronised via hostmap lock and not the hostinfo lock.
next, prev *HostInfo next, prev *HostInfo
//TODO: in, out, and others might benefit from being an atomic.Int32. We could collapse connectionManager pendingDeletion, relayUsed, and in/out into this 1 thing
in, out, pendingDeletion atomic.Bool
// lastUsed tracks the last time ConnectionManager checked the tunnel and it was in use.
// This value will be behind against actual tunnel utilization in the hot path.
// This should only be used by the ConnectionManagers ticker routine.
lastUsed time.Time
} }
type ViaSender struct { type ViaSender struct {
@@ -254,7 +278,7 @@ type cachedPacketMetrics struct {
dropped metrics.Counter dropped metrics.Counter
} }
func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *HostMap { func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap {
hm := newHostMap(l, vpnCIDR) hm := newHostMap(l, vpnCIDR)
hm.reload(c, true) hm.reload(c, true)
@@ -269,12 +293,12 @@ func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *Ho
return hm return hm
} }
func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap { func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap {
return &HostMap{ return &HostMap{
Indexes: map[uint32]*HostInfo{}, Indexes: map[uint32]*HostInfo{},
Relays: map[uint32]*HostInfo{}, Relays: map[uint32]*HostInfo{},
RemoteIndexes: map[uint32]*HostInfo{}, RemoteIndexes: map[uint32]*HostInfo{},
Hosts: map[iputil.VpnIp]*HostInfo{}, Hosts: map[netip.Addr]*HostInfo{},
vpnCIDR: vpnCIDR, vpnCIDR: vpnCIDR,
l: l, l: l,
} }
@@ -282,11 +306,11 @@ func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap {
func (hm *HostMap) reload(c *config.C, initial bool) { func (hm *HostMap) reload(c *config.C, initial bool) {
if initial || c.HasChanged("preferred_ranges") { if initial || c.HasChanged("preferred_ranges") {
var preferredRanges []*net.IPNet var preferredRanges []netip.Prefix
rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{}) rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
for _, rawPreferredRange := range rawPreferredRanges { for _, rawPreferredRange := range rawPreferredRanges {
_, preferredRange, err := net.ParseCIDR(rawPreferredRange) preferredRange, err := netip.ParsePrefix(rawPreferredRange)
if err != nil { if err != nil {
hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring") hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
@@ -374,11 +398,12 @@ func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
primary, ok := hm.Hosts[hostinfo.vpnIp] primary, ok := hm.Hosts[hostinfo.vpnIp]
isLastHostinfo := hostinfo.next == nil && hostinfo.prev == nil
if ok && primary == hostinfo { if ok && primary == hostinfo {
// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it // The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
delete(hm.Hosts, hostinfo.vpnIp) delete(hm.Hosts, hostinfo.vpnIp)
if len(hm.Hosts) == 0 { if len(hm.Hosts) == 0 {
hm.Hosts = map[iputil.VpnIp]*HostInfo{} hm.Hosts = map[netip.Addr]*HostInfo{}
} }
if hostinfo.next != nil { if hostinfo.next != nil {
@@ -423,6 +448,12 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
Debug("Hostmap hostInfo deleted") Debug("Hostmap hostInfo deleted")
} }
if isLastHostinfo {
// I have lost connectivity to my peers. My relay tunnel is likely broken. Mark the next
// hops as 'Disestablished' so that new relay tunnels are created in the future.
hm.unlockedDisestablishVpnAddrRelayFor(hostinfo)
}
// Clean up any local relay indexes for which I am acting as a relay hop
for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() { for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() {
delete(hm.Relays, localRelayIdx) delete(hm.Relays, localRelayIdx)
} }
@@ -461,11 +492,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
} }
} }
func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
return hm.queryVpnIp(vpnIp, nil) return hm.queryVpnIp(vpnIp, nil)
} }
func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*HostInfo, *Relay, error) { func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
hm.RLock() hm.RLock()
defer hm.RUnlock() defer hm.RUnlock()
@@ -483,7 +514,28 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host
return nil, nil, errors.New("unable to find host with relay") return nil, nil, errors.New("unable to find host with relay")
} }
func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo { func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) {
for _, relayHostIp := range hi.relayState.CopyRelayIps() {
if h, ok := hm.Hosts[relayHostIp]; ok {
for h != nil {
h.relayState.UpdateRelayForByIpState(hi.vpnIp, Disestablished)
h = h.next
}
}
}
for _, rs := range hi.relayState.CopyAllRelayFor() {
if rs.Type == ForwardingType {
if h, ok := hm.Hosts[rs.PeerIp]; ok {
for h != nil {
h.relayState.UpdateRelayForByIpState(hi.vpnIp, Disestablished)
h = h.next
}
}
}
}
}
func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
hm.RLock() hm.RLock()
if h, ok := hm.Hosts[vpnIp]; ok { if h, ok := hm.Hosts[vpnIp]; ok {
hm.RUnlock() hm.RUnlock()
@@ -535,7 +587,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
} }
} }
func (hm *HostMap) GetPreferredRanges() []*net.IPNet { func (hm *HostMap) GetPreferredRanges() []netip.Prefix {
//NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer //NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer
return *hm.preferredRanges.Load() return *hm.preferredRanges.Load()
} }
@@ -560,14 +612,14 @@ func (hm *HostMap) ForEachIndex(f controlEach) {
// TryPromoteBest handles re-querying lighthouses and probing for better paths // TryPromoteBest handles re-querying lighthouses and probing for better paths
// NOTE: It is an error to call this if you are a lighthouse since they should not roam clients! // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) { func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interface) {
c := i.promoteCounter.Add(1) c := i.promoteCounter.Add(1)
if c%ifce.tryPromoteEvery.Load() == 0 { if c%ifce.tryPromoteEvery.Load() == 0 {
remote := i.remote remote := i.remote
// return early if we are already on a preferred remote // return early if we are already on a preferred remote
if remote != nil { if remote.IsValid() {
rIP := remote.IP rIP := remote.Addr()
for _, l := range preferredRanges { for _, l := range preferredRanges {
if l.Contains(rIP) { if l.Contains(rIP) {
return return
@@ -575,8 +627,8 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
} }
} }
i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) { i.remotes.ForEach(preferredRanges, func(addr netip.AddrPort, preferred bool) {
if remote != nil && (addr == nil || !preferred) { if remote.IsValid() && (!addr.IsValid() || !preferred) {
return return
} }
@@ -605,23 +657,23 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate {
return nil return nil
} }
func (i *HostInfo) SetRemote(remote *udp.Addr) { func (i *HostInfo) SetRemote(remote netip.AddrPort) {
// We copy here because we likely got this remote from a source that reuses the object // We copy here because we likely got this remote from a source that reuses the object
if !i.remote.Equals(remote) { if i.remote != remote {
i.remote = remote.Copy() i.remote = remote
i.remotes.LearnRemote(i.vpnIp, remote.Copy()) i.remotes.LearnRemote(i.vpnIp, remote)
} }
} }
// SetRemoteIfPreferred returns true if the remote was changed. The lastRoam // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
// time on the HostInfo will also be updated. // time on the HostInfo will also be updated.
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool {
if newRemote == nil { if !newRemote.IsValid() {
// relays have nil udp Addrs // relays have nil udp Addrs
return false return false
} }
currentRemote := i.remote currentRemote := i.remote
if currentRemote == nil { if !currentRemote.IsValid() {
i.SetRemote(newRemote) i.SetRemote(newRemote)
return true return true
} }
@@ -631,11 +683,11 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
newIsPreferred := false newIsPreferred := false
for _, l := range hm.GetPreferredRanges() { for _, l := range hm.GetPreferredRanges() {
// return early if we are already on a preferred remote // return early if we are already on a preferred remote
if l.Contains(currentRemote.IP) { if l.Contains(currentRemote.Addr()) {
return false return false
} }
if l.Contains(newRemote.IP) { if l.Contains(newRemote.Addr()) {
newIsPreferred = true newIsPreferred = true
} }
} }
@@ -643,7 +695,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
if newIsPreferred { if newIsPreferred {
// Consider this a roaming event // Consider this a roaming event
i.lastRoam = time.Now() i.lastRoam = time.Now()
i.lastRoamRemote = currentRemote.Copy() i.lastRoamRemote = currentRemote
i.SetRemote(newRemote) i.SetRemote(newRemote)
@@ -666,13 +718,21 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
return return
} }
remoteCidr := cidr.NewTree4[struct{}]() remoteCidr := new(bart.Table[struct{}])
for _, ip := range c.Details.Ips { for _, ip := range c.Details.Ips {
remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) //TODO: IPV6-WORK what to do when ip is invalid?
nip, _ := netip.AddrFromSlice(ip.IP)
nip = nip.Unmap()
bits, _ := ip.Mask.Size()
remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
} }
for _, n := range c.Details.Subnets { for _, n := range c.Details.Subnets {
remoteCidr.AddCIDR(n, struct{}{}) //TODO: IPV6-WORK what to do when ip is invalid?
nip, _ := netip.AddrFromSlice(n.IP)
nip = nip.Unmap()
bits, _ := n.Mask.Size()
remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
} }
i.remoteCidr = remoteCidr i.remoteCidr = remoteCidr
} }
@@ -697,9 +757,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
// Utility functions // Utility functions
func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP { func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
//FIXME: This function is pretty garbage //FIXME: This function is pretty garbage
var ips []net.IP var ips []netip.Addr
ifaces, _ := net.Interfaces() ifaces, _ := net.Interfaces()
for _, i := range ifaces { for _, i := range ifaces {
allow := allowList.AllowName(i.Name) allow := allowList.AllowName(i.Name)
@@ -721,20 +781,29 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {
ip = v.IP ip = v.IP
} }
nip, ok := netip.AddrFromSlice(ip)
if !ok {
if l.Level >= logrus.DebugLevel {
l.WithField("localIp", ip).Debug("ip was invalid for netip")
}
continue
}
nip = nip.Unmap()
//TODO: Filtering out link local for now, this is probably the most correct thing //TODO: Filtering out link local for now, this is probably the most correct thing
//TODO: Would be nice to filter out SLAAC MAC based ips as well //TODO: Would be nice to filter out SLAAC MAC based ips as well
if ip.IsLoopback() == false && !ip.IsLinkLocalUnicast() { if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false {
allow := allowList.Allow(ip) allow := allowList.Allow(nip)
if l.Level >= logrus.TraceLevel { if l.Level >= logrus.TraceLevel {
l.WithField("localIp", ip).WithField("allow", allow).Trace("localAllowList.Allow") l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow")
} }
if !allow { if !allow {
continue continue
} }
ips = append(ips, ip) ips = append(ips, nip)
} }
} }
} }
return &ips return ips
} }

View File

@@ -1,30 +1,28 @@
package nebula package nebula
import ( import (
"net" "net/netip"
"testing" "testing"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestHostMap_MakePrimary(t *testing.T) { func TestHostMap_MakePrimary(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
hm := newHostMap( hm := newHostMap(
l, l,
&net.IPNet{ netip.MustParsePrefix("10.0.0.1/24"),
IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0},
},
) )
f := &Interface{} f := &Interface{}
h1 := &HostInfo{vpnIp: 1, localIndexId: 1} h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
h2 := &HostInfo{vpnIp: 1, localIndexId: 2} h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
h3 := &HostInfo{vpnIp: 1, localIndexId: 3} h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
h4 := &HostInfo{vpnIp: 1, localIndexId: 4} h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
hm.unlockedAddHostInfo(h4, f) hm.unlockedAddHostInfo(h4, f)
hm.unlockedAddHostInfo(h3, f) hm.unlockedAddHostInfo(h3, f)
@@ -32,7 +30,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.unlockedAddHostInfo(h1, f) hm.unlockedAddHostInfo(h1, f)
// Make sure we go h1 -> h2 -> h3 -> h4 // Make sure we go h1 -> h2 -> h3 -> h4
prim := hm.QueryVpnIp(1) prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@@ -47,7 +45,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h3) hm.MakePrimary(h3)
// Make sure we go h3 -> h1 -> h2 -> h4 // Make sure we go h3 -> h1 -> h2 -> h4
prim = hm.QueryVpnIp(1) prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h3.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.localIndexId)
assert.Equal(t, h1.localIndexId, prim.next.localIndexId) assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@@ -62,7 +60,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h4) hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2 // Make sure we go h4 -> h3 -> h1 -> h2
prim = hm.QueryVpnIp(1) prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@@ -77,7 +75,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h4) hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2 // Make sure we go h4 -> h3 -> h1 -> h2
prim = hm.QueryVpnIp(1) prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@@ -93,20 +91,17 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
hm := newHostMap( hm := newHostMap(
l, l,
&net.IPNet{ netip.MustParsePrefix("10.0.0.1/24"),
IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0},
},
) )
f := &Interface{} f := &Interface{}
h1 := &HostInfo{vpnIp: 1, localIndexId: 1} h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
h2 := &HostInfo{vpnIp: 1, localIndexId: 2} h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
h3 := &HostInfo{vpnIp: 1, localIndexId: 3} h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
h4 := &HostInfo{vpnIp: 1, localIndexId: 4} h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
h5 := &HostInfo{vpnIp: 1, localIndexId: 5} h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5}
h6 := &HostInfo{vpnIp: 1, localIndexId: 6} h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6}
hm.unlockedAddHostInfo(h6, f) hm.unlockedAddHostInfo(h6, f)
hm.unlockedAddHostInfo(h5, f) hm.unlockedAddHostInfo(h5, f)
@@ -122,7 +117,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h) assert.Nil(t, h)
// Make sure we go h1 -> h2 -> h3 -> h4 -> h5 // Make sure we go h1 -> h2 -> h3 -> h4 -> h5
prim := hm.QueryVpnIp(1) prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@@ -141,7 +136,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h1.next) assert.Nil(t, h1.next)
// Make sure we go h2 -> h3 -> h4 -> h5 // Make sure we go h2 -> h3 -> h4 -> h5
prim = hm.QueryVpnIp(1) prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@@ -159,7 +154,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h3.next) assert.Nil(t, h3.next)
// Make sure we go h2 -> h4 -> h5 // Make sure we go h2 -> h4 -> h5
prim = hm.QueryVpnIp(1) prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@@ -175,7 +170,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h5.next) assert.Nil(t, h5.next)
// Make sure we go h2 -> h4 // Make sure we go h2 -> h4
prim = hm.QueryVpnIp(1) prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@@ -189,7 +184,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h2.next) assert.Nil(t, h2.next)
// Make sure we only have h4 // Make sure we only have h4
prim = hm.QueryVpnIp(1) prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
assert.Nil(t, prim.next) assert.Nil(t, prim.next)
@@ -201,7 +196,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h4.next) assert.Nil(t, h4.next)
// Make sure we have nil // Make sure we have nil
prim = hm.QueryVpnIp(1) prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Nil(t, prim) assert.Nil(t, prim)
} }
@@ -211,14 +206,11 @@ func TestHostMap_reload(t *testing.T) {
hm := NewHostMapFromConfig( hm := NewHostMapFromConfig(
l, l,
&net.IPNet{ netip.MustParsePrefix("10.0.0.1/24"),
IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0},
},
c, c,
) )
toS := func(ipn []*net.IPNet) []string { toS := func(ipn []netip.Prefix) []string {
var s []string var s []string
for _, n := range ipn { for _, n := range ipn {
s = append(s, n.String()) s = append(s, n.String())
@@ -234,3 +226,31 @@ func TestHostMap_reload(t *testing.T) {
c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]") c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges())) assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
} }
func TestHostMap_RelayState(t *testing.T) {
h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
a1 := netip.MustParseAddr("::1")
a2 := netip.MustParseAddr("2001::1")
h1.relayState.InsertRelayTo(a1)
assert.Equal(t, h1.relayState.relays, []netip.Addr{a1})
h1.relayState.InsertRelayTo(a2)
assert.Equal(t, h1.relayState.relays, []netip.Addr{a1, a2})
// Ensure that the first relay added is the first one returned in the copy
currentRelays := h1.relayState.CopyRelayIps()
require.Len(t, currentRelays, 2)
assert.Equal(t, currentRelays[0], a1)
// Deleting the last one in the list works ok
h1.relayState.DeleteRelay(a2)
assert.Equal(t, h1.relayState.relays, []netip.Addr{a1})
// Deleting an element not in the list works ok
h1.relayState.DeleteRelay(a2)
assert.Equal(t, h1.relayState.relays, []netip.Addr{a1})
// Deleting the only element in the list works ok
h1.relayState.DeleteRelay(a1)
assert.Equal(t, h1.relayState.relays, []netip.Addr{})
}

View File

@@ -5,9 +5,11 @@ package nebula
// This file contains functions used to export information to the e2e testing framework // This file contains functions used to export information to the e2e testing framework
import "github.com/slackhq/nebula/iputil" import (
"net/netip"
)
func (i *HostInfo) GetVpnIp() iputil.VpnIp { func (i *HostInfo) GetVpnIp() netip.Addr {
return i.vpnIp return i.vpnIp
} }

View File

@@ -1,12 +1,13 @@
package nebula package nebula
import ( import (
"net/netip"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/udp"
) )
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
@@ -19,11 +20,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
} }
// Ignore local broadcast packets // Ignore local broadcast packets
if f.dropLocalBroadcast && fwPacket.RemoteIP == f.localBroadcast { if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr {
return return
} }
if fwPacket.RemoteIP == f.myVpnIp { if fwPacket.RemoteIP == f.myVpnNet.Addr() {
// Immediately forward packets from self to self. // Immediately forward packets from self to self.
// This should only happen on Darwin-based and FreeBSD hosts, which // This should only happen on Darwin-based and FreeBSD hosts, which
// routes packets from the Nebula IP to the Nebula IP through the Nebula // routes packets from the Nebula IP to the Nebula IP through the Nebula
@@ -39,8 +40,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return return
} }
// Ignore broadcast packets // Ignore multicast packets
if f.dropMulticast && isMulticast(fwPacket.RemoteIP) { if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() {
return return
} }
@@ -64,7 +65,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil { if dropReason == nil {
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, nil, packet, nb, out, q) f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
} else { } else {
f.rejectInside(packet, out, q) f.rejectInside(packet, out, q)
@@ -113,19 +114,19 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
return return
} }
f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, out, nb, packet, q) f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
} }
func (f *Interface) Handshake(vpnIp iputil.VpnIp) { func (f *Interface) Handshake(vpnIp netip.Addr) {
f.getOrHandshake(vpnIp, nil) f.getOrHandshake(vpnIp, nil)
} }
// getOrHandshake returns nil if the vpnIp is not routable. // getOrHandshake returns nil if the vpnIp is not routable.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) { if !f.myVpnNet.Contains(vpnIp) {
vpnIp = f.inside.RouteFor(vpnIp) vpnIp = f.inside.RouteFor(vpnIp)
if vpnIp == 0 { if !vpnIp.IsValid() {
return nil, false return nil, false
} }
} }
@@ -152,11 +153,11 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
return return
} }
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, nil, p, nb, out, 0) f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
} }
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) { func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) { hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
}) })
@@ -182,10 +183,10 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag
func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) { func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) {
f.messageMetrics.Tx(t, st, 1) f.messageMetrics.Tx(t, st, 1)
f.sendNoMetrics(t, st, ci, hostinfo, nil, p, nb, out, 0) f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0)
} }
func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte) { func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) {
f.messageMetrics.Tx(t, st, 1) f.messageMetrics.Tx(t, st, 1)
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
} }
@@ -212,7 +213,7 @@ func (f *Interface) SendVia(via *HostInfo,
c := via.ConnectionState.messageCounter.Add(1) c := via.ConnectionState.messageCounter.Add(1)
out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c) out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c)
f.connectionManager.Out(via.localIndexId) f.connectionManager.Out(via)
// Authenticate the header and payload, but do not encrypt for this message type. // Authenticate the header and payload, but do not encrypt for this message type.
// The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload. // The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload.
@@ -255,12 +256,12 @@ func (f *Interface) SendVia(via *HostInfo,
f.connectionManager.RelayUsed(relay.LocalIndex) f.connectionManager.RelayUsed(relay.LocalIndex)
} }
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) { func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) {
if ci.eKey == nil { if ci.eKey == nil {
//TODO: log warning //TODO: log warning
return return
} }
useRelay := remote == nil && hostinfo.remote == nil useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
fullOut := out fullOut := out
if useRelay { if useRelay {
@@ -281,7 +282,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p) //l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c) out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo.localIndexId) f.connectionManager.Out(hostinfo)
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against // Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
// all our IPs and enable a faster roaming. // all our IPs and enable a faster roaming.
@@ -308,13 +309,13 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
return return
} }
if remote != nil { if remote.IsValid() {
err = f.writers[q].WriteTo(out, remote) err = f.writers[q].WriteTo(out, remote)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err). hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet") WithField("udpAddr", remote).Error("Failed to write outgoing packet")
} }
} else if hostinfo.remote != nil { } else if hostinfo.remote.IsValid() {
err = f.writers[q].WriteTo(out, hostinfo.remote) err = f.writers[q].WriteTo(out, hostinfo.remote)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err). hostinfo.logger(f.l).WithError(err).
@@ -334,8 +335,3 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
} }
} }
} }
func isMulticast(ip iputil.VpnIp) bool {
// Class D multicast
return (((ip >> 24) & 0xff) & 0xf0) == 0xe0
}

View File

@@ -2,10 +2,11 @@ package nebula
import ( import (
"context" "context"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net" "net/netip"
"os" "os"
"runtime" "runtime"
"sync/atomic" "sync/atomic"
@@ -16,7 +17,6 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
) )
@@ -33,8 +33,7 @@ type InterfaceConfig struct {
ServeDns bool ServeDns bool
HandshakeManager *HandshakeManager HandshakeManager *HandshakeManager
lightHouse *LightHouse lightHouse *LightHouse
checkInterval time.Duration connectionManager *connectionManager
pendingDeletionInterval time.Duration
DropLocalBroadcast bool DropLocalBroadcast bool
DropMulticast bool DropMulticast bool
routines int routines int
@@ -63,8 +62,8 @@ type Interface struct {
serveDns bool serveDns bool
createTime time.Time createTime time.Time
lightHouse *LightHouse lightHouse *LightHouse
localBroadcast iputil.VpnIp myBroadcastAddr netip.Addr
myVpnIp iputil.VpnIp myVpnNet netip.Prefix
dropLocalBroadcast bool dropLocalBroadcast bool
dropMulticast bool dropMulticast bool
routines int routines int
@@ -102,9 +101,9 @@ type EncWriter interface {
out []byte, out []byte,
nocopy bool, nocopy bool,
) )
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte)
SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
Handshake(vpnIp iputil.VpnIp) Handshake(vpnIp netip.Addr)
} }
type sendRecvErrorConfig uint8 type sendRecvErrorConfig uint8
@@ -115,10 +114,10 @@ const (
sendRecvErrorPrivate sendRecvErrorPrivate
) )
func (s sendRecvErrorConfig) ShouldSendRecvError(ip net.IP) bool { func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool {
switch s { switch s {
case sendRecvErrorPrivate: case sendRecvErrorPrivate:
return ip.IsPrivate() return ip.Addr().IsPrivate()
case sendRecvErrorAlways: case sendRecvErrorAlways:
return true return true
case sendRecvErrorNever: case sendRecvErrorNever:
@@ -154,9 +153,32 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
if c.Firewall == nil { if c.Firewall == nil {
return nil, errors.New("no firewall rules") return nil, errors.New("no firewall rules")
} }
if c.connectionManager == nil {
return nil, errors.New("no connection manager")
}
certificate := c.pki.GetCertState().Certificate certificate := c.pki.GetCertState().Certificate
myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP)
myVpnAddr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
if !ok {
return nil, fmt.Errorf("invalid ip address in certificate: %s", certificate.Details.Ips[0].IP)
}
myVpnMask, ok := netip.AddrFromSlice(certificate.Details.Ips[0].Mask)
if !ok {
return nil, fmt.Errorf("invalid ip mask in certificate: %s", certificate.Details.Ips[0].Mask)
}
myVpnAddr = myVpnAddr.Unmap()
myVpnMask = myVpnMask.Unmap()
if myVpnAddr.BitLen() != myVpnMask.BitLen() {
return nil, fmt.Errorf("ip address and mask are different lengths in certificate")
}
ones, _ := certificate.Details.Ips[0].Mask.Size()
myVpnNet := netip.PrefixFrom(myVpnAddr, ones)
ifce := &Interface{ ifce := &Interface{
pki: c.pki, pki: c.pki,
hostMap: c.HostMap, hostMap: c.HostMap,
@@ -168,15 +190,15 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
handshakeManager: c.HandshakeManager, handshakeManager: c.HandshakeManager,
createTime: time.Now(), createTime: time.Now(),
lightHouse: c.lightHouse, lightHouse: c.lightHouse,
localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask),
dropLocalBroadcast: c.DropLocalBroadcast, dropLocalBroadcast: c.DropLocalBroadcast,
dropMulticast: c.DropMulticast, dropMulticast: c.DropMulticast,
routines: c.routines, routines: c.routines,
version: c.version, version: c.version,
writers: make([]udp.Conn, c.routines), writers: make([]udp.Conn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines), readers: make([]io.ReadWriteCloser, c.routines),
myVpnIp: myVpnIp, myVpnNet: myVpnNet,
relayManager: c.relayManager, relayManager: c.relayManager,
connectionManager: c.connectionManager,
conntrackCacheTimeout: c.ConntrackCacheTimeout, conntrackCacheTimeout: c.ConntrackCacheTimeout,
@@ -190,11 +212,17 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
l: c.l, l: c.l,
} }
if myVpnAddr.Is4() {
addr := myVpnNet.Masked().Addr().As4()
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask))
ifce.myBroadcastAddr = netip.AddrFrom4(addr)
}
ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait)) ifce.reQueryWait.Store(int64(c.reQueryWait))
ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy) ifce.connectionManager.intf = ifce
return ifce, nil return ifce, nil
} }

View File

@@ -6,6 +6,8 @@ import (
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
//TODO: IPV6-WORK can probably delete this
const ( const (
// Need 96 bytes for the largest reject packet: // Need 96 bytes for the largest reject packet:
// - 20 byte ipv4 header // - 20 byte ipv4 header

View File

@@ -1,93 +0,0 @@
package iputil
import (
"encoding/binary"
"fmt"
"net"
"net/netip"
)
type VpnIp uint32
const maxIPv4StringLen = len("255.255.255.255")
func (ip VpnIp) String() string {
b := make([]byte, maxIPv4StringLen)
n := ubtoa(b, 0, byte(ip>>24))
b[n] = '.'
n++
n += ubtoa(b, n, byte(ip>>16&255))
b[n] = '.'
n++
n += ubtoa(b, n, byte(ip>>8&255))
b[n] = '.'
n++
n += ubtoa(b, n, byte(ip&255))
return string(b[:n])
}
func (ip VpnIp) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf("\"%s\"", ip.String())), nil
}
func (ip VpnIp) ToIP() net.IP {
nip := make(net.IP, 4)
binary.BigEndian.PutUint32(nip, uint32(ip))
return nip
}
func (ip VpnIp) ToNetIpAddr() netip.Addr {
var nip [4]byte
binary.BigEndian.PutUint32(nip[:], uint32(ip))
return netip.AddrFrom4(nip)
}
func Ip2VpnIp(ip []byte) VpnIp {
if len(ip) == 16 {
return VpnIp(binary.BigEndian.Uint32(ip[12:16]))
}
return VpnIp(binary.BigEndian.Uint32(ip))
}
func ToNetIpAddr(ip net.IP) (netip.Addr, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return netip.Addr{}, fmt.Errorf("invalid net.IP: %v", ip)
}
return addr, nil
}
func ToNetIpPrefix(ipNet net.IPNet) (netip.Prefix, error) {
addr, err := ToNetIpAddr(ipNet.IP)
if err != nil {
return netip.Prefix{}, err
}
ones, bits := ipNet.Mask.Size()
if ones == 0 && bits == 0 {
return netip.Prefix{}, fmt.Errorf("invalid net.IP: %v", ipNet)
}
return netip.PrefixFrom(addr, ones), nil
}
// ubtoa encodes the string form of the integer v to dst[start:] and
// returns the number of bytes written to dst. The caller must ensure
// that dst has sufficient length.
func ubtoa(dst []byte, start int, v byte) int {
if v < 10 {
dst[start] = v + '0'
return 1
} else if v < 100 {
dst[start+1] = v%10 + '0'
dst[start] = v/10 + '0'
return 2
}
dst[start+2] = v%10 + '0'
dst[start+1] = (v/10)%10 + '0'
dst[start] = v/100 + '0'
return 3
}

View File

@@ -1,17 +0,0 @@
package iputil
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestVpnIp_String(t *testing.T) {
assert.Equal(t, "255.255.255.255", Ip2VpnIp(net.ParseIP("255.255.255.255")).String())
assert.Equal(t, "1.255.255.255", Ip2VpnIp(net.ParseIP("1.255.255.255")).String())
assert.Equal(t, "1.1.255.255", Ip2VpnIp(net.ParseIP("1.1.255.255")).String())
assert.Equal(t, "1.1.1.255", Ip2VpnIp(net.ParseIP("1.1.1.255")).String())
assert.Equal(t, "1.1.1.1", Ip2VpnIp(net.ParseIP("1.1.1.1")).String())
assert.Equal(t, "0.0.0.0", Ip2VpnIp(net.ParseIP("0.0.0.0")).String())
}

View File

@@ -7,16 +7,16 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
@@ -26,25 +26,18 @@ import (
var ErrHostNotKnown = errors.New("host not known") var ErrHostNotKnown = errors.New("host not known")
type netIpAndPort struct {
ip net.IP
port uint16
}
type LightHouse struct { type LightHouse struct {
//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time //TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
sync.RWMutex //Because we concurrently read and write to our maps sync.RWMutex //Because we concurrently read and write to our maps
ctx context.Context ctx context.Context
amLighthouse bool amLighthouse bool
myVpnIp iputil.VpnIp myVpnNet netip.Prefix
myVpnZeros iputil.VpnIp
myVpnNet *net.IPNet
punchConn udp.Conn punchConn udp.Conn
punchy *Punchy punchy *Punchy
// Local cache of answers from light houses // Local cache of answers from light houses
// map of vpn Ip to answers // map of vpn Ip to answers
addrMap map[iputil.VpnIp]*RemoteList addrMap map[netip.Addr]*RemoteList
// filters remote addresses allowed for each host // filters remote addresses allowed for each host
// - When we are a lighthouse, this filters what addresses we store and // - When we are a lighthouse, this filters what addresses we store and
@@ -57,26 +50,26 @@ type LightHouse struct {
localAllowList atomic.Pointer[LocalAllowList] localAllowList atomic.Pointer[LocalAllowList]
// used to trigger the HandshakeManager when we receive HostQueryReply // used to trigger the HandshakeManager when we receive HostQueryReply
handshakeTrigger chan<- iputil.VpnIp handshakeTrigger chan<- netip.Addr
// staticList exists to avoid having a bool in each addrMap entry // staticList exists to avoid having a bool in each addrMap entry
// since static should be rare // since static should be rare
staticList atomic.Pointer[map[iputil.VpnIp]struct{}] staticList atomic.Pointer[map[netip.Addr]struct{}]
lighthouses atomic.Pointer[map[iputil.VpnIp]struct{}] lighthouses atomic.Pointer[map[netip.Addr]struct{}]
interval atomic.Int64 interval atomic.Int64
updateCancel context.CancelFunc updateCancel context.CancelFunc
ifce EncWriter ifce EncWriter
nebulaPort uint32 // 32 bits because protobuf does not have a uint16 nebulaPort uint32 // 32 bits because protobuf does not have a uint16
advertiseAddrs atomic.Pointer[[]netIpAndPort] advertiseAddrs atomic.Pointer[[]netip.AddrPort]
// IP's of relays that can be used by peers to access me // IP's of relays that can be used by peers to access me
relaysForMe atomic.Pointer[[]iputil.VpnIp] relaysForMe atomic.Pointer[[]netip.Addr]
queryChan chan iputil.VpnIp queryChan chan netip.Addr
calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote
metrics *MessageMetrics metrics *MessageMetrics
metricHolepunchTx metrics.Counter metricHolepunchTx metrics.Counter
@@ -85,7 +78,7 @@ type LightHouse struct {
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
// addrMap should be nil unless this is during a config reload // addrMap should be nil unless this is during a config reload
func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc udp.Conn, p *Punchy) (*LightHouse, error) { func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet netip.Prefix, pc udp.Conn, p *Punchy) (*LightHouse, error) {
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
nebulaPort := uint32(c.GetInt("listen.port", 0)) nebulaPort := uint32(c.GetInt("listen.port", 0))
if amLighthouse && nebulaPort == 0 { if amLighthouse && nebulaPort == 0 {
@@ -98,26 +91,23 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
if err != nil { if err != nil {
return nil, util.NewContextualError("Failed to get listening port", nil, err) return nil, util.NewContextualError("Failed to get listening port", nil, err)
} }
nebulaPort = uint32(uPort.Port) nebulaPort = uint32(uPort.Port())
} }
ones, _ := myVpnNet.Mask.Size()
h := LightHouse{ h := LightHouse{
ctx: ctx, ctx: ctx,
amLighthouse: amLighthouse, amLighthouse: amLighthouse,
myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP),
myVpnZeros: iputil.VpnIp(32 - ones),
myVpnNet: myVpnNet, myVpnNet: myVpnNet,
addrMap: make(map[iputil.VpnIp]*RemoteList), addrMap: make(map[netip.Addr]*RemoteList),
nebulaPort: nebulaPort, nebulaPort: nebulaPort,
punchConn: pc, punchConn: pc,
punchy: p, punchy: p,
queryChan: make(chan iputil.VpnIp, c.GetUint32("handshakes.query_buffer", 64)), queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
l: l, l: l,
} }
lighthouses := make(map[iputil.VpnIp]struct{}) lighthouses := make(map[netip.Addr]struct{})
h.lighthouses.Store(&lighthouses) h.lighthouses.Store(&lighthouses)
staticList := make(map[iputil.VpnIp]struct{}) staticList := make(map[netip.Addr]struct{})
h.staticList.Store(&staticList) h.staticList.Store(&staticList)
if c.GetBool("stats.lighthouse_metrics", false) { if c.GetBool("stats.lighthouse_metrics", false) {
@@ -147,11 +137,11 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
return &h, nil return &h, nil
} }
func (lh *LightHouse) GetStaticHostList() map[iputil.VpnIp]struct{} { func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
return *lh.staticList.Load() return *lh.staticList.Load()
} }
func (lh *LightHouse) GetLighthouses() map[iputil.VpnIp]struct{} { func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
return *lh.lighthouses.Load() return *lh.lighthouses.Load()
} }
@@ -163,15 +153,15 @@ func (lh *LightHouse) GetLocalAllowList() *LocalAllowList {
return lh.localAllowList.Load() return lh.localAllowList.Load()
} }
func (lh *LightHouse) GetAdvertiseAddrs() []netIpAndPort { func (lh *LightHouse) GetAdvertiseAddrs() []netip.AddrPort {
return *lh.advertiseAddrs.Load() return *lh.advertiseAddrs.Load()
} }
func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp { func (lh *LightHouse) GetRelaysForMe() []netip.Addr {
return *lh.relaysForMe.Load() return *lh.relaysForMe.Load()
} }
func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4[[]*calculatedRemote] { func (lh *LightHouse) getCalculatedRemotes() *bart.Table[[]*calculatedRemote] {
return lh.calculatedRemotes.Load() return lh.calculatedRemotes.Load()
} }
@@ -182,25 +172,40 @@ func (lh *LightHouse) GetUpdateInterval() int64 {
func (lh *LightHouse) reload(c *config.C, initial bool) error { func (lh *LightHouse) reload(c *config.C, initial bool) error {
if initial || c.HasChanged("lighthouse.advertise_addrs") { if initial || c.HasChanged("lighthouse.advertise_addrs") {
rawAdvAddrs := c.GetStringSlice("lighthouse.advertise_addrs", []string{}) rawAdvAddrs := c.GetStringSlice("lighthouse.advertise_addrs", []string{})
advAddrs := make([]netIpAndPort, 0) advAddrs := make([]netip.AddrPort, 0)
for i, rawAddr := range rawAdvAddrs { for i, rawAddr := range rawAdvAddrs {
fIp, fPort, err := udp.ParseIPAndPort(rawAddr) host, sport, err := net.SplitHostPort(rawAddr)
if err != nil { if err != nil {
return util.NewContextualError("Unable to parse lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) return util.NewContextualError("Unable to parse lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err)
} }
if fPort == 0 { ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", host)
fPort = uint16(lh.nebulaPort) if err != nil {
return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err)
}
if len(ips) == 0 {
return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, nil)
} }
if ip4 := fIp.To4(); ip4 != nil && lh.myVpnNet.Contains(fIp) { port, err := strconv.Atoi(sport)
if err != nil {
return util.NewContextualError("Unable to parse port in lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err)
}
if port == 0 {
port = int(lh.nebulaPort)
}
//TODO: we could technically insert all returned ips instead of just the first one if a dns lookup was used
ip := ips[0].Unmap()
if lh.myVpnNet.Contains(ip) {
lh.l.WithField("addr", rawAddr).WithField("entry", i+1). lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range") Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
continue continue
} }
advAddrs = append(advAddrs, netIpAndPort{ip: fIp, port: fPort}) advAddrs = append(advAddrs, netip.AddrPortFrom(ip, uint16(port)))
} }
lh.advertiseAddrs.Store(&advAddrs) lh.advertiseAddrs.Store(&advAddrs)
@@ -278,8 +283,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
lh.RUnlock() lh.RUnlock()
} }
// Build a new list based on current config. // Build a new list based on current config.
staticList := make(map[iputil.VpnIp]struct{}) staticList := make(map[netip.Addr]struct{})
err := lh.loadStaticMap(c, lh.myVpnNet, staticList) err := lh.loadStaticMap(c, staticList)
if err != nil { if err != nil {
return err return err
} }
@@ -303,8 +308,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
} }
if initial || c.HasChanged("lighthouse.hosts") { if initial || c.HasChanged("lighthouse.hosts") {
lhMap := make(map[iputil.VpnIp]struct{}) lhMap := make(map[netip.Addr]struct{})
err := lh.parseLighthouses(c, lh.myVpnNet, lhMap) err := lh.parseLighthouses(c, lhMap)
if err != nil { if err != nil {
return err return err
} }
@@ -323,16 +328,17 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
if len(c.GetStringSlice("relay.relays", nil)) > 0 { if len(c.GetStringSlice("relay.relays", nil)) > 0 {
lh.l.Info("Ignoring relays from config because am_relay is true") lh.l.Info("Ignoring relays from config because am_relay is true")
} }
relaysForMe := []iputil.VpnIp{} relaysForMe := []netip.Addr{}
lh.relaysForMe.Store(&relaysForMe) lh.relaysForMe.Store(&relaysForMe)
case false: case false:
relaysForMe := []iputil.VpnIp{} relaysForMe := []netip.Addr{}
for _, v := range c.GetStringSlice("relay.relays", nil) { for _, v := range c.GetStringSlice("relay.relays", nil) {
lh.l.WithField("relay", v).Info("Read relay from config") lh.l.WithField("relay", v).Info("Read relay from config")
configRIP := net.ParseIP(v) configRIP, err := netip.ParseAddr(v)
if configRIP != nil { //TODO: We could print the error here
relaysForMe = append(relaysForMe, iputil.Ip2VpnIp(configRIP)) if err == nil {
relaysForMe = append(relaysForMe, configRIP)
} }
} }
lh.relaysForMe.Store(&relaysForMe) lh.relaysForMe.Store(&relaysForMe)
@@ -342,21 +348,21 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap map[iputil.VpnIp]struct{}) error { func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
lhs := c.GetStringSlice("lighthouse.hosts", []string{}) lhs := c.GetStringSlice("lighthouse.hosts", []string{})
if lh.amLighthouse && len(lhs) != 0 { if lh.amLighthouse && len(lhs) != 0 {
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config") lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
} }
for i, host := range lhs { for i, host := range lhs {
ip := net.ParseIP(host) ip, err := netip.ParseAddr(host)
if ip == nil { if err != nil {
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil) return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
} }
if !tunCidr.Contains(ip) { if !lh.myVpnNet.Contains(ip) {
return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": lh.myVpnNet}, nil)
} }
lhMap[iputil.Ip2VpnIp(ip)] = struct{}{} lhMap[ip] = struct{}{}
} }
if !lh.amLighthouse && len(lhMap) == 0 { if !lh.amLighthouse && len(lhMap) == 0 {
@@ -399,7 +405,7 @@ func getStaticMapNetwork(c *config.C) (string, error) {
return network, nil return network, nil
} }
func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error { func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struct{}) error {
d, err := getStaticMapCadence(c) d, err := getStaticMapCadence(c)
if err != nil { if err != nil {
return err return err
@@ -410,7 +416,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
return err return err
} }
lookup_timeout, err := getStaticMapLookupTimeout(c) lookupTimeout, err := getStaticMapLookupTimeout(c)
if err != nil { if err != nil {
return err return err
} }
@@ -419,16 +425,15 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
i := 0 i := 0
for k, v := range shm { for k, v := range shm {
rip := net.ParseIP(fmt.Sprintf("%v", k)) vpnIp, err := netip.ParseAddr(fmt.Sprintf("%v", k))
if rip == nil { if err != nil {
return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, nil) return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err)
} }
if !tunCidr.Contains(rip) { if !lh.myVpnNet.Contains(vpnIp) {
return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": rip, "network": tunCidr.String(), "entry": i + 1}, nil) return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": lh.myVpnNet, "entry": i + 1}, nil)
} }
vpnIp := iputil.Ip2VpnIp(rip)
vals, ok := v.([]interface{}) vals, ok := v.([]interface{})
if !ok { if !ok {
vals = []interface{}{v} vals = []interface{}{v}
@@ -438,7 +443,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v)) remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v))
} }
err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList) err = lh.addStaticRemotes(i, d, network, lookupTimeout, vpnIp, remoteAddrs, staticList)
if err != nil { if err != nil {
return err return err
} }
@@ -448,7 +453,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
return nil return nil
} }
func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList { func (lh *LightHouse) Query(ip netip.Addr) *RemoteList {
if !lh.IsLighthouseIP(ip) { if !lh.IsLighthouseIP(ip) {
lh.QueryServer(ip) lh.QueryServer(ip)
} }
@@ -462,7 +467,7 @@ func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList {
} }
// QueryServer is asynchronous so no reply should be expected // QueryServer is asynchronous so no reply should be expected
func (lh *LightHouse) QueryServer(ip iputil.VpnIp) { func (lh *LightHouse) QueryServer(ip netip.Addr) {
// Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses // Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses
if lh.amLighthouse || lh.IsLighthouseIP(ip) { if lh.amLighthouse || lh.IsLighthouseIP(ip) {
return return
@@ -471,7 +476,7 @@ func (lh *LightHouse) QueryServer(ip iputil.VpnIp) {
lh.queryChan <- ip lh.queryChan <- ip
} }
func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList { func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList {
lh.RLock() lh.RLock()
if v, ok := lh.addrMap[ip]; ok { if v, ok := lh.addrMap[ip]; ok {
lh.RUnlock() lh.RUnlock()
@@ -488,7 +493,7 @@ func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList {
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp // details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp
// If one is found then f() is called with proper locking, f() must return result of n.MarshalTo() // If one is found then f() is called with proper locking, f() must return result of n.MarshalTo()
func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (int, error)) (bool, int, error) { func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, error)) (bool, int, error) {
lh.RLock() lh.RLock()
// Do we have an entry in the main cache? // Do we have an entry in the main cache?
if v, ok := lh.addrMap[vpnIp]; ok { if v, ok := lh.addrMap[vpnIp]; ok {
@@ -511,7 +516,7 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (in
return false, 0, nil return false, 0, nil
} }
func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) { func (lh *LightHouse) DeleteVpnIp(vpnIp netip.Addr) {
// First we check the static mapping // First we check the static mapping
// and do nothing if it is there // and do nothing if it is there
if _, ok := lh.GetStaticHostList()[vpnIp]; ok { if _, ok := lh.GetStaticHostList()[vpnIp]; ok {
@@ -532,7 +537,7 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
// We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
// And we don't want a lighthouse query reply to interfere with our learned cache if we are a client // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
// NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it
func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error { func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error {
lh.Lock() lh.Lock()
am := lh.unlockedGetRemoteList(vpnIp) am := lh.unlockedGetRemoteList(vpnIp)
am.Lock() am.Lock()
@@ -553,20 +558,14 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
am.unlockedSetHostnamesResults(hr) am.unlockedSetHostnamesResults(hr)
for _, addrPort := range hr.GetIPs() { for _, addrPort := range hr.GetIPs() {
if !lh.shouldAdd(vpnIp, addrPort.Addr()) {
continue
}
switch { switch {
case addrPort.Addr().Is4(): case addrPort.Addr().Is4():
to := NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port()) am.unlockedPrependV4(lh.myVpnNet.Addr(), NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port()))
if !lh.unlockedShouldAddV4(vpnIp, to) {
continue
}
am.unlockedPrependV4(lh.myVpnIp, to)
case addrPort.Addr().Is6(): case addrPort.Addr().Is6():
to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port()) am.unlockedPrependV6(lh.myVpnNet.Addr(), NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port()))
if !lh.unlockedShouldAddV6(vpnIp, to) {
continue
}
am.unlockedPrependV6(lh.myVpnIp, to)
} }
} }
@@ -578,12 +577,12 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
// addCalculatedRemotes adds any calculated remotes based on the // addCalculatedRemotes adds any calculated remotes based on the
// lighthouse.calculated_remotes configuration. It returns true if any // lighthouse.calculated_remotes configuration. It returns true if any
// calculated remotes were added // calculated remotes were added
func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool { func (lh *LightHouse) addCalculatedRemotes(vpnIp netip.Addr) bool {
tree := lh.getCalculatedRemotes() tree := lh.getCalculatedRemotes()
if tree == nil { if tree == nil {
return false return false
} }
ok, calculatedRemotes := tree.MostSpecificContains(vpnIp) calculatedRemotes, ok := tree.Lookup(vpnIp)
if !ok { if !ok {
return false return false
} }
@@ -602,13 +601,13 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
defer am.Unlock() defer am.Unlock()
lh.Unlock() lh.Unlock()
am.unlockedSetV4(lh.myVpnIp, vpnIp, calculated, lh.unlockedShouldAddV4) am.unlockedSetV4(lh.myVpnNet.Addr(), vpnIp, calculated, lh.unlockedShouldAddV4)
return len(calculated) > 0 return len(calculated) > 0
} }
// unlockedGetRemoteList assumes you have the lh lock // unlockedGetRemoteList assumes you have the lh lock
func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList { func (lh *LightHouse) unlockedGetRemoteList(vpnIp netip.Addr) *RemoteList {
am, ok := lh.addrMap[vpnIp] am, ok := lh.addrMap[vpnIp]
if !ok { if !ok {
am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) }) am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) })
@@ -617,44 +616,27 @@ func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
return am return am
} }
func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool { func (lh *LightHouse) shouldAdd(vpnIp netip.Addr, to netip.Addr) bool {
switch { allow := lh.GetRemoteAllowList().Allow(vpnIp, to)
case to.Is4():
ipBytes := to.As4()
ip := iputil.Ip2VpnIp(ipBytes[:])
allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip)
if lh.l.Level >= logrus.TraceLevel { if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
} }
if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) { if !allow || lh.myVpnNet.Contains(to) {
return false return false
} }
case to.Is6():
ipBytes := to.As16()
hi := binary.BigEndian.Uint64(ipBytes[:8])
lo := binary.BigEndian.Uint64(ipBytes[8:])
allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo)
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow")
}
// We don't check our vpn network here because nebula does not support ipv6 on the inside
if !allow {
return false
}
}
return true return true
} }
// unlockedShouldAddV4 checks if to is allowed by our allow list // unlockedShouldAddV4 checks if to is allowed by our allow list
func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool { func (lh *LightHouse) unlockedShouldAddV4(vpnIp netip.Addr, to *Ip4AndPort) bool {
allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip)) ip := AddrPortFromIp4AndPort(to)
allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr())
if lh.l.Level >= logrus.TraceLevel { if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
} }
if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.VpnIp(to.Ip)) { if !allow || lh.myVpnNet.Contains(ip.Addr()) {
return false return false
} }
@@ -662,14 +644,14 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bo
} }
// unlockedShouldAddV6 checks if to is allowed by our allow list // unlockedShouldAddV6 checks if to is allowed by our allow list
func (lh *LightHouse) unlockedShouldAddV6(vpnIp iputil.VpnIp, to *Ip6AndPort) bool { func (lh *LightHouse) unlockedShouldAddV6(vpnIp netip.Addr, to *Ip6AndPort) bool {
allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, to.Hi, to.Lo) ip := AddrPortFromIp6AndPort(to)
allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr())
if lh.l.Level >= logrus.TraceLevel { if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow") lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
} }
// We don't check our vpn network here because nebula does not support ipv6 on the inside if !allow || lh.myVpnNet.Contains(ip.Addr()) {
if !allow {
return false return false
} }
@@ -683,26 +665,39 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP {
return ip return ip
} }
func (lh *LightHouse) IsLighthouseIP(vpnIp iputil.VpnIp) bool { func (lh *LightHouse) IsLighthouseIP(vpnIp netip.Addr) bool {
if _, ok := lh.GetLighthouses()[vpnIp]; ok { if _, ok := lh.GetLighthouses()[vpnIp]; ok {
return true return true
} }
return false return false
} }
func NewLhQueryByInt(VpnIp iputil.VpnIp) *NebulaMeta { func NewLhQueryByInt(vpnIp netip.Addr) *NebulaMeta {
if vpnIp.Is6() {
//TODO: need to support ipv6
panic("ipv6 is not yet supported")
}
b := vpnIp.As4()
return &NebulaMeta{ return &NebulaMeta{
Type: NebulaMeta_HostQuery, Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{ Details: &NebulaMetaDetails{
VpnIp: uint32(VpnIp), VpnIp: binary.BigEndian.Uint32(b[:]),
}, },
} }
} }
func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort { func AddrPortFromIp4AndPort(ip *Ip4AndPort) netip.AddrPort {
ipp := Ip4AndPort{Port: port} b := [4]byte{}
ipp.Ip = uint32(iputil.Ip2VpnIp(ip)) binary.BigEndian.PutUint32(b[:], ip.Ip)
return &ipp return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ip.Port))
}
func AddrPortFromIp6AndPort(ip *Ip6AndPort) netip.AddrPort {
b := [16]byte{}
binary.BigEndian.PutUint64(b[:8], ip.Hi)
binary.BigEndian.PutUint64(b[8:], ip.Lo)
return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ip.Port))
} }
func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort { func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort {
@@ -713,14 +708,7 @@ func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort {
} }
} }
func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort { // TODO: IPV6-WORK we can delete some more of these
return &Ip6AndPort{
Hi: binary.BigEndian.Uint64(ip[:8]),
Lo: binary.BigEndian.Uint64(ip[8:]),
Port: port,
}
}
func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort { func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort {
ip6Addr := ip.As16() ip6Addr := ip.As16()
return &Ip6AndPort{ return &Ip6AndPort{
@@ -729,17 +717,6 @@ func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort {
Port: uint32(port), Port: uint32(port),
} }
} }
func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr {
ip := ipp.Ip
return udp.NewAddr(
net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)),
uint16(ipp.Port),
)
}
func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
}
func (lh *LightHouse) startQueryWorker() { func (lh *LightHouse) startQueryWorker() {
if lh.amLighthouse { if lh.amLighthouse {
@@ -761,7 +738,7 @@ func (lh *LightHouse) startQueryWorker() {
}() }()
} }
func (lh *LightHouse) innerQueryServer(ip iputil.VpnIp, nb, out []byte) { func (lh *LightHouse) innerQueryServer(ip netip.Addr, nb, out []byte) {
if lh.IsLighthouseIP(ip) { if lh.IsLighthouseIP(ip) {
return return
} }
@@ -812,36 +789,41 @@ func (lh *LightHouse) SendUpdate() {
var v6 []*Ip6AndPort var v6 []*Ip6AndPort
for _, e := range lh.GetAdvertiseAddrs() { for _, e := range lh.GetAdvertiseAddrs() {
if ip := e.ip.To4(); ip != nil { if e.Addr().Is4() {
v4 = append(v4, NewIp4AndPort(e.ip, uint32(e.port))) v4 = append(v4, NewIp4AndPortFromNetIP(e.Addr(), e.Port()))
} else { } else {
v6 = append(v6, NewIp6AndPort(e.ip, uint32(e.port))) v6 = append(v6, NewIp6AndPortFromNetIP(e.Addr(), e.Port()))
} }
} }
lal := lh.GetLocalAllowList() lal := lh.GetLocalAllowList()
for _, e := range *localIps(lh.l, lal) { for _, e := range localIps(lh.l, lal) {
if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.Ip2VpnIp(ip4)) { if lh.myVpnNet.Contains(e) {
continue continue
} }
// Only add IPs that aren't my VPN/tun IP // Only add IPs that aren't my VPN/tun IP
if ip := e.To4(); ip != nil { if e.Is4() {
v4 = append(v4, NewIp4AndPort(e, lh.nebulaPort)) v4 = append(v4, NewIp4AndPortFromNetIP(e, uint16(lh.nebulaPort)))
} else { } else {
v6 = append(v6, NewIp6AndPort(e, lh.nebulaPort)) v6 = append(v6, NewIp6AndPortFromNetIP(e, uint16(lh.nebulaPort)))
} }
} }
var relays []uint32 var relays []uint32
for _, r := range lh.GetRelaysForMe() { for _, r := range lh.GetRelaysForMe() {
relays = append(relays, (uint32)(r)) //TODO: IPV6-WORK both relays and vpnip need ipv6 support
b := r.As4()
relays = append(relays, binary.BigEndian.Uint32(b[:]))
} }
//TODO: IPV6-WORK both relays and vpnip need ipv6 support
b := lh.myVpnNet.Addr().As4()
m := &NebulaMeta{ m := &NebulaMeta{
Type: NebulaMeta_HostUpdateNotification, Type: NebulaMeta_HostUpdateNotification,
Details: &NebulaMetaDetails{ Details: &NebulaMetaDetails{
VpnIp: uint32(lh.myVpnIp), VpnIp: binary.BigEndian.Uint32(b[:]),
Ip4AndPorts: v4, Ip4AndPorts: v4,
Ip6AndPorts: v6, Ip6AndPorts: v6,
RelayVpnIp: relays, RelayVpnIp: relays,
@@ -913,12 +895,12 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
} }
func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc { func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc {
return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) { return func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) {
lhh.HandleRequest(rAddr, vpnIp, p, f) lhh.HandleRequest(rAddr, vpnIp, p, f)
} }
} }
func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) { func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte, w EncWriter) {
n := lhh.resetMeta() n := lhh.resetMeta()
err := n.Unmarshal(p) err := n.Unmarshal(p)
if err != nil { if err != nil {
@@ -956,7 +938,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp,
} }
} }
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) { func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, addr netip.AddrPort, w EncWriter) {
// Exit if we don't answer queries // Exit if we don't answer queries
if !lhh.lh.amLighthouse { if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
@@ -967,8 +949,14 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp,
//TODO: we can DRY this further //TODO: we can DRY this further
reqVpnIp := n.Details.VpnIp reqVpnIp := n.Details.VpnIp
//TODO: IPV6-WORK
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
queryVpnIp := netip.AddrFrom4(b)
//TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data //TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data
found, ln, err := lhh.lh.queryAndPrepMessage(iputil.VpnIp(n.Details.VpnIp), func(c *cache) (int, error) { found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnIp, func(c *cache) (int, error) {
n = lhh.resetMeta() n = lhh.resetMeta()
n.Type = NebulaMeta_HostQueryReply n.Type = NebulaMeta_HostQueryReply
n.Details.VpnIp = reqVpnIp n.Details.VpnIp = reqVpnIp
@@ -994,8 +982,9 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp,
found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) { found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) {
n = lhh.resetMeta() n = lhh.resetMeta()
n.Type = NebulaMeta_HostPunchNotification n.Type = NebulaMeta_HostPunchNotification
n.Details.VpnIp = uint32(vpnIp) //TODO: IPV6-WORK
b = vpnIp.As4()
n.Details.VpnIp = binary.BigEndian.Uint32(b[:])
lhh.coalesceAnswers(c, n) lhh.coalesceAnswers(c, n)
return n.MarshalTo(lhh.pb) return n.MarshalTo(lhh.pb)
@@ -1011,7 +1000,11 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp,
} }
lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1) lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
w.SendMessageToVpnIp(header.LightHouse, 0, iputil.VpnIp(reqVpnIp), lhh.pb[:ln], lhh.nb, lhh.out[:0])
//TODO: IPV6-WORK
binary.BigEndian.PutUint32(b[:], reqVpnIp)
sendTo := netip.AddrFrom4(b)
w.SendMessageToVpnIp(header.LightHouse, 0, sendTo, lhh.pb[:ln], lhh.nb, lhh.out[:0])
} }
func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) { func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
@@ -1034,34 +1027,52 @@ func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
} }
if c.relay != nil { if c.relay != nil {
n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, c.relay.relay...) //TODO: IPV6-WORK
relays := make([]uint32, len(c.relay.relay))
b := [4]byte{}
for i, _ := range relays {
b = c.relay.relay[i].As4()
relays[i] = binary.BigEndian.Uint32(b[:])
}
n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, relays...)
} }
} }
func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp iputil.VpnIp) { func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp netip.Addr) {
if !lhh.lh.IsLighthouseIP(vpnIp) { if !lhh.lh.IsLighthouseIP(vpnIp) {
return return
} }
lhh.lh.Lock() lhh.lh.Lock()
am := lhh.lh.unlockedGetRemoteList(iputil.VpnIp(n.Details.VpnIp)) //TODO: IPV6-WORK
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
certVpnIp := netip.AddrFrom4(b)
am := lhh.lh.unlockedGetRemoteList(certVpnIp)
am.Lock() am.Lock()
lhh.lh.Unlock() lhh.lh.Unlock()
certVpnIp := iputil.VpnIp(n.Details.VpnIp) //TODO: IPV6-WORK
am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp)
//TODO: IPV6-WORK
relays := make([]netip.Addr, len(n.Details.RelayVpnIp))
for i, _ := range n.Details.RelayVpnIp {
binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i])
relays[i] = netip.AddrFrom4(b)
}
am.unlockedSetRelay(vpnIp, certVpnIp, relays)
am.Unlock() am.Unlock()
// Non-blocking attempt to trigger, skip if it would block // Non-blocking attempt to trigger, skip if it would block
select { select {
case lhh.lh.handshakeTrigger <- iputil.VpnIp(n.Details.VpnIp): case lhh.lh.handshakeTrigger <- certVpnIp:
default: default:
} }
} }
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) { func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) {
if !lhh.lh.amLighthouse { if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp) lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp)
@@ -1070,9 +1081,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
} }
//Simple check that the host sent this not someone else //Simple check that the host sent this not someone else
if n.Details.VpnIp != uint32(vpnIp) { //TODO: IPV6-WORK
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
detailsVpnIp := netip.AddrFrom4(b)
if detailsVpnIp != vpnIp {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnIp", vpnIp).WithField("answer", iputil.VpnIp(n.Details.VpnIp)).Debugln("Host sent invalid update") lhh.l.WithField("vpnIp", vpnIp).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update")
} }
return return
} }
@@ -1082,15 +1097,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
am.Lock() am.Lock()
lhh.lh.Unlock() lhh.lh.Unlock()
certVpnIp := iputil.VpnIp(n.Details.VpnIp) am.unlockedSetV4(vpnIp, detailsVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) am.unlockedSetV6(vpnIp, detailsVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp) //TODO: IPV6-WORK
relays := make([]netip.Addr, len(n.Details.RelayVpnIp))
for i, _ := range n.Details.RelayVpnIp {
binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i])
relays[i] = netip.AddrFrom4(b)
}
am.unlockedSetRelay(vpnIp, detailsVpnIp, relays)
am.Unlock() am.Unlock()
n = lhh.resetMeta() n = lhh.resetMeta()
n.Type = NebulaMeta_HostUpdateNotificationAck n.Type = NebulaMeta_HostUpdateNotificationAck
n.Details.VpnIp = uint32(vpnIp)
//TODO: IPV6-WORK
vpnIpB := vpnIp.As4()
n.Details.VpnIp = binary.BigEndian.Uint32(vpnIpB[:])
ln, err := n.MarshalTo(lhh.pb) ln, err := n.MarshalTo(lhh.pb)
if err != nil { if err != nil {
@@ -1102,14 +1126,14 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
} }
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) { func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) {
if !lhh.lh.IsLighthouseIP(vpnIp) { if !lhh.lh.IsLighthouseIP(vpnIp) {
return return
} }
empty := []byte{0} empty := []byte{0}
punch := func(vpnPeer *udp.Addr) { punch := func(vpnPeer netip.AddrPort) {
if vpnPeer == nil { if !vpnPeer.IsValid() {
return return
} }
@@ -1121,23 +1145,29 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp)) //TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, iputil.VpnIp(n.Details.VpnIp)) //TODO: IPV6-WORK, make this debug line not suck
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
lhh.l.Debugf("Punching on %d for %v", vpnPeer.Port(), netip.AddrFrom4(b))
} }
} }
for _, a := range n.Details.Ip4AndPorts { for _, a := range n.Details.Ip4AndPorts {
punch(NewUDPAddrFromLH4(a)) punch(AddrPortFromIp4AndPort(a))
} }
for _, a := range n.Details.Ip6AndPorts { for _, a := range n.Details.Ip6AndPorts {
punch(NewUDPAddrFromLH6(a)) punch(AddrPortFromIp6AndPort(a))
} }
// This sends a nebula test packet to the host trying to contact us. In the case // This sends a nebula test packet to the host trying to contact us. In the case
// of a double nat or other difficult scenario, this may help establish // of a double nat or other difficult scenario, this may help establish
// a tunnel. // a tunnel.
if lhh.lh.punchy.GetRespond() { if lhh.lh.punchy.GetRespond() {
queryVpnIp := iputil.VpnIp(n.Details.VpnIp) //TODO: IPV6-WORK
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
queryVpnIp := netip.AddrFrom4(b)
go func() { go func() {
time.Sleep(lhh.lh.punchy.GetRespondDelay()) time.Sleep(lhh.lh.punchy.GetRespondDelay())
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
@@ -1150,9 +1180,3 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i
}() }()
} }
} }
// ipMaskContains checks if testIp is contained by ip after applying a cidr
// zeros is 32 - bits from net.IPMask.Size()
func ipMaskContains(ip iputil.VpnIp, zeros iputil.VpnIp, testIp iputil.VpnIp) bool {
return (testIp^ip)>>zeros == 0
}

View File

@@ -2,15 +2,14 @@ package nebula
import ( import (
"context" "context"
"encoding/binary"
"fmt" "fmt"
"net" "net/netip"
"testing" "testing"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
@@ -23,15 +22,17 @@ func TestOldIPv4Only(t *testing.T) {
var m Ip4AndPort var m Ip4AndPort
err := m.Unmarshal(b) err := m.Unmarshal(b)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "10.1.1.1", iputil.VpnIp(m.GetIp()).String()) ip := netip.MustParseAddr("10.1.1.1")
bp := ip.As4()
assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp())
} }
func TestNewLhQuery(t *testing.T) { func TestNewLhQuery(t *testing.T) {
myIp := net.ParseIP("192.1.1.1") myIp, err := netip.ParseAddr("192.1.1.1")
myIpint := iputil.Ip2VpnIp(myIp) assert.NoError(t, err)
// Generating a new lh query should work // Generating a new lh query should work
a := NewLhQueryByInt(myIpint) a := NewLhQueryByInt(myIp)
// The result should be a nebulameta protobuf // The result should be a nebulameta protobuf
assert.IsType(t, &NebulaMeta{}, a) assert.IsType(t, &NebulaMeta{}, a)
@@ -49,7 +50,7 @@ func TestNewLhQuery(t *testing.T) {
func Test_lhStaticMapping(t *testing.T) { func Test_lhStaticMapping(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16") myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
lh1 := "10.128.0.2" lh1 := "10.128.0.2"
c := config.NewC(l) c := config.NewC(l)
@@ -68,7 +69,7 @@ func Test_lhStaticMapping(t *testing.T) {
func TestReloadLighthouseInterval(t *testing.T) { func TestReloadLighthouseInterval(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16") myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
lh1 := "10.128.0.2" lh1 := "10.128.0.2"
c := config.NewC(l) c := config.NewC(l)
@@ -83,21 +84,21 @@ func TestReloadLighthouseInterval(t *testing.T) {
lh.ifce = &mockEncWriter{} lh.ifce = &mockEncWriter{}
// The first one routine is kicked off by main.go currently, lets make sure that one dies // The first one routine is kicked off by main.go currently, lets make sure that one dies
c.ReloadConfigString("lighthouse:\n interval: 5") assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5"))
assert.Equal(t, int64(5), lh.interval.Load()) assert.Equal(t, int64(5), lh.interval.Load())
// Subsequent calls are killed off by the LightHouse.Reload function // Subsequent calls are killed off by the LightHouse.Reload function
c.ReloadConfigString("lighthouse:\n interval: 10") assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10"))
assert.Equal(t, int64(10), lh.interval.Load()) assert.Equal(t, int64(10), lh.interval.Load())
// If this completes then nothing is stealing our reload routine // If this completes then nothing is stealing our reload routine
c.ReloadConfigString("lighthouse:\n interval: 11") assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11"))
assert.Equal(t, int64(11), lh.interval.Load()) assert.Equal(t, int64(11), lh.interval.Load())
} }
func BenchmarkLighthouseHandleRequest(b *testing.B) { func BenchmarkLighthouseHandleRequest(b *testing.B) {
l := test.NewLogger() l := test.NewLogger()
_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0") myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
c := config.NewC(l) c := config.NewC(l)
lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
@@ -105,30 +106,33 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
b.Fatal() b.Fatal()
} }
hAddr := udp.NewAddrFromString("4.5.6.7:12345") hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
hAddr2 := udp.NewAddrFromString("4.5.6.7:12346") hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
lh.addrMap[3] = NewRemoteList(nil)
lh.addrMap[3].unlockedSetV4( vpnIp3 := netip.MustParseAddr("0.0.0.3")
3, lh.addrMap[vpnIp3] = NewRemoteList(nil)
3, lh.addrMap[vpnIp3].unlockedSetV4(
vpnIp3,
vpnIp3,
[]*Ip4AndPort{ []*Ip4AndPort{
NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)), NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()),
NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)), NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()),
}, },
func(iputil.VpnIp, *Ip4AndPort) bool { return true }, func(netip.Addr, *Ip4AndPort) bool { return true },
) )
rAddr := udp.NewAddrFromString("1.2.2.3:12345") rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
rAddr2 := udp.NewAddrFromString("1.2.2.3:12346") rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
lh.addrMap[2] = NewRemoteList(nil) vpnIp2 := netip.MustParseAddr("0.0.0.3")
lh.addrMap[2].unlockedSetV4( lh.addrMap[vpnIp2] = NewRemoteList(nil)
3, lh.addrMap[vpnIp2].unlockedSetV4(
3, vpnIp3,
vpnIp3,
[]*Ip4AndPort{ []*Ip4AndPort{
NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)), NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()),
NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)), NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()),
}, },
func(iputil.VpnIp, *Ip4AndPort) bool { return true }, func(netip.Addr, *Ip4AndPort) bool { return true },
) )
mw := &mockEncWriter{} mw := &mockEncWriter{}
@@ -145,7 +149,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
p, err := req.Marshal() p, err := req.Marshal()
assert.NoError(b, err) assert.NoError(b, err)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, 2, p, mw) lhh.HandleRequest(rAddr, vpnIp2, p, mw)
} }
}) })
b.Run("found", func(b *testing.B) { b.Run("found", func(b *testing.B) {
@@ -161,7 +165,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
assert.NoError(b, err) assert.NoError(b, err)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, 2, p, mw) lhh.HandleRequest(rAddr, vpnIp2, p, mw)
} }
}) })
} }
@@ -169,51 +173,51 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
func TestLighthouse_Memory(t *testing.T) { func TestLighthouse_Memory(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242} myUdpAddr0 := netip.MustParseAddrPort("10.0.0.2:4242")
myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242} myUdpAddr1 := netip.MustParseAddrPort("192.168.0.2:4242")
myUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.2"), Port: 4242} myUdpAddr2 := netip.MustParseAddrPort("172.16.0.2:4242")
myUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.2"), Port: 4242} myUdpAddr3 := netip.MustParseAddrPort("100.152.0.2:4242")
myUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.2"), Port: 4242} myUdpAddr4 := netip.MustParseAddrPort("24.15.0.2:4242")
myUdpAddr5 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4243} myUdpAddr5 := netip.MustParseAddrPort("192.168.0.2:4243")
myUdpAddr6 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4244} myUdpAddr6 := netip.MustParseAddrPort("192.168.0.2:4244")
myUdpAddr7 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4245} myUdpAddr7 := netip.MustParseAddrPort("192.168.0.2:4245")
myUdpAddr8 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4246} myUdpAddr8 := netip.MustParseAddrPort("192.168.0.2:4246")
myUdpAddr9 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4247} myUdpAddr9 := netip.MustParseAddrPort("192.168.0.2:4247")
myUdpAddr10 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4248} myUdpAddr10 := netip.MustParseAddrPort("192.168.0.2:4248")
myUdpAddr11 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4249} myUdpAddr11 := netip.MustParseAddrPort("192.168.0.2:4249")
myVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.2")) myVpnIp := netip.MustParseAddr("10.128.0.2")
theirUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.3"), Port: 4242} theirUdpAddr0 := netip.MustParseAddrPort("10.0.0.3:4242")
theirUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.3"), Port: 4242} theirUdpAddr1 := netip.MustParseAddrPort("192.168.0.3:4242")
theirUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.3"), Port: 4242} theirUdpAddr2 := netip.MustParseAddrPort("172.16.0.3:4242")
theirUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.3"), Port: 4242} theirUdpAddr3 := netip.MustParseAddrPort("100.152.0.3:4242")
theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242} theirUdpAddr4 := netip.MustParseAddrPort("24.15.0.3:4242")
theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3")) theirVpnIp := netip.MustParseAddr("10.128.0.3")
c := config.NewC(l) c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
lhh := lh.NewRequestHandler() lhh := lh.NewRequestHandler()
// Test that my first update responds with just that // Test that my first update responds with just that
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr2}, lhh) newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
// Ensure we don't accumulate addresses // Ensure we don't accumulate addresses
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr3}, lhh) newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
// Grow it back to 2 // Grow it back to 2
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr4}, lhh) newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
// Update a different host and ask about it // Update a different host and ask about it
newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh) r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
@@ -233,7 +237,7 @@ func TestLighthouse_Memory(t *testing.T) {
newLHHostUpdate( newLHHostUpdate(
myUdpAddr0, myUdpAddr0,
myVpnIp, myVpnIp,
[]*udp.Addr{ []netip.AddrPort{
myUdpAddr1, myUdpAddr1,
myUdpAddr2, myUdpAddr2,
myUdpAddr3, myUdpAddr3,
@@ -256,10 +260,10 @@ func TestLighthouse_Memory(t *testing.T) {
) )
// Make sure we won't add ips in our vpn network // Make sure we won't add ips in our vpn network
bad1 := &udp.Addr{IP: net.ParseIP("10.128.0.99"), Port: 4242} bad1 := netip.MustParseAddrPort("10.128.0.99:4242")
bad2 := &udp.Addr{IP: net.ParseIP("10.128.0.100"), Port: 4242} bad2 := netip.MustParseAddrPort("10.128.0.100:4242")
good := &udp.Addr{IP: net.ParseIP("1.128.0.99"), Port: 4242} good := netip.MustParseAddrPort("1.128.0.99:4242")
newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{bad1, bad2, good}, lhh) newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
} }
@@ -269,7 +273,7 @@ func TestLighthouse_reload(t *testing.T) {
c := config.NewC(l) c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
nc := map[interface{}]interface{}{ nc := map[interface{}]interface{}{
@@ -285,11 +289,13 @@ func TestLighthouse_reload(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply { func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
//TODO: IPV6-WORK
bip := queryVpnIp.As4()
req := &NebulaMeta{ req := &NebulaMeta{
Type: NebulaMeta_HostQuery, Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{ Details: &NebulaMetaDetails{
VpnIp: uint32(queryVpnIp), VpnIp: binary.BigEndian.Uint32(bip[:]),
}, },
} }
@@ -306,17 +312,19 @@ func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh
return w.lastReply return w.lastReply
} }
func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, lhh *LightHouseHandler) { func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) {
//TODO: IPV6-WORK
bip := vpnIp.As4()
req := &NebulaMeta{ req := &NebulaMeta{
Type: NebulaMeta_HostUpdateNotification, Type: NebulaMeta_HostUpdateNotification,
Details: &NebulaMetaDetails{ Details: &NebulaMetaDetails{
VpnIp: uint32(vpnIp), VpnIp: binary.BigEndian.Uint32(bip[:]),
Ip4AndPorts: make([]*Ip4AndPort, len(addrs)), Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
}, },
} }
for k, v := range addrs { for k, v := range addrs {
req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: uint32(iputil.Ip2VpnIp(v.IP)), Port: uint32(v.Port)} req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port())
} }
b, err := req.Marshal() b, err := req.Marshal()
@@ -394,16 +402,10 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr,
// ) // )
//} //}
func Test_ipMaskContains(t *testing.T) {
assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.0.255"))))
assert.False(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
}
type testLhReply struct { type testLhReply struct {
nebType header.MessageType nebType header.MessageType
nebSubType header.MessageSubType nebSubType header.MessageSubType
vpnIp iputil.VpnIp vpnIp netip.Addr
msg *NebulaMeta msg *NebulaMeta
} }
@@ -414,7 +416,7 @@ type testEncWriter struct {
func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
} }
func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) { func (tw *testEncWriter) Handshake(vpnIp netip.Addr) {
} }
func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) { func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) {
@@ -434,7 +436,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
} }
} }
func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) { func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
msg := &NebulaMeta{} msg := &NebulaMeta{}
err := msg.Unmarshal(p) err := msg.Unmarshal(p)
if tw.metaFilter == nil || msg.Type == *tw.metaFilter { if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
@@ -452,35 +454,16 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
} }
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) { func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) {
if !assert.Len(t, have, len(want)) { if !assert.Len(t, have, len(want)) {
return return
} }
for k, w := range want { for k, w := range want {
if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == uint32(w.Port)) { //TODO: IPV6-WORK
assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have))) h := AddrPortFromIp4AndPort(have[k])
if !(h == w) {
assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
} }
} }
} }
// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) {
if !assert.Len(t, have, len(want)) {
return
}
for k, w := range want {
if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have))
}
}
}
func translateV4toUdpAddr(ips []*Ip4AndPort) []*udp.Addr {
addrs := make([]*udp.Addr, len(ips))
for k, v := range ips {
addrs[k] = NewUDPAddrFromLH4(v)
}
return addrs
}

40
main.go
View File

@@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"net" "net"
"net/netip"
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -67,8 +68,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} }
l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
// TODO: make sure mask is 4 bytes ones, _ := certificate.Details.Ips[0].Mask.Size()
tunCidr := certificate.Details.Ips[0] addr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
if !ok {
err = util.NewContextualError(
"Invalid ip address in certificate",
m{"vpnIp": certificate.Details.Ips[0].IP},
nil,
)
return nil, err
}
tunCidr := netip.PrefixFrom(addr, ones)
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
if err != nil { if err != nil {
@@ -150,21 +160,25 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if !configTest { if !configTest {
rawListenHost := c.GetString("listen.host", "0.0.0.0") rawListenHost := c.GetString("listen.host", "0.0.0.0")
var listenHost *net.IPAddr var listenHost netip.Addr
if rawListenHost == "[::]" { if rawListenHost == "[::]" {
// Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve. // Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve.
listenHost = &net.IPAddr{IP: net.IPv6zero} listenHost = netip.IPv6Unspecified()
} else { } else {
listenHost, err = net.ResolveIPAddr("ip", rawListenHost) ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", rawListenHost)
if err != nil { if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err) return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
} }
if len(ips) == 0 {
return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
}
listenHost = ips[0].Unmap()
} }
for i := 0; i < routines; i++ { for i := 0; i < routines; i++ {
l.Infof("listening %q %d", listenHost.IP, port) l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64)) udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
if err != nil { if err != nil {
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
} }
@@ -178,13 +192,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if err != nil { if err != nil {
return nil, util.NewContextualError("Failed to get listening port", nil, err) return nil, util.NewContextualError("Failed to get listening port", nil, err)
} }
port = int(uPort.Port) port = int(uPort.Port())
} }
} }
} }
hostMap := NewHostMapFromConfig(l, tunCidr, c) hostMap := NewHostMapFromConfig(l, tunCidr, c)
punchy := NewPunchyFromConfig(l, c) punchy := NewPunchyFromConfig(l, c)
connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy) lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
if err != nil { if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err) return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
@@ -201,7 +216,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
handshakeConfig := HandshakeConfig{ handshakeConfig := HandshakeConfig{
tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
retries: c.GetInt("handshakes.retries", DefaultHandshakeRetries), retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)),
triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
useRelays: useRelays, useRelays: useRelays,
@@ -220,9 +235,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} }
} }
checkInterval := c.GetInt("timers.connection_alive_interval", 5)
pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
ifConfig := &InterfaceConfig{ ifConfig := &InterfaceConfig{
HostMap: hostMap, HostMap: hostMap,
Inside: tun, Inside: tun,
@@ -232,9 +244,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
Firewall: fw, Firewall: fw,
ServeDns: serveDns, ServeDns: serveDns,
HandshakeManager: handshakeManager, HandshakeManager: handshakeManager,
connectionManager: connManager,
lightHouse: lightHouse, lightHouse: lightHouse,
checkInterval: time.Second * time.Duration(checkInterval),
pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval),
tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery), tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery),
reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery), reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait), reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
@@ -311,5 +322,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
statsStart, statsStart,
dnsStart, dnsStart,
lightHouse.StartUpdateWorker, lightHouse.StartUpdateWorker,
connManager.Start,
}, nil }, nil
} }

View File

@@ -4,6 +4,7 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"net/netip"
"time" "time"
"github.com/flynn/noise" "github.com/flynn/noise"
@@ -11,7 +12,6 @@ import (
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
@@ -21,9 +21,10 @@ const (
minFwPacketLen = 4 minFwPacketLen = 4
) )
// TODO: IPV6-WORK this can likely be removed now
func readOutsidePackets(f *Interface) udp.EncReader { func readOutsidePackets(f *Interface) udp.EncReader {
return func( return func(
addr *udp.Addr, addr netip.AddrPort,
out []byte, out []byte,
packet []byte, packet []byte,
header *header.H, header *header.H,
@@ -37,29 +38,27 @@ func readOutsidePackets(f *Interface) udp.EncReader {
} }
} }
func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet) err := h.Parse(packet)
if err != nil { if err != nil {
// TODO: best if we return this and let caller log // TODO: best if we return this and let caller log
// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that? // TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(packet) > 1 { if len(packet) > 1 {
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err) f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
} }
return return
} }
//l.Error("in packet ", header, packet[HeaderLen:]) //l.Error("in packet ", header, packet[HeaderLen:])
if addr != nil { if ip.IsValid() {
if ip4 := addr.IP.To4(); ip4 != nil { if f.myVpnNet.Contains(ip.Addr()) {
if ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, iputil.VpnIp(binary.BigEndian.Uint32(ip4))) {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet") f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
} }
return return
} }
} }
}
var hostinfo *HostInfo var hostinfo *HostInfo
// verify if we've seen this index before, otherwise respond to the handshake initiation // verify if we've seen this index before, otherwise respond to the handshake initiation
@@ -77,7 +76,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
switch h.Type { switch h.Type {
case header.Message: case header.Message:
// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case. // TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
if !f.handleEncrypted(ci, addr, h) { if !f.handleEncrypted(ci, ip, h) {
return return
} }
@@ -101,9 +100,9 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
// Successfully validated the thing. Get rid of the Relay header. // Successfully validated the thing. Get rid of the Relay header.
signedPayload = signedPayload[header.Len:] signedPayload = signedPayload[header.Len:]
// Pull the Roaming parts up here, and return in all call paths. // Pull the Roaming parts up here, and return in all call paths.
f.handleHostRoaming(hostinfo, addr) f.handleHostRoaming(hostinfo, ip)
// Track usage of both the HostInfo and the Relay for the received & authenticated packet // Track usage of both the HostInfo and the Relay for the received & authenticated packet
f.connectionManager.In(hostinfo.localIndexId) f.connectionManager.In(hostinfo)
f.connectionManager.RelayUsed(h.RemoteIndex) f.connectionManager.RelayUsed(h.RemoteIndex)
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
@@ -118,7 +117,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
case TerminalType: case TerminalType:
// If I am the target of this relay, process the unwrapped packet // If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
f.readOutsidePackets(nil, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return return
case ForwardingType: case ForwardingType:
// Find the target HostInfo relay object // Find the target HostInfo relay object
@@ -148,13 +147,13 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
case header.LightHouse: case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, addr, h) { if !f.handleEncrypted(ci, ip, h) {
return return
} }
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet). WithField("packet", packet).
Error("Failed to decrypt lighthouse packet") Error("Failed to decrypt lighthouse packet")
@@ -163,19 +162,19 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
return return
} }
lhf(addr, hostinfo.vpnIp, d) lhf(ip, hostinfo.vpnIp, d)
// Fallthrough to the bottom to record incoming traffic // Fallthrough to the bottom to record incoming traffic
case header.Test: case header.Test:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, addr, h) { if !f.handleEncrypted(ci, ip, h) {
return return
} }
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet). WithField("packet", packet).
Error("Failed to decrypt test packet") Error("Failed to decrypt test packet")
@@ -187,7 +186,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
if h.Subtype == header.TestRequest { if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam // This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding // to the new IP address before responding
f.handleHostRoaming(hostinfo, addr) f.handleHostRoaming(hostinfo, ip)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
} }
@@ -198,34 +197,34 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
case header.Handshake: case header.Handshake:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handshakeManager.HandleIncoming(addr, via, packet, h) f.handshakeManager.HandleIncoming(ip, via, packet, h)
return return
case header.RecvError: case header.RecvError:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handleRecvError(addr, h) f.handleRecvError(ip, h)
return return
case header.CloseTunnel: case header.CloseTunnel:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, addr, h) { if !f.handleEncrypted(ci, ip, h) {
return return
} }
hostinfo.logger(f.l).WithField("udpAddr", addr). hostinfo.logger(f.l).WithField("udpAddr", ip).
Info("Close tunnel received, tearing down.") Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo) f.closeTunnel(hostinfo)
return return
case header.Control: case header.Control:
if !f.handleEncrypted(ci, addr, h) { if !f.handleEncrypted(ci, ip, h) {
return return
} }
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet). WithField("packet", packet).
Error("Failed to decrypt Control packet") Error("Failed to decrypt Control packet")
return return
@@ -241,13 +240,13 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
default: default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr) hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
return return
} }
f.handleHostRoaming(hostinfo, addr) f.handleHostRoaming(hostinfo, ip)
f.connectionManager.In(hostinfo.localIndexId) f.connectionManager.In(hostinfo)
} }
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
@@ -264,34 +263,34 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
} }
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) {
if addr != nil && !hostinfo.remote.Equals(addr) { if ip.IsValid() && hostinfo.remote != ip {
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) { if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) {
hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming") hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming")
return return
} }
if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
} }
return return
} }
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
Info("Host roamed to new udp ip/port.") Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now() hostinfo.lastRoam = time.Now()
hostinfo.lastRoamRemote = hostinfo.remote hostinfo.lastRoamRemote = hostinfo.remote
hostinfo.SetRemote(addr) hostinfo.SetRemote(ip)
} }
} }
func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udp.Addr, h *header.H) bool { func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
// If connectionstate exists and the replay protector allows, process packet // If connectionstate exists and the replay protector allows, process packet
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection. // Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
if ci == nil || !ci.window.Check(f.l, h.MessageCounter) { if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
if addr != nil { if addr.IsValid() {
f.maybeSendRecvError(addr, h.RemoteIndex) f.maybeSendRecvError(addr, h.RemoteIndex)
return false return false
} else { } else {
@@ -340,8 +339,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
// Firewall packets are locally oriented // Firewall packets are locally oriented
if incoming { if incoming {
fp.RemoteIP = iputil.Ip2VpnIp(data[12:16]) //TODO: IPV6-WORK
fp.LocalIP = iputil.Ip2VpnIp(data[16:20]) fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16])
fp.LocalIP, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP { if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0 fp.RemotePort = 0
fp.LocalPort = 0 fp.LocalPort = 0
@@ -350,8 +350,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
} }
} else { } else {
fp.LocalIP = iputil.Ip2VpnIp(data[12:16]) //TODO: IPV6-WORK
fp.RemoteIP = iputil.Ip2VpnIp(data[16:20]) fp.LocalIP, _ = netip.AddrFromSlice(data[12:16])
fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP { if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0 fp.RemotePort = 0
fp.LocalPort = 0 fp.LocalPort = 0
@@ -417,7 +418,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false return false
} }
f.connectionManager.In(hostinfo.localIndexId) f.connectionManager.In(hostinfo)
_, err = f.readers[q].Write(out) _, err = f.readers[q].Write(out)
if err != nil { if err != nil {
f.l.WithError(err).Error("Failed to write to tun") f.l.WithError(err).Error("Failed to write to tun")
@@ -425,13 +426,13 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return true return true
} }
func (f *Interface) maybeSendRecvError(endpoint *udp.Addr, index uint32) { func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {
if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint.IP) { if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint) {
f.sendRecvError(endpoint, index) f.sendRecvError(endpoint, index)
} }
} }
func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) { func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
f.messageMetrics.Tx(header.RecvError, 0, 1) f.messageMetrics.Tx(header.RecvError, 0, 1)
//TODO: this should be a signed message so we can trust that we should drop the index //TODO: this should be a signed message so we can trust that we should drop the index
@@ -444,7 +445,7 @@ func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) {
} }
} }
func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", h.RemoteIndex). f.l.WithField("index", h.RemoteIndex).
WithField("udpAddr", addr). WithField("udpAddr", addr).
@@ -461,7 +462,7 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
return return
} }
if hostinfo.remote != nil && !hostinfo.remote.Equals(addr) { if hostinfo.remote.IsValid() && hostinfo.remote != addr {
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
return return
} }

View File

@@ -2,10 +2,10 @@ package nebula
import ( import (
"net" "net"
"net/netip"
"testing" "testing"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
@@ -55,8 +55,8 @@ func Test_newPacket(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP)) assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2))) assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2"))
assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1))) assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1"))
assert.Equal(t, p.RemotePort, uint16(3)) assert.Equal(t, p.RemotePort, uint16(3))
assert.Equal(t, p.LocalPort, uint16(4)) assert.Equal(t, p.LocalPort, uint16(4))
@@ -76,8 +76,8 @@ func Test_newPacket(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(2)) assert.Equal(t, p.Protocol, uint8(2))
assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1))) assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1"))
assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2))) assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2"))
assert.Equal(t, p.RemotePort, uint16(6)) assert.Equal(t, p.RemotePort, uint16(6))
assert.Equal(t, p.LocalPort, uint16(5)) assert.Equal(t, p.LocalPort, uint16(5))
} }

View File

@@ -2,16 +2,14 @@ package overlay
import ( import (
"io" "io"
"net" "net/netip"
"github.com/slackhq/nebula/iputil"
) )
type Device interface { type Device interface {
io.ReadWriteCloser io.ReadWriteCloser
Activate() error Activate() error
Cidr() *net.IPNet Cidr() netip.Prefix
Name() string Name() string
RouteFor(iputil.VpnIp) iputil.VpnIp RouteFor(netip.Addr) netip.Addr
NewMultiQueueReader() (io.ReadWriteCloser, error) NewMultiQueueReader() (io.ReadWriteCloser, error)
} }

View File

@@ -1,34 +1,30 @@
package overlay package overlay
import ( import (
"bytes"
"fmt" "fmt"
"math" "math"
"net" "net"
"net/netip"
"runtime" "runtime"
"strconv" "strconv"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
) )
type Route struct { type Route struct {
MTU int MTU int
Metric int Metric int
Cidr *net.IPNet Cidr netip.Prefix
Via *iputil.VpnIp Via netip.Addr
Install bool Install bool
} }
// Equal determines if a route that could be installed in the system route table is equal to another // Equal determines if a route that could be installed in the system route table is equal to another
// Via is ignored since that is only consumed within nebula itself // Via is ignored since that is only consumed within nebula itself
func (r Route) Equal(t Route) bool { func (r Route) Equal(t Route) bool {
if !r.Cidr.IP.Equal(t.Cidr.IP) { if r.Cidr != t.Cidr {
return false
}
if !bytes.Equal(r.Cidr.Mask, t.Cidr.Mask) {
return false return false
} }
if r.Metric != t.Metric { if r.Metric != t.Metric {
@@ -51,21 +47,21 @@ func (r Route) String() string {
return s return s
} }
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) { func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) {
routeTree := cidr.NewTree4[iputil.VpnIp]() routeTree := new(bart.Table[netip.Addr])
for _, r := range routes { for _, r := range routes {
if !allowMTU && r.MTU > 0 { if !allowMTU && r.MTU > 0 {
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS) l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
} }
if r.Via != nil { if r.Via.IsValid() {
routeTree.AddCIDR(r.Cidr, *r.Via) routeTree.Insert(r.Cidr, r.Via)
} }
} }
return routeTree, nil return routeTree, nil
} }
func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
var err error var err error
r := c.Get("tun.routes") r := c.Get("tun.routes")
@@ -116,12 +112,12 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
MTU: mtu, MTU: mtu,
} }
_, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute))
if err != nil { if err != nil {
return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err) return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
} }
if !ipWithin(network, r.Cidr) { if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v", "entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v",
i+1, i+1,
@@ -136,7 +132,7 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
return routes, nil return routes, nil
} }
func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
var err error var err error
r := c.Get("tun.unsafe_routes") r := c.Get("tun.unsafe_routes")
@@ -202,9 +198,9 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia) return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia)
} }
nVia := net.ParseIP(via) viaVpnIp, err := netip.ParseAddr(via)
if nVia == nil { if err != nil {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, via) return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
} }
rRoute, ok := m["route"] rRoute, ok := m["route"]
@@ -212,8 +208,6 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1) return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1)
} }
viaVpnIp := iputil.Ip2VpnIp(nVia)
install := true install := true
rInstall, ok := m["install"] rInstall, ok := m["install"]
if ok { if ok {
@@ -224,18 +218,18 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
} }
r := Route{ r := Route{
Via: &viaVpnIp, Via: viaVpnIp,
MTU: mtu, MTU: mtu,
Metric: metric, Metric: metric,
Install: install, Install: install,
} }
_, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute))
if err != nil { if err != nil {
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err) return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
} }
if ipWithin(network, r.Cidr) { if network.Contains(r.Cidr.Addr()) {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v", "entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v",
i+1, i+1,

View File

@@ -2,11 +2,10 @@ package overlay
import ( import (
"fmt" "fmt"
"net" "net/netip"
"testing" "testing"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -14,7 +13,8 @@ import (
func Test_parseRoutes(t *testing.T) { func Test_parseRoutes(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(l)
_, n, _ := net.ParseCIDR("10.0.0.0/24") n, err := netip.ParsePrefix("10.0.0.0/24")
assert.NoError(t, err)
// test no routes config // test no routes config
routes, err := parseRoutes(c, n) routes, err := parseRoutes(c, n)
@@ -67,7 +67,7 @@ func Test_parseRoutes(t *testing.T) {
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, n)
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: invalid CIDR address: nope") assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// below network range // below network range
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
@@ -112,7 +112,8 @@ func Test_parseRoutes(t *testing.T) {
func Test_parseUnsafeRoutes(t *testing.T) { func Test_parseUnsafeRoutes(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(l)
_, n, _ := net.ParseCIDR("10.0.0.0/24") n, err := netip.ParsePrefix("10.0.0.0/24")
assert.NoError(t, err)
// test no routes config // test no routes config
routes, err := parseUnsafeRoutes(c, n) routes, err := parseUnsafeRoutes(c, n)
@@ -157,7 +158,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: nope") assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
// missing route // missing route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
@@ -169,7 +170,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: invalid CIDR address: nope") assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// within network range // within network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
@@ -252,7 +253,8 @@ func Test_parseUnsafeRoutes(t *testing.T) {
func Test_makeRouteTree(t *testing.T) { func Test_makeRouteTree(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(l)
_, n, _ := net.ParseCIDR("10.0.0.0/24") n, err := netip.ParsePrefix("10.0.0.0/24")
assert.NoError(t, err)
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
@@ -264,17 +266,26 @@ func Test_makeRouteTree(t *testing.T) {
routeTree, err := makeRouteTree(l, routes, true) routeTree, err := makeRouteTree(l, routes, true)
assert.NoError(t, err) assert.NoError(t, err)
ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2")) ip, err := netip.ParseAddr("1.0.0.2")
ok, r := routeTree.MostSpecificContains(ip) assert.NoError(t, err)
r, ok := routeTree.Lookup(ip)
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1")) nip, err := netip.ParseAddr("192.168.0.1")
ok, r = routeTree.MostSpecificContains(ip) assert.NoError(t, err)
assert.Equal(t, nip, r)
ip, err = netip.ParseAddr("1.0.0.1")
assert.NoError(t, err)
r, ok = routeTree.Lookup(ip)
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1")) nip, err = netip.ParseAddr("192.168.0.2")
ok, r = routeTree.MostSpecificContains(ip) assert.NoError(t, err)
assert.Equal(t, nip, r)
ip, err = netip.ParseAddr("1.1.0.1")
assert.NoError(t, err)
r, ok = routeTree.Lookup(ip)
assert.False(t, ok) assert.False(t, ok)
} }

View File

@@ -1,7 +1,7 @@
package overlay package overlay
import ( import (
"net" "net/netip"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
@@ -11,9 +11,9 @@ import (
const DefaultMTU = 1300 const DefaultMTU = 1300
// TODO: We may be able to remove routines // TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error)
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
switch { switch {
case c.GetBool("tun.disabled", false): case c.GetBool("tun.disabled", false):
tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
@@ -25,12 +25,12 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, rout
} }
func NewFdDeviceFromConfig(fd *int) DeviceFactory { func NewFdDeviceFromConfig(fd *int) DeviceFactory {
return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
return newTunFromFd(c, l, *fd, tunCidr) return newTunFromFd(c, l, *fd, tunCidr)
} }
} }
func getAllRoutesFromConfig(c *config.C, cidr *net.IPNet, initial bool) (bool, []Route, error) { func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) {
if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") { if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
return false, nil, nil return false, nil, nil
} }

View File

@@ -6,27 +6,26 @@ package overlay
import ( import (
"fmt" "fmt"
"io" "io"
"net" "net/netip"
"os" "os"
"sync/atomic" "sync/atomic"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
type tun struct { type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
fd int fd int
cidr *net.IPNet cidr netip.Prefix
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger l *logrus.Logger
} }
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) { func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
// Be sure not to call file.Fd() as it will set the fd to blocking mode. // Be sure not to call file.Fd() as it will set the fd to blocking mode.
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@@ -53,12 +52,12 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet)
return t, nil return t, nil
} }
func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) { func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in Android") return nil, fmt.Errorf("newTun not supported in Android")
} }
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
_, r := t.routeTree.Load().MostSpecificContains(ip) r, _ := t.routeTree.Load().Lookup(ip)
return r return r
} }
@@ -87,7 +86,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (t *tun) Cidr() *net.IPNet { func (t *tun) Cidr() netip.Prefix {
return t.cidr return t.cidr
} }

View File

@@ -8,15 +8,15 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"os" "os"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
"unsafe" "unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route" netroute "golang.org/x/net/route"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
@@ -25,10 +25,10 @@ import (
type tun struct { type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
Device string Device string
cidr *net.IPNet cidr netip.Prefix
DefaultMTU int DefaultMTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
linkAddr *netroute.LinkAddr linkAddr *netroute.LinkAddr
l *logrus.Logger l *logrus.Logger
@@ -73,7 +73,7 @@ type ifreqMTU struct {
pad [8]byte pad [8]byte
} }
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
name := c.GetString("tun.dev", "") name := c.GetString("tun.dev", "")
ifIndex := -1 ifIndex := -1
if name != "" && name != "utun" { if name != "" && name != "utun" {
@@ -172,7 +172,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
return return
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin") return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
} }
@@ -188,8 +188,13 @@ func (t *tun) Activate() error {
var addr, mask [4]byte var addr, mask [4]byte
copy(addr[:], t.cidr.IP.To4()) if !t.cidr.Addr().Is4() {
copy(mask[:], t.cidr.Mask) //TODO: IPV6-WORK
panic("need ipv6")
}
addr = t.cidr.Addr().As4()
copy(mask[:], prefixToMask(t.cidr))
s, err := unix.Socket( s, err := unix.Socket(
unix.AF_INET, unix.AF_INET,
@@ -329,13 +334,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
ok, r := t.routeTree.Load().MostSpecificContains(ip) r, ok := t.routeTree.Load().Lookup(ip)
if ok { if ok {
return r return r
} }
return netip.Addr{}
return 0
} }
// Get the LinkAddr for the interface of the given name // Get the LinkAddr for the interface of the given name
@@ -384,13 +388,19 @@ func (t *tun) addRoutes(logErrors bool) error {
maskAddr := &netroute.Inet4Addr{} maskAddr := &netroute.Inet4Addr{}
routes := *t.Routes.Load() routes := *t.Routes.Load()
for _, r := range routes { for _, r := range routes {
if r.Via == nil || !r.Install { if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via // We don't allow route MTUs so only install routes with a via
continue continue
} }
copy(routeAddr.IP[:], r.Cidr.IP.To4()) if !r.Cidr.Addr().Is4() {
copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4()) //TODO: implement ipv6
panic("Cant handle ipv6 routes yet")
}
routeAddr.IP = r.Cidr.Addr().As4()
//TODO: we could avoid the copy
copy(maskAddr.IP[:], prefixToMask(r.Cidr))
err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr) err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
if err != nil { if err != nil {
@@ -435,8 +445,13 @@ func (t *tun) removeRoutes(routes []Route) error {
continue continue
} }
copy(routeAddr.IP[:], r.Cidr.IP.To4()) if r.Cidr.Addr().Is6() {
copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4()) //TODO: implement ipv6
panic("Cant handle ipv6 routes yet")
}
routeAddr.IP = r.Cidr.Addr().As4()
copy(maskAddr.IP[:], prefixToMask(r.Cidr))
err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr) err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
if err != nil { if err != nil {
@@ -536,7 +551,7 @@ func (t *tun) Write(from []byte) (int, error) {
return n - 4, err return n - 4, err
} }
func (t *tun) Cidr() *net.IPNet { func (t *tun) Cidr() netip.Prefix {
return t.cidr return t.cidr
} }
@@ -547,3 +562,11 @@ func (t *tun) Name() string {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
} }
func prefixToMask(prefix netip.Prefix) []byte {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
return net.CIDRMask(prefix.Bits(), pLen)
}

View File

@@ -3,7 +3,7 @@ package overlay
import ( import (
"fmt" "fmt"
"io" "io"
"net" "net/netip"
"strings" "strings"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
@@ -13,7 +13,7 @@ import (
type disabledTun struct { type disabledTun struct {
read chan []byte read chan []byte
cidr *net.IPNet cidr netip.Prefix
// Track these metrics since we don't have the tun device to do it for us // Track these metrics since we don't have the tun device to do it for us
tx metrics.Counter tx metrics.Counter
@@ -21,7 +21,7 @@ type disabledTun struct {
l *logrus.Logger l *logrus.Logger
} }
func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
tun := &disabledTun{ tun := &disabledTun{
cidr: cidr, cidr: cidr,
read: make(chan []byte, queueLen), read: make(chan []byte, queueLen),
@@ -43,11 +43,11 @@ func (*disabledTun) Activate() error {
return nil return nil
} }
func (*disabledTun) RouteFor(iputil.VpnIp) iputil.VpnIp { func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
return 0 return netip.Addr{}
} }
func (t *disabledTun) Cidr() *net.IPNet { func (t *disabledTun) Cidr() netip.Prefix {
return t.cidr return t.cidr
} }

View File

@@ -9,7 +9,7 @@ import (
"fmt" "fmt"
"io" "io"
"io/fs" "io/fs"
"net" "net/netip"
"os" "os"
"os/exec" "os/exec"
"strconv" "strconv"
@@ -17,10 +17,9 @@ import (
"syscall" "syscall"
"unsafe" "unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
@@ -48,10 +47,10 @@ type ifreqDestroy struct {
type tun struct { type tun struct {
Device string Device string
cidr *net.IPNet cidr netip.Prefix
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger l *logrus.Logger
io.ReadWriteCloser io.ReadWriteCloser
@@ -79,11 +78,11 @@ func (t *tun) Close() error {
return nil return nil
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
} }
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device // Try to open existing tun device
var file *os.File var file *os.File
var err error var err error
@@ -174,7 +173,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error
func (t *tun) Activate() error { func (t *tun) Activate() error {
var err error var err error
// TODO use syscalls instead of exec.Command // TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil { if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
@@ -233,12 +232,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
_, r := t.routeTree.Load().MostSpecificContains(ip) r, _ := t.routeTree.Load().Lookup(ip)
return r return r
} }
func (t *tun) Cidr() *net.IPNet { func (t *tun) Cidr() netip.Prefix {
return t.cidr return t.cidr
} }
@@ -253,7 +252,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) addRoutes(logErrors bool) error { func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load() routes := *t.Routes.Load()
for _, r := range routes { for _, r := range routes {
if r.Via == nil || !r.Install { if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via // We don't allow route MTUs so only install routes with a via
continue continue
} }

View File

@@ -7,32 +7,31 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net" "net/netip"
"os" "os"
"sync" "sync"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
type tun struct { type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
cidr *net.IPNet cidr netip.Prefix
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger l *logrus.Logger
} }
func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) { func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in iOS") return nil, fmt.Errorf("newTun not supported in iOS")
} }
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) { func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/tun") file := os.NewFile(uintptr(deviceFd), "/dev/tun")
t := &tun{ t := &tun{
cidr: cidr, cidr: cidr,
@@ -80,8 +79,8 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
_, r := t.routeTree.Load().MostSpecificContains(ip) r, _ := t.routeTree.Load().Lookup(ip)
return r return r
} }
@@ -143,7 +142,7 @@ func (tr *tunReadCloser) Close() error {
return tr.f.Close() return tr.f.Close()
} }
func (t *tun) Cidr() *net.IPNet { func (t *tun) Cidr() netip.Prefix {
return t.cidr return t.cidr
} }

View File

@@ -4,19 +4,18 @@
package overlay package overlay
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"os" "os"
"strings" "strings"
"sync/atomic" "sync/atomic"
"unsafe" "unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
@@ -26,7 +25,7 @@ type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
fd int fd int
Device string Device string
cidr *net.IPNet cidr netip.Prefix
MaxMTU int MaxMTU int
DefaultMTU int DefaultMTU int
TXQueueLen int TXQueueLen int
@@ -34,7 +33,7 @@ type tun struct {
ioctlFd uintptr ioctlFd uintptr
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
routeChan chan struct{} routeChan chan struct{}
useSystemRoutes bool useSystemRoutes bool
@@ -65,7 +64,7 @@ type ifreqQLEN struct {
pad [8]byte pad [8]byte
} }
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) { func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, cidr) t, err := newTunGeneric(c, l, file, cidr)
@@ -78,7 +77,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet)
return t, nil return t, nil
} }
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker) // If /dev/net/tun doesn't exist, try to create it (will happen in docker)
@@ -123,7 +122,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*t
return t, nil return t, nil
} }
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr *net.IPNet) (*tun, error) { func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) {
t := &tun{ t := &tun{
ReadWriteCloser: file, ReadWriteCloser: file,
fd: int(file.Fd()), fd: int(file.Fd()),
@@ -231,8 +230,8 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return file, nil return file, nil
} }
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
_, r := t.routeTree.Load().MostSpecificContains(ip) r, _ := t.routeTree.Load().Lookup(ip)
return r return r
} }
@@ -275,8 +274,10 @@ func (t *tun) Activate() error {
var addr, mask [4]byte var addr, mask [4]byte
copy(addr[:], t.cidr.IP.To4()) //TODO: IPV6-WORK
copy(mask[:], t.cidr.Mask) addr = t.cidr.Addr().As4()
tmask := net.CIDRMask(t.cidr.Bits(), 32)
copy(mask[:], tmask)
s, err := unix.Socket( s, err := unix.Socket(
unix.AF_INET, unix.AF_INET,
@@ -364,14 +365,19 @@ func (t *tun) setMTU() {
func (t *tun) setDefaultRoute() error { func (t *tun) setDefaultRoute() error {
// Default route // Default route
dr := &net.IPNet{IP: t.cidr.IP.Mask(t.cidr.Mask), Mask: t.cidr.Mask}
dr := &net.IPNet{
IP: t.cidr.Masked().Addr().AsSlice(),
Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()),
}
nr := netlink.Route{ nr := netlink.Route{
LinkIndex: t.deviceIndex, LinkIndex: t.deviceIndex,
Dst: dr, Dst: dr,
MTU: t.DefaultMTU, MTU: t.DefaultMTU,
AdvMSS: t.advMSS(Route{}), AdvMSS: t.advMSS(Route{}),
Scope: unix.RT_SCOPE_LINK, Scope: unix.RT_SCOPE_LINK,
Src: t.cidr.IP, Src: net.IP(t.cidr.Addr().AsSlice()),
Protocol: unix.RTPROT_KERNEL, Protocol: unix.RTPROT_KERNEL,
Table: unix.RT_TABLE_MAIN, Table: unix.RT_TABLE_MAIN,
Type: unix.RTN_UNICAST, Type: unix.RTN_UNICAST,
@@ -392,9 +398,14 @@ func (t *tun) addRoutes(logErrors bool) error {
continue continue
} }
dr := &net.IPNet{
IP: r.Cidr.Masked().Addr().AsSlice(),
Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()),
}
nr := netlink.Route{ nr := netlink.Route{
LinkIndex: t.deviceIndex, LinkIndex: t.deviceIndex,
Dst: r.Cidr, Dst: dr,
MTU: r.MTU, MTU: r.MTU,
AdvMSS: t.advMSS(r), AdvMSS: t.advMSS(r),
Scope: unix.RT_SCOPE_LINK, Scope: unix.RT_SCOPE_LINK,
@@ -426,9 +437,14 @@ func (t *tun) removeRoutes(routes []Route) {
continue continue
} }
dr := &net.IPNet{
IP: r.Cidr.Masked().Addr().AsSlice(),
Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()),
}
nr := netlink.Route{ nr := netlink.Route{
LinkIndex: t.deviceIndex, LinkIndex: t.deviceIndex,
Dst: r.Cidr, Dst: dr,
MTU: r.MTU, MTU: r.MTU,
AdvMSS: t.advMSS(r), AdvMSS: t.advMSS(r),
Scope: unix.RT_SCOPE_LINK, Scope: unix.RT_SCOPE_LINK,
@@ -447,7 +463,7 @@ func (t *tun) removeRoutes(routes []Route) {
} }
} }
func (t *tun) Cidr() *net.IPNet { func (t *tun) Cidr() netip.Prefix {
return t.cidr return t.cidr
} }
@@ -499,7 +515,15 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
return return
} }
if !t.cidr.Contains(r.Gw) { //TODO: IPV6-WORK what if not ok?
gwAddr, ok := netip.AddrFromSlice(r.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
return
}
gwAddr = gwAddr.Unmap()
if !t.cidr.Contains(gwAddr) {
// Gateway isn't in our overlay network, ignore // Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network") t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
return return
@@ -511,28 +535,25 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
return return
} }
newTree := cidr.NewTree4[iputil.VpnIp]() dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
if r.Type == unix.RTM_NEWROUTE { if !ok {
for _, oldR := range t.routeTree.Load().List() { t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
newTree.AddCIDR(oldR.CIDR, oldR.Value) return
} }
ones, _ := r.Dst.Mask.Size()
dst := netip.PrefixFrom(dstAddr, ones)
newTree := t.routeTree.Load().Clone()
if r.Type == unix.RTM_NEWROUTE {
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route") t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw)) newTree.Insert(dst, gwAddr)
} else { } else {
gw := iputil.Ip2VpnIp(r.Gw) newTree.Delete(dst)
for _, oldR := range t.routeTree.Load().List() {
if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw {
// This is the record to delete
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route") t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
continue
} }
newTree.AddCIDR(oldR.CIDR, oldR.Value)
}
}
t.routeTree.Store(newTree) t.routeTree.Store(newTree)
} }

View File

@@ -6,7 +6,7 @@ package overlay
import ( import (
"fmt" "fmt"
"io" "io"
"net" "net/netip"
"os" "os"
"os/exec" "os/exec"
"regexp" "regexp"
@@ -15,10 +15,9 @@ import (
"syscall" "syscall"
"unsafe" "unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
@@ -29,10 +28,10 @@ type ifreqDestroy struct {
type tun struct { type tun struct {
Device string Device string
cidr *net.IPNet cidr netip.Prefix
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger l *logrus.Logger
io.ReadWriteCloser io.ReadWriteCloser
@@ -59,13 +58,13 @@ func (t *tun) Close() error {
return nil return nil
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
} }
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device // Try to open tun device
var file *os.File var file *os.File
var err error var err error
@@ -109,13 +108,13 @@ func (t *tun) Activate() error {
var err error var err error
// TODO use syscalls instead of exec.Command // TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil { if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
} }
cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.IP.String()) cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil { if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err) return fmt.Errorf("failed to run 'route add': %s", err)
@@ -168,12 +167,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
_, r := t.routeTree.Load().MostSpecificContains(ip) r, _ := t.routeTree.Load().Lookup(ip)
return r return r
} }
func (t *tun) Cidr() *net.IPNet { func (t *tun) Cidr() netip.Prefix {
return t.cidr return t.cidr
} }
@@ -188,12 +187,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) addRoutes(logErrors bool) error { func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load() routes := *t.Routes.Load()
for _, r := range routes { for _, r := range routes {
if r.Via == nil || !r.Install { if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via // We don't allow route MTUs so only install routes with a via
continue continue
} }
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String()) cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@@ -214,7 +213,7 @@ func (t *tun) removeRoutes(routes []Route) error {
continue continue
} }
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.IP.String()) cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.WithError(err).WithField("route", r).Error("Failed to remove route")

View File

@@ -6,7 +6,7 @@ package overlay
import ( import (
"fmt" "fmt"
"io" "io"
"net" "net/netip"
"os" "os"
"os/exec" "os/exec"
"regexp" "regexp"
@@ -14,19 +14,18 @@ import (
"sync/atomic" "sync/atomic"
"syscall" "syscall"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
type tun struct { type tun struct {
Device string Device string
cidr *net.IPNet cidr netip.Prefix
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger l *logrus.Logger
io.ReadWriteCloser io.ReadWriteCloser
@@ -43,13 +42,13 @@ func (t *tun) Close() error {
return nil return nil
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD") return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
} }
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
deviceName := c.GetString("tun.dev", "") deviceName := c.GetString("tun.dev", "")
if deviceName == "" { if deviceName == "" {
return nil, fmt.Errorf("a device name in the format of tunN must be specified") return nil, fmt.Errorf("a device name in the format of tunN must be specified")
@@ -127,7 +126,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
func (t *tun) Activate() error { func (t *tun) Activate() error {
var err error var err error
// TODO use syscalls instead of exec.Command // TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil { if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
@@ -139,7 +138,7 @@ func (t *tun) Activate() error {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
} }
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.IP.String()) cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil { if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err) return fmt.Errorf("failed to run 'route add': %s", err)
@@ -149,20 +148,20 @@ func (t *tun) Activate() error {
return t.addRoutes(false) return t.addRoutes(false)
} }
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
_, r := t.routeTree.Load().MostSpecificContains(ip) r, _ := t.routeTree.Load().Lookup(ip)
return r return r
} }
func (t *tun) addRoutes(logErrors bool) error { func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load() routes := *t.Routes.Load()
for _, r := range routes { for _, r := range routes {
if r.Via == nil || !r.Install { if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via // We don't allow route MTUs so only install routes with a via
continue continue
} }
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String()) cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@@ -183,7 +182,7 @@ func (t *tun) removeRoutes(routes []Route) error {
continue continue
} }
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.IP.String()) cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
@@ -194,7 +193,7 @@ func (t *tun) removeRoutes(routes []Route) error {
return nil return nil
} }
func (t *tun) Cidr() *net.IPNet { func (t *tun) Cidr() netip.Prefix {
return t.cidr return t.cidr
} }

View File

@@ -6,21 +6,20 @@ package overlay
import ( import (
"fmt" "fmt"
"io" "io"
"net" "net/netip"
"os" "os"
"sync/atomic" "sync/atomic"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
) )
type TestTun struct { type TestTun struct {
Device string Device string
cidr *net.IPNet cidr netip.Prefix
Routes []Route Routes []Route
routeTree *cidr.Tree4[iputil.VpnIp] routeTree *bart.Table[netip.Addr]
l *logrus.Logger l *logrus.Logger
closed atomic.Bool closed atomic.Bool
@@ -28,7 +27,7 @@ type TestTun struct {
TxPackets chan []byte // Packets transmitted outside by nebula TxPackets chan []byte // Packets transmitted outside by nebula
} }
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, error) { func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) {
_, routes, err := getAllRoutesFromConfig(c, cidr, true) _, routes, err := getAllRoutesFromConfig(c, cidr, true)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -49,7 +48,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, e
}, nil }, nil
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*TestTun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported") return nil, fmt.Errorf("newTunFromFd not supported")
} }
@@ -87,8 +86,8 @@ func (t *TestTun) Get(block bool) []byte {
// Below this is boilerplate implementation to make nebula actually work // Below this is boilerplate implementation to make nebula actually work
//********************************************************************************************************************// //********************************************************************************************************************//
func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr {
_, r := t.routeTree.MostSpecificContains(ip) r, _ := t.routeTree.Lookup(ip)
return r return r
} }
@@ -96,7 +95,7 @@ func (t *TestTun) Activate() error {
return nil return nil
} }
func (t *TestTun) Cidr() *net.IPNet { func (t *TestTun) Cidr() netip.Prefix {
return t.cidr return t.cidr
} }

View File

@@ -4,30 +4,30 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"os/exec" "os/exec"
"strconv" "strconv"
"sync/atomic" "sync/atomic"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
"github.com/songgao/water" "github.com/songgao/water"
) )
type waterTun struct { type waterTun struct {
Device string Device string
cidr *net.IPNet cidr netip.Prefix
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger l *logrus.Logger
f *net.Interface f *net.Interface
*water.Interface *water.Interface
} }
func newWaterTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*waterTun, error) { func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) {
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
t := &waterTun{ t := &waterTun{
cidr: cidr, cidr: cidr,
@@ -70,8 +70,8 @@ func (t *waterTun) Activate() error {
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address", `C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address",
fmt.Sprintf("name=%s", t.Device), fmt.Sprintf("name=%s", t.Device),
"source=static", "source=static",
fmt.Sprintf("addr=%s", t.cidr.IP), fmt.Sprintf("addr=%s", t.cidr.Addr()),
fmt.Sprintf("mask=%s", net.IP(t.cidr.Mask)), fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())),
"gateway=none", "gateway=none",
).Run() ).Run()
if err != nil { if err != nil {
@@ -141,7 +141,7 @@ func (t *waterTun) addRoutes(logErrors bool) error {
// Path routes // Path routes
routes := *t.Routes.Load() routes := *t.Routes.Load()
for _, r := range routes { for _, r := range routes {
if r.Via == nil || !r.Install { if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via // We don't allow route MTUs so only install routes with a via
continue continue
} }
@@ -182,12 +182,12 @@ func (t *waterTun) removeRoutes(routes []Route) {
} }
} }
func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr {
_, r := t.routeTree.Load().MostSpecificContains(ip) r, _ := t.routeTree.Load().Lookup(ip)
return r return r
} }
func (t *waterTun) Cidr() *net.IPNet { func (t *waterTun) Cidr() netip.Prefix {
return t.cidr return t.cidr
} }

View File

@@ -5,7 +5,7 @@ package overlay
import ( import (
"fmt" "fmt"
"net" "net/netip"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
@@ -15,11 +15,11 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
) )
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (Device, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows") return nil, fmt.Errorf("newTunFromFd not supported in Windows")
} }
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (Device, error) { func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) {
useWintun := true useWintun := true
if err := checkWinTunExists(); err != nil { if err := checkWinTunExists(); err != nil {
l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver") l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")

View File

@@ -4,15 +4,13 @@ import (
"crypto" "crypto"
"fmt" "fmt"
"io" "io"
"net"
"net/netip" "net/netip"
"sync/atomic" "sync/atomic"
"unsafe" "unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/wintun" "github.com/slackhq/nebula/wintun"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
@@ -23,11 +21,10 @@ const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
type winTun struct { type winTun struct {
Device string Device string
cidr *net.IPNet cidr netip.Prefix
prefix netip.Prefix
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger l *logrus.Logger
tun *wintun.NativeTun tun *wintun.NativeTun
@@ -52,22 +49,16 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
} }
func newWinTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*winTun, error) { func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) {
deviceName := c.GetString("tun.dev", "") deviceName := c.GetString("tun.dev", "")
guid, err := generateGUIDByDeviceName(deviceName) guid, err := generateGUIDByDeviceName(deviceName)
if err != nil { if err != nil {
return nil, fmt.Errorf("generate GUID failed: %w", err) return nil, fmt.Errorf("generate GUID failed: %w", err)
} }
prefix, err := iputil.ToNetIpPrefix(*cidr)
if err != nil {
return nil, err
}
t := &winTun{ t := &winTun{
Device: deviceName, Device: deviceName,
cidr: cidr, cidr: cidr,
prefix: prefix,
MTU: c.GetInt("tun.mtu", DefaultMTU), MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l, l: l,
} }
@@ -140,7 +131,7 @@ func (t *winTun) reload(c *config.C, initial bool) error {
func (t *winTun) Activate() error { func (t *winTun) Activate() error {
luid := winipcfg.LUID(t.tun.LUID()) luid := winipcfg.LUID(t.tun.LUID())
err := luid.SetIPAddresses([]netip.Prefix{t.prefix}) err := luid.SetIPAddresses([]netip.Prefix{t.cidr})
if err != nil { if err != nil {
return fmt.Errorf("failed to set address: %w", err) return fmt.Errorf("failed to set address: %w", err)
} }
@@ -159,24 +150,13 @@ func (t *winTun) addRoutes(logErrors bool) error {
foundDefault4 := false foundDefault4 := false
for _, r := range routes { for _, r := range routes {
if r.Via == nil || !r.Install { if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via // We don't allow route MTUs so only install routes with a via
continue continue
} }
prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
if err != nil {
retErr := util.NewContextualError("Failed to parse cidr to netip prefix, ignoring route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
continue
} else {
return retErr
}
}
// Add our unsafe route // Add our unsafe route
err = luid.AddRoute(prefix, r.Via.ToNetIpAddr(), uint32(r.Metric)) err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
if err != nil { if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors { if logErrors {
@@ -190,7 +170,7 @@ func (t *winTun) addRoutes(logErrors bool) error {
} }
if !foundDefault4 { if !foundDefault4 {
if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 { if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
foundDefault4 = true foundDefault4 = true
} }
} }
@@ -221,13 +201,7 @@ func (t *winTun) removeRoutes(routes []Route) error {
continue continue
} }
prefix, err := iputil.ToNetIpPrefix(*r.Cidr) err := luid.DeleteRoute(r.Cidr, r.Via)
if err != nil {
t.l.WithError(err).WithField("route", r).Info("Failed to convert cidr to netip prefix")
continue
}
err = luid.DeleteRoute(prefix, r.Via.ToNetIpAddr())
if err != nil { if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else { } else {
@@ -237,12 +211,12 @@ func (t *winTun) removeRoutes(routes []Route) error {
return nil return nil
} }
func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
_, r := t.routeTree.Load().MostSpecificContains(ip) r, _ := t.routeTree.Load().Lookup(ip)
return r return r
} }
func (t *winTun) Cidr() *net.IPNet { func (t *winTun) Cidr() netip.Prefix {
return t.cidr return t.cidr
} }

View File

@@ -2,18 +2,17 @@ package overlay
import ( import (
"io" "io"
"net" "net/netip"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
) )
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
return NewUserDevice(tunCidr) return NewUserDevice(tunCidr)
} }
func NewUserDevice(tunCidr *net.IPNet) (Device, error) { func NewUserDevice(tunCidr netip.Prefix) (Device, error) {
// these pipes guarantee each write/read will match 1:1 // these pipes guarantee each write/read will match 1:1
or, ow := io.Pipe() or, ow := io.Pipe()
ir, iw := io.Pipe() ir, iw := io.Pipe()
@@ -27,7 +26,7 @@ func NewUserDevice(tunCidr *net.IPNet) (Device, error) {
} }
type UserDevice struct { type UserDevice struct {
tunCidr *net.IPNet tunCidr netip.Prefix
outboundReader *io.PipeReader outboundReader *io.PipeReader
outboundWriter *io.PipeWriter outboundWriter *io.PipeWriter
@@ -39,9 +38,9 @@ type UserDevice struct {
func (d *UserDevice) Activate() error { func (d *UserDevice) Activate() error {
return nil return nil
} }
func (d *UserDevice) Cidr() *net.IPNet { return d.tunCidr } func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr }
func (d *UserDevice) Name() string { return "faketun0" } func (d *UserDevice) Name() string { return "faketun0" }
func (d *UserDevice) RouteFor(ip iputil.VpnIp) iputil.VpnIp { return ip } func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return d, nil return d, nil
} }

21
pki.go
View File

@@ -80,6 +80,8 @@ func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
} }
if !initial { if !initial {
//TODO: include check for mask equality as well
// did IP in cert change? if so, don't set // did IP in cert change? if so, don't set
currentCert := p.cs.Load().Certificate currentCert := p.cs.Load().Certificate
oldIPs := currentCert.Details.Ips oldIPs := currentCert.Details.Ips
@@ -221,22 +223,13 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, er
} }
} }
caPool, err := cert.NewCAPoolFromBytes(rawCA) caPool, warnings, err := cert.NewCAPoolFromBytes(rawCA)
if errors.Is(err, cert.ErrExpired) { for _, w := range warnings {
var expired int l.WithError(w).Warn("parsing a CA certificate failed")
for _, crt := range caPool.CAs {
if crt.Expired(time.Now()) {
expired++
l.WithField("cert", crt).Warn("expired certificate present in CA pool")
}
} }
if expired >= len(caPool.CAs) { if err != nil {
return nil, errors.New("no valid CA certificates present") return nil, fmt.Errorf("could not create CA certificate pool: %s", err)
}
} else if err != nil {
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
} }
for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) { for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {

View File

@@ -2,14 +2,15 @@ package nebula
import ( import (
"context" "context"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"net/netip"
"sync/atomic" "sync/atomic"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
) )
type relayManager struct { type relayManager struct {
@@ -50,7 +51,7 @@ func (rm *relayManager) setAmRelay(v bool) {
// AddRelay finds an available relay index on the hostmap, and associates the relay info with it. // AddRelay finds an available relay index on the hostmap, and associates the relay info with it.
// relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp. // relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp.
func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp iputil.VpnIp, remoteIdx *uint32, relayType int, state int) (uint32, error) { func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) {
hm.Lock() hm.Lock()
defer hm.Unlock() defer hm.Unlock()
for i := 0; i < 32; i++ { for i := 0; i < 32; i++ {
@@ -113,13 +114,17 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Inter
func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) { func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) {
rm.l.WithFields(logrus.Fields{ rm.l.WithFields(logrus.Fields{
"relayFrom": iputil.VpnIp(m.RelayFromIp), "relayFrom": m.RelayFromIp,
"relayTo": iputil.VpnIp(m.RelayToIp), "relayTo": m.RelayToIp,
"initiatorRelayIndex": m.InitiatorRelayIndex, "initiatorRelayIndex": m.InitiatorRelayIndex,
"responderRelayIndex": m.ResponderRelayIndex, "responderRelayIndex": m.ResponderRelayIndex,
"vpnIp": h.vpnIp}). "vpnIp": h.vpnIp}).
Info("handleCreateRelayResponse") Info("handleCreateRelayResponse")
target := iputil.VpnIp(m.RelayToIp) target := m.RelayToIp
//TODO: IPV6-WORK
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], m.RelayToIp)
targetAddr := netip.AddrFrom4(b)
relay, err := rm.EstablishRelay(h, m) relay, err := rm.EstablishRelay(h, m)
if err != nil { if err != nil {
@@ -136,18 +141,24 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer") rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer")
return return
} }
peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(target) peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
if !ok { if !ok {
rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo") rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo")
return return
} }
if peerRelay.State == PeerRequested { switch peerRelay.State {
peerRelay.State = Established case Requested:
// I initiated the request to this peer, but haven't heard back from the peer yet. I must wait for this peer
// to respond to complete the connection.
case PeerRequested, Disestablished, Established:
peerHostInfo.relayState.UpdateRelayForByIpState(targetAddr, Established)
//TODO: IPV6-WORK
b = peerHostInfo.vpnIp.As4()
resp := NebulaControl{ resp := NebulaControl{
Type: NebulaControl_CreateRelayResponse, Type: NebulaControl_CreateRelayResponse,
ResponderRelayIndex: peerRelay.LocalIndex, ResponderRelayIndex: peerRelay.LocalIndex,
InitiatorRelayIndex: peerRelay.RemoteIndex, InitiatorRelayIndex: peerRelay.RemoteIndex,
RelayFromIp: uint32(peerHostInfo.vpnIp), RelayFromIp: binary.BigEndian.Uint32(b[:]),
RelayToIp: uint32(target), RelayToIp: uint32(target),
} }
msg, err := resp.Marshal() msg, err := resp.Marshal()
@@ -157,8 +168,8 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
} else { } else {
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{ rm.l.WithFields(logrus.Fields{
"relayFrom": iputil.VpnIp(resp.RelayFromIp), "relayFrom": resp.RelayFromIp,
"relayTo": iputil.VpnIp(resp.RelayToIp), "relayTo": resp.RelayToIp,
"initiatorRelayIndex": resp.InitiatorRelayIndex, "initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex,
"vpnIp": peerHostInfo.vpnIp}). "vpnIp": peerHostInfo.vpnIp}).
@@ -168,9 +179,13 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
} }
func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) { func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) {
//TODO: IPV6-WORK
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], m.RelayFromIp)
from := netip.AddrFrom4(b)
from := iputil.VpnIp(m.RelayFromIp) binary.BigEndian.PutUint32(b[:], m.RelayToIp)
target := iputil.VpnIp(m.RelayToIp) target := netip.AddrFrom4(b)
logMsg := rm.l.WithFields(logrus.Fields{ logMsg := rm.l.WithFields(logrus.Fields{
"relayFrom": from, "relayFrom": from,
@@ -181,12 +196,12 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
logMsg.Info("handleCreateRelayRequest") logMsg.Info("handleCreateRelayRequest")
// Is the source of the relay me? This should never happen, but did happen due to // Is the source of the relay me? This should never happen, but did happen due to
// an issue migrating relays over to newly re-handshaked host info objects. // an issue migrating relays over to newly re-handshaked host info objects.
if from == f.myVpnIp { if from == f.myVpnNet.Addr() {
logMsg.WithField("myIP", f.myVpnIp).Error("Discarding relay request from myself") logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
return return
} }
// Is the target of the relay me? // Is the target of the relay me?
if target == f.myVpnIp { if target == f.myVpnNet.Addr() {
existingRelay, ok := h.relayState.QueryRelayForByIp(from) existingRelay, ok := h.relayState.QueryRelayForByIp(from)
if ok { if ok {
switch existingRelay.State { switch existingRelay.State {
@@ -204,6 +219,21 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
"existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
return return
} }
case Disestablished:
if existingRelay.RemoteIndex != m.InitiatorRelayIndex {
// We got a brand new Relay request, because its index is different than what we saw before.
// This should never happen. The peer should never change an index, once created.
logMsg.WithFields(logrus.Fields{
"existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
return
}
// Mark the relay as 'Established' because it's safe to use again
h.relayState.UpdateRelayForByIpState(from, Established)
case PeerRequested:
// I should never be in this state, because I am terminal, not forwarding.
logMsg.WithFields(logrus.Fields{
"existingRemoteIndex": existingRelay.RemoteIndex,
"state": existingRelay.State}).Error("Unexpected Relay State found")
} }
} else { } else {
_, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established) _, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established)
@@ -215,16 +245,20 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
relay, ok := h.relayState.QueryRelayForByIp(from) relay, ok := h.relayState.QueryRelayForByIp(from)
if !ok { if !ok {
logMsg.Error("Relay State not found") logMsg.WithField("from", from).Error("Relay State not found")
return return
} }
//TODO: IPV6-WORK
fromB := from.As4()
targetB := target.As4()
resp := NebulaControl{ resp := NebulaControl{
Type: NebulaControl_CreateRelayResponse, Type: NebulaControl_CreateRelayResponse,
ResponderRelayIndex: relay.LocalIndex, ResponderRelayIndex: relay.LocalIndex,
InitiatorRelayIndex: relay.RemoteIndex, InitiatorRelayIndex: relay.RemoteIndex,
RelayFromIp: uint32(from), RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
RelayToIp: uint32(target), RelayToIp: binary.BigEndian.Uint32(targetB[:]),
} }
msg, err := resp.Marshal() msg, err := resp.Marshal()
if err != nil { if err != nil {
@@ -233,8 +267,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
} else { } else {
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{ rm.l.WithFields(logrus.Fields{
"relayFrom": iputil.VpnIp(resp.RelayFromIp), //TODO: IPV6-WORK, this used to use the resp object but I am getting lazy now
"relayTo": iputil.VpnIp(resp.RelayToIp), "relayFrom": from,
"relayTo": target,
"initiatorRelayIndex": resp.InitiatorRelayIndex, "initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex,
"vpnIp": h.vpnIp}). "vpnIp": h.vpnIp}).
@@ -253,34 +288,31 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
f.Handshake(target) f.Handshake(target)
return return
} }
if peer.remote == nil { if !peer.remote.IsValid() {
// Only create relays to peers for whom I have a direct connection // Only create relays to peers for whom I have a direct connection
return return
} }
sendCreateRequest := false
var index uint32 var index uint32
var err error var err error
targetRelay, ok := peer.relayState.QueryRelayForByIp(from) targetRelay, ok := peer.relayState.QueryRelayForByIp(from)
if ok { if ok {
index = targetRelay.LocalIndex index = targetRelay.LocalIndex
if targetRelay.State == Requested {
sendCreateRequest = true
}
} else { } else {
// Allocate an index in the hostMap for this relay peer // Allocate an index in the hostMap for this relay peer
index, err = AddRelay(rm.l, peer, f.hostMap, from, nil, ForwardingType, Requested) index, err = AddRelay(rm.l, peer, f.hostMap, from, nil, ForwardingType, Requested)
if err != nil { if err != nil {
return return
} }
sendCreateRequest = true
} }
if sendCreateRequest { peer.relayState.UpdateRelayForByIpState(from, Requested)
// Send a CreateRelayRequest to the peer. // Send a CreateRelayRequest to the peer.
fromB := from.As4()
targetB := target.As4()
req := NebulaControl{ req := NebulaControl{
Type: NebulaControl_CreateRelayRequest, Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: index, InitiatorRelayIndex: index,
RelayFromIp: uint32(h.vpnIp), RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
RelayToIp: uint32(target), RelayToIp: binary.BigEndian.Uint32(targetB[:]),
} }
msg, err := req.Marshal() msg, err := req.Marshal()
if err != nil { if err != nil {
@@ -289,67 +321,24 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
} else { } else {
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{ rm.l.WithFields(logrus.Fields{
"relayFrom": iputil.VpnIp(req.RelayFromIp), //TODO: IPV6-WORK another lazy used to use the req object
"relayTo": iputil.VpnIp(req.RelayToIp), "relayFrom": h.vpnIp,
"relayTo": target,
"initiatorRelayIndex": req.InitiatorRelayIndex, "initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex, "responderRelayIndex": req.ResponderRelayIndex,
"vpnIp": target}). "vpnAddr": target}).
Info("send CreateRelayRequest") Info("send CreateRelayRequest")
}
}
// Also track the half-created Relay state just received // Also track the half-created Relay state just received
relay, ok := h.relayState.QueryRelayForByIp(target) _, ok := h.relayState.QueryRelayForByIp(target)
if !ok { if !ok {
// Add the relay // Add the relay
state := PeerRequested _, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested)
if targetRelay != nil && targetRelay.State == Established {
state = Established
}
_, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, state)
if err != nil { if err != nil {
logMsg. logMsg.
WithError(err).Error("relayManager Failed to allocate a local index for relay") WithError(err).Error("relayManager Failed to allocate a local index for relay")
return return
} }
} else {
switch relay.State {
case Established:
if relay.RemoteIndex != m.InitiatorRelayIndex {
// We got a brand new Relay request, because its index is different than what we saw before.
// This should never happen. The peer should never change an index, once created.
logMsg.WithFields(logrus.Fields{
"existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
return
}
resp := NebulaControl{
Type: NebulaControl_CreateRelayResponse,
ResponderRelayIndex: relay.LocalIndex,
InitiatorRelayIndex: relay.RemoteIndex,
RelayFromIp: uint32(h.vpnIp),
RelayToIp: uint32(target),
}
msg, err := resp.Marshal()
if err != nil {
rm.l.
WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
} else {
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
"relayFrom": iputil.VpnIp(resp.RelayFromIp),
"relayTo": iputil.VpnIp(resp.RelayToIp),
"initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex,
"vpnIp": h.vpnIp}).
Info("send CreateRelayResponse")
}
case Requested:
// Keep waiting for the other relay to complete
} }
} }
} }
} }
func (rm *relayManager) RemoveRelay(localIdx uint32) {
rm.hostmap.RemoveRelay(localIdx)
}

View File

@@ -1,7 +1,6 @@
package nebula package nebula
import ( import (
"bytes"
"context" "context"
"net" "net"
"net/netip" "net/netip"
@@ -12,16 +11,14 @@ import (
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
) )
// forEachFunc is used to benefit folks that want to do work inside the lock // forEachFunc is used to benefit folks that want to do work inside the lock
type forEachFunc func(addr *udp.Addr, preferred bool) type forEachFunc func(addr netip.AddrPort, preferred bool)
// The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate) // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
type checkFuncV4 func(vpnIp iputil.VpnIp, to *Ip4AndPort) bool type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool
type checkFuncV6 func(vpnIp iputil.VpnIp, to *Ip6AndPort) bool type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool
// CacheMap is a struct that better represents the lighthouse cache for humans // CacheMap is a struct that better represents the lighthouse cache for humans
// The string key is the owners vpnIp // The string key is the owners vpnIp
@@ -30,9 +27,9 @@ type CacheMap map[string]*Cache
// Cache is the other part of CacheMap to better represent the lighthouse cache for humans // Cache is the other part of CacheMap to better represent the lighthouse cache for humans
// We don't reason about ipv4 vs ipv6 here // We don't reason about ipv4 vs ipv6 here
type Cache struct { type Cache struct {
Learned []*udp.Addr `json:"learned,omitempty"` Learned []netip.AddrPort `json:"learned,omitempty"`
Reported []*udp.Addr `json:"reported,omitempty"` Reported []netip.AddrPort `json:"reported,omitempty"`
Relay []*net.IP `json:"relay"` Relay []netip.Addr `json:"relay"`
} }
//TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion //TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion
@@ -46,7 +43,7 @@ type cache struct {
} }
type cacheRelay struct { type cacheRelay struct {
relay []uint32 relay []netip.Addr
} }
// cacheV4 stores learned and reported ipv4 records under cache // cacheV4 stores learned and reported ipv4 records under cache
@@ -130,7 +127,7 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
continue continue
} }
for _, a := range addrs { for _, a := range addrs {
netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{} netipAddrs[netip.AddrPortFrom(a.Unmap(), hostPort.port)] = struct{}{}
} }
} }
origSet := r.ips.Load() origSet := r.ips.Load()
@@ -193,22 +190,22 @@ type RemoteList struct {
sync.RWMutex sync.RWMutex
// A deduplicated set of addresses. Any accessor should lock beforehand. // A deduplicated set of addresses. Any accessor should lock beforehand.
addrs []*udp.Addr addrs []netip.AddrPort
// A set of relay addresses. VpnIp addresses that the remote identified as relays. // A set of relay addresses. VpnIp addresses that the remote identified as relays.
relays []*iputil.VpnIp relays []netip.Addr
// These are maps to store v4 and v6 addresses per lighthouse // These are maps to store v4 and v6 addresses per lighthouse
// Map key is the vpnIp of the person that told us about this the cached entries underneath. // Map key is the vpnIp of the person that told us about this the cached entries underneath.
// For learned addresses, this is the vpnIp that sent the packet // For learned addresses, this is the vpnIp that sent the packet
cache map[iputil.VpnIp]*cache cache map[netip.Addr]*cache
hr *hostnamesResults hr *hostnamesResults
shouldAdd func(netip.Addr) bool shouldAdd func(netip.Addr) bool
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake // They should not be tried again during a handshake
badRemotes []*udp.Addr badRemotes []netip.AddrPort
// A flag that the cache may have changed and addrs needs to be rebuilt // A flag that the cache may have changed and addrs needs to be rebuilt
shouldRebuild bool shouldRebuild bool
@@ -217,9 +214,9 @@ type RemoteList struct {
// NewRemoteList creates a new empty RemoteList // NewRemoteList creates a new empty RemoteList
func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList { func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList {
return &RemoteList{ return &RemoteList{
addrs: make([]*udp.Addr, 0), addrs: make([]netip.AddrPort, 0),
relays: make([]*iputil.VpnIp, 0), relays: make([]netip.Addr, 0),
cache: make(map[iputil.VpnIp]*cache), cache: make(map[netip.Addr]*cache),
shouldAdd: shouldAdd, shouldAdd: shouldAdd,
} }
} }
@@ -232,7 +229,7 @@ func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
// Len locks and reports the size of the deduplicated address list // Len locks and reports the size of the deduplicated address list
// The deduplication work may need to occur here, so you must pass preferredRanges // The deduplication work may need to occur here, so you must pass preferredRanges
func (r *RemoteList) Len(preferredRanges []*net.IPNet) int { func (r *RemoteList) Len(preferredRanges []netip.Prefix) int {
r.Rebuild(preferredRanges) r.Rebuild(preferredRanges)
r.RLock() r.RLock()
defer r.RUnlock() defer r.RUnlock()
@@ -241,18 +238,18 @@ func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
// ForEach locks and will call the forEachFunc for every deduplicated address in the list // ForEach locks and will call the forEachFunc for every deduplicated address in the list
// The deduplication work may need to occur here, so you must pass preferredRanges // The deduplication work may need to occur here, so you must pass preferredRanges
func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc) { func (r *RemoteList) ForEach(preferredRanges []netip.Prefix, forEach forEachFunc) {
r.Rebuild(preferredRanges) r.Rebuild(preferredRanges)
r.RLock() r.RLock()
for _, v := range r.addrs { for _, v := range r.addrs {
forEach(v, isPreferred(v.IP, preferredRanges)) forEach(v, isPreferred(v.Addr(), preferredRanges))
} }
r.RUnlock() r.RUnlock()
} }
// CopyAddrs locks and makes a deep copy of the deduplicated address list // CopyAddrs locks and makes a deep copy of the deduplicated address list
// The deduplication work may need to occur here, so you must pass preferredRanges // The deduplication work may need to occur here, so you must pass preferredRanges
func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr { func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort {
if r == nil { if r == nil {
return nil return nil
} }
@@ -261,9 +258,9 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
r.RLock() r.RLock()
defer r.RUnlock() defer r.RUnlock()
c := make([]*udp.Addr, len(r.addrs)) c := make([]netip.AddrPort, len(r.addrs))
for i, v := range r.addrs { for i, v := range r.addrs {
c[i] = v.Copy() c[i] = v
} }
return c return c
} }
@@ -272,13 +269,13 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
// Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming. // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
// It will mark the deduplicated address list as dirty, so do not call it unless new information is available // It will mark the deduplicated address list as dirty, so do not call it unless new information is available
// TODO: this needs to support the allow list list // TODO: this needs to support the allow list list
func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) { func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) {
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
if v4 := addr.IP.To4(); v4 != nil { if remote.Addr().Is4() {
r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPort(v4, uint32(addr.Port))) r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port()))
} else { } else {
r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPort(addr.IP, uint32(addr.Port))) r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port()))
} }
} }
@@ -293,9 +290,9 @@ func (r *RemoteList) CopyCache() *CacheMap {
c := cm[vpnIp] c := cm[vpnIp]
if c == nil { if c == nil {
c = &Cache{ c = &Cache{
Learned: make([]*udp.Addr, 0), Learned: make([]netip.AddrPort, 0),
Reported: make([]*udp.Addr, 0), Reported: make([]netip.AddrPort, 0),
Relay: make([]*net.IP, 0), Relay: make([]netip.Addr, 0),
} }
cm[vpnIp] = c cm[vpnIp] = c
} }
@@ -307,28 +304,27 @@ func (r *RemoteList) CopyCache() *CacheMap {
if mc.v4 != nil { if mc.v4 != nil {
if mc.v4.learned != nil { if mc.v4.learned != nil {
c.Learned = append(c.Learned, NewUDPAddrFromLH4(mc.v4.learned)) c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned))
} }
for _, a := range mc.v4.reported { for _, a := range mc.v4.reported {
c.Reported = append(c.Reported, NewUDPAddrFromLH4(a)) c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a))
} }
} }
if mc.v6 != nil { if mc.v6 != nil {
if mc.v6.learned != nil { if mc.v6.learned != nil {
c.Learned = append(c.Learned, NewUDPAddrFromLH6(mc.v6.learned)) c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned))
} }
for _, a := range mc.v6.reported { for _, a := range mc.v6.reported {
c.Reported = append(c.Reported, NewUDPAddrFromLH6(a)) c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a))
} }
} }
if mc.relay != nil { if mc.relay != nil {
for _, a := range mc.relay.relay { for _, a := range mc.relay.relay {
nip := iputil.VpnIp(a).ToIP() c.Relay = append(c.Relay, a)
c.Relay = append(c.Relay, &nip)
} }
} }
} }
@@ -337,8 +333,8 @@ func (r *RemoteList) CopyCache() *CacheMap {
} }
// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
func (r *RemoteList) BlockRemote(bad *udp.Addr) { func (r *RemoteList) BlockRemote(bad netip.AddrPort) {
if bad == nil { if !bad.IsValid() {
// relays can have nil udp Addrs // relays can have nil udp Addrs
return return
} }
@@ -351,20 +347,20 @@ func (r *RemoteList) BlockRemote(bad *udp.Addr) {
} }
// We copy here because we are taking something else's memory and we can't trust everything // We copy here because we are taking something else's memory and we can't trust everything
r.badRemotes = append(r.badRemotes, bad.Copy()) r.badRemotes = append(r.badRemotes, bad)
// Mark the next interaction must recollect/dedupe // Mark the next interaction must recollect/dedupe
r.shouldRebuild = true r.shouldRebuild = true
} }
// CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list // CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list
func (r *RemoteList) CopyBlockedRemotes() []*udp.Addr { func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
r.RLock() r.RLock()
defer r.RUnlock() defer r.RUnlock()
c := make([]*udp.Addr, len(r.badRemotes)) c := make([]netip.AddrPort, len(r.badRemotes))
for i, v := range r.badRemotes { for i, v := range r.badRemotes {
c[i] = v.Copy() c[i] = v
} }
return c return c
} }
@@ -378,7 +374,7 @@ func (r *RemoteList) ResetBlockedRemotes() {
// Rebuild locks and generates the deduplicated address list only if there is work to be done // Rebuild locks and generates the deduplicated address list only if there is work to be done
// There is generally no reason to call this directly but it is safe to do so // There is generally no reason to call this directly but it is safe to do so
func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) { func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) {
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
@@ -394,9 +390,9 @@ func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
} }
// unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list // unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool { func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool {
for _, v := range r.badRemotes { for _, v := range r.badRemotes {
if v.Equals(remote) { if v == remote {
return true return true
} }
} }
@@ -405,14 +401,14 @@ func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool {
// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty // deduplicated address list as dirty
func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) { func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
r.shouldRebuild = true r.shouldRebuild = true
r.unlockedGetOrMakeV4(ownerVpnIp).learned = to r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
} }
// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty // and marks the deduplicated address list as dirty
func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip4AndPort, check checkFuncV4) { func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) {
r.shouldRebuild = true r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp) c := r.unlockedGetOrMakeV4(ownerVpnIp)
@@ -427,7 +423,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp,
} }
} }
func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []uint32) { func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.Addr) {
r.shouldRebuild = true r.shouldRebuild = true
c := r.unlockedGetOrMakeRelay(ownerVpnIp) c := r.unlockedGetOrMakeRelay(ownerVpnIp)
@@ -440,7 +436,7 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnI
// unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts // This is only useful for establishing static hosts
func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) { func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
r.shouldRebuild = true r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp) c := r.unlockedGetOrMakeV4(ownerVpnIp)
@@ -453,14 +449,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort)
// unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty // deduplicated address list as dirty
func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) { func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
r.shouldRebuild = true r.shouldRebuild = true
r.unlockedGetOrMakeV6(ownerVpnIp).learned = to r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
} }
// unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty // and marks the deduplicated address list as dirty
func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip6AndPort, check checkFuncV6) { func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) {
r.shouldRebuild = true r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp) c := r.unlockedGetOrMakeV6(ownerVpnIp)
@@ -477,7 +473,7 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp,
// unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts // This is only useful for establishing static hosts
func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) { func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
r.shouldRebuild = true r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp) c := r.unlockedGetOrMakeV6(ownerVpnIp)
@@ -488,7 +484,7 @@ func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort)
} }
} }
func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay { func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp netip.Addr) *cacheRelay {
am := r.cache[ownerVpnIp] am := r.cache[ownerVpnIp]
if am == nil { if am == nil {
am = &cache{} am = &cache{}
@@ -503,7 +499,7 @@ func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay
// unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established. // unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
// The caller must dirty the learned address cache if required // The caller must dirty the learned address cache if required
func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 { func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp netip.Addr) *cacheV4 {
am := r.cache[ownerVpnIp] am := r.cache[ownerVpnIp]
if am == nil { if am == nil {
am = &cache{} am = &cache{}
@@ -518,7 +514,7 @@ func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 {
// unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established. // unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
// The caller must dirty the learned address cache if required // The caller must dirty the learned address cache if required
func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp iputil.VpnIp) *cacheV6 { func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp netip.Addr) *cacheV6 {
am := r.cache[ownerVpnIp] am := r.cache[ownerVpnIp]
if am == nil { if am == nil {
am = &cache{} am = &cache{}
@@ -540,14 +536,14 @@ func (r *RemoteList) unlockedCollect() {
for _, c := range r.cache { for _, c := range r.cache {
if c.v4 != nil { if c.v4 != nil {
if c.v4.learned != nil { if c.v4.learned != nil {
u := NewUDPAddrFromLH4(c.v4.learned) u := AddrPortFromIp4AndPort(c.v4.learned)
if !r.unlockedIsBad(u) { if !r.unlockedIsBad(u) {
addrs = append(addrs, u) addrs = append(addrs, u)
} }
} }
for _, v := range c.v4.reported { for _, v := range c.v4.reported {
u := NewUDPAddrFromLH4(v) u := AddrPortFromIp4AndPort(v)
if !r.unlockedIsBad(u) { if !r.unlockedIsBad(u) {
addrs = append(addrs, u) addrs = append(addrs, u)
} }
@@ -556,14 +552,14 @@ func (r *RemoteList) unlockedCollect() {
if c.v6 != nil { if c.v6 != nil {
if c.v6.learned != nil { if c.v6.learned != nil {
u := NewUDPAddrFromLH6(c.v6.learned) u := AddrPortFromIp6AndPort(c.v6.learned)
if !r.unlockedIsBad(u) { if !r.unlockedIsBad(u) {
addrs = append(addrs, u) addrs = append(addrs, u)
} }
} }
for _, v := range c.v6.reported { for _, v := range c.v6.reported {
u := NewUDPAddrFromLH6(v) u := AddrPortFromIp6AndPort(v)
if !r.unlockedIsBad(u) { if !r.unlockedIsBad(u) {
addrs = append(addrs, u) addrs = append(addrs, u)
} }
@@ -572,8 +568,7 @@ func (r *RemoteList) unlockedCollect() {
if c.relay != nil { if c.relay != nil {
for _, v := range c.relay.relay { for _, v := range c.relay.relay {
ip := iputil.VpnIp(v) relays = append(relays, v)
relays = append(relays, &ip)
} }
} }
} }
@@ -581,11 +576,7 @@ func (r *RemoteList) unlockedCollect() {
dnsAddrs := r.hr.GetIPs() dnsAddrs := r.hr.GetIPs()
for _, addr := range dnsAddrs { for _, addr := range dnsAddrs {
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
v6 := addr.Addr().As16() addrs = append(addrs, addr)
addrs = append(addrs, &udp.Addr{
IP: v6[:],
Port: addr.Port(),
})
} }
} }
@@ -595,7 +586,7 @@ func (r *RemoteList) unlockedCollect() {
} }
// unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list // unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list
func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) {
n := len(r.addrs) n := len(r.addrs)
if n < 2 { if n < 2 {
return return
@@ -606,8 +597,8 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
b := r.addrs[j] b := r.addrs[j]
// Preferred addresses first // Preferred addresses first
aPref := isPreferred(a.IP, preferredRanges) aPref := isPreferred(a.Addr(), preferredRanges)
bPref := isPreferred(b.IP, preferredRanges) bPref := isPreferred(b.Addr(), preferredRanges)
switch { switch {
case aPref && !bPref: case aPref && !bPref:
// If i is preferred and j is not, i is less than j // If i is preferred and j is not, i is less than j
@@ -622,21 +613,21 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
} }
// ipv6 addresses 2nd // ipv6 addresses 2nd
a4 := a.IP.To4() a4 := a.Addr().Is4()
b4 := b.IP.To4() b4 := b.Addr().Is4()
switch { switch {
case a4 == nil && b4 != nil: case a4 == false && b4 == true:
// If i is v6 and j is v4, i is less than j // If i is v6 and j is v4, i is less than j
return true return true
case a4 != nil && b4 == nil: case a4 == true && b4 == false:
// If j is v6 and i is v4, i is not less than j // If j is v6 and i is v4, i is not less than j
return false return false
case a4 != nil && b4 != nil: case a4 == true && b4 == true:
// Special case for ipv4, a4 and b4 are not nil // i and j are both ipv4
aPrivate := isPrivateIP(a4) aPrivate := a.Addr().IsPrivate()
bPrivate := isPrivateIP(b4) bPrivate := b.Addr().IsPrivate()
switch { switch {
case !aPrivate && bPrivate: case !aPrivate && bPrivate:
// If i is a public ip (not private) and j is a private ip, i is less then j // If i is a public ip (not private) and j is a private ip, i is less then j
@@ -655,10 +646,10 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
} }
// lexical order of ips 3rd // lexical order of ips 3rd
c := bytes.Compare(a.IP, b.IP) c := a.Addr().Compare(b.Addr())
if c == 0 { if c == 0 {
// Ips are the same, Lexical order of ports 4th // Ips are the same, Lexical order of ports 4th
return a.Port < b.Port return a.Port() < b.Port()
} }
// Ip wasn't the same // Ip wasn't the same
@@ -671,7 +662,7 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
// Deduplicate // Deduplicate
a, b := 0, 1 a, b := 0, 1
for b < n { for b < n {
if !r.addrs[a].Equals(r.addrs[b]) { if r.addrs[a] != r.addrs[b] {
a++ a++
if a != b { if a != b {
r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a] r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a]
@@ -693,7 +684,7 @@ func minInt(a, b int) int {
} }
// isPreferred returns true of the ip is contained in the preferredRanges list // isPreferred returns true of the ip is contained in the preferredRanges list
func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool { func isPreferred(ip netip.Addr, preferredRanges []netip.Prefix) bool {
//TODO: this would be better in a CIDR6Tree //TODO: this would be better in a CIDR6Tree
for _, p := range preferredRanges { for _, p := range preferredRanges {
if p.Contains(ip) { if p.Contains(ip) {
@@ -702,14 +693,3 @@ func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool {
} }
return false return false
} }
var _, private24BitBlock, _ = net.ParseCIDR("10.0.0.0/8")
var _, private20BitBlock, _ = net.ParseCIDR("172.16.0.0/12")
var _, private16BitBlock, _ = net.ParseCIDR("192.168.0.0/16")
// isPrivateIP returns true if the ip is contained by a rfc 1918 private range
func isPrivateIP(ip net.IP) bool {
//TODO: another great cidrtree option
//TODO: Private for ipv6 or just let it ride?
return private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
}

View File

@@ -1,47 +1,47 @@
package nebula package nebula
import ( import (
"net" "encoding/binary"
"net/netip"
"testing" "testing"
"github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestRemoteList_Rebuild(t *testing.T) { func TestRemoteList_Rebuild(t *testing.T) {
rl := NewRemoteList(nil) rl := NewRemoteList(nil)
rl.unlockedSetV4( rl.unlockedSetV4(
0, netip.MustParseAddr("0.0.0.0"),
0, netip.MustParseAddr("0.0.0.0"),
[]*Ip4AndPort{ []*Ip4AndPort{
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is duped newIp4AndPortFromString("70.199.182.92:1475"), // this is duped
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, newIp4AndPortFromString("172.17.0.182:10101"),
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is duped newIp4AndPortFromString("172.17.1.1:10101"), // this is duped
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is duped newIp4AndPortFromString("172.18.0.1:10101"), // this is duped
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is a dupe newIp4AndPortFromString("172.18.0.1:10101"), // this is a dupe
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, newIp4AndPortFromString("172.19.0.1:10101"),
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, newIp4AndPortFromString("172.31.0.1:10101"),
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // almost dupe of 0 with a diff port newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is a dupe newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe
}, },
func(iputil.VpnIp, *Ip4AndPort) bool { return true }, func(netip.Addr, *Ip4AndPort) bool { return true },
) )
rl.unlockedSetV6( rl.unlockedSetV6(
1, netip.MustParseAddr("0.0.0.1"),
1, netip.MustParseAddr("0.0.0.1"),
[]*Ip6AndPort{ []*Ip6AndPort{
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is duped newIp6AndPortFromString("[1::1]:1"), // this is duped
NewIp6AndPort(net.ParseIP("1::1"), 2), // almost dupe of 0 with a diff port, also gets duped newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped
NewIp6AndPort(net.ParseIP("1:100::1"), 1), newIp6AndPortFromString("[1:100::1]:1"),
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe newIp6AndPortFromString("[1::1]:1"), // this is a dupe
NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe newIp6AndPortFromString("[1::1]:2"), // this is a dupe
}, },
func(iputil.VpnIp, *Ip6AndPort) bool { return true }, func(netip.Addr, *Ip6AndPort) bool { return true },
) )
rl.Rebuild([]*net.IPNet{}) rl.Rebuild([]netip.Prefix{})
assert.Len(t, rl.addrs, 10, "addrs contains too many entries") assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
// ipv6 first, sorted lexically within // ipv6 first, sorted lexically within
@@ -59,9 +59,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String()) assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
// Now ensure we can hoist ipv4 up // Now ensure we can hoist ipv4 up
_, ipNet, err := net.ParseCIDR("0.0.0.0/0") rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")})
assert.NoError(t, err)
rl.Rebuild([]*net.IPNet{ipNet})
assert.Len(t, rl.addrs, 10, "addrs contains too many entries") assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
// ipv4 first, public then private, lexically within them // ipv4 first, public then private, lexically within them
@@ -79,9 +77,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String()) assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String())
// Ensure we can hoist a specific ipv4 range over anything else // Ensure we can hoist a specific ipv4 range over anything else
_, ipNet, err = net.ParseCIDR("172.17.0.0/16") rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("172.17.0.0/16")})
assert.NoError(t, err)
rl.Rebuild([]*net.IPNet{ipNet})
assert.Len(t, rl.addrs, 10, "addrs contains too many entries") assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
// Preferred ipv4 first // Preferred ipv4 first
@@ -104,64 +100,61 @@ func TestRemoteList_Rebuild(t *testing.T) {
func BenchmarkFullRebuild(b *testing.B) { func BenchmarkFullRebuild(b *testing.B) {
rl := NewRemoteList(nil) rl := NewRemoteList(nil)
rl.unlockedSetV4( rl.unlockedSetV4(
0, netip.MustParseAddr("0.0.0.0"),
0, netip.MustParseAddr("0.0.0.0"),
[]*Ip4AndPort{ []*Ip4AndPort{
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, newIp4AndPortFromString("70.199.182.92:1475"),
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, newIp4AndPortFromString("172.17.0.182:10101"),
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, newIp4AndPortFromString("172.17.1.1:10101"),
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, newIp4AndPortFromString("172.18.0.1:10101"),
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, newIp4AndPortFromString("172.19.0.1:10101"),
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, newIp4AndPortFromString("172.31.0.1:10101"),
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
}, },
func(iputil.VpnIp, *Ip4AndPort) bool { return true }, func(netip.Addr, *Ip4AndPort) bool { return true },
) )
rl.unlockedSetV6( rl.unlockedSetV6(
0, netip.MustParseAddr("0.0.0.0"),
0, netip.MustParseAddr("0.0.0.0"),
[]*Ip6AndPort{ []*Ip6AndPort{
NewIp6AndPort(net.ParseIP("1::1"), 1), newIp6AndPortFromString("[1::1]:1"),
NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
NewIp6AndPort(net.ParseIP("1:100::1"), 1), newIp6AndPortFromString("[1:100::1]:1"),
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe newIp6AndPortFromString("[1::1]:1"), // this is a dupe
}, },
func(iputil.VpnIp, *Ip6AndPort) bool { return true }, func(netip.Addr, *Ip6AndPort) bool { return true },
) )
b.Run("no preferred", func(b *testing.B) { b.Run("no preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
rl.shouldRebuild = true rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{}) rl.Rebuild([]netip.Prefix{})
} }
}) })
_, ipNet, err := net.ParseCIDR("172.17.0.0/16") ipNet1 := netip.MustParsePrefix("172.17.0.0/16")
assert.NoError(b, err)
b.Run("1 preferred", func(b *testing.B) { b.Run("1 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
rl.shouldRebuild = true rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{ipNet}) rl.Rebuild([]netip.Prefix{ipNet1})
} }
}) })
_, ipNet2, err := net.ParseCIDR("70.0.0.0/8") ipNet2 := netip.MustParsePrefix("70.0.0.0/8")
assert.NoError(b, err)
b.Run("2 preferred", func(b *testing.B) { b.Run("2 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
rl.shouldRebuild = true rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) rl.Rebuild([]netip.Prefix{ipNet2})
} }
}) })
_, ipNet3, err := net.ParseCIDR("0.0.0.0/0") ipNet3 := netip.MustParsePrefix("0.0.0.0/0")
assert.NoError(b, err)
b.Run("3 preferred", func(b *testing.B) { b.Run("3 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
rl.shouldRebuild = true rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3})
} }
}) })
} }
@@ -169,67 +162,83 @@ func BenchmarkFullRebuild(b *testing.B) {
func BenchmarkSortRebuild(b *testing.B) { func BenchmarkSortRebuild(b *testing.B) {
rl := NewRemoteList(nil) rl := NewRemoteList(nil)
rl.unlockedSetV4( rl.unlockedSetV4(
0, netip.MustParseAddr("0.0.0.0"),
0, netip.MustParseAddr("0.0.0.0"),
[]*Ip4AndPort{ []*Ip4AndPort{
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, newIp4AndPortFromString("70.199.182.92:1475"),
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, newIp4AndPortFromString("172.17.0.182:10101"),
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, newIp4AndPortFromString("172.17.1.1:10101"),
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, newIp4AndPortFromString("172.18.0.1:10101"),
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, newIp4AndPortFromString("172.19.0.1:10101"),
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, newIp4AndPortFromString("172.31.0.1:10101"),
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
}, },
func(iputil.VpnIp, *Ip4AndPort) bool { return true }, func(netip.Addr, *Ip4AndPort) bool { return true },
) )
rl.unlockedSetV6( rl.unlockedSetV6(
0, netip.MustParseAddr("0.0.0.0"),
0, netip.MustParseAddr("0.0.0.0"),
[]*Ip6AndPort{ []*Ip6AndPort{
NewIp6AndPort(net.ParseIP("1::1"), 1), newIp6AndPortFromString("[1::1]:1"),
NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
NewIp6AndPort(net.ParseIP("1:100::1"), 1), newIp6AndPortFromString("[1:100::1]:1"),
NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe newIp6AndPortFromString("[1::1]:1"), // this is a dupe
}, },
func(iputil.VpnIp, *Ip6AndPort) bool { return true }, func(netip.Addr, *Ip6AndPort) bool { return true },
) )
b.Run("no preferred", func(b *testing.B) { b.Run("no preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
rl.shouldRebuild = true rl.shouldRebuild = true
rl.Rebuild([]*net.IPNet{}) rl.Rebuild([]netip.Prefix{})
} }
}) })
_, ipNet, err := net.ParseCIDR("172.17.0.0/16") ipNet1 := netip.MustParsePrefix("172.17.0.0/16")
rl.Rebuild([]*net.IPNet{ipNet}) rl.Rebuild([]netip.Prefix{ipNet1})
assert.NoError(b, err)
b.Run("1 preferred", func(b *testing.B) { b.Run("1 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
rl.Rebuild([]*net.IPNet{ipNet}) rl.Rebuild([]netip.Prefix{ipNet1})
} }
}) })
_, ipNet2, err := net.ParseCIDR("70.0.0.0/8") ipNet2 := netip.MustParsePrefix("70.0.0.0/8")
rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) rl.Rebuild([]netip.Prefix{ipNet1, ipNet2})
assert.NoError(b, err)
b.Run("2 preferred", func(b *testing.B) { b.Run("2 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) rl.Rebuild([]netip.Prefix{ipNet1, ipNet2})
} }
}) })
_, ipNet3, err := net.ParseCIDR("0.0.0.0/0") ipNet3 := netip.MustParsePrefix("0.0.0.0/0")
rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3})
assert.NoError(b, err)
b.Run("3 preferred", func(b *testing.B) { b.Run("3 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3})
} }
}) })
} }
func newIp4AndPortFromString(s string) *Ip4AndPort {
a := netip.MustParseAddrPort(s)
v4Addr := a.Addr().As4()
return &Ip4AndPort{
Ip: binary.BigEndian.Uint32(v4Addr[:]),
Port: uint32(a.Port()),
}
}
func newIp6AndPortFromString(s string) *Ip6AndPort {
a := netip.MustParseAddrPort(s)
v6Addr := a.Addr().As16()
return &Ip6AndPort{
Hi: binary.BigEndian.Uint64(v6Addr[:8]),
Lo: binary.BigEndian.Uint64(v6Addr[8:]),
Port: uint32(a.Port()),
}
}

View File

@@ -8,6 +8,7 @@ import (
"log" "log"
"math" "math"
"net" "net"
"net/netip"
"os" "os"
"strings" "strings"
"sync" "sync"
@@ -91,7 +92,7 @@ func New(config *config.C) (*Service, error) {
ipNet := device.Cidr() ipNet := device.Cidr()
pa := tcpip.ProtocolAddress{ pa := tcpip.ProtocolAddress{
AddressWithPrefix: tcpip.AddrFromSlice(ipNet.IP).WithPrefix(), AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(),
Protocol: ipv4.ProtocolNumber, Protocol: ipv4.ProtocolNumber,
} }
if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{ if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{
@@ -153,24 +154,48 @@ func New(config *config.C) (*Service, error) {
return &s, nil return &s, nil
} }
// DialContext dials the provided address. Currently only TCP is supported. func getProtocolNumber(addr netip.Addr) tcpip.NetworkProtocolNumber {
func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) { if addr.Is6() {
if network != "tcp" && network != "tcp4" { return ipv6.ProtocolNumber
return nil, errors.New("only tcp is supported") }
return ipv4.ProtocolNumber
} }
addr, err := net.ResolveTCPAddr(network, address) // DialContext dials the provided address.
func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
addr, err := net.ResolveUDPAddr(network, address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
fullAddr := tcpip.FullAddress{ fullAddr := tcpip.FullAddress{
NIC: nicID, NIC: nicID,
Addr: tcpip.AddrFromSlice(addr.IP), Addr: tcpip.AddrFromSlice(addr.IP),
Port: uint16(addr.Port), Port: uint16(addr.Port),
} }
num := getProtocolNumber(addr.AddrPort().Addr())
return gonet.DialUDP(s.ipstack, nil, &fullAddr, num)
case "tcp", "tcp4", "tcp6":
addr, err := net.ResolveTCPAddr(network, address)
if err != nil {
return nil, err
}
fullAddr := tcpip.FullAddress{
NIC: nicID,
Addr: tcpip.AddrFromSlice(addr.IP),
Port: uint16(addr.Port),
}
num := getProtocolNumber(addr.AddrPort().Addr())
return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, num)
default:
return nil, fmt.Errorf("unknown network type: %s", network)
}
}
return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber) // Dial dials the provided address
func (s *Service) Dial(network, address string) (net.Conn, error) {
return s.DialContext(context.Background(), network, address)
} }
// Listen listens on the provided address. Currently only TCP with wildcard // Listen listens on the provided address. Currently only TCP with wildcard

View File

@@ -4,7 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"errors" "errors"
"net" "net/netip"
"testing" "testing"
"time" "time"
@@ -18,12 +18,8 @@ import (
type m map[string]interface{} type m map[string]interface{}
func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service { func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{})
vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
copy(vpnIpNet.IP, udpIp)
_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
caB, err := caCrt.MarshalToPEM() caB, err := caCrt.MarshalToPEM()
if err != nil { if err != nil {
panic(err) panic(err)
@@ -83,8 +79,8 @@ func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string,
} }
func TestService(t *testing.T) { func TestService(t *testing.T) {
ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{ a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{
"static_host_map": m{}, "static_host_map": m{},
"lighthouse": m{ "lighthouse": m{
"am_lighthouse": true, "am_lighthouse": true,
@@ -94,7 +90,7 @@ func TestService(t *testing.T) {
"port": 4243, "port": 4243,
}, },
}) })
b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{ b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{
"static_host_map": m{ "static_host_map": m{
"10.0.0.1": []string{"localhost:4243"}, "10.0.0.1": []string{"localhost:4243"},
}, },

65
ssh.go
View File

@@ -7,6 +7,7 @@ import (
"flag" "flag"
"fmt" "fmt"
"net" "net"
"net/netip"
"os" "os"
"reflect" "reflect"
"runtime" "runtime"
@@ -18,9 +19,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/sshd" "github.com/slackhq/nebula/sshd"
"github.com/slackhq/nebula/udp"
) )
type sshListHostMapFlags struct { type sshListHostMapFlags struct {
@@ -431,7 +430,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
} }
sort.Slice(hm, func(i, j int) bool { sort.Slice(hm, func(i, j int) bool {
return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0 return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0
}) })
if fs.Json || fs.Pretty { if fs.Json || fs.Pretty {
@@ -545,13 +544,12 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
return w.WriteLine("No vpn ip was provided") return w.WriteLine("No vpn ip was provided")
} }
parsedIp := net.ParseIP(a[0]) vpnIp, err := netip.ParseAddr(a[0])
if parsedIp == nil { if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
vpnIp := iputil.Ip2VpnIp(parsedIp) if !vpnIp.IsValid() {
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
@@ -574,13 +572,12 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine("No vpn ip was provided") return w.WriteLine("No vpn ip was provided")
} }
parsedIp := net.ParseIP(a[0]) vpnIp, err := netip.ParseAddr(a[0])
if parsedIp == nil { if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
vpnIp := iputil.Ip2VpnIp(parsedIp) if !vpnIp.IsValid() {
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
@@ -616,13 +613,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine("No vpn ip was provided") return w.WriteLine("No vpn ip was provided")
} }
parsedIp := net.ParseIP(a[0]) vpnIp, err := netip.ParseAddr(a[0])
if parsedIp == nil { if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
vpnIp := iputil.Ip2VpnIp(parsedIp) if !vpnIp.IsValid() {
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
@@ -636,16 +632,16 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
} }
var addr *udp.Addr var addr netip.AddrPort
if flags.Address != "" { if flags.Address != "" {
addr = udp.NewAddrFromString(flags.Address) addr, err = netip.ParseAddrPort(flags.Address)
if addr == nil { if err != nil {
return w.WriteLine("Address could not be parsed") return w.WriteLine("Address could not be parsed")
} }
} }
hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil) hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil)
if addr != nil { if addr.IsValid() {
hostInfo.SetRemote(addr) hostInfo.SetRemote(addr)
} }
@@ -667,18 +663,17 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine("No address was provided") return w.WriteLine("No address was provided")
} }
addr := udp.NewAddrFromString(flags.Address) addr, err := netip.ParseAddrPort(flags.Address)
if addr == nil { if err != nil {
return w.WriteLine("Address could not be parsed") return w.WriteLine("Address could not be parsed")
} }
parsedIp := net.ParseIP(a[0]) vpnIp, err := netip.ParseAddr(a[0])
if parsedIp == nil { if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
vpnIp := iputil.Ip2VpnIp(parsedIp) if !vpnIp.IsValid() {
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
@@ -792,13 +787,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
cert := ifce.pki.GetCertState().Certificate cert := ifce.pki.GetCertState().Certificate
if len(a) > 0 { if len(a) > 0 {
parsedIp := net.ParseIP(a[0]) vpnIp, err := netip.ParseAddr(a[0])
if parsedIp == nil { if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
vpnIp := iputil.Ip2VpnIp(parsedIp) if !vpnIp.IsValid() {
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
@@ -862,14 +856,14 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
Error error Error error
Type string Type string
State string State string
PeerIp iputil.VpnIp PeerIp netip.Addr
LocalIndex uint32 LocalIndex uint32
RemoteIndex uint32 RemoteIndex uint32
RelayedThrough []iputil.VpnIp RelayedThrough []netip.Addr
} }
type RelayOutput struct { type RelayOutput struct {
NebulaIp iputil.VpnIp NebulaIp netip.Addr
RelayForIps []RelayFor RelayForIps []RelayFor
} }
@@ -952,13 +946,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine("No vpn ip was provided") return w.WriteLine("No vpn ip was provided")
} }
parsedIp := net.ParseIP(a[0]) vpnIp, err := netip.ParseAddr(a[0])
if parsedIp == nil { if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
vpnIp := iputil.Ip2VpnIp(parsedIp) if !vpnIp.IsValid() {
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }

View File

@@ -3,23 +3,21 @@ package test
import ( import (
"errors" "errors"
"io" "io"
"net" "net/netip"
"github.com/slackhq/nebula/iputil"
) )
type NoopTun struct{} type NoopTun struct{}
func (NoopTun) RouteFor(iputil.VpnIp) iputil.VpnIp { func (NoopTun) RouteFor(addr netip.Addr) netip.Addr {
return 0 return netip.Addr{}
} }
func (NoopTun) Activate() error { func (NoopTun) Activate() error {
return nil return nil
} }
func (NoopTun) Cidr() *net.IPNet { func (NoopTun) Cidr() netip.Prefix {
return nil return netip.Prefix{}
} }
func (NoopTun) Name() string { func (NoopTun) Name() string {

View File

@@ -1,6 +1,7 @@
package nebula package nebula
import ( import (
"net/netip"
"testing" "testing"
"time" "time"
@@ -115,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) {
assert.Equal(t, 0, tw.current) assert.Equal(t, 0, tw.current)
fps := []firewall.Packet{ fps := []firewall.Packet{
{LocalIP: 1}, {LocalIP: netip.MustParseAddr("0.0.0.1")},
{LocalIP: 2}, {LocalIP: netip.MustParseAddr("0.0.0.2")},
{LocalIP: 3}, {LocalIP: netip.MustParseAddr("0.0.0.3")},
{LocalIP: 4}, {LocalIP: netip.MustParseAddr("0.0.0.4")},
} }
tw.Add(fps[0], time.Second*1) tw.Add(fps[0], time.Second*1)

View File

@@ -1,6 +1,8 @@
package udp package udp
import ( import (
"net/netip"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
@@ -9,7 +11,7 @@ import (
const MTU = 9001 const MTU = 9001
type EncReader func( type EncReader func(
addr *Addr, addr netip.AddrPort,
out []byte, out []byte,
packet []byte, packet []byte,
header *header.H, header *header.H,
@@ -22,9 +24,9 @@ type EncReader func(
type Conn interface { type Conn interface {
Rebind() error Rebind() error
LocalAddr() (*Addr, error) LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int)
WriteTo(b []byte, addr *Addr) error WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C) ReloadConfig(c *config.C)
Close() error Close() error
} }
@@ -34,13 +36,13 @@ type NoopConn struct{}
func (NoopConn) Rebind() error { func (NoopConn) Rebind() error {
return nil return nil
} }
func (NoopConn) LocalAddr() (*Addr, error) { func (NoopConn) LocalAddr() (netip.AddrPort, error) {
return nil, nil return netip.AddrPort{}, nil
} }
func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) { func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) {
return return
} }
func (NoopConn) WriteTo(_ []byte, _ *Addr) error { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil return nil
} }
func (NoopConn) ReloadConfig(_ *config.C) { func (NoopConn) ReloadConfig(_ *config.C) {

5
udp/errors.go Normal file
View File

@@ -0,0 +1,5 @@
package udp
import "errors"
var ErrInvalidIPv6RemoteForSocket = errors.New("listener is IPv4, but writing to IPv6 remote")

View File

@@ -1,9 +1,10 @@
package udp package udp
import ( import (
"github.com/slackhq/nebula/iputil" "net/netip"
) )
//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare //TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte) // TODO: IPV6-WORK this can likely be removed now
type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte)

View File

@@ -1,100 +0,0 @@
package udp
import (
"encoding/json"
"fmt"
"net"
"strconv"
)
type m map[string]interface{}
type Addr struct {
IP net.IP
Port uint16
}
func NewAddr(ip net.IP, port uint16) *Addr {
addr := Addr{IP: make([]byte, net.IPv6len), Port: port}
copy(addr.IP, ip.To16())
return &addr
}
func NewAddrFromString(s string) *Addr {
ip, port, err := ParseIPAndPort(s)
//TODO: handle err
_ = err
return &Addr{IP: ip.To16(), Port: port}
}
func (ua *Addr) Equals(t *Addr) bool {
if t == nil || ua == nil {
return t == nil && ua == nil
}
return ua.IP.Equal(t.IP) && ua.Port == t.Port
}
func (ua *Addr) String() string {
if ua == nil {
return "<nil>"
}
return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port))
}
func (ua *Addr) MarshalJSON() ([]byte, error) {
if ua == nil {
return nil, nil
}
return json.Marshal(m{"ip": ua.IP, "port": ua.Port})
}
func (ua *Addr) Copy() *Addr {
if ua == nil {
return nil
}
nu := Addr{
Port: ua.Port,
IP: make(net.IP, len(ua.IP)),
}
copy(nu.IP, ua.IP)
return &nu
}
type AddrSlice []*Addr
func (a AddrSlice) Equal(b AddrSlice) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if !a[i].Equals(b[i]) {
return false
}
}
return true
}
func ParseIPAndPort(s string) (net.IP, uint16, error) {
rIp, sPort, err := net.SplitHostPort(s)
if err != nil {
return nil, 0, err
}
addr, err := net.ResolveIPAddr("ip", rIp)
if err != nil {
return nil, 0, err
}
iPort, err := strconv.Atoi(sPort)
if err != nil {
return nil, 0, err
}
return addr.IP, uint16(iPort), nil
}

View File

@@ -6,13 +6,14 @@ package udp
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"syscall" "syscall"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch) return NewGenericListener(l, ip, port, multi, batch)
} }

View File

@@ -9,13 +9,14 @@ package udp
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"syscall" "syscall"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch) return NewGenericListener(l, ip, port, multi, batch)
} }

View File

@@ -6,16 +6,63 @@ package udp
// Darwin support is primarily implemented in udp_generic, besides NewListenConfig // Darwin support is primarily implemented in udp_generic, besides NewListenConfig
import ( import (
"context"
"encoding/binary"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"syscall" "syscall"
"unsafe"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { type StdConn struct {
return NewGenericListener(l, ip, port, multi, batch) *net.UDPConn
isV4 bool
sysFd uintptr
l *logrus.Logger
}
var _ Conn = &StdConn{}
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
lc := NewListenConfig(multi)
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
if err != nil {
return nil, err
}
if uc, ok := pc.(*net.UDPConn); ok {
c := &StdConn{UDPConn: uc, l: l}
rc, err := uc.SyscallConn()
if err != nil {
return nil, fmt.Errorf("failed to open udp socket: %w", err)
}
err = rc.Control(func(fd uintptr) {
c.sysFd = fd
})
if err != nil {
return nil, fmt.Errorf("failed to get udp fd: %w", err)
}
la, err := c.LocalAddr()
if err != nil {
return nil, err
}
c.isV4 = la.Addr().Is4()
return c, nil
}
return nil, fmt.Errorf("unexpected PacketConn: %T %#v", pc, pc)
} }
func NewListenConfig(multi bool) net.ListenConfig { func NewListenConfig(multi bool) net.ListenConfig {
@@ -42,16 +89,130 @@ func NewListenConfig(multi bool) net.ListenConfig {
} }
} }
func (u *GenericConn) Rebind() error { //go:linkname sendto golang.org/x/sys/unix.sendto
rc, err := u.UDPConn.SyscallConn() //go:noescape
if err != nil { func sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen int32) (err error)
return err
func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
var sa unsafe.Pointer
var addrLen int32
if u.isV4 {
if ap.Addr().Is6() {
return ErrInvalidIPv6RemoteForSocket
}
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
rsa.Addr = ap.Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
sa = unsafe.Pointer(&rsa)
addrLen = syscall.SizeofSockaddrInet4
} else {
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
rsa.Addr = ap.Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
sa = unsafe.Pointer(&rsa)
addrLen = syscall.SizeofSockaddrInet6
}
// Golang stdlib doesn't handle EAGAIN correctly in some situations so we do writes ourselves
// See https://github.com/golang/go/issues/73919
for {
//_, _, err := unix.Syscall6(unix.SYS_SENDTO, u.sysFd, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), 0, sa, addrLen)
err := sendto(int(u.sysFd), b, 0, sa, addrLen)
if err == nil {
// Written, get out before the error handling
return nil
}
if errors.Is(err, syscall.EINTR) {
// Write was interrupted, retry
continue
}
if errors.Is(err, syscall.EAGAIN) {
return &net.OpError{Op: "sendto", Err: unix.EWOULDBLOCK}
}
if errors.Is(err, syscall.EBADF) {
return net.ErrClosed
}
return &net.OpError{Op: "sendto", Err: err}
}
}
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
a := u.UDPConn.LocalAddr()
switch v := a.(type) {
case *net.UDPAddr:
addr, ok := netip.AddrFromSlice(v.IP)
if !ok {
return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP)
}
return netip.AddrPortFrom(addr, uint16(v.Port)), nil
default:
return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a)
}
}
func (u *StdConn) ReloadConfig(c *config.C) {
// TODO
}
func NewUDPStatsEmitter(udpConns []Conn) func() {
// No UDP stats for non-linux
return func() {}
}
func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU)
buffer := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
for {
// Just read one packet at a time
n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
u.l.WithError(err).Error("unexpected udp socket receive error")
}
r(
netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()),
plaintext[:0],
buffer[:n],
h,
fwPacket,
lhf,
nb,
q,
cache.Get(u.l),
)
}
}
func (u *StdConn) Rebind() error {
var err error
if u.isV4 {
err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, 0)
} else {
err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, 0)
} }
return rc.Control(func(fd uintptr) {
err := syscall.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0)
if err != nil { if err != nil {
u.l.WithError(err).Error("Failed to rebind udp socket") u.l.WithError(err).Error("Failed to rebind udp socket")
} }
})
return nil
} }

View File

@@ -1,6 +1,7 @@
//go:build (!linux || android) && !e2e_testing //go:build (!linux || android) && !e2e_testing && !darwin
// +build !linux android // +build !linux android
// +build !e2e_testing // +build !e2e_testing
// +build !darwin
// udp_generic implements the nebula UDP interface in pure Go stdlib. This // udp_generic implements the nebula UDP interface in pure Go stdlib. This
// means it can be used on platforms like Darwin and Windows. // means it can be used on platforms like Darwin and Windows.
@@ -11,6 +12,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
@@ -25,7 +27,7 @@ type GenericConn struct {
var _ Conn = &GenericConn{} var _ Conn = &GenericConn{}
func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
lc := NewListenConfig(multi) lc := NewListenConfig(multi)
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
if err != nil { if err != nil {
@@ -37,23 +39,24 @@ func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch
return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc) return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
} }
func (u *GenericConn) WriteTo(b []byte, addr *Addr) error { func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
_, err := u.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)}) _, err := u.UDPConn.WriteToUDPAddrPort(b, addr)
return err return err
} }
func (u *GenericConn) LocalAddr() (*Addr, error) { func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
a := u.UDPConn.LocalAddr() a := u.UDPConn.LocalAddr()
switch v := a.(type) { switch v := a.(type) {
case *net.UDPAddr: case *net.UDPAddr:
addr := &Addr{IP: make([]byte, len(v.IP))} addr, ok := netip.AddrFromSlice(v.IP)
copy(addr.IP, v.IP) if !ok {
addr.Port = uint16(v.Port) return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP)
return addr, nil }
return netip.AddrPortFrom(addr, uint16(v.Port)), nil
default: default:
return nil, fmt.Errorf("LocalAddr returned: %#v", a) return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a)
} }
} }
@@ -75,19 +78,26 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f
buffer := make([]byte, MTU) buffer := make([]byte, MTU)
h := &header.H{} h := &header.H{}
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
udpAddr := &Addr{IP: make([]byte, 16)}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
for { for {
// Just read one packet at a time // Just read one packet at a time
n, rua, err := u.ReadFromUDP(buffer) n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil { if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop") u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return return
} }
udpAddr.IP = rua.IP r(
udpAddr.Port = uint16(rua.Port) netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()),
r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) plaintext[:0],
buffer[:n],
h,
fwPacket,
lhf,
nb,
q,
cache.Get(u.l),
)
} }
} }

View File

@@ -7,6 +7,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"net" "net"
"net/netip"
"syscall" "syscall"
"unsafe" "unsafe"
@@ -27,25 +28,6 @@ type StdConn struct {
batch int batch int
} }
var x int
// From linux/sock_diag.h
const (
_SK_MEMINFO_RMEM_ALLOC = iota
_SK_MEMINFO_RCVBUF
_SK_MEMINFO_WMEM_ALLOC
_SK_MEMINFO_SNDBUF
_SK_MEMINFO_FWD_ALLOC
_SK_MEMINFO_WMEM_QUEUED
_SK_MEMINFO_OPTMEM
_SK_MEMINFO_BACKLOG
_SK_MEMINFO_DROPS
_SK_MEMINFO_VARS
)
type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
func maybeIPV4(ip net.IP) (net.IP, bool) { func maybeIPV4(ip net.IP) (net.IP, bool) {
ip4 := ip.To4() ip4 := ip.To4()
if ip4 != nil { if ip4 != nil {
@@ -54,10 +36,9 @@ func maybeIPV4(ip net.IP) (net.IP, bool) {
return ip, false return ip, false
} }
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
ipV4, isV4 := maybeIPV4(ip)
af := unix.AF_INET6 af := unix.AF_INET6
if isV4 { if ip.Is4() {
af = unix.AF_INET af = unix.AF_INET
} }
syscall.ForkLock.RLock() syscall.ForkLock.RLock()
@@ -80,13 +61,13 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (
//TODO: support multiple listening IPs (for limiting ipv6) //TODO: support multiple listening IPs (for limiting ipv6)
var sa unix.Sockaddr var sa unix.Sockaddr
if isV4 { if ip.Is4() {
sa4 := &unix.SockaddrInet4{Port: port} sa4 := &unix.SockaddrInet4{Port: port}
copy(sa4.Addr[:], ipV4) sa4.Addr = ip.As4()
sa = sa4 sa = sa4
} else { } else {
sa6 := &unix.SockaddrInet6{Port: port} sa6 := &unix.SockaddrInet6{Port: port}
copy(sa6.Addr[:], ip.To16()) sa6.Addr = ip.As16()
sa = sa6 sa = sa6
} }
if err = unix.Bind(fd, sa); err != nil { if err = unix.Bind(fd, sa); err != nil {
@@ -98,7 +79,7 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (
//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU) //v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
//l.Println(v, err) //l.Println(v, err)
return &StdConn{sysFd: fd, isV4: isV4, l: l, batch: batch}, err return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
} }
func (u *StdConn) Rebind() error { func (u *StdConn) Rebind() error {
@@ -121,30 +102,29 @@ func (u *StdConn) GetSendBuffer() (int, error) {
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
} }
func (u *StdConn) LocalAddr() (*Addr, error) { func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
sa, err := unix.Getsockname(u.sysFd) sa, err := unix.Getsockname(u.sysFd)
if err != nil { if err != nil {
return nil, err return netip.AddrPort{}, err
} }
addr := &Addr{}
switch sa := sa.(type) { switch sa := sa.(type) {
case *unix.SockaddrInet4: case *unix.SockaddrInet4:
addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16() return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil
addr.Port = uint16(sa.Port)
case *unix.SockaddrInet6:
addr.IP = sa.Addr[0:]
addr.Port = uint16(sa.Port)
}
return addr, nil case *unix.SockaddrInet6:
return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil
default:
return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa)
}
} }
func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU) plaintext := make([]byte, MTU)
h := &header.H{} h := &header.H{}
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
udpAddr := &Addr{} var ip netip.Addr
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
//TODO: should we track this? //TODO: should we track this?
@@ -165,12 +145,23 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
//metric.Update(int64(n)) //metric.Update(int64(n))
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
if u.isV4 { if u.isV4 {
udpAddr.IP = names[i][4:8] ip, _ = netip.AddrFromSlice(names[i][4:8])
//TODO: IPV6-WORK what is not ok?
} else { } else {
udpAddr.IP = names[i][8:24] ip, _ = netip.AddrFromSlice(names[i][8:24])
//TODO: IPV6-WORK what is not ok?
} }
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) r(
r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l)) netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])),
plaintext[:0],
buffers[i][:msgs[i].Len],
h,
fwPacket,
lhf,
nb,
q,
cache.Get(u.l),
)
} }
} }
} }
@@ -216,19 +207,18 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
} }
} }
func (u *StdConn) WriteTo(b []byte, addr *Addr) error { func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
if u.isV4 { if u.isV4 {
return u.writeTo4(b, addr) return u.writeTo4(b, ip)
} }
return u.writeTo6(b, addr) return u.writeTo6(b, ip)
} }
func (u *StdConn) writeTo6(b []byte, addr *Addr) error { func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
var rsa unix.RawSockaddrInet6 var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6 rsa.Family = unix.AF_INET6
// Little Endian -> Network Endian rsa.Addr = ip.Addr().As16()
rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8) binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port())
copy(rsa.Addr[:], addr.IP.To16())
for { for {
_, _, err := unix.Syscall6( _, _, err := unix.Syscall6(
@@ -251,17 +241,15 @@ func (u *StdConn) writeTo6(b []byte, addr *Addr) error {
} }
} }
func (u *StdConn) writeTo4(b []byte, addr *Addr) error { func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
addrV4, isAddrV4 := maybeIPV4(addr.IP) if !ip.Addr().Is4() {
if !isAddrV4 { return ErrInvalidIPv6RemoteForSocket
return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
} }
var rsa unix.RawSockaddrInet4 var rsa unix.RawSockaddrInet4
rsa.Family = unix.AF_INET rsa.Family = unix.AF_INET
// Little Endian -> Network Endian rsa.Addr = ip.Addr().As4()
rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8) binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port())
copy(rsa.Addr[:], addrV4)
for { for {
_, _, err := unix.Syscall6( _, _, err := unix.Syscall6(
@@ -316,8 +304,8 @@ func (u *StdConn) ReloadConfig(c *config.C) {
} }
} }
func (u *StdConn) getMemInfo(meminfo *_SK_MEMINFO) error { func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
var vallen uint32 = 4 * _SK_MEMINFO_VARS var vallen uint32 = 4 * unix.SK_MEMINFO_VARS
_, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) _, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)
if err != 0 { if err != 0 {
return err return err
@@ -332,12 +320,12 @@ func (u *StdConn) Close() error {
func NewUDPStatsEmitter(udpConns []Conn) func() { func NewUDPStatsEmitter(udpConns []Conn) func() {
// Check if our kernel supports SO_MEMINFO before registering the gauges // Check if our kernel supports SO_MEMINFO before registering the gauges
var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge var udpGauges [][unix.SK_MEMINFO_VARS]metrics.Gauge
var meminfo _SK_MEMINFO var meminfo [unix.SK_MEMINFO_VARS]uint32
if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil { if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil {
udpGauges = make([][_SK_MEMINFO_VARS]metrics.Gauge, len(udpConns)) udpGauges = make([][unix.SK_MEMINFO_VARS]metrics.Gauge, len(udpConns))
for i := range udpConns { for i := range udpConns {
udpGauges[i] = [_SK_MEMINFO_VARS]metrics.Gauge{ udpGauges[i] = [unix.SK_MEMINFO_VARS]metrics.Gauge{
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", i), nil), metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", i), nil), metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", i), nil), metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", i), nil),
@@ -354,7 +342,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
return func() { return func() {
for i, gauges := range udpGauges { for i, gauges := range udpGauges {
if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil { if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil {
for j := 0; j < _SK_MEMINFO_VARS; j++ { for j := 0; j < unix.SK_MEMINFO_VARS; j++ {
gauges[j].Update(int64(meminfo[j])) gauges[j].Update(int64(meminfo[j]))
} }
} }

View File

@@ -8,13 +8,14 @@ package udp
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"syscall" "syscall"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch) return NewGenericListener(l, ip, port, multi, batch)
} }

View File

@@ -10,6 +10,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
@@ -61,16 +62,14 @@ type RIOConn struct {
results [packetsPerRing]winrio.Result results [packetsPerRing]winrio.Result
} }
func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) { func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) {
if !winrio.Initialize() { if !winrio.Initialize() {
return nil, errors.New("could not initialize winrio") return nil, errors.New("could not initialize winrio")
} }
u := &RIOConn{l: l} u := &RIOConn{l: l}
addr := [16]byte{} err := u.bind(&windows.SockaddrInet6{Addr: addr.As16(), Port: port})
copy(addr[:], ip.To16())
err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port})
if err != nil { if err != nil {
return nil, fmt.Errorf("bind: %w", err) return nil, fmt.Errorf("bind: %w", err)
} }
@@ -96,6 +95,25 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
// Enable v4 for this socket // Enable v4 for this socket
syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
// Disable reporting of PORT_UNREACHABLE and NET_UNREACHABLE errors from the UDP socket receive call.
// These errors are returned on Windows during UDP receives based on the receipt of ICMP packets. Disable
// the UDP receive error returns with these ioctl calls.
ret := uint32(0)
flag := uint32(0)
size := uint32(unsafe.Sizeof(flag))
err = syscall.WSAIoctl(syscall.Handle(u.sock), syscall.SIO_UDP_CONNRESET, (*byte)(unsafe.Pointer(&flag)), size, nil, 0, &ret, nil, 0)
if err != nil {
return err
}
ret = 0
flag = 0
size = uint32(unsafe.Sizeof(flag))
SIO_UDP_NETRESET := uint32(syscall.IOC_IN | syscall.IOC_VENDOR | 15)
err = syscall.WSAIoctl(syscall.Handle(u.sock), SIO_UDP_NETRESET, (*byte)(unsafe.Pointer(&flag)), size, nil, 0, &ret, nil, 0)
if err != nil {
return err
}
err = u.rx.Open() err = u.rx.Open()
if err != nil { if err != nil {
return err return err
@@ -124,22 +142,31 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
buffer := make([]byte, MTU) buffer := make([]byte, MTU)
h := &header.H{} h := &header.H{}
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
udpAddr := &Addr{IP: make([]byte, 16)}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
for { for {
// Just read one packet at a time // Just read one packet at a time
n, rua, err := u.receive(buffer) n, rua, err := u.receive(buffer)
if err != nil { if err != nil {
if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop") u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return return
} }
u.l.WithError(err).Error("unexpected udp socket receive error")
continue
}
udpAddr.IP = rua.Addr[:] r(
p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port)) netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)),
p[0] = byte(rua.Port >> 8) plaintext[:0],
p[1] = byte(rua.Port) buffer[:n],
r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) h,
fwPacket,
lhf,
nb,
q,
cache.Get(u.l),
)
} }
} }
@@ -231,7 +258,7 @@ retry:
return n, ep, nil return n, ep, nil
} }
func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error { func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error {
if !u.isOpen.Load() { if !u.isOpen.Load() {
return net.ErrClosed return net.ErrClosed
} }
@@ -274,10 +301,9 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
packet := u.tx.Push() packet := u.tx.Push()
packet.addr.Family = windows.AF_INET6 packet.addr.Family = windows.AF_INET6
p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port)) packet.addr.Addr = ip.Addr().As16()
p[0] = byte(addr.Port >> 8) port := ip.Port()
p[1] = byte(addr.Port) packet.addr.Port = (port >> 8) | ((port & 0xff) << 8)
copy(packet.addr.Addr[:], addr.IP.To16())
copy(packet.data[:], buf) copy(packet.data[:], buf)
dataBuffer := &winrio.Buffer{ dataBuffer := &winrio.Buffer{
@@ -295,17 +321,15 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
} }
func (u *RIOConn) LocalAddr() (*Addr, error) { func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
sa, err := windows.Getsockname(u.sock) sa, err := windows.Getsockname(u.sock)
if err != nil { if err != nil {
return nil, err return netip.AddrPort{}, err
} }
v6 := sa.(*windows.SockaddrInet6) v6 := sa.(*windows.SockaddrInet6)
return &Addr{ return netip.AddrPortFrom(netip.AddrFrom16(v6.Addr).Unmap(), uint16(v6.Port)), nil
IP: v6.Addr[:],
Port: uint16(v6.Port),
}, nil
} }
func (u *RIOConn) Rebind() error { func (u *RIOConn) Rebind() error {

View File

@@ -4,9 +4,8 @@
package udp package udp
import ( import (
"fmt"
"io" "io"
"net" "net/netip"
"sync/atomic" "sync/atomic"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -16,30 +15,24 @@ import (
) )
type Packet struct { type Packet struct {
ToIp net.IP To netip.AddrPort
ToPort uint16 From netip.AddrPort
FromIp net.IP
FromPort uint16
Data []byte Data []byte
} }
func (u *Packet) Copy() *Packet { func (u *Packet) Copy() *Packet {
n := &Packet{ n := &Packet{
ToIp: make(net.IP, len(u.ToIp)), To: u.To,
ToPort: u.ToPort, From: u.From,
FromIp: make(net.IP, len(u.FromIp)),
FromPort: u.FromPort,
Data: make([]byte, len(u.Data)), Data: make([]byte, len(u.Data)),
} }
copy(n.ToIp, u.ToIp)
copy(n.FromIp, u.FromIp)
copy(n.Data, u.Data) copy(n.Data, u.Data)
return n return n
} }
type TesterConn struct { type TesterConn struct {
Addr *Addr Addr netip.AddrPort
RxPackets chan *Packet // Packets to receive into nebula RxPackets chan *Packet // Packets to receive into nebula
TxPackets chan *Packet // Packets transmitted outside by nebula TxPackets chan *Packet // Packets transmitted outside by nebula
@@ -48,9 +41,9 @@ type TesterConn struct {
l *logrus.Logger l *logrus.Logger
} }
func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, error) { func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) {
return &TesterConn{ return &TesterConn{
Addr: &Addr{ip, uint16(port)}, Addr: netip.AddrPortFrom(ip, uint16(port)),
RxPackets: make(chan *Packet, 10), RxPackets: make(chan *Packet, 10),
TxPackets: make(chan *Packet, 10), TxPackets: make(chan *Packet, 10),
l: l, l: l,
@@ -71,7 +64,7 @@ func (u *TesterConn) Send(packet *Packet) {
} }
if u.l.Level >= logrus.DebugLevel { if u.l.Level >= logrus.DebugLevel {
u.l.WithField("header", h). u.l.WithField("header", h).
WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)). WithField("udpAddr", packet.From).
WithField("dataLen", len(packet.Data)). WithField("dataLen", len(packet.Data)).
Debug("UDP receiving injected packet") Debug("UDP receiving injected packet")
} }
@@ -98,23 +91,18 @@ func (u *TesterConn) Get(block bool) *Packet {
// Below this is boilerplate implementation to make nebula actually work // Below this is boilerplate implementation to make nebula actually work
//********************************************************************************************************************// //********************************************************************************************************************//
func (u *TesterConn) WriteTo(b []byte, addr *Addr) error { func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
if u.closed.Load() { if u.closed.Load() {
return io.ErrClosedPipe return io.ErrClosedPipe
} }
p := &Packet{ p := &Packet{
Data: make([]byte, len(b), len(b)), Data: make([]byte, len(b), len(b)),
FromIp: make([]byte, 16), From: u.Addr,
FromPort: u.Addr.Port, To: addr,
ToIp: make([]byte, 16),
ToPort: addr.Port,
} }
copy(p.Data, b) copy(p.Data, b)
copy(p.ToIp, addr.IP.To16())
copy(p.FromIp, u.Addr.IP.To16())
u.TxPackets <- p u.TxPackets <- p
return nil return nil
} }
@@ -123,7 +111,6 @@ func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *fi
plaintext := make([]byte, MTU) plaintext := make([]byte, MTU)
h := &header.H{} h := &header.H{}
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
ua := &Addr{IP: make([]byte, 16)}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
for { for {
@@ -131,9 +118,7 @@ func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *fi
if !ok { if !ok {
return return
} }
ua.Port = p.FromPort r(p.From, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
copy(ua.IP, p.FromIp.To16())
r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
} }
} }
@@ -144,7 +129,7 @@ func NewUDPStatsEmitter(_ []Conn) func() {
return func() {} return func() {}
} }
func (u *TesterConn) LocalAddr() (*Addr, error) { func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
return u.Addr, nil return u.Addr, nil
} }

View File

@@ -6,12 +6,13 @@ package udp
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"syscall" "syscall"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
if multi { if multi {
//NOTE: Technically we can support it with RIO but it wouldn't be at the socket level //NOTE: Technically we can support it with RIO but it wouldn't be at the socket level
// The udp stack would need to be reworked to hide away the implementation differences between // The udp stack would need to be reworked to hide away the implementation differences between