Compare commits


84 Commits

Author SHA1 Message Date
JackDoan
625acb7cc0 less garbage 2025-11-17 13:30:02 -06:00
JackDoan
f07275edff re-use IOVs 2025-11-17 12:49:46 -06:00
JackDoan
c93fbbe726 make a lil less garbage 2025-11-14 13:28:38 -06:00
JackDoan
dab4745365 FASTER, enable IFF_NAPI 2025-11-14 12:24:22 -06:00
JackDoan
3443d3ffd3 correctly set ethertype, at least for single-IP networks 2025-11-14 11:38:34 -06:00
JackDoan
a39d3c79b2 we can go faster tbh 2025-11-13 16:44:35 -06:00
JackDoan
e8ea021bdd BAM! a hit from my spice-weasel. Now TX is zero-copy, at the expense of my sanity! 2025-11-13 16:29:16 -06:00
JackDoan
994bc8c32b a little cleaner 2025-11-13 12:47:48 -06:00
JackDoan
4e4a85a891 beeg chains 2025-11-13 12:43:45 -06:00
JackDoan
39d028f16a more! 2025-11-13 12:20:27 -06:00
JackDoan
d2254d6fdd note code I 'borrowed' 2025-11-13 12:07:37 -06:00
JackDoan
978ef636b4 remove unused stuff, broken tests 2025-11-13 12:05:48 -06:00
JackDoan
e671bb66b5 add metrics 2025-11-13 12:02:26 -06:00
JackDoan
de95fc0117 big batch 2025-11-13 12:02:26 -06:00
JackDoan
715750d7c8 more better 2025-11-13 12:02:26 -06:00
JackDoan
1a83817cc2 multiqueue but it doesn't help 2025-11-13 12:02:26 -06:00
JackDoan
9bdc513da0 block less? 2025-11-13 12:02:26 -06:00
JackDoan
c6bee8e981 hmm yes time 2025-11-13 12:02:26 -06:00
JackDoan
685ac3e112 no locks all speed 2025-11-13 12:02:25 -06:00
JackDoan
c026e8624a why does it work 2025-11-13 12:02:25 -06:00
JackDoan
17a6917428 why does it work 2025-11-13 12:02:25 -06:00
JackDoan
400fdace9d tweak 2025-11-13 12:02:25 -06:00
JackDoan
cd30e5aa01 working again 2025-11-13 12:02:25 -06:00
JackDoan
1719149594 broken chkpt 2025-11-13 12:02:25 -06:00
JackDoan
e7f01390a3 broken chkpt 2025-11-13 12:02:25 -06:00
JackDoan
c645a45438 what about with bad GRO on UDP 2025-11-13 12:02:25 -06:00
JackDoan
42591c2042 this is awful, but also it's about 20% better 2025-11-13 12:02:25 -06:00
JackDoan
1f043f84f3 not sure if switching to this epoll actually helped 2025-11-13 12:02:25 -06:00
JackDoan
987f45baf0 yeah 2025-11-13 12:02:25 -06:00
JackDoan
edff19a05b yeah 2025-11-13 12:02:25 -06:00
JackDoan
e0f93c9d4b yeah 2025-11-13 12:02:25 -06:00
JackDoan
aab3333615 move things I'm gclog-ing to the bottom 2025-11-13 12:02:25 -06:00
JackDoan
ea1a9e5785 pull deps in for optimization, maybe slice back out later 2025-11-13 12:02:25 -06:00
JackDoan
1a51ee7884 it works I guess 2025-11-13 12:02:25 -06:00
JackDoan
9b29a3fe14 christ 2025-11-13 12:02:25 -06:00
JackDoan
e7176bca01 tx is good? 2025-11-13 12:02:25 -06:00
JackDoan
e3be0943fd checkpt 2025-11-13 12:02:24 -06:00
JackDoan
6e22bfeeb1 vhost 2025-11-13 12:01:59 -06:00
Jack Doan
a89f95182c Firewall types and cross-stack subnet stuff (#1509)
* firewall can distinguish if the host connecting has an overlapping network, is a VPN peer without an overlapping network, or is an unsafe network

* Cross stack subnet stuff (#1512)

* experiment with not filtering out non-common addresses in hostinfo.networks

* allow handshakes without overlaps

* unsafe network test

* change HostInfo.buildNetworks argument to reference the cert
2025-11-12 13:40:20 -06:00
dependabot[bot]
6a8a2992ff Bump google.golang.org/protobuf in the protobuf-dependencies group (#1502)
Bumps the protobuf-dependencies group with 1 update: google.golang.org/protobuf.


Updates `google.golang.org/protobuf` from 1.36.8 to 1.36.10

---
updated-dependencies:
- dependency-name: google.golang.org/protobuf
  dependency-version: 1.36.10
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: protobuf-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-12 09:59:47 -06:00
dependabot[bot]
3d94dfe6a1 Bump the golang-x-dependencies group across 1 directory with 5 updates (#1526)
Bumps the golang-x-dependencies group with 2 updates in the / directory: [golang.org/x/crypto](https://github.com/golang/crypto) and [golang.org/x/sync](https://github.com/golang/sync).


Updates `golang.org/x/crypto` from 0.43.0 to 0.44.0
- [Commits](https://github.com/golang/crypto/compare/v0.43.0...v0.44.0)

Updates `golang.org/x/net` from 0.45.0 to 0.46.0
- [Commits](https://github.com/golang/net/compare/v0.45.0...v0.46.0)

Updates `golang.org/x/sync` from 0.17.0 to 0.18.0
- [Commits](https://github.com/golang/sync/compare/v0.17.0...v0.18.0)

Updates `golang.org/x/sys` from 0.37.0 to 0.38.0
- [Commits](https://github.com/golang/sys/compare/v0.37.0...v0.38.0)

Updates `golang.org/x/term` from 0.36.0 to 0.37.0
- [Commits](https://github.com/golang/term/compare/v0.36.0...v0.37.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-version: 0.44.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/net
  dependency-version: 0.46.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/sync
  dependency-version: 0.18.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/sys
  dependency-version: 0.38.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/term
  dependency-version: 0.37.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-12 09:55:34 -06:00
dependabot[bot]
3670e24fa0 Bump actions/checkout from 4 to 5 (#1450)
Bumps [actions/checkout](https://github.com/actions/checkout) from 4 to 5.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v4...v5)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-12 09:51:00 -06:00
dependabot[bot]
b348ee726e Bump actions/download-artifact from 4 to 6 (#1516)
Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 4 to 6.
- [Release notes](https://github.com/actions/download-artifact/releases)
- [Commits](https://github.com/actions/download-artifact/compare/v4...v6)

---
updated-dependencies:
- dependency-name: actions/download-artifact
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-12 09:48:58 -06:00
dependabot[bot]
a941b65114 Bump actions/upload-artifact from 4 to 5 (#1515)
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 4 to 5.
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](https://github.com/actions/upload-artifact/compare/v4...v5)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-12 09:47:38 -06:00
dependabot[bot]
17101d425f Bump golangci/golangci-lint-action from 8 to 9 (#1523)
Bumps [golangci/golangci-lint-action](https://github.com/golangci/golangci-lint-action) from 8 to 9.
- [Release notes](https://github.com/golangci/golangci-lint-action/releases)
- [Commits](https://github.com/golangci/golangci-lint-action/compare/v8...v9)

---
updated-dependencies:
- dependency-name: golangci/golangci-lint-action
  dependency-version: '9'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-12 09:46:10 -06:00
Nate Brown
52f1908126 Don't log every blocklisted fingerprint (#1525) 2025-11-12 09:41:46 -06:00
Wade Simmons
48f1ae98ba switch to go.yaml.in/yaml (#1478)
The `gopkg.in/yaml.v3` library has been declared as Unmaintained:

- https://github.com/go-yaml/yaml?tab=readme-ov-file#this-project-is-unmaintained

The YAML org has taken over maintaining it and now publishes it as
`go.yaml.in/yaml`:

- https://github.com/yaml/go-yaml
2025-11-12 10:26:22 -05:00
Wade Simmons
97b3972c11 honor remote_allow_list in hole punch response (#1186)
* honor remote_allow_list in hole punch response

When we receive a "hole punch notification" from a Lighthouse, we send
a hole punch packet to every remote of that host, even if we don't
include those remotes in our "remote_allow_list". Change the logic here
to check if the remote IP is in our allow list before sending the hole
punch packet.

* fix for netip

* cleanup
2025-11-10 13:52:40 -05:00
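
A minimal Go sketch of the check described in the commit above; the allowList type, Allow method, and punchTargets helper are illustrative stand-ins, not Nebula's actual remote_allow_list code:

    package main

    import (
    	"fmt"
    	"net/netip"
    )

    // allowList stands in for the remote_allow_list lookup; the real
    // implementation matches configured CIDRs against each candidate address.
    type allowList struct{ allowed []netip.Prefix }

    func (a *allowList) Allow(addr netip.Addr) bool {
    	for _, p := range a.allowed {
    		if p.Contains(addr) {
    			return true
    		}
    	}
    	return false
    }

    // punchTargets mirrors the idea in the commit: filter the remotes reported
    // by the lighthouse through the allow list before any hole punch packets go out.
    func punchTargets(remotes []netip.AddrPort, al *allowList) []netip.AddrPort {
    	var out []netip.AddrPort
    	for _, r := range remotes {
    		if al.Allow(r.Addr()) {
    			out = append(out, r)
    		}
    	}
    	return out
    }

    func main() {
    	al := &allowList{allowed: []netip.Prefix{netip.MustParsePrefix("192.0.2.0/24")}}
    	remotes := []netip.AddrPort{
    		netip.MustParseAddrPort("192.0.2.10:4242"),
    		netip.MustParseAddrPort("198.51.100.7:4242"),
    	}
    	fmt.Println(punchTargets(remotes, al)) // only 192.0.2.10:4242 survives
    }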
Jack Doan
0f305d5397 don't block startup on failure to configure SSH (#1520) 2025-11-05 10:41:56 -06:00
Jack Doan
01909f4715 try to make certificate addition/removal reloadable in some cases (#1468)
* try to make certificate addition/removal reloadable in some cases

* very spicy change to respond to handshakes with cert versions we cannot match with a cert that we can indeed match

* even spicier change to rehandshake if we detect our cert is lower-version than our peer, and we have a newer-version cert available

* make tryRehandshake easier to understand
2025-11-03 19:38:44 -06:00
Jack Doan
770147264d fix make bench (#1510) 2025-10-21 11:32:34 -05:00
dependabot[bot]
fa8c013b97 Bump github.com/miekg/dns from 1.1.65 to 1.1.68 (#1444)
Bumps [github.com/miekg/dns](https://github.com/miekg/dns) from 1.1.65 to 1.1.68.
- [Changelog](https://github.com/miekg/dns/blob/master/Makefile.release)
- [Commits](https://github.com/miekg/dns/compare/v1.1.65...v1.1.68)

---
updated-dependencies:
- dependency-name: github.com/miekg/dns
  dependency-version: 1.1.68
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 16:41:51 -04:00
dependabot[bot]
2710f2af06 Bump github.com/kardianos/service from 1.2.2 to 1.2.4 (#1433)
Bumps [github.com/kardianos/service](https://github.com/kardianos/service) from 1.2.2 to 1.2.4.
- [Commits](https://github.com/kardianos/service/compare/v1.2.2...v1.2.4)

---
updated-dependencies:
- dependency-name: github.com/kardianos/service
  dependency-version: 1.2.4
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 15:58:15 -04:00
dependabot[bot]
ad6d3e6bac Bump the golang-x-dependencies group across 1 directory with 5 updates (#1409)
Bumps the golang-x-dependencies group with 3 updates in the / directory: [golang.org/x/crypto](https://github.com/golang/crypto), [golang.org/x/net](https://github.com/golang/net) and [golang.org/x/sync](https://github.com/golang/sync).


Updates `golang.org/x/crypto` from 0.37.0 to 0.38.0
- [Commits](https://github.com/golang/crypto/compare/v0.37.0...v0.38.0)

Updates `golang.org/x/net` from 0.39.0 to 0.40.0
- [Commits](https://github.com/golang/net/compare/v0.39.0...v0.40.0)

Updates `golang.org/x/sync` from 0.13.0 to 0.14.0
- [Commits](https://github.com/golang/sync/compare/v0.13.0...v0.14.0)

Updates `golang.org/x/sys` from 0.32.0 to 0.33.0
- [Commits](https://github.com/golang/sys/compare/v0.32.0...v0.33.0)

Updates `golang.org/x/term` from 0.31.0 to 0.32.0
- [Commits](https://github.com/golang/term/compare/v0.31.0...v0.32.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-version: 0.38.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/net
  dependency-version: 0.40.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/sync
  dependency-version: 0.14.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/sys
  dependency-version: 0.33.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/term
  dependency-version: 0.32.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 15:54:38 -04:00
dependabot[bot]
2b0aa74e85 Bump github.com/prometheus/client_golang from 1.22.0 to 1.23.2 (#1470)
Bumps [github.com/prometheus/client_golang](https://github.com/prometheus/client_golang) from 1.22.0 to 1.23.2.
- [Release notes](https://github.com/prometheus/client_golang/releases)
- [Changelog](https://github.com/prometheus/client_golang/blob/main/CHANGELOG.md)
- [Commits](https://github.com/prometheus/client_golang/compare/v1.22.0...v1.23.2)

---
updated-dependencies:
- dependency-name: github.com/prometheus/client_golang
  dependency-version: 1.23.2
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 15:16:24 -04:00
dependabot[bot]
b126d88963 Bump github.com/gaissmai/bart from 0.20.4 to 0.25.0 (#1471)
Bumps [github.com/gaissmai/bart](https://github.com/gaissmai/bart) from 0.20.4 to 0.25.0.
- [Release notes](https://github.com/gaissmai/bart/releases)
- [Commits](https://github.com/gaissmai/bart/compare/v0.20.4...v0.25.0)

---
updated-dependencies:
- dependency-name: github.com/gaissmai/bart
  dependency-version: 0.25.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 15:15:07 -04:00
Nate Brown
45c1d3eab3 Support for multi proto tun device on OpenBSD (#1495) 2025-10-08 16:56:42 -05:00
Gary Guo
634181ba66 Fix incorrect CIDR construction in hostmap (#1493)
* Fix incorrect CIDR construction in hostmap

* Introduce a regression test for incorrect hostmap CIDR
2025-10-08 11:02:36 -05:00
Nate Brown
eb89839d13 Support for multi proto tun device on NetBSD (#1492) 2025-10-07 20:17:50 -05:00
Nate Brown
fb7f0c3657 Use x/net/route to manage routes directly (#1488) 2025-10-03 10:59:53 -05:00
sl274
b1f53d8d25 Support IPv6 tunneling in FreeBSD (#1399)
Recent merge of cert-v2 support introduced the ability to tunnel IPv6. However, FreeBSD's IPv6 tunneling does not work for 2 reasons:
* The ifconfig commands did not work for IPv6 addresses
* The tunnel device was not configured for link-layer mode, so it only supported IPv4

This PR improves FreeBSD tunneling support in 3 ways:
* Use ioctl instead of exec'ing ifconfig to configure the interface, with additional logic to support IPv6
* Configure the tunnel in link-layer mode, allowing IPv6 traffic
* Use readv() and writev() to communicate with the tunnel device, to avoid the need to copy the packet buffer
2025-10-02 21:54:30 -05:00
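
A rough sketch of the scatter-gather write the last bullet above describes, using golang.org/x/sys/unix. A pipe stands in for the tun fd so the sketch runs anywhere, and the 4-byte address-family header is only an assumption about the framing, not the PR's actual code:

    package main

    import (
    	"fmt"

    	"golang.org/x/sys/unix"
    )

    func main() {
    	// Stand-in for the tun fd: a pipe.
    	fds := make([]int, 2)
    	if err := unix.Pipe(fds); err != nil {
    		panic(err)
    	}

    	// Hypothetical framing: an address-family header followed by the packet.
    	header := []byte{0, 0, 0, 28}
    	payload := []byte("packet bytes")

    	// Writev hands both buffers to the kernel in one syscall, so the payload
    	// never has to be copied into a contiguous staging buffer first.
    	n, err := unix.Writev(fds[1], [][]byte{header, payload})
    	if err != nil {
    		panic(err)
    	}
    	fmt.Println("wrote", n, "bytes from two iovecs")
    }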
Jack Doan
8824eeaea2 helper functions to more correctly marshal curve 25519 public keys (#1481) 2025-10-02 13:56:41 -05:00
dependabot[bot]
071589f7c7 Bump actions/setup-go from 5 to 6 (#1469)
* Bump actions/setup-go from 5 to 6

Bumps [actions/setup-go](https://github.com/actions/setup-go) from 5 to 6.
- [Release notes](https://github.com/actions/setup-go/releases)
- [Commits](https://github.com/actions/setup-go/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/setup-go
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

* Hardcode the last one to go v1.25

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Nate Brown <nbrown.us@gmail.com>
2025-10-02 00:05:12 -05:00
Jack Doan
f1e992f6dd don't require a detailsVpnAddr in a HostUpdateNotification (#1472)
* don't require a detailsVpnAddr in a HostUpdateNotification

* don't send our own addr on HostUpdateNotification for v2
2025-09-29 13:43:12 -05:00
Jack Doan
1ea5f776d7 update to go 1.25, use the cool new ECDSA key marshalling functions (#1483)
* update to go 1.25, use the cool new ECDSA key marshalling functions

* bonk the runners

* actually bump go.mod

* bump golangci-lint
2025-09-29 13:02:25 -05:00
Henry Graham
4cdeb284ef Set CKA_VALUE_LEN attribute in DeriveNoise (#1482) 2025-09-25 13:24:52 -05:00
Jack Doan
5cccd39465 update RemoteList.vpnAddrs when we complete a handshake (#1467) 2025-09-10 09:44:25 -05:00
Jack Doan
8196c22b5a store lighthouses as a slice (#1473)
* store lighthouses as a slice. If you have fewer than 16 lighthouses (and fewer than 16 vpnaddrs on a host, I guess), it's faster
2025-09-10 09:43:25 -05:00
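
A small illustration of the trade-off the commit message claims; isLighthouse is a hypothetical helper, not the actual RemoteList code. With only a handful of entries, a linear scan over a slice avoids the hashing and bucket lookups a map needs:

    package main

    import (
    	"fmt"
    	"net/netip"
    	"slices"
    )

    // isLighthouse scans a small slice of lighthouse addresses linearly.
    func isLighthouse(lighthouses []netip.Addr, addr netip.Addr) bool {
    	return slices.Contains(lighthouses, addr)
    }

    func main() {
    	lighthouses := []netip.Addr{
    		netip.MustParseAddr("192.0.2.1"),
    		netip.MustParseAddr("192.0.2.2"),
    	}
    	fmt.Println(isLighthouse(lighthouses, netip.MustParseAddr("192.0.2.2"))) // true
    	fmt.Println(isLighthouse(lighthouses, netip.MustParseAddr("192.0.2.9"))) // false
    }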
Jack Doan
65cc253c19 prevent linux from assigning ipv6 link-local addresses (#1476) 2025-09-09 13:25:23 -05:00
Wade Simmons
73cfa7b5b1 add firewall tests for ipv6 (#1451)
Test things like cidr and local_cidr with ipv6 addresses, to ensure
everything is working correctly.
2025-09-08 13:57:36 -04:00
Jack Doan
768325c9b4 cert-v2 chores (#1466) 2025-09-05 15:08:22 -05:00
Jack Doan
932e329164 Don't delete static host mappings for non-primary IPs (#1464)
* Don't delete a vpnaddr if it's part of a certificate that contains a vpnaddr that's in the static host map

* remove unused arg from ConnectionManager.shouldSwapPrimary()
2025-09-04 14:49:40 -05:00
Jack Doan
4bea299265 don't send recv errors for packets outside the connection window anymore (#1463)
* don't send recv errors for packets outside the connection window anymore

* Pull in fix from #1459, add my opinion on maxRecvError

* remove recv_error counter entirely
2025-09-03 11:52:52 -05:00
Wade Simmons
5cff83b282 netlink: ignore route updates with no destination (#1437)
Currently we assume each route update must have a destination, but we
should check that it is set before we try to use it.

See: #1436
2025-08-25 13:05:35 -05:00
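
A minimal sketch of the guard this commit describes, written against the vishvananda/netlink types (Linux only); handleRouteUpdate is a hypothetical name, not the actual Nebula function:

    package main

    import (
    	"fmt"
    	"net"

    	"github.com/vishvananda/netlink"
    )

    // handleRouteUpdate checks that a route update actually carries a destination
    // before dereferencing it, which is the fix the commit makes.
    func handleRouteUpdate(u netlink.RouteUpdate) {
    	if u.Dst == nil {
    		fmt.Println("ignoring route update with no destination")
    		return
    	}
    	fmt.Println("route update for", u.Dst)
    }

    func main() {
    	handleRouteUpdate(netlink.RouteUpdate{}) // no destination set

    	_, dst, _ := net.ParseCIDR("192.0.2.0/24")
    	handleRouteUpdate(netlink.RouteUpdate{Route: netlink.Route{Dst: dst}})
    }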
Wade Simmons
7da79685ff fix lighthouse.calculated_remotes parsing (#1438)
This was broken with the change to yaml.v3:

- https://github.com/slackhq/nebula/pull/1148

We forgot to update these references to `map[string]any`.

Without this fix, Nebula crashes with an error like this:

    {"error":"config `lighthouse.calculated_remotes` has invalid type: map[string]interface {}","level":"error","msg":"Invalid lighthouse.calculated_remotes","time":"2025-07-29T15:50:06.479499Z"}
2025-07-29 13:12:07 -04:00
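
The failure mode is easy to reproduce in a standalone sketch: yaml.v3 (and the go.yaml.in fork) decodes string-keyed mappings into map[string]any, so the old map[any]any assertion fails while the corrected one succeeds. This is only an illustration of the root cause, not Nebula's config code:

    package main

    import (
    	"fmt"

    	"gopkg.in/yaml.v3"
    )

    func main() {
    	var v any
    	if err := yaml.Unmarshal([]byte("10.0.0.0/8:\n  - mask: 24\n"), &v); err != nil {
    		panic(err)
    	}

    	// yaml.v3 produces map[string]any for string-keyed mappings.
    	_, okOld := v.(map[any]any)
    	_, okNew := v.(map[string]any)
    	fmt.Println(okOld, okNew) // false true
    }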
brad-defined
91eff03418 Update slack OSS invite link (#1435)
2025-07-15 16:05:28 -04:00
Nate Brown
52623820c2 Drop inactive tunnels (#1427) 2025-07-03 09:58:37 -05:00
Nate Brown
c2420642a0 Darwin udp fix (#1428) 2025-07-02 15:50:22 -05:00
brad-defined
b3a1f7b0a3 Disable UDP receive error returns due to ICMP messages on Windows. (#1412) (#1415) 2025-07-02 11:37:41 -04:00
brad-defined
94142aded5 Fix relay migration panic by covering every possible relay state (#1414) 2025-07-02 08:48:02 -04:00
brad-defined
b158eb0c4c Use a list for relay IPs instead of a map (#1423)
* Use a list for relay IPs instead of a map

* linter
2025-07-02 08:47:05 -04:00
dependabot[bot]
e4b7dbcfb0 Bump dario.cat/mergo from 1.0.1 to 1.0.2 (#1408)
Bumps [dario.cat/mergo](https://github.com/imdario/mergo) from 1.0.1 to 1.0.2.
- [Release notes](https://github.com/imdario/mergo/releases)
- [Commits](https://github.com/imdario/mergo/compare/v1.0.1...v1.0.2)

---
updated-dependencies:
- dependency-name: dario.cat/mergo
  dependency-version: 1.0.2
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-07-01 23:30:40 -05:00
dependabot[bot]
882edf11d7 Bump github.com/vishvananda/netlink from 1.3.0 to 1.3.1 (#1407)
Bumps [github.com/vishvananda/netlink](https://github.com/vishvananda/netlink) from 1.3.0 to 1.3.1.
- [Release notes](https://github.com/vishvananda/netlink/releases)
- [Commits](https://github.com/vishvananda/netlink/compare/v1.3.0...v1.3.1)

---
updated-dependencies:
- dependency-name: github.com/vishvananda/netlink
  dependency-version: 1.3.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-07-01 23:29:15 -05:00
dependabot[bot]
d34c2b8e06 Bump golangci/golangci-lint-action from 7 to 8 (#1400)
* Bump golangci/golangci-lint-action from 7 to 8

Bumps [golangci/golangci-lint-action](https://github.com/golangci/golangci-lint-action) from 7 to 8.
- [Release notes](https://github.com/golangci/golangci-lint-action/releases)
- [Commits](https://github.com/golangci/golangci-lint-action/compare/v7...v8)

---
updated-dependencies:
- dependency-name: golangci/golangci-lint-action
  dependency-version: '8'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

* bump golangci-lint version

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Wade Simmons <wsimmons@slack-corp.com>
2025-07-01 23:25:24 -05:00
107 changed files with 8748 additions and 1289 deletions

View File

@@ -17,5 +17,5 @@ contact_links:
     about: 'The documentation is the best place to start if you are new to Nebula.'
   - name: 💁 Support/Chat
-    url: https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ
+    url: https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA
     about: 'For faster support, join us on Slack for assistance!'

View File

@@ -14,11 +14,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-go@v5
+      - uses: actions/checkout@v5
+      - uses: actions/setup-go@v6
        with:
-          go-version: '1.24'
+          go-version: '1.25'
           check-latest: true
       - name: Install goimports

View File

@@ -10,11 +10,11 @@ jobs:
     name: Build Linux/BSD All
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-go@v5
+      - uses: actions/checkout@v5
+      - uses: actions/setup-go@v6
        with:
-          go-version: '1.24'
+          go-version: '1.25'
           check-latest: true
       - name: Build
@@ -24,7 +24,7 @@ jobs:
           mv build/*.tar.gz release
       - name: Upload artifacts
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v5
         with:
           name: linux-latest
           path: release
@@ -33,11 +33,11 @@ jobs:
     name: Build Windows
     runs-on: windows-latest
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-go@v5
+      - uses: actions/checkout@v5
+      - uses: actions/setup-go@v6
        with:
-          go-version: '1.24'
+          go-version: '1.25'
           check-latest: true
       - name: Build
@@ -55,7 +55,7 @@ jobs:
           mv dist\windows\wintun build\dist\windows\
       - name: Upload artifacts
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v5
         with:
           name: windows-latest
           path: build
@@ -66,11 +66,11 @@ jobs:
       HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
     runs-on: macos-latest
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-go@v5
+      - uses: actions/checkout@v5
+      - uses: actions/setup-go@v6
        with:
-          go-version: '1.24'
+          go-version: '1.25'
           check-latest: true
       - name: Import certificates
@@ -104,7 +104,7 @@ jobs:
           fi
       - name: Upload artifacts
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v5
         with:
           name: darwin-latest
           path: ./release/*
@@ -124,11 +124,11 @@ jobs:
       # be overwritten
       - name: Checkout code
         if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
-        uses: actions/checkout@v4
+        uses: actions/checkout@v5
       - name: Download artifacts
         if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
-        uses: actions/download-artifact@v4
+        uses: actions/download-artifact@v6
         with:
           name: linux-latest
           path: artifacts
@@ -160,10 +160,10 @@ jobs:
     needs: [build-linux, build-darwin, build-windows]
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
       - name: Download artifacts
-        uses: actions/download-artifact@v4
+        uses: actions/download-artifact@v6
        with:
           path: artifacts

View File

@@ -20,11 +20,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-go@v5
+      - uses: actions/checkout@v5
+      - uses: actions/setup-go@v6
        with:
-          go-version-file: 'go.mod'
+          go-version: '1.25'
           check-latest: true
       - name: add hashicorp source

View File

@@ -18,11 +18,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-go@v5
+      - uses: actions/checkout@v5
+      - uses: actions/setup-go@v6
        with:
-          go-version: '1.24'
+          go-version: '1.25'
           check-latest: true
       - name: build

View File

@@ -18,11 +18,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-go@v5
+      - uses: actions/checkout@v5
+      - uses: actions/setup-go@v6
        with:
-          go-version: '1.24'
+          go-version: '1.25'
           check-latest: true
       - name: Build
@@ -32,9 +32,9 @@ jobs:
         run: make vet
       - name: golangci-lint
-        uses: golangci/golangci-lint-action@v7
+        uses: golangci/golangci-lint-action@v9
         with:
-          version: v2.0
+          version: v2.5
       - name: Test
         run: make test
@@ -45,7 +45,7 @@ jobs:
       - name: Build test mobile
         run: make build-test-mobile
-      - uses: actions/upload-artifact@v4
+      - uses: actions/upload-artifact@v5
         with:
           name: e2e packet flow linux-latest
           path: e2e/mermaid/linux-latest
@@ -56,11 +56,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-go@v5
+      - uses: actions/checkout@v5
+      - uses: actions/setup-go@v6
        with:
-          go-version: '1.24'
+          go-version: '1.25'
           check-latest: true
       - name: Build
@@ -77,11 +77,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-go@v5
+      - uses: actions/checkout@v5
+      - uses: actions/setup-go@v6
        with:
-          go-version: '1.22'
+          go-version: '1.25'
           check-latest: true
       - name: Build
@@ -98,11 +98,11 @@ jobs:
       os: [windows-latest, macos-latest]
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-go@v5
+      - uses: actions/checkout@v5
+      - uses: actions/setup-go@v6
        with:
-          go-version: '1.24'
+          go-version: '1.25'
           check-latest: true
       - name: Build nebula
@@ -115,9 +115,9 @@ jobs:
         run: make vet
       - name: golangci-lint
-        uses: golangci/golangci-lint-action@v7
+        uses: golangci/golangci-lint-action@v9
         with:
-          version: v2.0
+          version: v2.5
       - name: Test
         run: make test
@@ -125,7 +125,7 @@ jobs:
       - name: End 2 end
         run: make e2evv
-      - uses: actions/upload-artifact@v4
+      - uses: actions/upload-artifact@v5
        with:
           name: e2e packet flow ${{ matrix.os }}
           path: e2e/mermaid/${{ matrix.os }}

View File

@@ -61,7 +61,7 @@ ALL = $(ALL_LINUX) \
 	windows-arm64
 e2e:
-	$(TEST_ENV) go test -tags=synctrace,e2e_testing -count=1 $(TEST_FLAGS) ./e2e
+	$(TEST_ENV) go test -tags=e2e_testing -count=1 $(TEST_FLAGS) ./e2e
 e2ev: TEST_FLAGS += -v
 e2ev: e2e
@@ -215,7 +215,6 @@ ifeq ($(words $(MAKECMDGOALS)),1)
 	@$(MAKE) service ${.DEFAULT_GOAL} --no-print-directory
 endif
-bin-docker: BUILD_ARGS = -tags=synctrace
 bin-docker: bin build/linux-amd64/nebula build/linux-amd64/nebula-cert
 smoke-docker: bin-docker

View File

@@ -12,7 +12,7 @@ Further documentation can be found [here](https://nebula.defined.net/docs/).
 You can read more about Nebula [here](https://medium.com/p/884110a5579).
-You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ).
+You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA).
 ## Supported Platforms

View File

@@ -43,7 +43,7 @@ func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
 	}
 	// Not within the window
-	l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
+	l.Debugf("rejected a packet (top) %d %d delta %d\n", b.current, i, b.current-i)
 	return false
 }

View File

@@ -84,16 +84,11 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
 	calculatedRemotes := new(bart.Table[[]*calculatedRemote])
-	rawMap, ok := value.(map[any]any)
+	rawMap, ok := value.(map[string]any)
 	if !ok {
 		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
 	}
-	for rawKey, rawValue := range rawMap {
-		rawCIDR, ok := rawKey.(string)
-		if !ok {
-			return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
-		}
+	for rawCIDR, rawValue := range rawMap {
 		cidr, err := netip.ParsePrefix(rawCIDR)
 		if err != nil {
 			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
@@ -129,7 +124,7 @@ func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculat
 	}
 func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
-	rawMap, ok := raw.(map[any]any)
+	rawMap, ok := raw.(map[string]any)
 	if !ok {
 		return nil, fmt.Errorf("invalid type: %T", raw)
 	}

View File

@@ -58,6 +58,9 @@ type Certificate interface {
 	// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
 	PublicKey() []byte
+	// MarshalPublicKeyPEM is the value of PublicKey marshalled to PEM
+	MarshalPublicKeyPEM() []byte
 	// Curve identifies which curve was used for the PublicKey and Signature.
 	Curve() Curve
@@ -135,8 +138,7 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
 	case Version2:
 		c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve)
 	default:
-		//TODO: CERT-V2 make a static var
-		return nil, fmt.Errorf("unknown certificate version %d", v)
+		return nil, ErrUnknownVersion
 	}
 	if err != nil {

View File

@@ -83,6 +83,10 @@ func (c *certificateV1) PublicKey() []byte {
 	return c.details.publicKey
 }
+func (c *certificateV1) MarshalPublicKeyPEM() []byte {
+	return marshalCertPublicKeyToPEM(c)
+}
 func (c *certificateV1) Signature() []byte {
 	return c.signature
 }
@@ -110,8 +114,10 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
 	case Curve_CURVE25519:
 		return ed25519.Verify(key, b, c.signature)
 	case Curve_P256:
-		x, y := elliptic.Unmarshal(elliptic.P256(), key)
-		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
+		pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
+		if err != nil {
+			return false
+		}
 		hashed := sha256.Sum256(b)
 		return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
 	default:

View File

@@ -1,6 +1,7 @@
 package cert
 import (
+	"crypto/ed25519"
 	"fmt"
 	"net/netip"
 	"testing"
@@ -13,6 +14,7 @@ import (
 )
 func TestCertificateV1_Marshal(t *testing.T) {
+	t.Parallel()
 	before := time.Now().Add(time.Second * -60).Round(time.Second)
 	after := time.Now().Add(time.Second * 60).Round(time.Second)
 	pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -60,6 +62,58 @@ func TestCertificateV1_Marshal(t *testing.T) {
 	assert.Equal(t, nc.Groups(), nc2.Groups())
 }
+func TestCertificateV1_PublicKeyPem(t *testing.T) {
+	t.Parallel()
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
+	nc := certificateV1{
+		details: detailsV1{
+			name: "testing",
+			networks: []netip.Prefix{},
+			unsafeNetworks: []netip.Prefix{},
+			groups: []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: before,
+			notAfter: after,
+			publicKey: pubKey,
+			isCA: false,
+			issuer: "1234567890abcedfghij1234567890ab",
+		},
+		signature: []byte("1234567890abcedfghij1234567890ab"),
+	}
+	assert.Equal(t, Version1, nc.Version())
+	assert.Equal(t, Curve_CURVE25519, nc.Curve())
+	pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
+	assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
+	assert.False(t, nc.IsCA())
+	nc.details.isCA = true
+	assert.Equal(t, Curve_CURVE25519, nc.Curve())
+	pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
+	assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
+	assert.True(t, nc.IsCA())
+	pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
+AAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA P256 PUBLIC KEY-----
+`)
+	pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
+	require.NoError(t, err)
+	nc.details.curve = Curve_P256
+	nc.details.publicKey = pubP256Key
+	assert.Equal(t, Curve_P256, nc.Curve())
+	assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
+	assert.True(t, nc.IsCA())
+	nc.details.isCA = false
+	assert.Equal(t, Curve_P256, nc.Curve())
+	assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
+	assert.False(t, nc.IsCA())
+}
 func TestCertificateV1_Expired(t *testing.T) {
 	nc := certificateV1{
 		details: detailsV1{

View File

@@ -114,6 +114,10 @@ func (c *certificateV2) PublicKey() []byte {
 	return c.publicKey
 }
+func (c *certificateV2) MarshalPublicKeyPEM() []byte {
+	return marshalCertPublicKeyToPEM(c)
+}
 func (c *certificateV2) Signature() []byte {
 	return c.signature
 }
@@ -149,8 +153,10 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
 	case Curve_CURVE25519:
 		return ed25519.Verify(key, b, c.signature)
 	case Curve_P256:
-		x, y := elliptic.Unmarshal(elliptic.P256(), key)
-		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
+		pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
+		if err != nil {
+			return false
+		}
 		hashed := sha256.Sum256(b)
 		return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
 	default:

View File

@@ -15,6 +15,7 @@ import (
 )
 func TestCertificateV2_Marshal(t *testing.T) {
+	t.Parallel()
 	before := time.Now().Add(time.Second * -60).Round(time.Second)
 	after := time.Now().Add(time.Second * 60).Round(time.Second)
 	pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -75,6 +76,58 @@ func TestCertificateV2_Marshal(t *testing.T) {
 	assert.Equal(t, nc.Groups(), nc2.Groups())
 }
+func TestCertificateV2_PublicKeyPem(t *testing.T) {
+	t.Parallel()
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
+	nc := certificateV2{
+		details: detailsV2{
+			name: "testing",
+			networks: []netip.Prefix{},
+			unsafeNetworks: []netip.Prefix{},
+			groups: []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: before,
+			notAfter: after,
+			isCA: false,
+			issuer: "1234567890abcedfghij1234567890ab",
+		},
+		publicKey: pubKey,
+		signature: []byte("1234567890abcedfghij1234567890ab"),
+	}
+	assert.Equal(t, Version2, nc.Version())
+	assert.Equal(t, Curve_CURVE25519, nc.Curve())
+	pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
+	assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
+	assert.False(t, nc.IsCA())
+	nc.details.isCA = true
+	assert.Equal(t, Curve_CURVE25519, nc.Curve())
+	pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
+	assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
+	assert.True(t, nc.IsCA())
+	pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
+AAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA P256 PUBLIC KEY-----
+`)
+	pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
+	require.NoError(t, err)
+	nc.curve = Curve_P256
+	nc.publicKey = pubP256Key
+	assert.Equal(t, Curve_P256, nc.Curve())
+	assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
+	assert.True(t, nc.IsCA())
+	nc.details.isCA = false
+	assert.Equal(t, Curve_P256, nc.Curve())
+	assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
+	assert.False(t, nc.IsCA())
+}
 func TestCertificateV2_Expired(t *testing.T) {
 	nc := certificateV2{
 		details: detailsV2{

View File

@@ -20,6 +20,7 @@ var (
 	ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair")
 	ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
 	ErrCaNotFound = errors.New("could not find ca for the certificate")
+	ErrUnknownVersion = errors.New("certificate version unrecognized")
 	ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
 	ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")

View File

@@ -7,19 +7,26 @@ import (
 	"golang.org/x/crypto/ed25519"
 )
-const (
+const ( //cert banners
 	CertificateBanner = "NEBULA CERTIFICATE"
 	CertificateV2Banner = "NEBULA CERTIFICATE V2"
+)
+const ( //key-agreement-key banners
 	X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
 	X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
+	P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
+	P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
+)
+/* including "ECDSA" in the P256 banners is a clue that these keys should be used only for signing */
+const ( //signing key banners
+	EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
+	ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
+	ECDSAP256PublicKeyBanner = "NEBULA ECDSA P256 PUBLIC KEY"
 	EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
 	Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
 	Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
-	P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
-	P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
-	EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
-	ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
 )
 // UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
@@ -51,6 +58,16 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
 }
+func marshalCertPublicKeyToPEM(c Certificate) []byte {
+	if c.IsCA() {
+		return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
+	} else {
+		return MarshalPublicKeyToPEM(c.Curve(), c.PublicKey())
+	}
+}
+// MarshalPublicKeyToPEM returns a PEM representation of a public key used for ECDH.
+// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
 func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
 	switch curve {
 	case Curve_CURVE25519:
@@ -62,6 +79,19 @@ func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
 	}
 }
+// MarshalSigningPublicKeyToPEM returns a PEM representation of a public key used for signing.
+// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
+func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte {
+	switch curve {
+	case Curve_CURVE25519:
+		return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b})
+	case Curve_P256:
+		return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
+	default:
+		return nil
+	}
+}
 func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
 	k, r := pem.Decode(b)
 	if k == nil {
@@ -73,7 +103,7 @@ func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
 	case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
 		expectedLen = 32
 		curve = Curve_CURVE25519
-	case P256PublicKeyBanner:
+	case P256PublicKeyBanner, ECDSAP256PublicKeyBanner:
 		// Uncompressed
 		expectedLen = 65
 		curve = Curve_P256

View File

@@ -177,6 +177,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 }
 func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
+	t.Parallel()
 	pubKey := []byte(`# A good key
 -----BEGIN NEBULA ED25519 PUBLIC KEY-----
 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -230,6 +231,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 }
 func TestUnmarshalX25519PublicKey(t *testing.T) {
+	t.Parallel()
 	pubKey := []byte(`# A good key
 -----BEGIN NEBULA X25519 PUBLIC KEY-----
 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -240,6 +242,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
 AAAAAAAAAAAAAAAAAAAAAAA=
 -----END NEBULA P256 PUBLIC KEY-----
 `)
+	oldPubP256Key := []byte(`# A good key
+-----BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
+AAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA ECDSA P256 PUBLIC KEY-----
+`)
 	shortKey := []byte(`# A short key
 -----BEGIN NEBULA X25519 PUBLIC KEY-----
@@ -256,15 +264,22 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 -END NEBULA X25519 PUBLIC KEY-----`)
-	keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
+	keyBundle := appendByteSlices(pubKey, pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem)
 	// Success test case
 	k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
 	assert.Len(t, k, 32)
 	require.NoError(t, err)
-	assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, rest, appendByteSlices(pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem))
 	assert.Equal(t, Curve_CURVE25519, curve)
+	// Success test case
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Len(t, k, 65)
+	require.NoError(t, err)
+	assert.Equal(t, rest, appendByteSlices(oldPubP256Key, shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_P256, curve)
 	// Success test case
 	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
 	assert.Len(t, k, 65)

View File

@@ -7,7 +7,6 @@ import (
 	"crypto/rand"
 	"crypto/sha256"
 	"fmt"
-	"math/big"
 	"net/netip"
 	"time"
 )
@@ -55,15 +54,10 @@ func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Cert
 		}
 		return t.SignWith(signer, curve, sp)
 	case Curve_P256:
-		pk := &ecdsa.PrivateKey{
-			PublicKey: ecdsa.PublicKey{
-				Curve: elliptic.P256(),
-			},
-			// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
-			D: new(big.Int).SetBytes(key),
-		}
-		// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
-		pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
+		pk, err := ecdsa.ParseRawPrivateKey(elliptic.P256(), key)
+		if err != nil {
+			return nil, err
+		}
 		sp := func(certBytes []byte) ([]byte, error) {
 			// We need to hash first for ECDSA
 			// - https://pkg.go.dev/crypto/ecdsa#SignASN1

View File

@@ -114,6 +114,33 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
 	return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
 }
+func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
+	nc := &cert.TBSCertificate{
+		Version: v,
+		Curve: c.Curve(),
+		Name: c.Name(),
+		Networks: c.Networks(),
+		UnsafeNetworks: c.UnsafeNetworks(),
+		Groups: c.Groups(),
+		NotBefore: time.Unix(c.NotBefore().Unix(), 0),
+		NotAfter: time.Unix(c.NotAfter().Unix(), 0),
+		PublicKey: c.PublicKey(),
+		IsCA: false,
+	}
+	c, err := nc.Sign(ca, ca.Curve(), key)
+	if err != nil {
+		panic(err)
+	}
+	pem, err := c.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+	return c, pem
+}
 func X25519Keypair() ([]byte, []byte) {
 	privkey := make([]byte, 32)
 	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {

View File

@@ -3,6 +3,9 @@ package main
 import (
 	"flag"
 	"fmt"
+	"log"
+	"net/http"
+	_ "net/http/pprof"
 	"os"
 	"github.com/sirupsen/logrus"
@@ -58,6 +61,10 @@ func main() {
 		os.Exit(1)
 	}
+	go func() {
+		log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
+	}()
 	if !*configTest {
 		ctrl.Start()
 		notifyReady(l)

View File

@@ -11,13 +11,13 @@ import (
 	"sort"
 	"strconv"
 	"strings"
+	"sync"
 	"syscall"
 	"time"
 	"dario.cat/mergo"
 	"github.com/sirupsen/logrus"
-	"github.com/wadey/synctrace"
-	"gopkg.in/yaml.v3"
+	"go.yaml.in/yaml/v3"
 )
 type C struct {
@@ -27,14 +27,13 @@ type C struct {
 	oldSettings map[string]any
 	callbacks []func(*C)
 	l *logrus.Logger
-	reloadLock synctrace.Mutex
+	reloadLock sync.Mutex
 }
 func NewC(l *logrus.Logger) *C {
 	return &C{
 		Settings: make(map[string]any),
 		l: l,
-		reloadLock: synctrace.NewMutex("config-reload"),
 	}
 }

View File

@@ -10,7 +10,7 @@ import (
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-	"gopkg.in/yaml.v3"
+	"go.yaml.in/yaml/v3"
 )
 func TestConfig_Load(t *testing.T) {

View File

@@ -4,14 +4,17 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/binary" "encoding/binary"
"fmt"
"net/netip" "net/netip"
"sync"
"sync/atomic"
"time" "time"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/wadey/synctrace"
) )
type trafficDecision int type trafficDecision int
@@ -27,130 +30,124 @@ const (
) )
type connectionManager struct { type connectionManager struct {
in map[uint32]struct{}
inLock synctrace.RWMutex
out map[uint32]struct{}
outLock synctrace.RWMutex
// relayUsed holds which relay localIndexs are in use // relayUsed holds which relay localIndexs are in use
relayUsed map[uint32]struct{} relayUsed map[uint32]struct{}
relayUsedLock synctrace.RWMutex relayUsedLock *sync.RWMutex
hostMap *HostMap hostMap *HostMap
trafficTimer *LockingTimerWheel[uint32] trafficTimer *LockingTimerWheel[uint32]
intf *Interface intf *Interface
pendingDeletion map[uint32]struct{}
punchy *Punchy punchy *Punchy
// Configuration settings
checkInterval time.Duration checkInterval time.Duration
pendingDeletionInterval time.Duration pendingDeletionInterval time.Duration
inactivityTimeout atomic.Int64
dropInactive atomic.Bool
metricsTxPunchy metrics.Counter metricsTxPunchy metrics.Counter
l *logrus.Logger l *logrus.Logger
} }
func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager { func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
var max time.Duration cm := &connectionManager{
if checkInterval < pendingDeletionInterval { hostMap: hm,
max = pendingDeletionInterval
} else {
max = checkInterval
}
nc := &connectionManager{
hostMap: intf.hostMap,
in: make(map[uint32]struct{}),
inLock: synctrace.NewRWMutex("connection-manager-in"),
out: make(map[uint32]struct{}),
outLock: synctrace.NewRWMutex("connection-manager-out"),
relayUsed: make(map[uint32]struct{}),
relayUsedLock: synctrace.NewRWMutex("connection-manager-relay-used"),
trafficTimer: NewLockingTimerWheel[uint32]("traffic-timer", time.Millisecond*500, max),
intf: intf,
pendingDeletion: make(map[uint32]struct{}),
checkInterval: checkInterval,
pendingDeletionInterval: pendingDeletionInterval,
punchy: punchy,
metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
l: l, l: l,
punchy: p,
relayUsed: make(map[uint32]struct{}),
relayUsedLock: &sync.RWMutex{},
metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
} }
nc.Start(ctx) cm.reload(c, true)
return nc c.RegisterReloadCallback(func(c *config.C) {
cm.reload(c, false)
})
return cm
} }
func (n *connectionManager) In(localIndex uint32) { func (cm *connectionManager) reload(c *config.C, initial bool) {
n.inLock.RLock() if initial {
cm.checkInterval = time.Duration(c.GetInt("timers.connection_alive_interval", 5)) * time.Second
cm.pendingDeletionInterval = time.Duration(c.GetInt("timers.pending_deletion_interval", 10)) * time.Second
// We want at least a minimum resolution of 500ms per tick so that we can hit these intervals
// pretty close to their configured duration.
// The inactivity duration is checked each time a hostinfo ticks through so we don't need the wheel to contain it.
minDuration := min(time.Millisecond*500, cm.checkInterval, cm.pendingDeletionInterval)
maxDuration := max(cm.checkInterval, cm.pendingDeletionInterval)
cm.trafficTimer = NewLockingTimerWheel[uint32](minDuration, maxDuration)
}
if initial || c.HasChanged("tunnels.inactivity_timeout") {
old := cm.getInactivityTimeout()
cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
if !initial {
cm.l.WithField("oldDuration", old).
WithField("newDuration", cm.getInactivityTimeout()).
Info("Inactivity timeout has changed")
}
}
if initial || c.HasChanged("tunnels.drop_inactive") {
old := cm.dropInactive.Load()
cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
if !initial {
cm.l.WithField("oldBool", old).
WithField("newBool", cm.dropInactive.Load()).
Info("Drop inactive setting has changed")
}
}
}
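For orientation, a minimal wiring sketch in the spirit of the tests later in this diff (it would have to live inside package nebula, since the constructor is unexported). The config keys and defaults come from reload() above; the function name and literal values are illustrative assumptions, not part of this change.

func exampleConnectionManagerSetup(l *logrus.Logger, hm *HostMap) *connectionManager {
	conf := config.NewC(l)
	// Defaults from reload(): timers.connection_alive_interval=5s, timers.pending_deletion_interval=10s,
	// tunnels.inactivity_timeout=10m, tunnels.drop_inactive=false.
	conf.Settings["tunnels"] = map[string]any{
		"drop_inactive": true, // illustrative: opt in to inactivity teardown
	}
	punchy := NewPunchyFromConfig(l, conf)
	cm := newConnectionManagerFromConfig(l, conf, hm, punchy)
	// cm.intf still needs to be set and Start(ctx) called; Control.Start now does the latter.
	return cm
}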
func (cm *connectionManager) getInactivityTimeout() time.Duration {
return (time.Duration)(cm.inactivityTimeout.Load())
}
func (cm *connectionManager) In(h *HostInfo) {
h.in.Store(true)
}
func (cm *connectionManager) Out(h *HostInfo) {
h.out.Store(true)
}
func (cm *connectionManager) RelayUsed(localIndex uint32) {
cm.relayUsedLock.RLock()
// If this already exists, return // If this already exists, return
if _, ok := n.in[localIndex]; ok { if _, ok := cm.relayUsed[localIndex]; ok {
n.inLock.RUnlock() cm.relayUsedLock.RUnlock()
return return
} }
n.inLock.RUnlock() cm.relayUsedLock.RUnlock()
n.inLock.Lock() cm.relayUsedLock.Lock()
n.in[localIndex] = struct{}{} cm.relayUsed[localIndex] = struct{}{}
n.inLock.Unlock() cm.relayUsedLock.Unlock()
}
func (n *connectionManager) Out(localIndex uint32) {
n.outLock.RLock()
// If this already exists, return
if _, ok := n.out[localIndex]; ok {
n.outLock.RUnlock()
return
}
n.outLock.RUnlock()
n.outLock.Lock()
n.out[localIndex] = struct{}{}
n.outLock.Unlock()
}
func (n *connectionManager) RelayUsed(localIndex uint32) {
n.relayUsedLock.RLock()
// If this already exists, return
if _, ok := n.relayUsed[localIndex]; ok {
n.relayUsedLock.RUnlock()
return
}
n.relayUsedLock.RUnlock()
n.relayUsedLock.Lock()
n.relayUsed[localIndex] = struct{}{}
n.relayUsedLock.Unlock()
} }
// getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and // getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
// resets the state for this local index // resets the state for this local index
func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) { func (cm *connectionManager) getAndResetTrafficCheck(h *HostInfo, now time.Time) (bool, bool) {
n.inLock.Lock() in := h.in.Swap(false)
n.outLock.Lock() out := h.out.Swap(false)
_, in := n.in[localIndex] if in || out {
_, out := n.out[localIndex] h.lastUsed = now
delete(n.in, localIndex) }
delete(n.out, localIndex)
n.inLock.Unlock()
n.outLock.Unlock()
return in, out return in, out
} }
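The old per-index in/out maps and their RWMutexes are gone; traffic tracking is now two atomic flags on each HostInfo plus a lastUsed timestamp that only the traffic-check loop writes. A standalone sketch of that pattern (field names borrowed from the diff, the type itself is illustrative and assumes Go 1.19+ for atomic.Bool; imports: sync/atomic, time):

type trafficFlags struct {
	in, out  atomic.Bool
	lastUsed time.Time // written only from the traffic-check loop
}

// In() and Out() above reduce to a lock-free store.
func (t *trafficFlags) markIn()  { t.in.Store(true) }
func (t *trafficFlags) markOut() { t.out.Store(true) }

// getAndResetTrafficCheck above reduces to Swap(false), which reads and clears
// each flag in one atomic step; any observed traffic refreshes lastUsed.
func (t *trafficFlags) getAndReset(now time.Time) (in, out bool) {
	in = t.in.Swap(false)
	out = t.out.Swap(false)
	if in || out {
		t.lastUsed = now
	}
	return in, out
}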
func (n *connectionManager) AddTrafficWatch(localIndex uint32) { // AddTrafficWatch must be called for every new HostInfo.
// Use a write lock directly because it should be incredibly rare that we are ever already tracking this index // We will continue to monitor the HostInfo until the tunnel is dropped.
n.outLock.Lock() func (cm *connectionManager) AddTrafficWatch(h *HostInfo) {
if _, ok := n.out[localIndex]; ok { if h.out.Swap(true) == false {
n.outLock.Unlock() cm.trafficTimer.Add(h.localIndexId, cm.checkInterval)
return
} }
n.out[localIndex] = struct{}{}
n.trafficTimer.Add(localIndex, n.checkInterval)
n.outLock.Unlock()
} }
func (n *connectionManager) Start(ctx context.Context) { func (cm *connectionManager) Start(ctx context.Context) {
go n.Run(ctx) clockSource := time.NewTicker(cm.trafficTimer.t.tickDuration)
}
func (n *connectionManager) Run(ctx context.Context) {
//TODO: this tick should be based on the min wheel tick? Check firewall
clockSource := time.NewTicker(500 * time.Millisecond)
defer clockSource.Stop() defer clockSource.Stop()
p := []byte("") p := []byte("")
@@ -163,61 +160,61 @@ func (n *connectionManager) Run(ctx context.Context) {
return return
case now := <-clockSource.C: case now := <-clockSource.C:
n.trafficTimer.Advance(now) cm.trafficTimer.Advance(now)
for { for {
localIndex, has := n.trafficTimer.Purge() localIndex, has := cm.trafficTimer.Purge()
if !has { if !has {
break break
} }
n.doTrafficCheck(localIndex, p, nb, out, now) cm.doTrafficCheck(localIndex, p, nb, out, now)
} }
} }
} }
} }
func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) { func (cm *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now) decision, hostinfo, primary := cm.makeTrafficDecision(localIndex, now)
switch decision { switch decision {
case deleteTunnel: case deleteTunnel:
if n.hostMap.DeleteHostInfo(hostinfo) { if cm.hostMap.DeleteHostInfo(hostinfo) {
// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap // Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs) cm.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
} }
case closeTunnel: case closeTunnel:
n.intf.sendCloseTunnel(hostinfo) cm.intf.sendCloseTunnel(hostinfo)
n.intf.closeTunnel(hostinfo) cm.intf.closeTunnel(hostinfo)
case swapPrimary: case swapPrimary:
n.swapPrimary(hostinfo, primary) cm.swapPrimary(hostinfo, primary)
case migrateRelays: case migrateRelays:
n.migrateRelayUsed(hostinfo, primary) cm.migrateRelayUsed(hostinfo, primary)
case tryRehandshake: case tryRehandshake:
n.tryRehandshake(hostinfo) cm.tryRehandshake(hostinfo)
case sendTestPacket: case sendTestPacket:
n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out) cm.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
} }
n.resetRelayTrafficCheck(hostinfo) cm.resetRelayTrafficCheck(hostinfo)
} }
func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) { func (cm *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
if hostinfo != nil { if hostinfo != nil {
n.relayUsedLock.Lock() cm.relayUsedLock.Lock()
defer n.relayUsedLock.Unlock() defer cm.relayUsedLock.Unlock()
// No need to migrate any relays, delete usage info now. // No need to migrate any relays, delete usage info now.
for _, idx := range hostinfo.relayState.CopyRelayForIdxs() { for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
delete(n.relayUsed, idx) delete(cm.relayUsed, idx)
} }
} }
} }
func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) { func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
relayFor := oldhostinfo.relayState.CopyAllRelayFor() relayFor := oldhostinfo.relayState.CopyAllRelayFor()
for _, r := range relayFor { for _, r := range relayFor {
@@ -227,46 +224,51 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
var relayFrom netip.Addr var relayFrom netip.Addr
var relayTo netip.Addr var relayTo netip.Addr
switch { switch {
case ok && existing.State == Established: case ok:
switch existing.State {
case Established, PeerRequested, Disestablished:
// This relay already exists in newhostinfo, then do nothing. // This relay already exists in newhostinfo, then do nothing.
continue continue
case ok && existing.State == Requested: case Requested:
// The relay exists in a Requested state; re-send the request // The relay exists in a Requested state; re-send the request
index = existing.LocalIndex index = existing.LocalIndex
switch r.Type { switch r.Type {
case TerminalType: case TerminalType:
relayFrom = n.intf.myVpnAddrs[0] relayFrom = cm.intf.myVpnAddrs[0]
relayTo = existing.PeerAddr relayTo = existing.PeerAddr
case ForwardingType: case ForwardingType:
relayFrom = existing.PeerAddr relayFrom = existing.PeerAddr
relayTo = newhostinfo.vpnAddrs[0] relayTo = newhostinfo.vpnAddrs[0]
default: default:
// should never happen // should never happen
panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
}
} }
case !ok: case !ok:
n.relayUsedLock.RLock() cm.relayUsedLock.RLock()
if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed { if _, relayUsed := cm.relayUsed[r.LocalIndex]; !relayUsed {
// The relay hasn't been used; don't migrate it. // The relay hasn't been used; don't migrate it.
n.relayUsedLock.RUnlock() cm.relayUsedLock.RUnlock()
continue continue
} }
n.relayUsedLock.RUnlock() cm.relayUsedLock.RUnlock()
// The relay doesn't exist at all; create some relay state and send the request. // The relay doesn't exist at all; create some relay state and send the request.
var err error var err error
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested) index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
if err != nil { if err != nil {
n.l.WithError(err).Error("failed to migrate relay to new hostinfo") cm.l.WithError(err).Error("failed to migrate relay to new hostinfo")
continue continue
} }
switch r.Type { switch r.Type {
case TerminalType: case TerminalType:
relayFrom = n.intf.myVpnAddrs[0] relayFrom = cm.intf.myVpnAddrs[0]
relayTo = r.PeerAddr relayTo = r.PeerAddr
case ForwardingType: case ForwardingType:
relayFrom = r.PeerAddr relayFrom = r.PeerAddr
relayTo = newhostinfo.vpnAddrs[0] relayTo = newhostinfo.vpnAddrs[0]
default: default:
// should never happen // should never happen
panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
} }
} }
@@ -279,12 +281,12 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
switch newhostinfo.GetCert().Certificate.Version() { switch newhostinfo.GetCert().Certificate.Version() {
case cert.Version1: case cert.Version1:
if !relayFrom.Is4() { if !relayFrom.Is4() {
n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version") cm.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
continue continue
} }
if !relayTo.Is4() { if !relayTo.Is4() {
n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version") cm.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue continue
} }
@@ -296,16 +298,16 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
req.RelayFromAddr = netAddrToProtoAddr(relayFrom) req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
req.RelayToAddr = netAddrToProtoAddr(relayTo) req.RelayToAddr = netAddrToProtoAddr(relayTo)
default: default:
newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay") newhostinfo.logger(cm.l).Error("Unknown certificate version found while attempting to migrate relay")
continue continue
} }
msg, err := req.Marshal() msg, err := req.Marshal()
if err != nil { if err != nil {
n.l.WithError(err).Error("failed to marshal Control message to migrate relay") cm.l.WithError(err).Error("failed to marshal Control message to migrate relay")
} else { } else {
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
n.l.WithFields(logrus.Fields{ cm.l.WithFields(logrus.Fields{
"relayFrom": req.RelayFromAddr, "relayFrom": req.RelayFromAddr,
"relayTo": req.RelayToAddr, "relayTo": req.RelayToAddr,
"initiatorRelayIndex": req.InitiatorRelayIndex, "initiatorRelayIndex": req.InitiatorRelayIndex,
@@ -316,46 +318,44 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
} }
} }
func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) { func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
n.hostMap.RLock() // Read lock the main hostmap to order decisions based on tunnels being the primary tunnel
defer n.hostMap.RUnlock() cm.hostMap.RLock()
defer cm.hostMap.RUnlock()
hostinfo := n.hostMap.Indexes[localIndex] hostinfo := cm.hostMap.Indexes[localIndex]
if hostinfo == nil { if hostinfo == nil {
n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap") cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap")
delete(n.pendingDeletion, localIndex)
return doNothing, nil, nil return doNothing, nil, nil
} }
if n.isInvalidCertificate(now, hostinfo) { if cm.isInvalidCertificate(now, hostinfo) {
delete(n.pendingDeletion, hostinfo.localIndexId)
return closeTunnel, hostinfo, nil return closeTunnel, hostinfo, nil
} }
primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]] primary := cm.hostMap.Hosts[hostinfo.vpnAddrs[0]]
mainHostInfo := true mainHostInfo := true
if primary != nil && primary != hostinfo { if primary != nil && primary != hostinfo {
mainHostInfo = false mainHostInfo = false
} }
// Check for traffic on this hostinfo // Check for traffic on this hostinfo
inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex) inTraffic, outTraffic := cm.getAndResetTrafficCheck(hostinfo, now)
// A hostinfo is determined alive if there is incoming traffic // A hostinfo is determined alive if there is incoming traffic
if inTraffic { if inTraffic {
decision := doNothing decision := doNothing
if n.l.Level >= logrus.DebugLevel { if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l). hostinfo.logger(cm.l).
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
Debug("Tunnel status") Debug("Tunnel status")
} }
delete(n.pendingDeletion, hostinfo.localIndexId) hostinfo.pendingDeletion.Store(false)
if mainHostInfo { if mainHostInfo {
decision = tryRehandshake decision = tryRehandshake
} else { } else {
if n.shouldSwapPrimary(hostinfo, primary) { if cm.shouldSwapPrimary(hostinfo) {
decision = swapPrimary decision = swapPrimary
} else { } else {
// migrate the relays to the primary, if in use. // migrate the relays to the primary, if in use.
@@ -363,46 +363,55 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
} }
} }
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
if !outTraffic { if !outTraffic {
// Send a punch packet to keep the NAT state alive // Send a punch packet to keep the NAT state alive
n.sendPunch(hostinfo) cm.sendPunch(hostinfo)
} }
return decision, hostinfo, primary return decision, hostinfo, primary
} }
if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok { if hostinfo.pendingDeletion.Load() {
// We have already sent a test packet and nothing was returned, this hostinfo is dead // We have already sent a test packet and nothing was returned, this hostinfo is dead
hostinfo.logger(n.l). hostinfo.logger(cm.l).
WithField("tunnelCheck", m{"state": "dead", "method": "active"}). WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
Info("Tunnel status") Info("Tunnel status")
delete(n.pendingDeletion, hostinfo.localIndexId)
return deleteTunnel, hostinfo, nil return deleteTunnel, hostinfo, nil
} }
decision := doNothing decision := doNothing
if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo { if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
if !outTraffic { if !outTraffic {
// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel. inactiveFor, isInactive := cm.isInactive(hostinfo, now)
// Just maintain NAT state if configured to do so. if isInactive {
n.sendPunch(hostinfo) // Tunnel is inactive, tear it down
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) hostinfo.logger(cm.l).
return doNothing, nil, nil WithField("inactiveDuration", inactiveFor).
WithField("primary", mainHostInfo).
Info("Dropping tunnel due to inactivity")
return closeTunnel, hostinfo, primary
} }
if n.punchy.GetTargetEverything() { // If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
// Just maintain NAT state if configured to do so.
cm.sendPunch(hostinfo)
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
return doNothing, nil, nil
}
if cm.punchy.GetTargetEverything() {
// This is similar to the old punchy behavior with a slight optimization. // This is similar to the old punchy behavior with a slight optimization.
// We aren't receiving traffic but we are sending it, punch on all known // We aren't receiving traffic but we are sending it, punch on all known
// ips in case we need to re-prime NAT state // ips in case we need to re-prime NAT state
n.sendPunch(hostinfo) cm.sendPunch(hostinfo)
} }
if n.l.Level >= logrus.DebugLevel { if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l). hostinfo.logger(cm.l).
WithField("tunnelCheck", m{"state": "testing", "method": "active"}). WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
Debug("Tunnel status") Debug("Tunnel status")
} }
@@ -411,17 +420,33 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
decision = sendTestPacket decision = sendTestPacket
} else { } else {
if n.l.Level >= logrus.DebugLevel { if cm.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l).Debugf("Hostinfo sadness") hostinfo.logger(cm.l).Debugf("Hostinfo sadness")
} }
} }
n.pendingDeletion[hostinfo.localIndexId] = struct{}{} hostinfo.pendingDeletion.Store(true)
n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval) cm.trafficTimer.Add(hostinfo.localIndexId, cm.pendingDeletionInterval)
return decision, hostinfo, nil return decision, hostinfo, nil
} }
func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time.Duration, bool) {
if cm.dropInactive.Load() == false {
// We aren't configured to drop inactive tunnels
return 0, false
}
inactiveDuration := now.Sub(hostinfo.lastUsed)
if inactiveDuration < cm.getInactivityTimeout() {
// It's not considered inactive
return inactiveDuration, false
}
// The tunnel is inactive
return inactiveDuration, true
}
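The reloadable knobs use the same lock-free idea: the timeout lives in an atomic.Int64 holding a time.Duration and drop_inactive in an atomic.Bool, so reload() can update them while the traffic-check loop reads them without a lock. A condensed sketch of the pattern (type and method names here are illustrative):

type inactivitySettings struct {
	inactivityTimeout atomic.Int64 // a time.Duration stored as int64
	dropInactive      atomic.Bool
}

func (s *inactivitySettings) setTimeout(d time.Duration) { s.inactivityTimeout.Store(int64(d)) }
func (s *inactivitySettings) timeout() time.Duration     { return time.Duration(s.inactivityTimeout.Load()) }

// isInactive above then reduces to one comparison against lastUsed.
func (s *inactivitySettings) isInactive(lastUsed, now time.Time) (time.Duration, bool) {
	if !s.dropInactive.Load() {
		return 0, false // feature disabled
	}
	idle := now.Sub(lastUsed)
	return idle, idle >= s.timeout()
}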
func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
// The primary tunnel is the most recent handshake to complete locally and should work entirely fine. // The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
// Let's sort this out. // Let's sort this out.
@@ -429,83 +454,127 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
// Only one side should swap because if both swap then we may never resolve to a single tunnel. // Only one side should swap because if both swap then we may never resolve to a single tunnel.
// vpn addr is static across all tunnels for this host pair so lets // vpn addr is static across all tunnels for this host pair so lets
// use that to determine if we should consider swapping. // use that to determine if we should consider swapping.
if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 { if current.vpnAddrs[0].Compare(cm.intf.myVpnAddrs[0]) < 0 {
// Their primary vpn addr is less than mine. Do not swap. // Their primary vpn addr is less than mine. Do not swap.
return false return false
} }
crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version()) crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
if crt == nil {
//my cert was reloaded away. We should definitely swap from this tunnel
return true
}
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things // If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
// settle down. // settle down.
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature()) return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
} }
func (n *connectionManager) swapPrimary(current, primary *HostInfo) { func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
n.hostMap.Lock() cm.hostMap.Lock()
// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake. // Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
if n.hostMap.Hosts[current.vpnAddrs[0]] == primary { if cm.hostMap.Hosts[current.vpnAddrs[0]] == primary {
n.hostMap.unlockedMakePrimary(current) cm.hostMap.unlockedMakePrimary(current)
} }
n.hostMap.Unlock() cm.hostMap.Unlock()
} }
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and // isInvalidCertificate decides if we should destroy a tunnel.
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid // returns true if pki.disconnect_invalid is true and the certificate is no longer valid.
// check and return true. // Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool { func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
remoteCert := hostinfo.GetCert() remoteCert := hostinfo.GetCert()
if remoteCert == nil { if remoteCert == nil {
return false return false //don't tear down tunnels for handshakes in progress
} }
caPool := n.intf.pki.GetCAPool() caPool := cm.intf.pki.GetCAPool()
err := caPool.VerifyCachedCertificate(now, remoteCert) err := caPool.VerifyCachedCertificate(now, remoteCert)
if err == nil { if err == nil {
return false return false //cert is still valid! yay!
} } else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
// Block listed certificates should always be disconnected // Block listed certificates should always be disconnected
return false hostinfo.logger(cm.l).WithError(err).
} WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is blocked, tearing down the tunnel")
hostinfo.logger(n.l).WithError(err). return true
} else if cm.intf.disconnectInvalid.Load() {
hostinfo.logger(cm.l).WithError(err).
WithField("fingerprint", remoteCert.Fingerprint). WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is no longer valid, tearing down the tunnel") Info("Remote certificate is no longer valid, tearing down the tunnel")
return true return true
} else {
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
return false
}
} }
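Put differently, the rewritten isInvalidCertificate is a three-way decision: a blocklisted certificate always tears the tunnel down, and any other verification failure only does so when pki.disconnect_invalid is set. A condensed restatement, sketch only, with hypothetical parameter names:

func shouldTearDown(verifyErr error, blockListed, disconnectInvalid bool) bool {
	if verifyErr == nil {
		return false // certificate still verifies, keep the tunnel
	}
	if blockListed {
		return true // blocklist always wins
	}
	return disconnectInvalid // otherwise honor pki.disconnect_invalid
}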
func (n *connectionManager) sendPunch(hostinfo *HostInfo) { func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
if !n.punchy.GetPunch() { if !cm.punchy.GetPunch() {
// Punching is disabled // Punching is disabled
return return
} }
if n.punchy.GetTargetEverything() { if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { // Do not punch to lighthouses, we assume our lighthouse update interval is good enough.
n.metricsTxPunchy.Inc(1) // In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse
n.intf.outside.WriteTo([]byte{1}, addr) // would lose the ability to notify us and punchy.respond would become unreliable.
})
} else if hostinfo.remote.IsValid() {
n.metricsTxPunchy.Inc(1)
n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
}
}
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
cs := n.intf.pki.getCertState()
curCrt := hostinfo.ConnectionState.myCert
myCrt := cs.getCertificate(curCrt.Version())
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
// The current tunnel is using the latest certificate and version, no need to rehandshake.
return return
} }
n.l.WithField("vpnAddrs", hostinfo.vpnAddrs). if cm.punchy.GetTargetEverything() {
hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
cm.metricsTxPunchy.Inc(1)
cm.intf.outside.WriteTo([]byte{1}, addr)
})
} else if hostinfo.remote.IsValid() {
cm.metricsTxPunchy.Inc(1)
cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
}
}
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
cs := cm.intf.pki.getCertState()
curCrt := hostinfo.ConnectionState.myCert
curCrtVersion := curCrt.Version()
myCrt := cs.getCertificate(curCrtVersion)
if myCrt == nil {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("version", curCrtVersion).
WithField("reason", "local certificate removed").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return
}
peerCrt := hostinfo.ConnectionState.peerCert
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("version", curCrtVersion).
WithField("peerVersion", peerCrt.Certificate.Version()).
WithField("reason", "local certificate version lower than peer, attempting to correct").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
})
return
}
}
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "local certificate is not current"). WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote") Info("Re-handshaking with remote")
n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return
}
if curCrtVersion < cs.initiatingVersion {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "current cert version < pki.initiatingVersion").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return
}
} }
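tryRehandshake now has four triggers, evaluated in order: the local certificate for the tunnel's version was reloaded away, the peer presented a newer certificate version that we can match, the local certificate's signature changed, or the tunnel's version is below the configured initiating version. A compact restatement of just the decision (illustrative names; cert.Version is assumed to be an ordered type, as the comparisons above imply):

func needsRehandshake(cur, peer, initiating cert.Version, haveCur, havePeerVersion, signatureCurrent bool) bool {
	switch {
	case !haveCur: // our cert for this version was removed by a reload
		return true
	case peer > cur && havePeerVersion: // catch up to the peer's newer version
		return true
	case !signatureCurrent: // our cert was renewed or replaced
		return true
	case cur < initiating: // config now asks for a newer initiating version
		return true
	default:
		return false
	}
}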


@@ -1,7 +1,6 @@
package nebula package nebula
import ( import (
"context"
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"net/netip" "net/netip"
@@ -23,7 +22,7 @@ func newTestLighthouse() *LightHouse {
addrMap: map[netip.Addr]*RemoteList{}, addrMap: map[netip.Addr]*RemoteList{},
queryChan: make(chan netip.Addr, 10), queryChan: make(chan netip.Addr, 10),
} }
lighthouses := map[netip.Addr]struct{}{} lighthouses := []netip.Addr{}
staticList := map[netip.Addr]struct{}{} staticList := map[netip.Addr]struct{}{}
lh.lighthouses.Store(&lighthouses) lh.lighthouses.Store(&lighthouses)
@@ -64,10 +63,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
ifce.pki.cs.Store(cs) ifce.pki.cs.Store(cs)
// Create manager // Create manager
ctx, cancel := context.WithCancel(context.Background()) conf := config.NewC(l)
defer cancel() punchy := NewPunchyFromConfig(l, conf)
punchy := NewPunchyFromConfig(l, config.NewC(l)) nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) nc.intf = ifce
p := []byte("") p := []byte("")
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
@@ -85,32 +84,33 @@ func Test_NewConnectionManagerTest(t *testing.T) {
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
// We saw traffic out to vpnIp // We saw traffic out to vpnIp
nc.Out(hostinfo.localIndexId) nc.Out(hostinfo)
nc.In(hostinfo.localIndexId) nc.In(hostinfo)
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.False(t, hostinfo.pendingDeletion.Load())
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.out, hostinfo.localIndexId) assert.True(t, hostinfo.out.Load())
assert.True(t, hostinfo.in.Load())
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.False(t, hostinfo.pendingDeletion.Load())
assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.False(t, hostinfo.out.Load())
assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.False(t, hostinfo.in.Load())
// Do another traffic check tick, this host should be pending deletion now // Do another traffic check tick, this host should be pending deletion now
nc.Out(hostinfo.localIndexId) nc.Out(hostinfo)
assert.True(t, hostinfo.out.Load())
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.True(t, hostinfo.pendingDeletion.Load())
assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.False(t, hostinfo.out.Load())
assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// Do a final traffic check tick, the host should now be removed // Do a final traffic check tick, the host should now be removed
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs)
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
} }
@@ -146,10 +146,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
ifce.pki.cs.Store(cs) ifce.pki.cs.Store(cs)
// Create manager // Create manager
ctx, cancel := context.WithCancel(context.Background()) conf := config.NewC(l)
defer cancel() punchy := NewPunchyFromConfig(l, conf)
punchy := NewPunchyFromConfig(l, config.NewC(l)) nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) nc.intf = ifce
p := []byte("") p := []byte("")
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
@@ -167,33 +167,129 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
// We saw traffic out to vpnIp // We saw traffic out to vpnIp
nc.Out(hostinfo.localIndexId) nc.Out(hostinfo)
nc.In(hostinfo.localIndexId) nc.In(hostinfo)
assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0]) assert.True(t, hostinfo.in.Load())
assert.True(t, hostinfo.out.Load())
assert.False(t, hostinfo.pendingDeletion.Load())
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.False(t, hostinfo.pendingDeletion.Load())
assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.False(t, hostinfo.out.Load())
assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.False(t, hostinfo.in.Load())
// Do another traffic check tick, this host should be pending deletion now // Do another traffic check tick, this host should be pending deletion now
nc.Out(hostinfo.localIndexId) nc.Out(hostinfo)
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.True(t, hostinfo.pendingDeletion.Load())
assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.False(t, hostinfo.out.Load())
assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// We saw traffic, should no longer be pending deletion // We saw traffic, should no longer be pending deletion
nc.In(hostinfo.localIndexId) nc.In(hostinfo)
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.False(t, hostinfo.pendingDeletion.Load())
assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.False(t, hostinfo.out.Load())
assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
}
func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
l := test.NewLogger()
localrange := netip.MustParsePrefix("10.1.1.1/24")
vpnAddrs := []netip.Addr{netip.MustParseAddr("172.1.1.2")}
preferredRanges := []netip.Prefix{localrange}
// Very incomplete mock objects
hostMap := newHostMap(l)
hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{
initiatingVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
}
lh := newTestLighthouse()
ifce := &Interface{
hostMap: hostMap,
inside: &test.NoopTun{},
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
pki: &PKI{},
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
}
ifce.pki.cs.Store(cs)
// Create manager
conf := config.NewC(l)
conf.Settings["tunnels"] = map[string]any{
"drop_inactive": true,
}
punchy := NewPunchyFromConfig(l, conf)
nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
assert.True(t, nc.dropInactive.Load())
nc.intf = ifce
// Add an ip we have established a connection w/ to hostmap
hostinfo := &HostInfo{
vpnAddrs: vpnAddrs,
localIndexId: 1099,
remoteIndexId: 9901,
}
hostinfo.ConnectionState = &ConnectionState{
myCert: &dummyCert{version: cert.Version1},
H: &noise.HandshakeState{},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
// Do a traffic check tick, in and out should be cleared but should not be pending deletion
nc.Out(hostinfo)
nc.In(hostinfo)
assert.True(t, hostinfo.out.Load())
assert.True(t, hostinfo.in.Load())
now := time.Now()
decision, _, _ := nc.makeTrafficDecision(hostinfo.localIndexId, now)
assert.Equal(t, tryRehandshake, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*5))
assert.Equal(t, doNothing, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
// Do another traffic check tick, should still not be pending deletion
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*10))
assert.Equal(t, doNothing, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// Finally advance beyond the inactivity timeout
decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Minute*10))
assert.Equal(t, closeTunnel, decision)
assert.Equal(t, now, hostinfo.lastUsed)
assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load())
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
} }
@@ -264,10 +360,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
ifce.disconnectInvalid.Store(true) ifce.disconnectInvalid.Store(true)
// Create manager // Create manager
ctx, cancel := context.WithCancel(context.Background()) conf := config.NewC(l)
defer cancel() punchy := NewPunchyFromConfig(l, conf)
punchy := NewPunchyFromConfig(l, config.NewC(l)) nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) nc.intf = ifce
ifce.connectionManager = nc ifce.connectionManager = nc
hostinfo := &HostInfo{ hostinfo := &HostInfo{
@@ -350,6 +446,10 @@ func (d *dummyCert) PublicKey() []byte {
return d.publicKey return d.publicKey
} }
func (d *dummyCert) MarshalPublicKeyPEM() []byte {
return cert.MarshalPublicKeyToPEM(d.curve, d.publicKey)
}
func (d *dummyCert) Signature() []byte { func (d *dummyCert) Signature() []byte {
return d.signature return d.signature
} }


@@ -4,16 +4,16 @@ import (
"crypto/rand" "crypto/rand"
"encoding/json" "encoding/json"
"fmt" "fmt"
"sync"
"sync/atomic" "sync/atomic"
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/noiseutil"
"github.com/wadey/synctrace"
) )
const ReplayWindow = 1024 const ReplayWindow = 4096
type ConnectionState struct { type ConnectionState struct {
eKey *NebulaCipherState eKey *NebulaCipherState
@@ -24,7 +24,7 @@ type ConnectionState struct {
initiator bool initiator bool
messageCounter atomic.Uint64 messageCounter atomic.Uint64
window *Bits window *Bits
writeLock synctrace.Mutex writeLock sync.Mutex
} }
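ReplayWindow grows from 1024 to 4096, so up to 4096 packets can arrive out of order behind the highest counter seen before legitimate traffic starts being rejected as replays. For context, a generic sliding-window replay check; this is illustrative only and is not the project's Bits type, it just shows the rule the window size controls:

// Illustrative only: a replay window tracked with a set rather than a bitmap.
type replayWindow struct {
	size    uint64
	highest uint64
	seen    map[uint64]struct{}
}

func newReplayWindow(size uint64) *replayWindow {
	return &replayWindow{size: size, seen: make(map[uint64]struct{})}
}

// accept reports whether counter is new and inside the window, recording it if so.
func (w *replayWindow) accept(counter uint64) bool {
	if counter+w.size <= w.highest {
		return false // older than the window: treat as a replay
	}
	if _, dup := w.seen[counter]; dup {
		return false // already seen
	}
	w.seen[counter] = struct{}{}
	if counter > w.highest {
		w.highest = counter
		for c := range w.seen { // evict counters that fell out of the window
			if c+w.size <= w.highest {
				delete(w.seen, c)
			}
		}
	}
	return true
}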
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) { func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
@@ -76,7 +76,6 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
initiator: initiator, initiator: initiator,
window: b, window: b,
myCert: crt, myCert: crt,
writeLock: synctrace.NewMutex("connection-state"),
} }
// always start the counter from 2, as packet 1 and packet 2 are handshake packets. // always start the counter from 2, as packet 1 and packet 2 are handshake packets.
ci.messageCounter.Add(2) ci.messageCounter.Add(2)


@@ -34,6 +34,7 @@ type Control struct {
statsStart func() statsStart func()
dnsStart func() dnsStart func()
lighthouseStart func() lighthouseStart func()
connectionManagerStart func(context.Context)
} }
type ControlHostInfo struct { type ControlHostInfo struct {
@@ -63,6 +64,9 @@ func (c *Control) Start() {
if c.dnsStart != nil { if c.dnsStart != nil {
go c.dnsStart() go c.dnsStart()
} }
if c.connectionManagerStart != nil {
go c.connectionManagerStart(c.ctx)
}
if c.lighthouseStart != nil { if c.lighthouseStart != nil {
c.lighthouseStart() c.lighthouseStart()
} }


@@ -53,7 +53,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
localIndexId: 201, localIndexId: 201,
vpnAddrs: []netip.Addr{vpnIp}, vpnAddrs: []netip.Addr{vpnIp},
relayState: RelayState{ relayState: RelayState{
relays: map[netip.Addr]struct{}{}, relays: nil,
relayForByAddr: map[netip.Addr]*Relay{}, relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{}, relayForByIdx: map[uint32]*Relay{},
}, },
@@ -72,7 +72,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
localIndexId: 201, localIndexId: 201,
vpnAddrs: []netip.Addr{vpnIp2}, vpnAddrs: []netip.Addr{vpnIp2},
relayState: RelayState{ relayState: RelayState{
relays: map[netip.Addr]struct{}{}, relays: nil,
relayForByAddr: map[netip.Addr]*Relay{}, relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{}, relayForByIdx: map[uint32]*Relay{},
}, },


@@ -174,6 +174,10 @@ func (c *Control) GetHostmap() *HostMap {
return c.f.hostMap return c.f.hostMap
} }
func (c *Control) GetF() *Interface {
return c.f
}
func (c *Control) GetCertState() *CertState { func (c *Control) GetCertState() *CertState {
return c.f.pki.getCertState() return c.f.pki.getCertState()
} }


@@ -6,12 +6,12 @@ import (
"net/netip" "net/netip"
"strconv" "strconv"
"strings" "strings"
"sync"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/wadey/synctrace"
) )
// This whole thing should be rewritten to use context // This whole thing should be rewritten to use context
@@ -21,7 +21,7 @@ var dnsServer *dns.Server
var dnsAddr string var dnsAddr string
type dnsRecords struct { type dnsRecords struct {
synctrace.RWMutex sync.RWMutex
l *logrus.Logger l *logrus.Logger
dnsMap4 map[string]netip.Addr dnsMap4 map[string]netip.Addr
dnsMap6 map[string]netip.Addr dnsMap6 map[string]netip.Addr
@@ -31,7 +31,6 @@ type dnsRecords struct {
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords { func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
return &dnsRecords{ return &dnsRecords{
RWMutex: synctrace.NewRWMutex("dns-records"),
l: l, l: l,
dnsMap4: make(map[string]netip.Addr), dnsMap4: make(map[string]netip.Addr),
dnsMap6: make(map[string]netip.Addr), dnsMap6: make(map[string]netip.Addr),


@@ -20,7 +20,7 @@ import (
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gopkg.in/yaml.v3" "go.yaml.in/yaml/v3"
) )
func BenchmarkHotPath(b *testing.B) { func BenchmarkHotPath(b *testing.B) {
@@ -97,6 +97,41 @@ func TestGoodHandshake(t *testing.T) {
theirControl.Stop() theirControl.Stop()
} }
func TestGoodHandshakeNoOverlap(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack!
// Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
empty := []byte{}
t.Log("do something to cause a handshake")
myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, empty, empty)
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
t.Log("Get their stage 1 packet")
stage1Packet := theirControl.GetFromUDP(true)
t.Log("Have me consume their stage 1 packet. I have a tunnel now")
myControl.InjectUDPPacket(stage1Packet)
t.Log("Wait until we see a test packet come through to make sure we give the tunnel time to complete")
myControl.WaitForType(header.Test, 0, theirControl)
t.Log("Make sure our host infos are correct")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestWrongResponderHandshake(t *testing.T) { func TestWrongResponderHandshake(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
@@ -464,6 +499,35 @@ func TestRelays(t *testing.T) {
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
} }
func TestRelaysDontCareAboutIps(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
defer r.RenderFlow()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
}
func TestReestablishRelays(t *testing.T) { func TestReestablishRelays(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
@@ -506,7 +570,7 @@ func TestReestablishRelays(t *testing.T) {
curIndexes := len(myControl.GetHostmap().Indexes) curIndexes := len(myControl.GetHostmap().Indexes)
for curIndexes >= start { for curIndexes >= start {
curIndexes = len(myControl.GetHostmap().Indexes) curIndexes = len(myControl.GetHostmap().Indexes)
r.Logf("Wait for the dead index to go away:start=%v indexes, currnet=%v indexes", start, curIndexes) r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes)
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail"))
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@@ -1052,6 +1116,9 @@ func TestRehandshakingLoser(t *testing.T) {
t.Log("Stand up a tunnel between me and them") t.Log("Stand up a tunnel between me and them")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
r.Log("Renew their certificate and spin until mine sees it") r.Log("Renew their certificate and spin until mine sees it")
@@ -1224,3 +1291,109 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
myControl.Stop() myControl.Stop()
theirControl.Stop() theirControl.Stop()
} }
func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}})
o := m{
"static_host_map": m{
lhVpnIpNet[0].Addr().String(): []string{lhUdpAddr.String()},
},
"lighthouse": m{
"hosts": []string{lhVpnIpNet[0].Addr().String()},
"local_allow_list": m{
// Try and block our lighthouse updates from using the actual addresses assigned to this computer
// If we start discovering addresses the test router doesn't know about then test traffic cant flow
"10.0.0.0/24": true,
"::/0": false,
},
},
}
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o)
theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, lhControl, myControl, theirControl)
defer r.RenderFlow()
// Start the servers
lhControl.Start()
myControl.Start()
theirControl.Start()
t.Log("Stand up an ipv6 tunnel between me and them")
assert.True(t, myVpnIpNet[1].Addr().Is6())
assert.True(t, theirVpnIpNet[1].Addr().Is6())
assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)
lhControl.Stop()
myControl.Stop()
theirControl.Stop()
}
func TestGoodHandshakeUnsafeDest(t *testing.T) {
unsafePrefix := "192.168.6.0/24"
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil)
route := m{"route": unsafePrefix, "via": theirVpnIpNet[0].Addr().String()}
myCfg := m{
"tun": m{
"unsafe_routes": []m{route},
},
}
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg)
t.Logf("my config %v", myConfig)
// Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
spookyDest := netip.MustParseAddr("192.168.6.4")
// Start the servers
myControl.Start()
theirControl.Start()
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
t.Log("Get their stage 1 packet so that we can play with it")
stage1Packet := theirControl.GetFromUDP(true)
t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
badPacket := stage1Packet.Copy()
badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
myControl.InjectUDPPacket(badPacket)
t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
myControl.InjectUDPPacket(stage1Packet)
t.Log("Wait until we see my cached packet come through")
myControl.WaitForType(1, 0, theirControl)
t.Log("Make sure our host infos are correct")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
t.Log("Get that cached packet and make sure it looks right")
myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)
//reply
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))
//wait for reply
theirControl.WaitForType(1, 0, myControl)
theirCachedPacket := myControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from the spookyman"), theirCachedPacket, spookyDest, myVpnIpNet[0].Addr(), 80, 80)
t.Log("Do a bidirectional tunnel test")
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}


@@ -22,15 +22,14 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3" "github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3"
) )
type m = map[string]any type m = map[string]any
// newSimpleServer creates a nebula instance with many assumptions // newSimpleServer creates a nebula instance with many assumptions
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
var vpnNetworks []netip.Prefix var vpnNetworks []netip.Prefix
for _, sn := range strings.Split(sVpnNetworks, ",") { for _, sn := range strings.Split(sVpnNetworks, ",") {
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn)) vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
@@ -56,7 +55,54 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
budpIp[3] = 239 budpIp[3] = 239
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
} }
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{}) return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
}
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
return newSimpleServerWithUdpAndUnsafeNetworks(v, caCrt, caKey, name, sVpnNetworks, udpAddr, "", overrides)
}
func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, sUnsafeNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
var vpnNetworks []netip.Prefix
for _, sn := range strings.Split(sVpnNetworks, ",") {
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
if err != nil {
panic(err)
}
vpnNetworks = append(vpnNetworks, vpnIpNet)
}
if len(vpnNetworks) == 0 {
panic("no vpn networks")
}
firewallInbound := []m{{
"proto": "any",
"port": "any",
"host": "any",
}}
var unsafeNetworks []netip.Prefix
if sUnsafeNetworks != "" {
firewallInbound = []m{{
"proto": "any",
"port": "any",
"host": "any",
"local_cidr": "0.0.0.0/0",
}}
for _, sn := range strings.Split(sUnsafeNetworks, ",") {
x, err := netip.ParsePrefix(strings.TrimSpace(sn))
if err != nil {
panic(err)
}
unsafeNetworks = append(unsafeNetworks, x)
}
}
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{})
caB, err := caCrt.MarshalPEM() caB, err := caCrt.MarshalPEM()
if err != nil { if err != nil {
@@ -76,11 +122,7 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
"port": "any", "port": "any",
"host": "any", "host": "any",
}}, }},
"inbound": []m{{ "inbound": firewallInbound,
"proto": "any",
"port": "any",
"host": "any",
}},
}, },
//"handshakes": m{ //"handshakes": m{
// "try_interval": "1s", // "try_interval": "1s",
@@ -129,6 +171,109 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
return control, vpnNetworks, udpAddr, c return control, vpnNetworks, udpAddr, c
} }
// newServer creates a nebula instance with fewer assumptions
func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
vpnNetworks := certs[len(certs)-1].Networks()
var udpAddr netip.AddrPort
if vpnNetworks[0].Addr().Is4() {
budpIp := vpnNetworks[0].Addr().As4()
budpIp[1] -= 128
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
} else {
budpIp := vpnNetworks[0].Addr().As16()
// beef for funsies
budpIp[2] = 190
budpIp[3] = 239
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
}
caStr := ""
for _, ca := range caCrt {
x, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
caStr += string(x)
}
certStr := ""
for _, c := range certs {
x, err := c.MarshalPEM()
if err != nil {
panic(err)
}
certStr += string(x)
}
mc := m{
"pki": m{
"ca": caStr,
"cert": certStr,
"key": string(key),
},
//"tun": m{"disabled": true},
"firewall": m{
"outbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
"inbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
},
//"handshakes": m{
// "try_interval": "1s",
//},
"listen": m{
"host": udpAddr.Addr().String(),
"port": udpAddr.Port(),
},
"logging": m{
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
"level": l.Level.String(),
},
"timers": m{
"pending_deletion_interval": 2,
"connection_alive_interval": 2,
},
}
if overrides != nil {
final := m{}
err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
mc = final
}
cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}
c := config.NewC(l)
cStr := string(cb)
c.LoadString(cStr)
control, err := nebula.Main(c, false, "e2e-test", l, nil)
if err != nil {
panic(err)
}
return control, vpnNetworks, udpAddr, c
}
type doneCb func() type doneCb func()
func deadline(t *testing.T, seconds time.Duration) doneCb { func deadline(t *testing.T, seconds time.Duration) doneCb {
@@ -163,10 +308,10 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
// Get both host infos // Get both host infos
//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things //TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false) hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA") require.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false) hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB") require.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
// Check that both vpn and real addr are correct // Check that both vpn and real addr are correct
assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A") assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
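Not part of the diff above: the helper refactor splits newSimpleServer into a small chain (newSimpleServer → newSimpleServerWithUdp → newSimpleServerWithUdpAndUnsafeNetworks) so a test can pin the underlay address and attach unsafe networks. A minimal sketch of driving the new entry point directly, assuming it sits next to the other e2e tests; the test name and addresses are invented for illustration.

//go:build e2e_testing

package e2e

import (
	"net/netip"
	"testing"
	"time"

	"github.com/slackhq/nebula/cert"
	"github.com/slackhq/nebula/cert_test"
	"github.com/slackhq/nebula/e2e/router"
)

// TestUnsafeRouteServerSketch is a hypothetical example, not part of this changeset.
func TestUnsafeRouteServerSketch(t *testing.T) {
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})

	// "me" is certed for an unsafe route; the helper widens its inbound rule to
	// local_cidr 0.0.0.0/0 so packets addressed to the unsafe route are not
	// rejected by the default overlay-only local address matching.
	myUdp := netip.MustParseAddrPort("10.0.0.1:4242")
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(
		cert.Version2, ca, caKey, "me", "10.128.0.1/24", myUdp, "192.168.0.0/24", m{})

	// "them" keeps the simple defaults; its underlay address is derived from its vpn address.
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(
		cert.Version2, ca, caKey, "them", "10.128.0.2/24", m{})

	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)

	r := router.NewR(t, myControl, theirControl)
	defer r.RenderFlow()

	myControl.Start()
	theirControl.Start()

	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)

	myControl.Stop()
	theirControl.Stop()
}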


@@ -700,6 +700,7 @@ func (r *R) FlushAll() {
r.Unlock() r.Unlock()
panic("Can't FlushAll for host: " + p.To.String()) panic("Can't FlushAll for host: " + p.To.String())
} }
receiver.InjectUDPPacket(p)
r.Unlock() r.Unlock()
} }
} }

e2e/tunnels_test.go Normal file

@@ -0,0 +1,367 @@
//go:build e2e_testing
// +build e2e_testing
package e2e
import (
"fmt"
"net/netip"
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3"
)
func TestDropInactiveTunnels(t *testing.T) {
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
r.Log("Assert the tunnel between me and them works")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("Go inactive and wait for the tunnels to get dropped")
waitStart := time.Now()
for {
myIndexes := len(myControl.GetHostmap().Indexes)
theirIndexes := len(theirControl.GetHostmap().Indexes)
if myIndexes == 0 && theirIndexes == 0 {
break
}
since := time.Since(waitStart)
r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
if since > time.Second*30 {
t.Fatal("Tunnel should have been declared inactive after 5 seconds and before 30 seconds")
}
time.Sleep(1 * time.Second)
r.FlushAll()
}
r.Logf("Inactive tunnels were dropped within %v", time.Since(waitStart))
myControl.Stop()
theirControl.Stop()
}
func TestCertUpgrade(t *testing.T) {
// The goal of this test is to ensure a host can reload from a v1 certificate to a v2 certificate
// while the existing tunnel stays up and the peer eventually observes the new cert version
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
caB, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca2B, err := ca2.MarshalPEM()
if err != nil {
panic(err)
}
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
r.Log("Assert the tunnel between me and them works")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("yay")
//todo ???
time.Sleep(1 * time.Second)
r.FlushAll()
mc := m{
"pki": m{
"ca": caStr,
"cert": string(myCert2Pem),
"key": string(myPrivKey),
},
//"tun": m{"disabled": true},
"firewall": myC.Settings["firewall"],
//"handshakes": m{
// "try_interval": "1s",
//},
"listen": myC.Settings["listen"],
"logging": myC.Settings["logging"],
"timers": myC.Settings["timers"],
}
cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}
r.Logf("reload new v2-only config")
err = myC.ReloadConfigString(string(cb))
assert.NoError(t, err)
r.Log("yay, spin until their sees it")
waitStart := time.Now()
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
if c == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
r.Logf("version %d", version)
if version == cert.Version2 {
break
}
}
since := time.Since(waitStart)
if since > time.Second*10 {
t.Fatal("Cert should be new by now")
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestCertDowngrade(t *testing.T) {
// The goal of this test is to ensure a host can reload from a v2 certificate back to a v1 certificate
// while the existing tunnel stays up and the peer eventually observes the older cert version
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
caB, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca2B, err := ca2.MarshalPEM()
if err != nil {
panic(err)
}
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
r.Log("Assert the tunnel between me and them works")
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
//r.Log("yay")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("yay")
//todo ???
time.Sleep(1 * time.Second)
r.FlushAll()
mc := m{
"pki": m{
"ca": caStr,
"cert": string(myCertPem),
"key": string(myPrivKey),
},
"firewall": myC.Settings["firewall"],
"listen": myC.Settings["listen"],
"logging": myC.Settings["logging"],
"timers": myC.Settings["timers"],
}
cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}
r.Logf("reload new v1-only config")
err = myC.ReloadConfigString(string(cb))
assert.NoError(t, err)
r.Log("yay, spin until their sees it")
waitStart := time.Now()
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
if c == nil || c2 == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
theirVersion := c2.Cert.Version()
r.Logf("version %d,%d", version, theirVersion)
if version == cert.Version1 {
break
}
}
since := time.Since(waitStart)
if since > time.Second*5 {
r.Log("it is unusual that the cert is not new yet, but not a failure yet")
}
if since > time.Second*10 {
r.Log("wtf")
t.Fatal("Cert should be new by now")
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestCertMismatchCorrection(t *testing.T) {
// The goal of this test is to ensure that two hosts which start with mismatched certificate versions
// converge on a common version while the tunnel keeps working
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
r.Log("Assert the tunnel between me and them works")
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
//r.Log("yay")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("yay")
//todo ???
time.Sleep(1 * time.Second)
r.FlushAll()
waitStart := time.Now()
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
if c == nil || c2 == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
theirVersion := c2.Cert.Version()
r.Logf("version %d,%d", version, theirVersion)
if version == theirVersion {
break
}
}
since := time.Since(waitStart)
if since > time.Second*5 {
r.Log("wtf")
}
if since > time.Second*10 {
r.Log("wtf")
t.Fatal("Cert should be new by now")
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestCrossStackRelaysWork(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
//myVpnV4 := myVpnIpNet[0]
myVpnV6 := myVpnIpNet[1]
relayVpnV4 := relayVpnIpNet[0]
relayVpnV6 := relayVpnIpNet[1]
theirVpnV6 := theirVpnIpNet[0]
// Teach "me" how to get to the relay and that "them" can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
defer r.RenderFlow()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
t.Log("reply?")
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
p = r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
//t.Log("finish up")
//myControl.Stop()
//theirControl.Stop()
//relayControl.Stop()
}
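TestCertUpgrade, TestCertDowngrade, and TestCertMismatchCorrection all hand-roll the same poll-until-deadline loop. A small helper could express that pattern once; the sketch below is purely illustrative (the waitFor name and its placement in the e2e package are assumptions, not part of this changeset), with an example of how TestCertUpgrade's spin loop would read on top of it.

// waitFor is a hypothetical helper: it re-checks cond every second, flushing the
// router between attempts, and fails the test if the deadline passes first.
func waitFor(t *testing.T, r *router.R, deadline time.Duration, desc string, cond func() bool) {
	t.Helper()
	start := time.Now()
	for !cond() {
		if time.Since(start) > deadline {
			t.Fatalf("timed out after %v waiting for %s", time.Since(start), desc)
		}
		time.Sleep(time.Second)
		r.FlushAll()
	}
}

// Example use for the v2 cert rollout check:
//	waitFor(t, r, 10*time.Second, "them to see my v2 cert", func() bool {
//		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
//		c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
//		return c != nil && c.Cert.Version() == cert.Version2
//	})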


@@ -338,6 +338,18 @@ logging:
# after receiving the response for lighthouse queries # after receiving the response for lighthouse queries
#trigger_buffer: 64 #trigger_buffer: 64
# Tunnel manager settings
#tunnels:
# drop_inactive controls whether inactive tunnels are maintained or dropped after the inactivity_timeout period has
# elapsed.
# In general, it is a good idea to enable this setting. It will be enabled by default in a future release.
# This setting is reloadable
#drop_inactive: false
# inactivity_timeout controls how long a tunnel MUST NOT see any inbound or outbound traffic before being considered
# inactive and eligible to be dropped.
# This setting is reloadable
#inactivity_timeout: 10m
# Nebula security group configuration # Nebula security group configuration
firewall: firewall:
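For reference, this is the override map the new TestDropInactiveTunnels feeds the test helpers for its fast side; the 5s timeout is deliberately aggressive for the e2e loop, while the commented 10m above is the kind of value a real deployment would use.

overrides := m{
	"tunnels": m{
		"drop_inactive":      true,
		"inactivity_timeout": "5s", // test-only value; production configs look more like the 10m shown above
	},
}
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", overrides)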


@@ -10,6 +10,7 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
@@ -18,7 +19,6 @@ import (
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/wadey/synctrace"
) )
type FirewallInterface interface { type FirewallInterface interface {
@@ -76,7 +76,7 @@ type firewallMetrics struct {
} }
type FirewallConntrack struct { type FirewallConntrack struct {
synctrace.Mutex sync.Mutex
Conns map[firewall.Packet]*conn Conns map[firewall.Packet]*conn
TimerWheel *TimerWheel[firewall.Packet] TimerWheel *TimerWheel[firewall.Packet]
@@ -164,7 +164,6 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
return &Firewall{ return &Firewall{
Conntrack: &FirewallConntrack{ Conntrack: &FirewallConntrack{
Mutex: synctrace.NewMutex("firewall-conntrack"),
Conns: make(map[firewall.Packet]*conn), Conns: make(map[firewall.Packet]*conn),
TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax), TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
}, },
@@ -418,30 +417,45 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return nil return nil
} }
var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets") var ErrUnknownNetworkType = errors.New("unknown network type")
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs") var ErrPeerRejected = errors.New("remote address is not within a network that we handle")
var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate networks")
var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
var ErrNoMatchingRule = errors.New("no matching rule in firewall table") var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It // Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped. // returns nil if the packet should not be dropped.
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) error {
// Check if we spoke to this tuple, if we did then allow this packet // Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(fp, h, caPool, localCache) { if f.inConns(fp, h, caPool, localCache, now) {
return nil return nil
} }
// Make sure remote address matches nebula certificate // Make sure remote address matches nebula certificate, and determine how to treat it
if h.networks != nil { if h.networks == nil {
if !h.networks.Contains(fp.RemoteAddr) {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
} else {
// Simple case: Certificate has one address and no unsafe networks // Simple case: Certificate has one address and no unsafe networks
if h.vpnAddrs[0] != fp.RemoteAddr { if h.vpnAddrs[0] != fp.RemoteAddr {
f.metrics(incoming).droppedRemoteAddr.Inc(1) f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
} }
} else {
nwType, ok := h.networks.Lookup(fp.RemoteAddr)
if !ok {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
switch nwType {
case NetworkTypeVPN:
break // nothing special
case NetworkTypeVPNPeer:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrPeerRejected // reject for now, one day this may have different FW rules
case NetworkTypeUnsafe:
break // nothing special, one day this may have different FW rules
default:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrUnknownNetworkType //should never happen
}
} }
// Make sure we are supposed to be handling this local ip address // Make sure we are supposed to be handling this local ip address
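A condensed model of the new remote-address check follows, with the per-host lookup table and the NetworkType constants stubbed locally to keep it self-contained; the real types live in the hostinfo/firewall code, and this sketch only mirrors the dispatch logic that Drop now performs.

package main

import (
	"errors"
	"fmt"
	"net/netip"

	"github.com/gaissmai/bart"
)

// Local stand-ins for the NetworkType values referenced in the diff; the real
// constants and their underlying values are defined elsewhere in the codebase.
type NetworkType int

const (
	NetworkTypeVPN     NetworkType = iota + 1 // one of the peer's own overlay addresses
	NetworkTypeVPNPeer                        // overlay address we know about but do not handle for this peer
	NetworkTypeUnsafe                         // falls inside one of the peer's unsafe networks
)

var (
	errInvalidRemoteIP = errors.New("remote address is not in remote certificate networks")
	errPeerRejected    = errors.New("remote address is not within a network that we handle")
)

// classifyRemote mirrors the Drop() check: look the remote address up in the
// per-host table and decide whether the packet may continue to rule matching.
func classifyRemote(networks *bart.Table[NetworkType], remote netip.Addr) error {
	nwType, ok := networks.Lookup(remote)
	if !ok {
		return errInvalidRemoteIP
	}
	switch nwType {
	case NetworkTypeVPN, NetworkTypeUnsafe:
		return nil // both proceed to normal firewall rule evaluation today
	case NetworkTypeVPNPeer:
		return errPeerRejected // cross-stack peer traffic is rejected for now
	default:
		return errors.New("unknown network type")
	}
}

func main() {
	networks := new(bart.Table[NetworkType])
	networks.Insert(netip.MustParsePrefix("10.128.0.2/32"), NetworkTypeVPN)
	networks.Insert(netip.MustParsePrefix("192.168.0.0/24"), NetworkTypeUnsafe)

	fmt.Println(classifyRemote(networks, netip.MustParseAddr("10.128.0.2")))  // <nil>
	fmt.Println(classifyRemote(networks, netip.MustParseAddr("192.168.0.9"))) // <nil>
	fmt.Println(classifyRemote(networks, netip.MustParseAddr("9.9.9.9")))     // invalid remote
}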
@@ -462,7 +476,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
} }
// We always want to conntrack since it is a faster operation // We always want to conntrack since it is a faster operation
f.addConn(fp, incoming) f.addConn(fp, incoming, now)
return nil return nil
} }
@@ -491,7 +505,7 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
} }
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool { func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) bool {
if localCache != nil { if localCache != nil {
if _, ok := localCache[fp]; ok { if _, ok := localCache[fp]; ok {
return true return true
@@ -503,7 +517,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
// Purge every time we test // Purge every time we test
ep, has := conntrack.TimerWheel.Purge() ep, has := conntrack.TimerWheel.Purge()
if has { if has {
f.evict(ep) f.evict(ep, now)
} }
c, ok := conntrack.Conns[fp] c, ok := conntrack.Conns[fp]
@@ -550,11 +564,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
switch fp.Protocol { switch fp.Protocol {
case firewall.ProtoTCP: case firewall.ProtoTCP:
c.Expires = time.Now().Add(f.TCPTimeout) c.Expires = now.Add(f.TCPTimeout)
case firewall.ProtoUDP: case firewall.ProtoUDP:
c.Expires = time.Now().Add(f.UDPTimeout) c.Expires = now.Add(f.UDPTimeout)
default: default:
c.Expires = time.Now().Add(f.DefaultTimeout) c.Expires = now.Add(f.DefaultTimeout)
} }
conntrack.Unlock() conntrack.Unlock()
@@ -566,7 +580,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
return true return true
} }
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { func (f *Firewall) addConn(fp firewall.Packet, incoming bool, now time.Time) {
var timeout time.Duration var timeout time.Duration
c := &conn{} c := &conn{}
@@ -582,7 +596,7 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
conntrack := f.Conntrack conntrack := f.Conntrack
conntrack.Lock() conntrack.Lock()
if _, ok := conntrack.Conns[fp]; !ok { if _, ok := conntrack.Conns[fp]; !ok {
conntrack.TimerWheel.Advance(time.Now()) conntrack.TimerWheel.Advance(now)
conntrack.TimerWheel.Add(fp, timeout) conntrack.TimerWheel.Add(fp, timeout)
} }
@@ -590,14 +604,14 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
// firewall reload // firewall reload
c.incoming = incoming c.incoming = incoming
c.rulesVersion = f.rulesVersion c.rulesVersion = f.rulesVersion
c.Expires = time.Now().Add(timeout) c.Expires = now.Add(timeout)
conntrack.Conns[fp] = c conntrack.Conns[fp] = c
conntrack.Unlock() conntrack.Unlock()
} }
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
// Caller must own the connMutex lock! // Caller must own the connMutex lock!
func (f *Firewall) evict(p firewall.Packet) { func (f *Firewall) evict(p firewall.Packet, now time.Time) {
// Are we still tracking this conn? // Are we still tracking this conn?
conntrack := f.Conntrack conntrack := f.Conntrack
t, ok := conntrack.Conns[p] t, ok := conntrack.Conns[p]
@@ -605,11 +619,11 @@ func (f *Firewall) evict(p firewall.Packet) {
return return
} }
newT := t.Expires.Sub(time.Now()) newT := t.Expires.Sub(now)
// Timeout is in the future, re-add the timer // Timeout is in the future, re-add the timer
if newT > 0 { if newT > 0 {
conntrack.TimerWheel.Advance(time.Now()) conntrack.TimerWheel.Advance(now)
conntrack.TimerWheel.Add(p, newT) conntrack.TimerWheel.Add(p, newT)
return return
} }
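Drop, inConns, addConn, and evict now take the current time as a parameter instead of calling time.Now() internally. A sketch of a caller under the new signature, using a hypothetical processBatch helper assumed to live in the same package: one clock read feeds every conntrack update and eviction in the batch, and tests can pass a fixed timestamp for deterministic expiry.

// processBatch is a hypothetical helper, not part of this changeset. It reads the
// clock once and hands the same timestamp to every Drop call so conntrack expiry
// and timer-wheel advancement all see a consistent "now".
func processBatch(fw *Firewall, h *HostInfo, caPool *cert.CAPool, cache firewall.ConntrackCache, pkts []firewall.Packet) []firewall.Packet {
	now := time.Now()
	accepted := pkts[:0]
	for _, fp := range pkts {
		if err := fw.Drop(fp, true, h, caPool, cache, now); err != nil {
			continue // dropped; err is ErrInvalidRemoteIP, ErrPeerRejected, ErrNoMatchingRule, ...
		}
		accepted = append(accepted, fp)
	}
	return accepted
}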


@@ -8,6 +8,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
@@ -68,6 +70,9 @@ func TestFirewall_AddRule(t *testing.T) {
ti, err := netip.ParsePrefix("1.2.3.4/32") ti, err := netip.ParsePrefix("1.2.3.4/32")
require.NoError(t, err) require.NoError(t, err)
ti6, err := netip.ParsePrefix("fd12::34/128")
require.NoError(t, err)
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
// An empty rule is any // An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
@@ -92,12 +97,24 @@ func TestFirewall_AddRule(t *testing.T) {
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
@@ -117,6 +134,13 @@ func TestFirewall_AddRule(t *testing.T) {
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp6, err := netip.ParsePrefix("::/0")
require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions // Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -127,7 +151,8 @@ func TestFirewall_Drop(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteAddr: netip.MustParseAddr("1.2.3.4"), RemoteAddr: netip.MustParseAddr("1.2.3.4"),
@@ -152,7 +177,7 @@ func TestFirewall_Drop(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
} }
h.buildNetworks(c.networks, c.unsafeNetworks) h.buildNetworks(myVpnNetworksTable, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -199,6 +224,85 @@ func TestFirewall_Drop(t *testing.T) {
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
func TestFirewall_DropV6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"),
RemoteAddr: netip.MustParseAddr("fd12::34"),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
c := dummyCert{
name: "host1",
networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
groups: []string{"default-group"},
issuer: "signer-shasum",
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &cert.CachedCertificate{
Certificate: &c,
InvertedGroups: map[string]struct{}{"default-group": {}},
},
},
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
}
h.buildNetworks(myVpnNetworksTable, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteAddr
p.RemoteAddr = netip.MustParseAddr("fd12::56")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func BenchmarkFirewallTable_match(b *testing.B) { func BenchmarkFirewallTable_match(b *testing.B) {
f := &Firewall{} f := &Firewall{}
ft := FirewallTable{ ft := FirewallTable{
@@ -208,6 +312,10 @@ func BenchmarkFirewallTable_match(b *testing.B) {
pfix := netip.MustParsePrefix("172.1.1.1/32") pfix := netip.MustParsePrefix("172.1.1.1/32")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "") _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "") _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
pfix6 := netip.MustParsePrefix("fd11::11/128")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
cp := cert.NewCAPool() cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) { b.Run("fail on proto", func(b *testing.B) {
@@ -239,6 +347,15 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp)) assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
} }
}) })
b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{},
}
ip := netip.MustParsePrefix("fd99::99/128")
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
}
})
b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) { b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{ c := &cert.CachedCertificate{
@@ -252,6 +369,18 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)) assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
} }
}) })
b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
},
InvertedGroups: map[string]struct{}{"nope": {}},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
}
})
b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) { b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{ c := &cert.CachedCertificate{
@@ -265,6 +394,18 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
} }
}) })
b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
},
InvertedGroups: map[string]struct{}{"nope": {}},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
}
})
b.Run("pass on group on any local cidr", func(b *testing.B) { b.Run("pass on group on any local cidr", func(b *testing.B) {
c := &cert.CachedCertificate{ c := &cert.CachedCertificate{
@@ -289,6 +430,17 @@ func BenchmarkFirewallTable_match(b *testing.B) {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
} }
}) })
b.Run("pass on group on specific local cidr6", func(b *testing.B) {
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
},
InvertedGroups: map[string]struct{}{"good-group": {}},
}
for n := 0; n < b.N; n++ {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
}
})
b.Run("pass on name", func(b *testing.B) { b.Run("pass on name", func(b *testing.B) {
c := &cert.CachedCertificate{ c := &cert.CachedCertificate{
@@ -307,6 +459,8 @@ func TestFirewall_Drop2(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -332,7 +486,7 @@ func TestFirewall_Drop2(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) h.buildNetworks(myVpnNetworksTable, c.Certificate)
c1 := cert.CachedCertificate{ c1 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@@ -347,7 +501,7 @@ func TestFirewall_Drop2(t *testing.T) {
peerCert: &c1, peerCert: &c1,
}, },
} }
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -364,6 +518,8 @@ func TestFirewall_Drop3(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -395,7 +551,7 @@ func TestFirewall_Drop3(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
c2 := cert.CachedCertificate{ c2 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@@ -410,7 +566,7 @@ func TestFirewall_Drop3(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks()) h2.buildNetworks(myVpnNetworksTable, c2.Certificate)
c3 := cert.CachedCertificate{ c3 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@@ -425,7 +581,7 @@ func TestFirewall_Drop3(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks()) h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -447,10 +603,50 @@ func TestFirewall_Drop3(t *testing.T) {
require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
} }
func TestFirewall_Drop3V6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"),
RemoteAddr: netip.MustParseAddr("fd12::34"),
LocalPort: 1,
RemotePort: 1,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
network := netip.MustParsePrefix("fd12::34/120")
c := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host-owner",
networks: []netip.Prefix{network},
},
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c,
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h.buildNetworks(myVpnNetworksTable, c.Certificate)
// Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func TestFirewall_DropConntrackReload(t *testing.T) { func TestFirewall_DropConntrackReload(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -477,7 +673,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) h.buildNetworks(myVpnNetworksTable, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -510,6 +706,52 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
} }
func TestFirewall_DropIPSpoofing(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
c := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host-owner",
networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")},
},
}
c1 := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host",
networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")},
unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")},
},
}
h1 := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c1,
},
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
}
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Packet spoofed by `c1`. Note that the remote addr is not a valid one.
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("192.0.2.1"),
RemoteAddr: netip.MustParseAddr("192.0.2.3"),
LocalPort: 1,
RemotePort: 1,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
}
func BenchmarkLookup(b *testing.B) { func BenchmarkLookup(b *testing.B) {
ml := func(m map[string]struct{}, a [][]string) { ml := func(m map[string]struct{}, a [][]string) {
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
@@ -727,6 +969,21 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
// Test adding rule with cidr ipv6
cidr6 := netip.MustParsePrefix("fd00::/8")
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with local_cidr ipv6
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
// Test adding rule with ca_sha // Test adding rule with ca_sha
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
@@ -806,6 +1063,171 @@ func TestFirewall_convertRule(t *testing.T) {
assert.Equal(t, "group1", r.Group) assert.Equal(t, "group1", r.Group)
} }
type testcase struct {
h *HostInfo
p firewall.Packet
c cert.Certificate
err error
}
func (c *testcase) Test(t *testing.T, fw *Firewall) {
t.Helper()
cp := cert.NewCAPool()
resetConntrack(fw)
err := fw.Drop(c.p, true, c.h, cp, nil)
if c.err == nil {
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
} else {
require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr)
}
}
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
c1 := dummyCert{
name: "host1",
networks: theirPrefixes,
groups: []string{"default-group"},
issuer: "signer-shasum",
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &cert.CachedCertificate{
Certificate: &c1,
InvertedGroups: map[string]struct{}{"default-group": {}},
},
},
vpnAddrs: make([]netip.Addr, len(theirPrefixes)),
}
for i := range theirPrefixes {
h.vpnAddrs[i] = theirPrefixes[i].Addr()
}
h.buildNetworks(setup.myVpnNetworksTable, &c1)
p := firewall.Packet{
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
RemoteAddr: theirPrefixes[0].Addr(),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
return testcase{
h: &h,
p: p,
c: &c1,
err: err,
}
}
type testsetup struct {
c dummyCert
myVpnNetworksTable *bart.Lite
fw *Firewall
}
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
c := dummyCert{
name: "me",
networks: myPrefixes,
groups: []string{"default-group"},
issuer: "signer-shasum",
}
return newSetupFromCert(t, l, c)
}
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
myVpnNetworksTable := new(bart.Lite)
for _, prefix := range c.Networks() {
myVpnNetworksTable.Insert(prefix)
}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
return testsetup{
c: c,
fw: fw,
myVpnNetworksTable: myVpnNetworksTable,
}
}
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
t.Parallel()
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
t.Run("allow inbound all matching", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24"))
tc.Test(t, setup.fw)
})
t.Run("allow inbound local matching", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24"))
tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8")
tc.Test(t, setup.fw)
})
t.Run("block inbound remote mismatched", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24"))
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
tc.Test(t, setup.fw)
})
t.Run("Block a vpn peer packet", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24"))
tc.Test(t, setup.fw)
})
twoPrefixes := []netip.Prefix{
netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"),
}
t.Run("allow inbound one matching", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, nil, twoPrefixes...)
tc.Test(t, setup.fw)
})
t.Run("block inbound multimismatch", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...)
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
tc.Test(t, setup.fw)
})
t.Run("allow inbound 2nd one matching", func(t *testing.T) {
t.Parallel()
setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24"))
tc := buildTestCase(setup2, nil, twoPrefixes...)
tc.p.RemoteAddr = twoPrefixes[1].Addr()
tc.Test(t, setup2.fw)
})
t.Run("allow inbound unsafe route", func(t *testing.T) {
t.Parallel()
unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
c := dummyCert{
name: "me",
networks: []netip.Prefix{myPrefix},
unsafeNetworks: []netip.Prefix{unsafePrefix},
groups: []string{"default-group"},
issuer: "signer-shasum",
}
unsafeSetup := newSetupFromCert(t, l, c)
tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
tc.err = ErrNoMatchingRule
tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, unsafePrefix, "", ""))
tc.err = nil
tc.Test(t, unsafeSetup.fw) //should pass
})
}
type addRuleCall struct { type addRuleCall struct {
incoming bool incoming bool
proto uint8 proto uint8
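The new table-driven helpers (testsetup, buildTestCase, newSetup) make extra Drop scenarios cheap to add. A sibling case exercising the IPv6 path could look like the sketch below, dropped into TestFirewall_Drop_EnforceIPMatch; it is illustrative only and not part of the diff.

t.Run("block inbound remote mismatched v6", func(t *testing.T) {
	t.Parallel()
	setup := newSetup(t, l, netip.MustParsePrefix("fd00::1/64"))
	tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("fd00::2/64"))
	tc.p.RemoteAddr = netip.MustParseAddr("fd99::9") // not in the peer cert's networks
	tc.Test(t, setup.fw)
})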

go.mod

@@ -1,40 +1,38 @@
module github.com/slackhq/nebula

-go 1.23.0
+go 1.25
-toolchain go1.24.1

require (
-    dario.cat/mergo v1.0.1
+    dario.cat/mergo v1.0.2
    github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
    github.com/armon/go-radix v1.0.0
    github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
    github.com/flynn/noise v1.1.0
-    github.com/gaissmai/bart v0.20.4
+    github.com/gaissmai/bart v0.25.0
    github.com/gogo/protobuf v1.3.2
    github.com/google/gopacket v1.1.19
-    github.com/kardianos/service v1.2.2
+    github.com/kardianos/service v1.2.4
-    github.com/miekg/dns v1.1.65
+    github.com/miekg/dns v1.1.68
    github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
    github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
-    github.com/prometheus/client_golang v1.22.0
+    github.com/prometheus/client_golang v1.23.2
    github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
    github.com/sirupsen/logrus v1.9.3
    github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
    github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
-    github.com/stretchr/testify v1.10.0
+    github.com/stretchr/testify v1.11.1
-    github.com/vishvananda/netlink v1.3.0
+    github.com/vishvananda/netlink v1.3.1
-    github.com/wadey/synctrace v0.0.0-20250612192159-94547ef50dfe
+    go.yaml.in/yaml/v3 v3.0.4
-    golang.org/x/crypto v0.37.0
+    golang.org/x/crypto v0.44.0
    golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
-    golang.org/x/net v0.39.0
+    golang.org/x/net v0.46.0
-    golang.org/x/sync v0.13.0
+    golang.org/x/sync v0.18.0
-    golang.org/x/sys v0.32.0
+    golang.org/x/sys v0.38.0
-    golang.org/x/term v0.31.0
+    golang.org/x/term v0.37.0
    golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
    golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
    golang.zx2c4.com/wireguard/windows v0.5.3
-    google.golang.org/protobuf v1.36.6
+    google.golang.org/protobuf v1.36.10
    gopkg.in/yaml.v3 v3.0.1
    gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
)
@@ -43,18 +41,15 @@ require (
    github.com/beorn7/perks v1.0.1 // indirect
    github.com/cespare/xxhash/v2 v2.3.0 // indirect
    github.com/davecgh/go-spew v1.1.1 // indirect
-    github.com/emirpasic/gods v1.18.1 // indirect
    github.com/google/btree v1.1.2 // indirect
-    github.com/google/uuid v1.3.0 // indirect
-    github.com/heimdalr/dag v1.4.0 // indirect
    github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
    github.com/pmezard/go-difflib v1.0.0 // indirect
-    github.com/prometheus/client_model v0.6.1 // indirect
+    github.com/prometheus/client_model v0.6.2 // indirect
-    github.com/prometheus/common v0.62.0 // indirect
+    github.com/prometheus/common v0.66.1 // indirect
-    github.com/prometheus/procfs v0.15.1 // indirect
+    github.com/prometheus/procfs v0.16.1 // indirect
-    github.com/timandy/routine v1.1.5 // indirect
-    github.com/vishvananda/netns v0.0.4 // indirect
+    github.com/vishvananda/netns v0.0.5 // indirect
+    go.yaml.in/yaml/v2 v2.4.2 // indirect
-    golang.org/x/mod v0.23.0 // indirect
+    golang.org/x/mod v0.24.0 // indirect
-    golang.org/x/time v0.5.0 // indirect
+    golang.org/x/time v0.7.0 // indirect
-    golang.org/x/tools v0.30.0 // indirect
+    golang.org/x/tools v0.33.0 // indirect
)

go.sum

@@ -1,6 +1,6 @@
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
@@ -22,12 +22,10 @@ github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/gaissmai/bart v0.20.4 h1:Ik47r1fy3jRVU+1eYzKSW3ho2UgBVTVnUS8O993584U= github.com/gaissmai/bart v0.25.0 h1:eqiokVPqM3F94vJ0bTHXHtH91S8zkKL+bKh+BsGOsJM=
github.com/gaissmai/bart v0.20.4/go.mod h1:cEed+ge8dalcbpi8wtS9x9m2hn/fNJH5suhdGQOHnYk= github.com/gaissmai/bart v0.25.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -35,8 +33,6 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9
github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/go-test/deep v1.1.0 h1:WOcxcdHcvdgThNXjw0t76K42FXTU7HpNQWHpA2HHNlg=
github.com/go-test/deep v1.1.0/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
@@ -62,18 +58,14 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/heimdalr/dag v1.4.0 h1:zG3JA4RDVLc55k3AXAgfwa+EgBNZ0TkfOO3C29Ucpmg=
github.com/heimdalr/dag v1.4.0/go.mod h1:OCh6ghKmU0hPjtwMqWBoNxPmtRioKd1xSu7Zs4sbIqM=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX60= github.com/kardianos/service v1.2.4 h1:XNlGtZOYNx2u91urOdg/Kfmc+gfmuIo1Dd3rEi2OgBk=
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/kardianos/service v1.2.4/go.mod h1:E4V9ufUuY82F7Ztlu1eN9VXWIQxg8NoLQlmFe0MtrXc=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
@@ -91,8 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/miekg/dns v1.1.65 h1:0+tIPHzUW0GCge7IiK3guGP57VAw7hoPDfApjkMD1Fc= github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
github.com/miekg/dns v1.1.65/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck= github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -114,24 +106,24 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
@@ -151,33 +143,35 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/timandy/routine v1.1.5 h1:LSpm7Iijwb9imIPlucl4krpr2EeCeAUvifiQ9Uf5X+M= github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
github.com/timandy/routine v1.1.5/go.mod h1:kXslgIosdY8LW0byTyPnenDgn4/azt2euufAq9rK51w= github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/wadey/synctrace v0.0.0-20250612192159-94547ef50dfe h1:dc8Q42VsX+ABr0drJw27f3smvGfcz7eB8rJx+IkVMAo=
github.com/wadey/synctrace v0.0.0-20250612192159-94547ef50dfe/go.mod h1:F2VCml4UxGPgAAqqm9T0ZfnVRWITrQS1EMZM+KCAm/Q=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM= golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -188,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -197,8 +191,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -209,30 +203,29 @@ golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -251,8 +244,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

handshake_ix.go

@@ -2,14 +2,12 @@ package nebula
import (
    "net/netip"
-    "slices"
    "time"

    "github.com/flynn/noise"
    "github.com/sirupsen/logrus"
    "github.com/slackhq/nebula/cert"
    "github.com/slackhq/nebula/header"
-    "github.com/wadey/synctrace"
)

// NOISE IX Handshakes
@@ -24,15 +22,19 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
        return false
    }

-    // If we're connecting to a v6 address we must use a v2 cert
    cs := f.pki.getCertState()
    v := cs.initiatingVersion
+    if hh.initiatingVersionOverride != cert.VersionPre1 {
+        v = hh.initiatingVersionOverride
+    } else if v < cert.Version2 {
+        // If we're connecting to a v6 address we should encourage use of a V2 cert
        for _, a := range hh.hostinfo.vpnAddrs {
            if a.Is6() {
                v = cert.Version2
                break
            }
        }
+    }

    crt := cs.getCertificate(v)
    if crt == nil {
@@ -49,6 +51,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v). WithField("certVersion", v).
Error("Unable to handshake with host because no certificate handshake bytes is available") Error("Unable to handshake with host because no certificate handshake bytes is available")
return false
} }
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX) ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
@@ -104,6 +107,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.initiatingVersion). WithField("certVersion", cs.initiatingVersion).
Error("Unable to handshake with host because no certificate is available") Error("Unable to handshake with host because no certificate is available")
return
} }
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
@@ -144,8 +148,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
    remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
    if err != nil {
-        fp, err := rc.Fingerprint()
-        if err != nil {
+        fp, fperr := rc.Fingerprint()
+        if fperr != nil {
            fp = "<error generating certificate fingerprint>"
        }
@@ -164,16 +168,19 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
    if remoteCert.Certificate.Version() != ci.myCert.Version() {
        // We started off using the wrong certificate version, lets see if we can match the version that was sent to us
-        rc := cs.getCertificate(remoteCert.Certificate.Version())
-        if rc == nil {
-            f.l.WithError(err).WithField("udpAddr", addr).
-                WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
-                Info("Unable to handshake with host due to missing certificate version")
-            return
-        }
-        // Record the certificate we are actually using
-        ci.myCert = rc
+        myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
+        if myCertOtherVersion == nil {
+            if f.l.Level >= logrus.DebugLevel {
+                f.l.WithError(err).WithFields(m{
+                    "udpAddr":   addr,
+                    "handshake": m{"stage": 1, "style": "ix_psk0"},
+                    "cert":      remoteCert,
+                }).Debug("Might be unable to handshake with host due to missing certificate version")
+            }
+        } else {
+            // Record the certificate we are actually using
+            ci.myCert = myCertOtherVersion
+        }
    }

    if len(remoteCert.Certificate.Networks()) == 0 {
@@ -184,17 +191,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
        return
    }

-    var vpnAddrs []netip.Addr
-    var filteredNetworks []netip.Prefix
    certName := remoteCert.Certificate.Name()
    certVersion := remoteCert.Certificate.Version()
    fingerprint := remoteCert.Fingerprint
    issuer := remoteCert.Certificate.Issuer()
+    vpnNetworks := remoteCert.Certificate.Networks()

-    for _, network := range remoteCert.Certificate.Networks() {
-        vpnAddr := network.Addr()
-        if f.myVpnAddrsTable.Contains(vpnAddr) {
-            f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
+    anyVpnAddrsInCommon := false
+    vpnAddrs := make([]netip.Addr, len(vpnNetworks))
+    for i, network := range vpnNetworks {
+        if f.myVpnAddrsTable.Contains(network.Addr()) {
+            f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
                WithField("certName", certName).
                WithField("certVersion", certVersion).
                WithField("fingerprint", fingerprint).
@@ -202,24 +209,10 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
return return
} }
vpnAddrs[i] = network.Addr()
// vpnAddrs outside our vpn networks are of no use to us, filter them out if f.myVpnNetworksTable.Contains(network.Addr()) {
if !f.myVpnNetworksTable.Contains(vpnAddr) { anyVpnAddrsInCommon = true
continue
} }
filteredNetworks = append(filteredNetworks, network)
vpnAddrs = append(vpnAddrs, vpnAddr)
}
if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
return
} }
if addr.IsValid() { if addr.IsValid() {
@@ -250,33 +243,36 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
        HandshakePacket:   make(map[uint8][]byte, 0),
        lastHandshakeTime: hs.Details.Time,
        relayState: RelayState{
-            RWMutex:        synctrace.NewRWMutex("relay-state"),
-            relays:         map[netip.Addr]struct{}{},
+            relays:         nil,
            relayForByAddr: map[netip.Addr]*Relay{},
            relayForByIdx:  map[uint32]*Relay{},
        },
    }

-    f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
-        WithField("certName", certName).
-        WithField("certVersion", certVersion).
-        WithField("fingerprint", fingerprint).
-        WithField("issuer", issuer).
-        WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
-        WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-        Info("Handshake message received")
+    msgRxL := f.l.WithFields(m{
+        "vpnAddrs":       vpnAddrs,
+        "udpAddr":        addr,
+        "certName":       certName,
+        "certVersion":    certVersion,
+        "fingerprint":    fingerprint,
+        "issuer":         issuer,
+        "initiatorIndex": hs.Details.InitiatorIndex,
+        "responderIndex": hs.Details.ResponderIndex,
+        "remoteIndex":    h.RemoteIndex,
+        "handshake":      m{"stage": 1, "style": "ix_psk0"},
+    })
+    if anyVpnAddrsInCommon {
+        msgRxL.Info("Handshake message received")
+    } else {
+        //todo warn if not lighthouse or relay?
+        msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
+    }

    hs.Details.ResponderIndex = myIndex
    hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
    if hs.Details.Cert == nil {
-        f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
-            WithField("certName", certName).
-            WithField("certVersion", certVersion).
-            WithField("fingerprint", fingerprint).
-            WithField("issuer", issuer).
-            WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
-            WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-            WithField("certVersion", ci.myCert.Version()).
+        msgRxL.WithField("myCertVersion", ci.myCert.Version()).
            Error("Unable to handshake with host because no certificate handshake bytes is available")
        return
    }
@@ -334,7 +330,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
    hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
    hostinfo.SetRemote(addr)
-    hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
+    hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)

    existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
    if err != nil {
@@ -459,9 +455,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
Info("Handshake message sent") Info("Handshake message sent")
} }
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) f.connectionManager.AddTrafficWatch(hostinfo)
hostinfo.remotes.ResetBlockedRemotes() hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
return return
} }
@@ -575,31 +571,22 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
        hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
    }

-    var vpnAddrs []netip.Addr
-    var filteredNetworks []netip.Prefix
-    for _, network := range vpnNetworks {
-        // vpnAddrs outside our vpn networks are of no use to us, filter them out
-        vpnAddr := network.Addr()
-        if !f.myVpnNetworksTable.Contains(vpnAddr) {
-            continue
-        }
-        filteredNetworks = append(filteredNetworks, network)
-        vpnAddrs = append(vpnAddrs, vpnAddr)
-    }
-    if len(vpnAddrs) == 0 {
-        f.l.WithError(err).WithField("udpAddr", addr).
-            WithField("certName", certName).
-            WithField("certVersion", certVersion).
-            WithField("fingerprint", fingerprint).
-            WithField("issuer", issuer).
-            WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
-        return true
-    }
+    correctHostResponded := false
+    anyVpnAddrsInCommon := false
+    vpnAddrs := make([]netip.Addr, len(vpnNetworks))
+    for i, network := range vpnNetworks {
+        vpnAddrs[i] = network.Addr()
+        if f.myVpnNetworksTable.Contains(network.Addr()) {
+            anyVpnAddrsInCommon = true
+        }
+        if hostinfo.vpnAddrs[0] == network.Addr() {
+            // todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
+            correctHostResponded = true
+        }
+    }

    // Ensure the right host responded
-    if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
+    if !correctHostResponded {
        f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
            WithField("udpAddr", addr).
            WithField("certName", certName).
@@ -611,6 +598,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
        f.handshakeManager.DeleteHostInfo(hostinfo)

        // Create a new hostinfo/handshake for the intended vpn ip
+        //TODO is hostinfo.vpnAddrs[0] always the address to use?
        f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
            // Block the current used address
            newHH.hostinfo.remotes = hostinfo.remotes
@@ -637,7 +625,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
    ci.window.Update(f.l, 2)

    duration := time.Since(hh.startTime).Nanoseconds()
-    f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
+    msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
        WithField("certName", certName).
        WithField("certVersion", certVersion).
        WithField("fingerprint", fingerprint).
@@ -645,16 +633,21 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("durationNs", duration). WithField("durationNs", duration).
WithField("sentCachedPackets", len(hh.packetStore)). WithField("sentCachedPackets", len(hh.packetStore))
Info("Handshake message received") if anyVpnAddrsInCommon {
msgRxL.Info("Handshake message received")
} else {
//todo warn if not lighthouse or relay?
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
}
// Build up the radix for the firewall if we have subnets in the cert // Build up the radix for the firewall if we have subnets in the cert
hostinfo.vpnAddrs = vpnAddrs hostinfo.vpnAddrs = vpnAddrs
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks()) hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
f.handshakeManager.Complete(hostinfo, f) f.handshakeManager.Complete(hostinfo, f)
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) f.connectionManager.AddTrafficWatch(hostinfo)
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore)) hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
@@ -669,7 +662,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
        f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
    }

-    hostinfo.remotes.ResetBlockedRemotes()
+    hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
    f.metricHandshakes.Update(duration)

    return false

handshake_manager.go

@@ -8,6 +8,7 @@ import (
"errors" "errors"
"net/netip" "net/netip"
"slices" "slices"
"sync"
"time" "time"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
@@ -15,7 +16,6 @@ import (
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/wadey/synctrace"
) )
const ( const (
@@ -45,7 +45,7 @@ type HandshakeConfig struct {
type HandshakeManager struct {
    // Mutex for interacting with the vpnIps and indexes maps
-    synctrace.RWMutex
+    sync.RWMutex

    vpnIps  map[netip.Addr]*HandshakeHostInfo
    indexes map[uint32]*HandshakeHostInfo
@@ -66,10 +66,11 @@ type HandshakeManager struct {
}

type HandshakeHostInfo struct {
-    synctrace.Mutex
+    sync.Mutex

    startTime                 time.Time        // Time that we first started trying with this handshake
    ready                     bool             // Is the handshake ready
+    initiatingVersionOverride cert.Version     // Should we use a non-default cert version for this handshake?
    counter                   int64            // How many attempts have we made so far
    lastRemotes               []netip.AddrPort // Remotes that we sent to during the previous attempt
    packetStore               []*cachedPacket  // A set of packets to be transmitted once the handshake completes
@@ -104,7 +105,6 @@ func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType,
func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
    return &HandshakeManager{
-        RWMutex:     synctrace.NewRWMutex("handshake-manager"),
        vpnIps:      map[netip.Addr]*HandshakeHostInfo{},
        indexes:     map[uint32]*HandshakeHostInfo{},
        mainHostMap: mainHostMap,
@@ -112,7 +112,7 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig
        outside:                outside,
        config:                 config,
        trigger:                make(chan netip.Addr, config.triggerBuffer),
-        OutboundHandshakeTimer: NewLockingTimerWheel[netip.Addr]("outbound-handshake-timer", config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
+        OutboundHandshakeTimer: NewLockingTimerWheel[netip.Addr](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
        messageMetrics:         config.messageMetrics,
        metricInitiated:        metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
        metricTimedOut:         metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
@@ -269,12 +269,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
// Send a RelayRequest to all known Relay IP's // Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays { for _, relay := range hostinfo.remotes.relays {
// Don't relay to myself // Don't relay through the host I'm trying to connect to
if relay == vpnIp { if relay == vpnIp {
continue continue
} }
// Don't relay through the host I'm trying to connect to // Don't relay to myself
if hm.f.myVpnAddrsTable.Contains(relay) { if hm.f.myVpnAddrsTable.Contains(relay) {
continue continue
} }
@@ -451,15 +451,13 @@ func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*Han
        vpnAddrs:        []netip.Addr{vpnAddr},
        HandshakePacket: make(map[uint8][]byte, 0),
        relayState: RelayState{
-            RWMutex:        synctrace.NewRWMutex("relay-state"),
-            relays:         map[netip.Addr]struct{}{},
+            relays:         nil,
            relayForByAddr: map[netip.Addr]*Relay{},
            relayForByIdx:  map[uint32]*Relay{},
        },
    }

    hh := &HandshakeHostInfo{
-        Mutex:     synctrace.NewMutex("handshake-hostinfo"),
        hostinfo:  hostinfo,
        startTime: time.Now(),
    }

hostmap.go

@@ -4,6 +4,8 @@ import (
"errors" "errors"
"net" "net"
"net/netip" "net/netip"
"slices"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -13,15 +15,12 @@ import (
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/wadey/synctrace"
) )
// const ProbeLen = 100
const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse
const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery
const MaxRemotes = 10 const MaxRemotes = 10
const maxRecvError = 4
// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip // MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
// 5 allows for an initial handshake and each host pair re-handshaking twice // 5 allows for an initial handshake and each host pair re-handshaking twice
@@ -53,7 +52,7 @@ type Relay struct {
}

type HostMap struct {
-    synctrace.RWMutex //Because we concurrently read and write to our maps
+    sync.RWMutex //Because we concurrently read and write to our maps
    Indexes       map[uint32]*HostInfo
    Relays        map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
    RemoteIndexes map[uint32]*HostInfo
@@ -66,9 +65,9 @@ type HostMap struct {
// struct, make a copy of an existing value, edit the fileds in the copy, and
// then store a pointer to the new copy in both realyForBy* maps.
type RelayState struct {
-    synctrace.RWMutex
+    sync.RWMutex

-    relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
+    relays []netip.Addr // Ordered set of VpnAddrs of Hosts to use as relays to access this peer
    // For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data,
    // modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with
    // the RelayState Lock held)
@@ -79,7 +78,12 @@ type RelayState struct {
func (rs *RelayState) DeleteRelay(ip netip.Addr) {
    rs.Lock()
    defer rs.Unlock()
-    delete(rs.relays, ip)
+    for idx, val := range rs.relays {
+        if val == ip {
+            rs.relays = append(rs.relays[:idx], rs.relays[idx+1:]...)
+            return
+        }
+    }
}

func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
@@ -124,16 +128,16 @@ func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
func (rs *RelayState) InsertRelayTo(ip netip.Addr) { func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
rs.Lock() rs.Lock()
defer rs.Unlock() defer rs.Unlock()
rs.relays[ip] = struct{}{} if !slices.Contains(rs.relays, ip) {
rs.relays = append(rs.relays, ip)
}
} }
func (rs *RelayState) CopyRelayIps() []netip.Addr { func (rs *RelayState) CopyRelayIps() []netip.Addr {
ret := make([]netip.Addr, len(rs.relays))
rs.RLock() rs.RLock()
defer rs.RUnlock() defer rs.RUnlock()
ret := make([]netip.Addr, 0, len(rs.relays)) copy(ret, rs.relays)
for ip := range rs.relays {
ret = append(ret, ip)
}
return ret return ret
} }
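The hunk above replaces the relays set (map[netip.Addr]struct{}) with an ordered []netip.Addr. A minimal standalone sketch of the same semantics, using only the standard library (the relaySet type and its method names are illustrative, not nebula's API): inserts keep first-added order and skip duplicates, deletes preserve the order of the remaining entries, and copies are taken under the read lock.

package main

import (
	"fmt"
	"net/netip"
	"slices"
	"sync"
)

type relaySet struct {
	sync.RWMutex
	relays []netip.Addr // ordered set: the first relay added is tried first
}

func (r *relaySet) insert(a netip.Addr) {
	r.Lock()
	defer r.Unlock()
	if !slices.Contains(r.relays, a) {
		r.relays = append(r.relays, a)
	}
}

func (r *relaySet) remove(a netip.Addr) {
	r.Lock()
	defer r.Unlock()
	if i := slices.Index(r.relays, a); i >= 0 {
		r.relays = slices.Delete(r.relays, i, i+1)
	}
}

func (r *relaySet) copyAll() []netip.Addr {
	r.RLock()
	defer r.RUnlock()
	return slices.Clone(r.relays)
}

func main() {
	rs := &relaySet{}
	rs.insert(netip.MustParseAddr("::1"))
	rs.insert(netip.MustParseAddr("2001::1"))
	rs.insert(netip.MustParseAddr("::1")) // duplicate, ignored
	rs.remove(netip.MustParseAddr("2001::1"))
	fmt.Println(rs.copyAll()) // [::1]
}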
@@ -208,6 +212,18 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
rs.relayForByIdx[idx] = r rs.relayForByIdx[idx] = r
} }
type NetworkType uint8
const (
NetworkTypeUnknown NetworkType = iota
// NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
NetworkTypeVPN
// NetworkTypeVPNPeer is a network that does not overlap one of our networks
NetworkTypeVPNPeer
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
NetworkTypeUnsafe
)
type HostInfo struct { type HostInfo struct {
remote netip.AddrPort remote netip.AddrPort
remotes *RemoteList remotes *RemoteList
@@ -220,10 +236,9 @@ type HostInfo struct {
// The host may have other vpn addresses that are outside our // The host may have other vpn addresses that are outside our
// vpn networks but were removed because they are not usable // vpn networks but were removed because they are not usable
vpnAddrs []netip.Addr vpnAddrs []netip.Addr
recvError atomic.Uint32
// networks are both all vpn and unsafe networks assigned to this host // networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host.
networks *bart.Lite networks *bart.Table[NetworkType]
relayState RelayState relayState RelayState
// HandshakePacket records the packets used to create this hostinfo // HandshakePacket records the packets used to create this hostinfo
@@ -250,6 +265,14 @@ type HostInfo struct {
// Used to track other hostinfos for this vpn ip since only 1 can be primary // Used to track other hostinfos for this vpn ip since only 1 can be primary
// Synchronised via hostmap lock and not the hostinfo lock. // Synchronised via hostmap lock and not the hostinfo lock.
next, prev *HostInfo next, prev *HostInfo
//TODO: in, out, and others might benefit from being an atomic.Int32. We could collapse connectionManager pendingDeletion, relayUsed, and in/out into this 1 thing
in, out, pendingDeletion atomic.Bool
// lastUsed tracks the last time ConnectionManager checked the tunnel and found it in use.
// This value will lag behind actual tunnel utilization in the hot path.
// It should only be used by the ConnectionManager's ticker routine.
lastUsed time.Time
} }
type ViaSender struct { type ViaSender struct {
@@ -288,7 +311,6 @@ func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
func newHostMap(l *logrus.Logger) *HostMap { func newHostMap(l *logrus.Logger) *HostMap {
return &HostMap{ return &HostMap{
RWMutex: synctrace.NewRWMutex("hostmap"),
Indexes: map[uint32]*HostInfo{}, Indexes: map[uint32]*HostInfo{},
Relays: map[uint32]*HostInfo{}, Relays: map[uint32]*HostInfo{},
RemoteIndexes: map[uint32]*HostInfo{}, RemoteIndexes: map[uint32]*HostInfo{},
@@ -720,26 +742,26 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
return false return false
} }
func (i *HostInfo) RecvErrorExceeded() bool { // buildNetworks fills in the networks field of HostInfo. It accepts a cert.Certificate so you never ever mix the network types up.
if i.recvError.Add(1) >= maxRecvError { func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certificate) {
return true if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
if myVpnNetworksTable.Contains(c.Networks()[0].Addr()) {
return // Simple case, no BART needed
} }
return true
} }
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) { i.networks = new(bart.Table[NetworkType])
if len(networks) == 1 && len(unsafeNetworks) == 0 { for _, network := range c.Networks() {
// Simple case, no CIDRTree needed nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
return if myVpnNetworksTable.Contains(network.Addr()) {
i.networks.Insert(nprefix, NetworkTypeVPN)
} else {
i.networks.Insert(nprefix, NetworkTypeVPNPeer)
}
} }
i.networks = new(bart.Lite) for _, network := range c.UnsafeNetworks() {
for _, network := range networks { i.networks.Insert(network, NetworkTypeUnsafe)
i.networks.Insert(network)
}
for _, network := range unsafeNetworks {
i.networks.Insert(network)
} }
} }
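buildNetworks above stores host-length prefixes (/32 or /128) for certificate networks, tagged VPN or VPNPeer, and whole unsafe prefixes tagged Unsafe, in a bart.Table[NetworkType]. A hedged sketch of how such a table can be queried to classify a destination address, assuming the Lookup method of github.com/gaissmai/bart's generic Table; the classify helper and the sample prefixes are illustrative, not nebula code, and a nil table stands in for the simple single-network case where no table is built.

package main

import (
	"fmt"
	"net/netip"

	"github.com/gaissmai/bart"
)

type NetworkType uint8

const (
	NetworkTypeUnknown NetworkType = iota
	NetworkTypeVPN
	NetworkTypeVPNPeer
	NetworkTypeUnsafe
)

// classify returns the most specific NetworkType covering dst, or Unknown.
func classify(networks *bart.Table[NetworkType], dst netip.Addr) NetworkType {
	if networks == nil {
		// Mirrors the simple case in buildNetworks: one overlapping network, no table.
		return NetworkTypeVPN
	}
	if t, ok := networks.Lookup(dst); ok {
		return t
	}
	return NetworkTypeUnknown
}

func main() {
	t := new(bart.Table[NetworkType])
	t.Insert(netip.PrefixFrom(netip.MustParseAddr("10.1.0.5"), 32), NetworkTypeVPNPeer)
	t.Insert(netip.MustParsePrefix("192.168.50.0/24"), NetworkTypeUnsafe)

	fmt.Println(classify(t, netip.MustParseAddr("10.1.0.5")))      // 2 (VPNPeer)
	fmt.Println(classify(t, netip.MustParseAddr("192.168.50.10"))) // 3 (Unsafe)
	fmt.Println(classify(t, netip.MustParseAddr("8.8.8.8")))       // 0 (Unknown)
}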

View File

@@ -7,6 +7,7 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestHostMap_MakePrimary(t *testing.T) { func TestHostMap_MakePrimary(t *testing.T) {
@@ -215,3 +216,31 @@ func TestHostMap_reload(t *testing.T) {
c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]") c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges())) assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
} }
func TestHostMap_RelayState(t *testing.T) {
h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
a1 := netip.MustParseAddr("::1")
a2 := netip.MustParseAddr("2001::1")
h1.relayState.InsertRelayTo(a1)
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
h1.relayState.InsertRelayTo(a2)
assert.Equal(t, []netip.Addr{a1, a2}, h1.relayState.relays)
// Ensure that the first relay added is the first one returned in the copy
currentRelays := h1.relayState.CopyRelayIps()
require.Len(t, currentRelays, 2)
assert.Equal(t, a1, currentRelays[0])
// Deleting the last one in the list works ok
h1.relayState.DeleteRelay(a2)
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
// Deleting an element not in the list works ok
h1.relayState.DeleteRelay(a2)
assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
// Deleting the only element in the list works ok
h1.relayState.DeleteRelay(a1)
assert.Equal(t, []netip.Addr{}, h1.relayState.relays)
}

108 inside.go

@@ -2,16 +2,18 @@ package nebula
import ( import (
"net/netip" "net/netip"
"time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache, now time.Time) {
err := newPacket(packet, false, fwPacket) err := newPacket(packet, false, fwPacket)
if err != nil { if err != nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
@@ -53,7 +55,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
}) })
if hostinfo == nil { if hostinfo == nil {
f.rejectInside(packet, out, q) f.rejectInside(packet, out.Payload, q) //todo vector?
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddr", fwPacket.RemoteAddr). f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
WithField("fwPacket", fwPacket). WithField("fwPacket", fwPacket).
@@ -66,12 +68,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return return
} }
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache, now)
if dropReason == nil { if dropReason == nil {
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) f.sendNoMetricsDelayed(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
} else { } else {
f.rejectInside(packet, out, q) f.rejectInside(packet, out.Payload, q) //todo vector?
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l). hostinfo.logger(f.l).
WithField("fwPacket", fwPacket). WithField("fwPacket", fwPacket).
@@ -120,9 +121,10 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
} }
// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established // Handshake will attempt to initiate a tunnel with the provided vpn address. This is a no-op if the tunnel is already established or being established
// it does not check if it is within our vpn networks!
func (f *Interface) Handshake(vpnAddr netip.Addr) { func (f *Interface) Handshake(vpnAddr netip.Addr) {
f.getOrHandshakeNoRouting(vpnAddr, nil) f.handshakeManager.GetOrHandshake(vpnAddr, nil)
} }
// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable. // getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
@@ -138,7 +140,6 @@ func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback fu
// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary. // getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
destinationAddr := fwPacket.RemoteAddr destinationAddr := fwPacket.RemoteAddr
hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback) hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
@@ -218,7 +219,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
} }
// check if packet is in outbound fw rules // check if packet is in outbound fw rules
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil) dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil, time.Now())
if dropReason != nil { if dropReason != nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("fwPacket", fp). f.l.WithField("fwPacket", fp).
@@ -231,9 +232,10 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
} }
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr.
// This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) { func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) { hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
}) })
@@ -288,7 +290,7 @@ func (f *Interface) SendVia(via *HostInfo,
c := via.ConnectionState.messageCounter.Add(1) c := via.ConnectionState.messageCounter.Add(1)
out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c) out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c)
f.connectionManager.Out(via.localIndexId) f.connectionManager.Out(via)
// Authenticate the header and payload, but do not encrypt for this message type. // Authenticate the header and payload, but do not encrypt for this message type.
// The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload. // The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload.
@@ -356,7 +358,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p) //l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c) out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo.localIndexId) f.connectionManager.Out(hostinfo)
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against // Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
// all our addrs and enable a faster roaming. // all our addrs and enable a faster roaming.
@@ -409,3 +411,81 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
} }
} }
} }
func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.Packet, q int) {
if ci.eKey == nil {
return
}
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
fullOut := out.Payload
if useRelay {
if len(out.Payload) < header.Len {
// out always has a capacity of mtu, but not always a length of at least header.Len.
// Grow it to make sure the next operation works.
out.Payload = out.Payload[:header.Len]
}
// Save a header's worth of data at the front of the 'out' buffer.
out.Payload = out.Payload[header.Len:]
}
if noiseutil.EncryptLockNeeded {
// NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
ci.writeLock.Lock()
}
c := ci.messageCounter.Add(1)
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
out.Payload = header.Encode(out.Payload, header.Version, t, st, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo)
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
// all our addrs and enable a faster roaming.
if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
}
}
var err error
out.Payload, err = ci.eKey.EncryptDanger(out.Payload, out.Payload, p, c, nb)
if noiseutil.EncryptLockNeeded {
ci.writeLock.Unlock()
}
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).WithField("counter", c).
WithField("attemptedCounter", c).
Error("Failed to encrypt outgoing packet")
return
}
if remote.IsValid() {
err = f.writers[q].Prep(out, remote)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet")
}
} else if hostinfo.remote.IsValid() {
err = f.writers[q].Prep(out, hostinfo.remote)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet")
}
} else {
// Try to send via a relay
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
if err != nil {
hostinfo.relayState.DeleteRelay(relayIP)
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
continue
}
//todo vector!!
f.SendVia(relayHostInfo, relay, out.Payload, nb, fullOut[:header.Len+len(out.Payload)], true)
break
}
}
}
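For the relay path, sendNoMetricsDelayed slides out.Payload forward by header.Len before encrypting, so SendVia can later prepend the relay header into the same backing array (fullOut) without an extra copy. A minimal standalone sketch of that buffer trick, with hypothetical sizes and names standing in for header.Len and header.Encode:

package main

import "fmt"

const headerLen = 16 // stand-in for header.Len
const mtu = 1500

func main() {
	full := make([]byte, 0, mtu)
	full = full[:headerLen]     // ensure the first headerLen bytes exist
	payload := full[headerLen:] // encrypt/encode only into the tail

	payload = append(payload, []byte("encrypted inner packet")...)

	// The relay header is written into the reserved front, and the whole frame is
	// one contiguous slice of the original buffer, with no second allocation or copy.
	frame := full[:headerLen+len(payload)]
	copy(frame[:headerLen], make([]byte, headerLen)) // stand-in for header.Encode
	fmt.Println(len(frame))                          // headerLen + payload length
}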

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"net/netip" "net/netip"
"os" "os"
"runtime" "runtime"
@@ -18,22 +17,24 @@ import (
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
) )
const mtu = 9001 const mtu = 9001
const batch = 1024 //todo config!
type InterfaceConfig struct { type InterfaceConfig struct {
HostMap *HostMap HostMap *HostMap
Outside udp.Conn Outside udp.Conn
Inside overlay.Device Inside overlay.Device
pki *PKI pki *PKI
Cipher string
Firewall *Firewall Firewall *Firewall
ServeDns bool ServeDns bool
HandshakeManager *HandshakeManager HandshakeManager *HandshakeManager
lightHouse *LightHouse lightHouse *LightHouse
checkInterval time.Duration connectionManager *connectionManager
pendingDeletionInterval time.Duration
DropLocalBroadcast bool DropLocalBroadcast bool
DropMulticast bool DropMulticast bool
routines int routines int
@@ -86,12 +87,18 @@ type Interface struct {
conntrackCacheTimeout time.Duration conntrackCacheTimeout time.Duration
writers []udp.Conn writers []udp.Conn
readers []io.ReadWriteCloser readers []overlay.TunDev
metricHandshakes metrics.Histogram metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics cachedPacketMetrics *cachedPacketMetrics
listenInN int
listenOutN int
listenInMetric metrics.Histogram
listenOutMetric metrics.Histogram
l *logrus.Logger l *logrus.Logger
} }
@@ -157,6 +164,9 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
if c.Firewall == nil { if c.Firewall == nil {
return nil, errors.New("no firewall rules") return nil, errors.New("no firewall rules")
} }
if c.connectionManager == nil {
return nil, errors.New("no connection manager")
}
cs := c.pki.getCertState() cs := c.pki.getCertState()
ifce := &Interface{ ifce := &Interface{
@@ -174,14 +184,14 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
routines: c.routines, routines: c.routines,
version: c.version, version: c.version,
writers: make([]udp.Conn, c.routines), writers: make([]udp.Conn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines), readers: make([]overlay.TunDev, c.routines),
myVpnNetworks: cs.myVpnNetworks, myVpnNetworks: cs.myVpnNetworks,
myVpnNetworksTable: cs.myVpnNetworksTable, myVpnNetworksTable: cs.myVpnNetworksTable,
myVpnAddrs: cs.myVpnAddrs, myVpnAddrs: cs.myVpnAddrs,
myVpnAddrsTable: cs.myVpnAddrsTable, myVpnAddrsTable: cs.myVpnAddrsTable,
myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable, myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
relayManager: c.relayManager, relayManager: c.relayManager,
connectionManager: c.connectionManager,
conntrackCacheTimeout: c.ConntrackCacheTimeout, conntrackCacheTimeout: c.ConntrackCacheTimeout,
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)), metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
@@ -193,12 +203,14 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
l: c.l, l: c.l,
} }
ifce.listenInMetric = metrics.GetOrRegisterHistogram("vhost.listenIn.n", nil, metrics.NewExpDecaySample(1028, 0.015))
ifce.listenOutMetric = metrics.GetOrRegisterHistogram("vhost.listenOut.n", nil, metrics.NewExpDecaySample(1028, 0.015))
ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait)) ifce.reQueryWait.Store(int64(c.reQueryWait))
ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy) ifce.connectionManager.intf = ifce
return ifce, nil return ifce, nil
} }
@@ -222,7 +234,7 @@ func (f *Interface) activate() {
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
// Prepare n tun queues // Prepare n tun queues
var reader io.ReadWriteCloser = f.inside var reader overlay.TunDev = f.inside
for i := 0; i < f.routines; i++ { for i := 0; i < f.routines; i++ {
if i > 0 { if i > 0 {
reader, err = f.inside.NewMultiQueueReader() reader, err = f.inside.NewMultiQueueReader()
@@ -251,40 +263,71 @@ func (f *Interface) run() {
} }
} }
func (f *Interface) listenOut(i int) { func (f *Interface) listenOut(q int) {
runtime.LockOSThread() runtime.LockOSThread()
var li udp.Conn var li udp.Conn
if i > 0 { if q > 0 {
li = f.writers[i] li = f.writers[q]
} else { } else {
li = f.outside li = f.outside
} }
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler() lhh := f.lightHouse.NewRequestHandler()
plaintext := make([]byte, udp.MTU)
outPackets := make([]*packet.OutPacket, batch)
for i := 0; i < batch; i++ {
outPackets[i] = packet.NewOut()
}
h := &header.H{} h := &header.H{}
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { toSend := make([][]byte, batch)
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
li.ListenOut(func(pkts []*packet.Packet) {
toSend = toSend[:0]
for i := range outPackets {
outPackets[i].Valid = false
outPackets[i].SegCounter = 0
}
f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l), time.Now())
//we opportunistically tx, but try to also send stragglers
if _, err := f.readers[q].WriteMany(outPackets, q); err != nil {
f.l.WithError(err).Error("Failed to send packets")
}
//todo I broke this
//n := len(toSend)
//if f.l.Level == logrus.DebugLevel {
// f.listenOutMetric.Update(int64(n))
//}
//f.listenOutN = n
}) })
} }
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
runtime.LockOSThread() runtime.LockOSThread()
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
packets := make([]*packet.VirtIOPacket, batch)
outPackets := make([]*packet.Packet, batch)
for i := 0; i < batch; i++ {
packets[i] = packet.NewVIO()
outPackets[i] = packet.New(false) //todo?
}
for { for {
n, err := reader.Read(packet) n, err := reader.ReadMany(packets, queueNum)
//todo!!
if err != nil { if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() { if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return return
@@ -295,7 +338,22 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
os.Exit(2) os.Exit(2)
} }
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) if f.l.Level == logrus.DebugLevel {
f.listenInMetric.Update(int64(n))
}
f.listenInN = n
now := time.Now()
for i, pkt := range packets[:n] {
outPackets[i].OutLen = -1
f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now)
reader.RecycleRxSeg(pkt, i == (n-1), queueNum) //todo handle err?
pkt.Reset()
}
_, err = f.writers[queueNum].WriteBatch(outPackets[:n])
if err != nil {
f.l.WithError(err).Error("Error while writing outbound packets")
}
} }
} }
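The rewritten listenIn above replaces the single-packet Read/consume/Write loop with a batched one. An illustrative, non-nebula outline of the same shape, with readMany and writeBatch standing in for overlay.TunDev.ReadMany and udp.Conn.WriteBatch: allocate the packet slots once, read up to `batch` packets in one call, process each into its pre-allocated output slot, then flush the whole batch at once.

package main

import "fmt"

const batch = 1024

type pkt struct{ payload []byte }

// readMany and writeBatch are stand-ins for the device and socket batch calls.
func readMany(dst []*pkt) int {
	dst[0].payload = append(dst[0].payload[:0], "hello"...)
	return 1
}

func writeBatch(src []*pkt) (int, error) { return len(src), nil }

func main() {
	in := make([]*pkt, batch)
	out := make([]*pkt, batch)
	for i := 0; i < batch; i++ { // pre-allocate once, reuse every iteration
		in[i], out[i] = &pkt{}, &pkt{}
	}

	n := readMany(in) // one call yields up to `batch` packets; real code loops forever
	for i := range in[:n] {
		// per-packet work (firewall check, encrypt) goes here
		out[i].payload = append(out[i].payload[:0], in[i].payload...)
	}
	if _, err := writeBatch(out[:n]); err != nil {
		fmt.Println("write failed:", err)
	}
	fmt.Println("wrote", n, "packets")
}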
@@ -433,6 +491,11 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
} else { } else {
certMaxVersion.Update(int64(certState.v1Cert.Version())) certMaxVersion.Update(int64(certState.v1Cert.Version()))
} }
if f.l.Level != logrus.DebugLevel {
f.listenInMetric.Update(int64(f.listenInN))
f.listenOutMetric.Update(int64(f.listenOutN))
}
} }
} }
} }

View File

@@ -9,6 +9,7 @@ import (
"net/netip" "net/netip"
"slices" "slices"
"strconv" "strconv"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -20,14 +21,14 @@ import (
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
"github.com/wadey/synctrace"
) )
var ErrHostNotKnown = errors.New("host not known") var ErrHostNotKnown = errors.New("host not known")
var ErrBadDetailsVpnAddr = errors.New("invalid packet, malformed detailsVpnAddr")
type LightHouse struct { type LightHouse struct {
//TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time //TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time
synctrace.RWMutex //Because we concurrently read and write to our maps sync.RWMutex //Because we concurrently read and write to our maps
ctx context.Context ctx context.Context
amLighthouse bool amLighthouse bool
@@ -56,7 +57,7 @@ type LightHouse struct {
// staticList exists to avoid having a bool in each addrMap entry // staticList exists to avoid having a bool in each addrMap entry
// since static should be rare // since static should be rare
staticList atomic.Pointer[map[netip.Addr]struct{}] staticList atomic.Pointer[map[netip.Addr]struct{}]
lighthouses atomic.Pointer[map[netip.Addr]struct{}] lighthouses atomic.Pointer[[]netip.Addr]
interval atomic.Int64 interval atomic.Int64
updateCancel context.CancelFunc updateCancel context.CancelFunc
@@ -96,7 +97,6 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
} }
h := LightHouse{ h := LightHouse{
RWMutex: synctrace.NewRWMutex("lighthouse"),
ctx: ctx, ctx: ctx,
amLighthouse: amLighthouse, amLighthouse: amLighthouse,
myVpnNetworks: cs.myVpnNetworks, myVpnNetworks: cs.myVpnNetworks,
@@ -108,7 +108,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
l: l, l: l,
} }
lighthouses := make(map[netip.Addr]struct{}) lighthouses := make([]netip.Addr, 0)
h.lighthouses.Store(&lighthouses) h.lighthouses.Store(&lighthouses)
staticList := make(map[netip.Addr]struct{}) staticList := make(map[netip.Addr]struct{})
h.staticList.Store(&staticList) h.staticList.Store(&staticList)
@@ -144,7 +144,7 @@ func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
return *lh.staticList.Load() return *lh.staticList.Load()
} }
func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} { func (lh *LightHouse) GetLighthouses() []netip.Addr {
return *lh.lighthouses.Load() return *lh.lighthouses.Load()
} }
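With the lighthouses field changed from a map to a slice behind an atomic.Pointer, readers iterate an immutable snapshot and a config reload publishes a freshly built slice. A small sketch of that copy-on-reload pattern using only the standard library; the variable names are illustrative, not nebula's:

package main

import (
	"fmt"
	"net/netip"
	"sync/atomic"
)

func main() {
	var lighthouses atomic.Pointer[[]netip.Addr]

	initial := []netip.Addr{netip.MustParseAddr("10.128.0.2")}
	lighthouses.Store(&initial)

	// Hot path: lock-free read of whatever snapshot is current.
	for _, a := range *lighthouses.Load() {
		fmt.Println("lighthouse:", a)
	}

	// Reload path: build a brand-new slice and swap the pointer; the published
	// snapshot is never mutated in place.
	updated := []netip.Addr{
		netip.MustParseAddr("10.128.0.2"),
		netip.MustParseAddr("10.128.0.3"),
	}
	lighthouses.Store(&updated)
}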
@@ -307,13 +307,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
} }
if initial || c.HasChanged("lighthouse.hosts") { if initial || c.HasChanged("lighthouse.hosts") {
lhMap := make(map[netip.Addr]struct{}) lhList, err := lh.parseLighthouses(c)
err := lh.parseLighthouses(c, lhMap)
if err != nil { if err != nil {
return err return err
} }
lh.lighthouses.Store(&lhMap) lh.lighthouses.Store(&lhList)
if !initial { if !initial {
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic //NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
lh.l.Info("lighthouse.hosts has changed") lh.l.Info("lighthouse.hosts has changed")
@@ -347,36 +346,38 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error { func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
lhs := c.GetStringSlice("lighthouse.hosts", []string{}) lhs := c.GetStringSlice("lighthouse.hosts", []string{})
if lh.amLighthouse && len(lhs) != 0 { if lh.amLighthouse && len(lhs) != 0 {
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config") lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
} }
out := make([]netip.Addr, len(lhs))
for i, host := range lhs { for i, host := range lhs {
addr, err := netip.ParseAddr(host) addr, err := netip.ParseAddr(host)
if err != nil { if err != nil {
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err) return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
} }
if !lh.myVpnNetworksTable.Contains(addr) { if !lh.myVpnNetworksTable.Contains(addr) {
return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil) lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}).
Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
} }
lhMap[addr] = struct{}{} out[i] = addr
} }
if !lh.amLighthouse && len(lhMap) == 0 { if !lh.amLighthouse && len(out) == 0 {
lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries") lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
} }
staticList := lh.GetStaticHostList() staticList := lh.GetStaticHostList()
for lhAddr, _ := range lhMap { for i := range out {
if _, ok := staticList[lhAddr]; !ok { if _, ok := staticList[out[i]]; !ok {
return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr) return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i])
} }
} }
return nil return out, nil
} }
func getStaticMapCadence(c *config.C) (time.Duration, error) { func getStaticMapCadence(c *config.C) (time.Duration, error) {
@@ -431,7 +432,8 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
} }
if !lh.myVpnNetworksTable.Contains(vpnAddr) { if !lh.myVpnNetworksTable.Contains(vpnAddr) {
return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil) lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}).
Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
} }
vals, ok := v.([]any) vals, ok := v.([]any)
@@ -473,7 +475,6 @@ func (lh *LightHouse) QueryServer(vpnAddr netip.Addr) {
return return
} }
synctrace.ChanDebugSend("lighthouse-querychan")
lh.queryChan <- vpnAddr lh.queryChan <- vpnAddr
} }
@@ -488,7 +489,7 @@ func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList {
lh.Lock() lh.Lock()
defer lh.Unlock() defer lh.Unlock()
// Add an entry if we don't already have one // Add an entry if we don't already have one
return lh.unlockedGetRemoteList(vpnAddrs) return lh.unlockedGetRemoteList(vpnAddrs) //todo CERT-V2 this contains addrmap lookups we could potentially skip
} }
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
@@ -521,11 +522,15 @@ func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (in
} }
func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) { func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
// First we check the static mapping // First we check the static host map. If any of the VpnAddrs to be deleted are present, do nothing.
// and do nothing if it is there staticList := lh.GetStaticHostList()
if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok { for _, addr := range allVpnAddrs {
if _, ok := staticList[addr]; ok {
return return
} }
}
// None of the VpnAddrs were present. Now we can do the deletes.
lh.Lock() lh.Lock()
rm, ok := lh.addrMap[allVpnAddrs[0]] rm, ok := lh.addrMap[allVpnAddrs[0]]
if ok { if ok {
@@ -567,7 +572,7 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
am.unlockedSetHostnamesResults(hr) am.unlockedSetHostnamesResults(hr)
for _, addrPort := range hr.GetAddrs() { for _, addrPort := range hr.GetAddrs() {
if !lh.shouldAdd(vpnAddr, addrPort.Addr()) { if !lh.shouldAdd([]netip.Addr{vpnAddr}, addrPort.Addr()) {
continue continue
} }
switch { switch {
@@ -629,23 +634,30 @@ func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool {
return len(calculatedV4) > 0 || len(calculatedV6) > 0 return len(calculatedV4) > 0 || len(calculatedV6) > 0
} }
// unlockedGetRemoteList // unlockedGetRemoteList assumes you have the lh lock
// assumes you have the lh lock
func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList { func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
am, ok := lh.addrMap[allAddrs[0]] // before we go and make a new remotelist, we need to make sure we don't have one for any of this set of vpnaddrs yet
if !ok { for i, addr := range allAddrs {
am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) }) am, ok := lh.addrMap[addr]
if ok {
if i != 0 {
lh.addrMap[allAddrs[0]] = am
}
return am
}
}
am := NewRemoteList(allAddrs, lh.shouldAdd)
for _, addr := range allAddrs { for _, addr := range allAddrs {
lh.addrMap[addr] = am lh.addrMap[addr] = am
} }
}
return am return am
} }
func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool { func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
allow := lh.GetRemoteAllowList().Allow(vpnAddr, to) allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
if lh.l.Level >= logrus.TraceLevel { if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow). lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
Trace("remoteAllowList.Allow") Trace("remoteAllowList.Allow")
} }
if !allow { if !allow {
@@ -700,21 +712,24 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
} }
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool { func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
if _, ok := lh.GetLighthouses()[vpnAddr]; ok { l := lh.GetLighthouses()
for i := range l {
if l[i] == vpnAddr {
return true return true
} }
}
return false return false
} }
// TODO: CERT-V2 IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially
func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool {
l := lh.GetLighthouses() l := lh.GetLighthouses()
for _, a := range vpnAddr { for i := range vpnAddrs {
if _, ok := l[a]; ok { for j := range l {
if l[j] == vpnAddrs[i] {
return true return true
} }
} }
}
return false return false
} }
@@ -727,11 +742,9 @@ func (lh *LightHouse) startQueryWorker() {
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
synctrace.ChanDebugRecvLock("lighthouse-querychan")
for { for {
select { select {
case <-lh.ctx.Done(): case <-lh.ctx.Done():
synctrace.ChanDebugRecvUnlock("lighthouse-querychan")
return return
case addr := <-lh.queryChan: case addr := <-lh.queryChan:
lh.innerQueryServer(addr, nb, out) lh.innerQueryServer(addr, nb, out)
@@ -756,7 +769,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
queried := 0 queried := 0
lighthouses := lh.GetLighthouses() lighthouses := lh.GetLighthouses()
for lhVpnAddr := range lighthouses { for _, lhVpnAddr := range lighthouses {
hi := lh.ifce.GetHostInfo(lhVpnAddr) hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil { if hi != nil {
v = hi.ConnectionState.myCert.Version() v = hi.ConnectionState.myCert.Version()
@@ -874,7 +887,7 @@ func (lh *LightHouse) SendUpdate() {
updated := 0 updated := 0
lighthouses := lh.GetLighthouses() lighthouses := lh.GetLighthouses()
for lhVpnAddr := range lighthouses { for _, lhVpnAddr := range lighthouses {
var v cert.Version var v cert.Version
hi := lh.ifce.GetHostInfo(lhVpnAddr) hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil { if hi != nil {
@@ -932,7 +945,6 @@ func (lh *LightHouse) SendUpdate() {
V4AddrPorts: v4, V4AddrPorts: v4,
V6AddrPorts: v6, V6AddrPorts: v6,
RelayVpnAddrs: relays, RelayVpnAddrs: relays,
VpnAddr: netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()),
}, },
} }
@@ -1052,19 +1064,19 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
return return
} }
useVersion := cert.Version1 queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
var queryVpnAddr netip.Addr if err != nil {
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
queryVpnAddr = netip.AddrFrom4(b)
useVersion = 1
} else if n.Details.VpnAddr != nil {
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
useVersion = 2
} else {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery") lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
Debugln("Dropping malformed HostQuery")
}
return
}
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
// this case really shouldn't be possible to represent, but reject it anyway.
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
Debugln("invalid vpn addr for v1 handleHostQuery")
} }
return return
} }
@@ -1073,9 +1085,6 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
n = lhh.resetMeta() n = lhh.resetMeta()
n.Type = NebulaMeta_HostQueryReply n.Type = NebulaMeta_HostQueryReply
if useVersion == cert.Version1 { if useVersion == cert.Version1 {
if !queryVpnAddr.Is4() {
return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
}
b := queryVpnAddr.As4() b := queryVpnAddr.As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
} else { } else {
@@ -1120,8 +1129,9 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
if ok { if ok {
whereToPunch = newDest whereToPunch = newDest
} else { } else {
//TODO: CERT-V2 this means the destination will have no addresses in common with the punch-ee if lhh.l.Level >= logrus.DebugLevel {
//choosing to do nothing for now, but maybe we return an error? lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
}
} }
} }
@@ -1180,19 +1190,17 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
if !r.Is4() { if !r.Is4() {
continue continue
} }
b = r.As4() b = r.As4()
n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:])) n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:]))
} }
} else if v == cert.Version2 { } else if v == cert.Version2 {
for _, r := range c.relay.relay { for _, r := range c.relay.relay {
n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r)) n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
} }
} else { } else {
//TODO: CERT-V2 don't panic if lhh.l.Level >= logrus.DebugLevel {
panic("unsupported version") lhh.l.WithField("version", v).Debug("unsupported protocol version")
}
} }
} }
} }
@@ -1202,18 +1210,16 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
return return
} }
lhh.lh.Lock() certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
var certVpnAddr netip.Addr if lhh.l.Level >= logrus.DebugLevel {
if n.Details.OldVpnAddr != 0 { lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
b := [4]byte{} }
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) return
certVpnAddr = netip.AddrFrom4(b)
} else if n.Details.VpnAddr != nil {
certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
} }
relays := n.Details.GetRelays() relays := n.Details.GetRelays()
lhh.lh.Lock()
am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr}) am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr})
am.Lock() am.Lock()
lhh.lh.Unlock() lhh.lh.Unlock()
@@ -1238,27 +1244,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
return return
} }
// not using GetVpnAddrAndVersion because we don't want to error on a blank detailsVpnAddr
var detailsVpnAddr netip.Addr var detailsVpnAddr netip.Addr
useVersion := cert.Version1 var useVersion cert.Version
if n.Details.OldVpnAddr != 0 { if n.Details.OldVpnAddr != 0 { //v1 always sets this field
b := [4]byte{} b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
detailsVpnAddr = netip.AddrFrom4(b) detailsVpnAddr = netip.AddrFrom4(b)
useVersion = cert.Version1 useVersion = cert.Version1
} else if n.Details.VpnAddr != nil { } else if n.Details.VpnAddr != nil { //this field is "optional" in v2, but if it's set, we should enforce it
detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
useVersion = cert.Version2 useVersion = cert.Version2
} else { } else {
if lhh.l.Level >= logrus.DebugLevel { detailsVpnAddr = netip.Addr{}
lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification") useVersion = cert.Version2
}
return
} }
//TODO: CERT-V2 hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4? //Simple check that the host sent this not someone else, if detailsVpnAddr is filled
//TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right? if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
//Simple check that the host sent this not someone else
if !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update") lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
} }
@@ -1272,24 +1275,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
am.Lock() am.Lock()
lhh.lh.Unlock() lhh.lh.Unlock()
am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) am.unlockedSetV4(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) am.unlockedSetV6(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
am.unlockedSetRelay(fromVpnAddrs[0], relays) am.unlockedSetRelay(fromVpnAddrs[0], relays)
am.Unlock() am.Unlock()
n = lhh.resetMeta() n = lhh.resetMeta()
n.Type = NebulaMeta_HostUpdateNotificationAck n.Type = NebulaMeta_HostUpdateNotificationAck
switch useVersion {
if useVersion == cert.Version1 { case cert.Version1:
if !fromVpnAddrs[0].Is4() { if !fromVpnAddrs[0].Is4() {
lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message") lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
return return
} }
vpnAddrB := fromVpnAddrs[0].As4() vpnAddrB := fromVpnAddrs[0].As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:]) n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:])
} else if useVersion == cert.Version2 { case cert.Version2:
n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0]) // do nothing, we want to send a blank message
} else { default:
lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version") lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
return return
} }
@@ -1307,13 +1310,20 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
//It's possible the lighthouse is communicating with us using a non primary vpn addr, //It's possible the lighthouse is communicating with us using a non primary vpn addr,
//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs. //which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
//maybe one day we'll have a better idea, if it matters.
if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) { if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
return return
} }
detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
}
return
}
empty := []byte{0} empty := []byte{0}
punch := func(vpnPeer netip.AddrPort) { punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
if !vpnPeer.IsValid() { if !vpnPeer.IsValid() {
return return
} }
@@ -1325,48 +1335,38 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
}() }()
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
var logVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
logVpnAddr = netip.AddrFrom4(b)
} else if n.Details.VpnAddr != nil {
logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
}
lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr) lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
} }
} }
remoteAllowList := lhh.lh.GetRemoteAllowList()
for _, a := range n.Details.V4AddrPorts { for _, a := range n.Details.V4AddrPorts {
punch(protoV4AddrPortToNetAddrPort(a)) b := protoV4AddrPortToNetAddrPort(a)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr)
}
} }
for _, a := range n.Details.V6AddrPorts { for _, a := range n.Details.V6AddrPorts {
punch(protoV6AddrPortToNetAddrPort(a)) b := protoV6AddrPortToNetAddrPort(a)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr)
}
} }
// This sends a nebula test packet to the host trying to contact us. In the case // This sends a nebula test packet to the host trying to contact us. In the case
// of a double nat or other difficult scenario, this may help establish // of a double nat or other difficult scenario, this may help establish
// a tunnel. // a tunnel.
if lhh.lh.punchy.GetRespond() { if lhh.lh.punchy.GetRespond() {
var queryVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
queryVpnAddr = netip.AddrFrom4(b)
} else if n.Details.VpnAddr != nil {
queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
}
go func() { go func() {
time.Sleep(lhh.lh.punchy.GetRespondDelay()) time.Sleep(lhh.lh.punchy.GetRespondDelay())
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr) lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
} }
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine //NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine // for each punchBack packet. We should move this into a timerwheel or a single goroutine
// managed by a channel. // managed by a channel.
w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}() }()
} }
} }
@@ -1445,3 +1445,17 @@ func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr,
} }
return netip.Addr{}, false return netip.Addr{}, false
} }
func (d *NebulaMetaDetails) GetVpnAddrAndVersion() (netip.Addr, cert.Version, error) {
if d.OldVpnAddr != 0 {
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], d.OldVpnAddr)
detailsVpnAddr := netip.AddrFrom4(b)
return detailsVpnAddr, cert.Version1, nil
} else if d.VpnAddr != nil {
detailsVpnAddr := protoAddrToNetAddr(d.VpnAddr)
return detailsVpnAddr, cert.Version2, nil
} else {
return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
}
}

View File

@@ -14,7 +14,7 @@ import (
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gopkg.in/yaml.v3" "go.yaml.in/yaml/v3"
) )
func TestOldIPv4Only(t *testing.T) { func TestOldIPv4Only(t *testing.T) {
@@ -493,3 +493,123 @@ func Test_findNetworkUnion(t *testing.T) {
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81}) out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok) assert.False(t, ok)
} }
func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) {
l := test.NewLogger()
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
testSameHostNotStatic := netip.MustParseAddr("10.128.0.41")
testStaticHost := netip.MustParseAddr("10.128.0.42")
//myVpnIp := netip.MustParseAddr("10.128.0.2")
c := config.NewC(l)
lh1 := "10.128.0.2"
c.Settings["lighthouse"] = map[string]any{
"hosts": []any{lh1},
"interval": "1s",
}
c.Settings["listen"] = map[string]any{"port": 4242}
c.Settings["static_host_map"] = map[string]any{
lh1: []any{"1.1.1.1:4242"},
"10.128.0.42": []any{"1.2.3.4:4242"},
}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.NoError(t, err)
lh.ifce = &mockEncWriter{}
//test that we actually have the static entry:
out := lh.Query(testStaticHost)
assert.NotNil(t, out)
assert.Equal(t, out.vpnAddrs[0], testStaticHost)
out.Rebuild([]netip.Prefix{}) //why tho
assert.Equal(t, out.addrs[0], myUdpAddr2)
//bolt on a lower numbered primary IP
am := lh.unlockedGetRemoteList([]netip.Addr{testStaticHost})
am.vpnAddrs = []netip.Addr{testSameHostNotStatic, testStaticHost}
lh.addrMap[testSameHostNotStatic] = am
out.Rebuild([]netip.Prefix{}) //???
//test that we actually have the static entry:
out = lh.Query(testStaticHost)
assert.NotNil(t, out)
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
assert.Equal(t, out.addrs[0], myUdpAddr2)
//test that we actually have the static entry for BOTH:
out2 := lh.Query(testSameHostNotStatic)
assert.Same(t, out2, out)
//now do the delete
lh.DeleteVpnAddrs([]netip.Addr{testSameHostNotStatic, testStaticHost})
//verify
out = lh.Query(testSameHostNotStatic)
assert.NotNil(t, out)
if out == nil {
t.Fatal("expected non-nil query for the static host")
}
assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
assert.Equal(t, out.vpnAddrs[1], testStaticHost)
assert.Equal(t, out.addrs[0], myUdpAddr2)
}
func TestLighthouse_DeletesWork(t *testing.T) {
l := test.NewLogger()
myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
testHost := netip.MustParseAddr("10.128.0.42")
c := config.NewC(l)
lh1 := "10.128.0.2"
c.Settings["lighthouse"] = map[string]any{
"hosts": []any{lh1},
"interval": "1s",
}
c.Settings["listen"] = map[string]any{"port": 4242}
c.Settings["static_host_map"] = map[string]any{
lh1: []any{"1.1.1.1:4242"},
}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Lite)
nt.Insert(myVpnNet)
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.NoError(t, err)
lh.ifce = &mockEncWriter{}
//insert the host
am := lh.unlockedGetRemoteList([]netip.Addr{testHost})
am.vpnAddrs = []netip.Addr{testHost}
am.addrs = []netip.AddrPort{myUdpAddr2}
lh.addrMap[testHost] = am
am.Rebuild([]netip.Prefix{}) //???
//test that we actually have the entry:
out := lh.Query(testHost)
assert.NotNil(t, out)
assert.Equal(t, out.vpnAddrs[0], testHost)
out.Rebuild([]netip.Prefix{}) //why tho
assert.Equal(t, out.addrs[0], myUdpAddr2)
//now do the delete
lh.DeleteVpnAddrs([]netip.Addr{testHost})
//verify
out = lh.Query(testHost)
assert.Nil(t, out)
}

14 main.go

@@ -13,7 +13,7 @@ import (
 	"github.com/slackhq/nebula/sshd"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/util"
-	"gopkg.in/yaml.v3"
+	"go.yaml.in/yaml/v3"
 )

 type m = map[string]any
@@ -75,7 +75,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	if c.GetBool("sshd.enabled", false) {
 		sshStart, err = configSSH(l, ssh, c)
 		if err != nil {
-			return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
+			l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
+			sshStart = nil
 		}
 	}
@@ -185,6 +186,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	hostMap := NewHostMapFromConfig(l, c)
 	punchy := NewPunchyFromConfig(l, c)
+	connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy)
 	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
@@ -220,9 +222,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		}
 	}
-	checkInterval := c.GetInt("timers.connection_alive_interval", 5)
-	pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
 	ifConfig := &InterfaceConfig{
 		HostMap: hostMap,
 		Inside:  tun,
@@ -231,9 +230,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		Firewall:         fw,
 		ServeDns:         serveDns,
 		HandshakeManager: handshakeManager,
+		connectionManager: connManager,
 		lightHouse:       lightHouse,
-		checkInterval:           time.Second * time.Duration(checkInterval),
-		pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval),
 		tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery),
 		reQueryEvery:    c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
 		reQueryWait:     c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
@@ -244,7 +242,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		version:               buildVersion,
 		relayManager:          NewRelayManager(ctx, l, hostMap, c),
 		punchy:                punchy,
 		ConntrackCacheTimeout: conntrackCacheTimeout,
 		l:                     l,
 	}
@@ -296,5 +293,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		statsStart,
 		dnsStart,
 		lightHouse.StartUpdateWorker,
+		connManager.Start,
 	}, nil
 }


@@ -7,6 +7,7 @@ import (
 	"time"

 	"github.com/google/gopacket/layers"
+	"github.com/slackhq/nebula/packet"
 	"golang.org/x/net/ipv6"
 	"github.com/sirupsen/logrus"
@@ -19,7 +20,7 @@ const (
 	minFwPacketLen = 4
 )

-func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
 	err := h.Parse(packet)
 	if err != nil {
 		// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
@@ -61,7 +62,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 		switch h.Subtype {
 		case header.MessageNone:
-			if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) {
+			if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, now) {
 				return
 			}
 		case header.MessageRelay:
@@ -81,7 +82,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			// Pull the Roaming parts up here, and return in all call paths.
 			f.handleHostRoaming(hostinfo, ip)
 			// Track usage of both the HostInfo and the Relay for the received & authenticated packet
-			f.connectionManager.In(hostinfo.localIndexId)
+			f.connectionManager.In(hostinfo)
 			f.connectionManager.RelayUsed(h.RemoteIndex)

 			relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
@@ -96,7 +97,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			case TerminalType:
 				// If I am the target of this relay, process the unwrapped packet
 				// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
-				f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
+				f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
 				return
 			case ForwardingType:
 				// Find the target HostInfo relay object
@@ -213,7 +214,218 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 	f.handleHostRoaming(hostinfo, ip)

-	f.connectionManager.In(hostinfo.localIndexId)
+	f.connectionManager.In(hostinfo)
 }
func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
for i, pkt := range packets {
out[i].Scratch = out[i].Scratch[:0]
ip := pkt.AddrPort()
//l.Error("in packet ", header, packet[HeaderLen:])
if ip.IsValid() {
if f.myVpnNetworksTable.Contains(ip.Addr()) {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
}
return
}
}
//todo per-segment!
for segment := range pkt.Segments() {
err := h.Parse(segment)
if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(segment) > 1 {
f.l.WithField("packet", pkt).Infof("Error while parsing inbound packet from %s: %s", ip, err)
}
return
}
var hostinfo *HostInfo
// verify if we've seen this index before, otherwise respond to the handshake initiation
if h.Type == header.Message && h.Subtype == header.MessageRelay {
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
} else {
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
}
var ci *ConnectionState
if hostinfo != nil {
ci = hostinfo.ConnectionState
}
switch h.Type {
case header.Message:
// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
if !f.handleEncrypted(ci, ip, h) {
return
}
switch h.Subtype {
case header.MessageNone:
if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i], pkt, segment, fwPacket, nb, q, localCache, now) {
return
}
case header.MessageRelay:
// The entire body is sent as AD, not encrypted.
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
// which will gracefully fail in the DecryptDanger call.
signedPayload := segment[:len(segment)-hostinfo.ConnectionState.dKey.Overhead()]
signatureValue := segment[len(segment)-hostinfo.ConnectionState.dKey.Overhead():]
out[i].Scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(out[i].Scratch, signedPayload, signatureValue, h.MessageCounter, nb)
if err != nil {
return
}
// Successfully validated the thing. Get rid of the Relay header.
signedPayload = signedPayload[header.Len:]
// Pull the Roaming parts up here, and return in all call paths.
f.handleHostRoaming(hostinfo, ip)
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
f.connectionManager.In(hostinfo)
f.connectionManager.RelayUsed(h.RemoteIndex)
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
if !ok {
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
// its internal mapping. This should never happen.
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
return
}
switch relay.Type {
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i].Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
return
case ForwardingType:
// Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
if err != nil {
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
return
}
// If that relay is Established, forward the payload through it
if targetRelay.State == Established {
switch targetRelay.Type {
case ForwardingType:
// Forward this packet through the relay tunnel
// Find the target HostInfo
f.SendVia(targetHI, targetRelay, signedPayload, nb, out[i].Scratch, false)
return
case TerminalType:
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
}
} else {
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
return
}
}
}
case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", segment).
Error("Failed to decrypt lighthouse packet")
return
}
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic
case header.Test:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", segment).
Error("Failed to decrypt test packet")
return
}
if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, ip)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out[i].Scratch)
}
// Fallthrough to the bottom to record incoming traffic
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
// are unauthenticated
case header.Handshake:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handshakeManager.HandleIncoming(ip, nil, segment, h)
return
case header.RecvError:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handleRecvError(ip, h)
return
case header.CloseTunnel:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) {
return
}
hostinfo.logger(f.l).WithField("udpAddr", ip).
Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo)
return
case header.Control:
if !f.handleEncrypted(ci, ip, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", segment).
Error("Failed to decrypt Control packet")
return
}
f.relayManager.HandleControlMsg(hostinfo, d, f)
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
return
}
f.handleHostRoaming(hostinfo, ip)
f.connectionManager.In(hostinfo)
}
_, err := f.readers[q].WriteOne(out[i], false, q)
if err != nil {
f.l.WithError(err).Error("Failed to write packet")
}
}
}

// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
@@ -254,16 +466,18 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
 }

+// handleEncrypted returns true if a packet should be processed, false otherwise
 func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
-	// If connectionstate exists and the replay protector allows, process packet
-	// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
-	if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
+	// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
+	if ci == nil {
 		if addr.IsValid() {
 			f.maybeSendRecvError(addr, h.RemoteIndex)
-			return false
-		} else {
-			return false
 		}
+		return false
+	}
+
+	// If the window check fails, refuse to process the packet, but don't send a recv error
+	if !ci.window.Check(f.l, h.MessageCounter) {
+		return false
 	}

 	return true
@@ -463,7 +677,55 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
 	return out, nil
 }

-func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
var err error
seg, err := f.readers[q].AllocSeg(out, q)
if err != nil {
f.l.WithError(err).Errorln("decryptToTunDelayWrite: failed to allocate segment")
return false
}
out.SegmentPayloads[seg] = out.SegmentPayloads[seg][:0]
out.SegmentPayloads[seg], err = hostinfo.ConnectionState.dKey.DecryptDanger(out.SegmentPayloads[seg], inSegment[:header.Len], inSegment[header.Len:], messageCounter, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
return false
}
err = newPacket(out.SegmentPayloads[seg], true, fwPacket)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
Warnf("Error while validating inbound packet")
return false
}
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
Debugln("dropping out of window packet")
return false
}
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now)
if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in
f.rejectOutside(out.SegmentPayloads[seg], hostinfo.ConnectionState, hostinfo, nb, inSegment, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping inbound packet")
}
return false
}
f.connectionManager.In(hostinfo)
pkt.OutLen += len(inSegment)
out.Segments[seg] = out.Segments[seg][:len(out.SegmentHeaders[seg])+len(out.SegmentPayloads[seg])]
return true
}
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
var err error var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
@@ -485,7 +747,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 		return false
 	}
-	dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
+	dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now)
 	if dropReason != nil {
 		// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
 		// This gives us a buffer to build the reject packet in
@@ -498,7 +760,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 		return false
 	}
-	f.connectionManager.In(hostinfo.localIndexId)
+	f.connectionManager.In(hostinfo)
 	_, err = f.readers[q].Write(out)
 	if err != nil {
 		f.l.WithError(err).Error("Failed to write to tun")
@@ -537,10 +799,6 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
 		return
 	}
-	if !hostinfo.RecvErrorExceeded() {
-		return
-	}
 	if hostinfo.remote.IsValid() && hostinfo.remote != addr {
 		f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
 		return


@@ -1,17 +1,16 @@
 package overlay

 import (
-	"io"
 	"net/netip"

 	"github.com/slackhq/nebula/routing"
 )

 type Device interface {
-	io.ReadWriteCloser
+	TunDev
 	Activate() error
 	Networks() []netip.Prefix
 	Name() string
 	RoutesFor(netip.Addr) routing.Gateways
-	NewMultiQueueReader() (io.ReadWriteCloser, error)
+	NewMultiQueueReader() (TunDev, error)
 }


@@ -0,0 +1,91 @@
package eventfd
import (
"encoding/binary"
"syscall"
"golang.org/x/sys/unix"
)
type EventFD struct {
fd int
buf [8]byte
}
func New() (EventFD, error) {
fd, err := unix.Eventfd(0, unix.EFD_NONBLOCK)
if err != nil {
return EventFD{}, err
}
return EventFD{
fd: fd,
buf: [8]byte{},
}, nil
}
func (e *EventFD) Kick() error {
binary.LittleEndian.PutUint64(e.buf[:], 1) //is this right???
_, err := syscall.Write(int(e.fd), e.buf[:])
return err
}
func (e *EventFD) Close() error {
if e.fd != 0 {
return unix.Close(e.fd)
}
return nil
}
func (e *EventFD) FD() int {
return e.fd
}
type Epoll struct {
fd int
buf [8]byte
events []syscall.EpollEvent
}
func NewEpoll() (Epoll, error) {
fd, err := unix.EpollCreate1(0)
if err != nil {
return Epoll{}, err
}
return Epoll{
fd: fd,
buf: [8]byte{},
events: make([]syscall.EpollEvent, 1),
}, nil
}
func (ep *Epoll) AddEvent(fdToAdd int) error {
event := syscall.EpollEvent{
Events: syscall.EPOLLIN,
Fd: int32(fdToAdd),
}
return syscall.EpollCtl(ep.fd, syscall.EPOLL_CTL_ADD, fdToAdd, &event)
}
func (ep *Epoll) Block() (int, error) {
n, err := syscall.EpollWait(ep.fd, ep.events, -1)
if err != nil {
//goland:noinspection GoDirectComparisonOfErrors
if err == syscall.EINTR {
return 0, nil //??
}
return -1, err
}
return n, nil
}
func (ep *Epoll) Clear() error {
_, err := syscall.Read(int(ep.events[0].Fd), ep.buf[:])
return err
}
func (ep *Epoll) Close() error {
if ep.fd != 0 {
return unix.Close(ep.fd)
}
return nil
}
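Note (editor's illustration, not part of the change set): a minimal sketch of how the EventFD and Epoll helpers above could be wired together, written as if it lived in the same eventfd package so no import path has to be assumed; the function name is hypothetical.

// exampleEventLoop is an illustrative sketch only: one goroutine kicks the eventfd,
// the other blocks on epoll and then clears the counter so it can fire again.
func exampleEventLoop() error {
	efd, err := New()
	if err != nil {
		return err
	}
	defer efd.Close()

	ep, err := NewEpoll()
	if err != nil {
		return err
	}
	defer ep.Close()

	// Register the eventfd with the epoll instance so Block() wakes up on Kick().
	if err := ep.AddEvent(efd.FD()); err != nil {
		return err
	}

	// Producer side: signal readiness from another goroutine.
	go func() {
		_ = efd.Kick()
	}()

	// Consumer side: wait for the kick, then drain the eventfd counter.
	n, err := ep.Block()
	if err != nil {
		return err
	}
	if n > 0 {
		return ep.Clear()
	}
	return nil
}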


@@ -1,15 +1,30 @@
 package overlay

 import (
+	"fmt"
+	"io"
+	"net"
 	"net/netip"

 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/packet"
 	"github.com/slackhq/nebula/util"
 )

 const DefaultMTU = 1300

+type TunDev interface {
+	io.WriteCloser
+	ReadMany(x []*packet.VirtIOPacket, q int) (int, error)
+	//todo this interface sux
+	AllocSeg(pkt *packet.OutPacket, q int) (int, error)
+	WriteOne(x *packet.OutPacket, kick bool, q int) (int, error)
+	WriteMany(x []*packet.OutPacket, q int) (int, error)
+	RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error
+}
+
 // TODO: We may be able to remove routines
 type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
@@ -24,11 +39,11 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref
 	}
 }

-func NewFdDeviceFromConfig(fd *int) DeviceFactory {
-	return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
-		return newTunFromFd(c, l, *fd, vpnNetworks)
-	}
-}
+//func NewFdDeviceFromConfig(fd *int) DeviceFactory {
+//	return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
+//		return newTunFromFd(c, l, *fd, vpnNetworks)
+//	}
+//}

 func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) {
 	if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
@@ -70,3 +85,51 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
 	return removed
 }
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}
func flipBytes(b []byte) []byte {
for i := 0; i < len(b); i++ {
b[i] ^= 0xFF
}
return b
}
func orBytes(a []byte, b []byte) []byte {
ret := make([]byte, len(a))
for i := 0; i < len(a); i++ {
ret[i] = a[i] | b[i]
}
return ret
}
func getBroadcast(cidr netip.Prefix) netip.Addr {
broadcast, _ := netip.AddrFromSlice(
orBytes(
cidr.Addr().AsSlice(),
flipBytes(prefixToMask(cidr).AsSlice()),
),
)
return broadcast
}
func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
for _, gateway := range gateways {
if dest.Addr().Is4() && gateway.Addr().Is4() {
return gateway, nil
}
if dest.Addr().Is6() && gateway.Addr().Is6() {
return gateway, nil
}
}
return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest)
}
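Note (editor's illustration, not part of the change set): a small sketch exercising the address helpers above with assumed example values, written as if it lived in the overlay package; the function name and values are hypothetical.

// exampleAddrHelpers is an illustrative sketch only.
func exampleAddrHelpers() {
	cidr := netip.MustParsePrefix("10.128.0.0/24")
	mask := prefixToMask(cidr)  // 255.255.255.0
	bcast := getBroadcast(cidr) // 10.128.0.255

	// selectGateway returns the first gateway whose address family matches the destination,
	// so the IPv4 destination skips the IPv6 gateway here.
	gw, err := selectGateway(cidr, []netip.Prefix{
		netip.MustParsePrefix("fd00::1/64"),
		netip.MustParsePrefix("10.1.0.1/24"),
	})
	// gw is 10.1.0.1/24 and err is nil in this example.
	_, _, _, _ = mask, bcast, gw, err
}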


@@ -7,7 +7,6 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"net"
 	"net/netip"
 	"os"
 	"sync/atomic"
@@ -295,7 +294,6 @@ func (t *tun) activate6(network netip.Prefix) error {
 			Vltime: 0xffffffff,
 			Pltime: 0xffffffff,
 		},
-		//TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address?
 		Flags: _IN6_IFF_NODAD,
 	}
@@ -554,13 +552,3 @@ func (t *tun) Name() string {
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
 }
-
-func prefixToMask(prefix netip.Prefix) netip.Addr {
-	pLen := 128
-	if prefix.Addr().Is4() {
-		pLen = 32
-	}
-
-	addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
-	return addr
-}


@@ -9,6 +9,8 @@ import (
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/overlay/virtqueue"
+	"github.com/slackhq/nebula/packet"
 	"github.com/slackhq/nebula/routing"
 )
@@ -22,6 +24,10 @@ type disabledTun struct {
 	l *logrus.Logger
 }

+func (*disabledTun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
+	return nil
+}
+
 func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
 	tun := &disabledTun{
 		vpnNetworks: vpnNetworks,
@@ -40,6 +46,10 @@ func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled boo
 	return tun
 }

+func (*disabledTun) GetQueues() []*virtqueue.SplitQueue {
+	return nil
+}
+
 func (*disabledTun) Activate() error {
 	return nil
 }
@@ -105,7 +115,23 @@ func (t *disabledTun) Write(b []byte) (int, error) {
 	return len(b), nil
 }

-func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
+func (t *disabledTun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
+	return 0, fmt.Errorf("tun_disabled: AllocSeg not implemented")
+}
+
+func (t *disabledTun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
+	return 0, fmt.Errorf("tun_disabled: WriteOne not implemented")
+}
+
+func (t *disabledTun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
+	return 0, fmt.Errorf("tun_disabled: WriteMany not implemented")
+}
+
+func (t *disabledTun) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
+	return t.Read(b[0].Payload)
+}
+
+func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {
 	return t, nil
 }


@@ -10,11 +10,9 @@ import (
 	"io"
 	"io/fs"
 	"net/netip"
-	"os"
-	"os/exec"
-	"strconv"
 	"sync/atomic"
 	"syscall"
+	"time"
 	"unsafe"

 	"github.com/gaissmai/bart"
@@ -22,12 +20,18 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
+	netroute "golang.org/x/net/route"
+	"golang.org/x/sys/unix"
 )

 const (
 	// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
 	// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
 	FIODGNAME = 0x80106678
+	TUNSIFMODE       = 0x8004745e
+	TUNSIFHEAD       = 0x80047460
+	OSIOCAIFADDR_IN6 = 0x8088691b
+	IN6_IFF_NODAD    = 0x0020
 )
type fiodgnameArg struct { type fiodgnameArg struct {
@@ -37,43 +41,159 @@ type fiodgnameArg struct {
} }
type ifreqRename struct { type ifreqRename struct {
Name [16]byte Name [unix.IFNAMSIZ]byte
Data uintptr Data uintptr
} }
type ifreqDestroy struct { type ifreqDestroy struct {
Name [16]byte Name [unix.IFNAMSIZ]byte
pad [16]byte pad [16]byte
} }
type ifReq struct {
Name [unix.IFNAMSIZ]byte
Flags uint16
}
type ifreqMTU struct {
Name [unix.IFNAMSIZ]byte
MTU int32
}
type addrLifetime struct {
Expire uint64
Preferred uint64
Vltime uint32
Pltime uint32
}
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
VHid uint32
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime addrLifetime
VHid uint32
}
 type tun struct {
 	Device      string
 	vpnNetworks []netip.Prefix
 	MTU         int
 	Routes      atomic.Pointer[[]Route]
 	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
+	linkAddr    *netroute.LinkAddr
 	l           *logrus.Logger
-	io.ReadWriteCloser
+	devFd       int
 }

func (t *tun) Read(to []byte) (int, error) {
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
// first 4 bytes is protocol family, in network byte order
head := make([]byte, 4)
iovecs := []syscall.Iovec{
{&head[0], 4},
{&to[0], uint64(len(to))},
}
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
}
// fix bytes read number to exclude header
bytesRead := int(n)
if bytesRead < 0 {
return bytesRead, err
} else if bytesRead < 4 {
return 0, err
} else {
return bytesRead - 4, err
}
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
if t.devFd < 0 {
return -1, syscall.EINVAL
}
if len(from) <= 1 {
return 0, syscall.EIO
}
ipVer := from[0] >> 4
var head []byte
// first 4 bytes is protocol family, in network byte order
if ipVer == 4 {
head = []byte{0, 0, 0, syscall.AF_INET}
} else if ipVer == 6 {
head = []byte{0, 0, 0, syscall.AF_INET6}
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
iovecs := []syscall.Iovec{
{&head[0], 4},
{&from[0], uint64(len(from))},
}
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
var err error
if errno != 0 {
err = syscall.Errno(errno)
} else {
err = nil
}
return int(n) - 4, err
} }
func (t *tun) Close() error { func (t *tun) Close() error {
if t.ReadWriteCloser != nil { if t.devFd >= 0 {
if err := t.ReadWriteCloser.Close(); err != nil { err := syscall.Close(t.devFd)
return err
}
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil { if err != nil {
return err t.l.WithError(err).Error("Error closing device")
} }
t.devFd = -1
c := make(chan struct{})
go func() {
// destroying the interface can block if a read() is still pending. Do this asynchronously.
defer close(c)
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err == nil {
defer syscall.Close(s) defer syscall.Close(s)
ifreq := ifreqDestroy{Name: t.deviceBytes()} ifreq := ifreqDestroy{Name: t.deviceBytes()}
// Destroy the interface
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
return err }
if err != nil {
t.l.WithError(err).Error("Error destroying tunnel")
}
}()
// wait up to 1 second so we start blocking at the ioctl
select {
case <-c:
case <-time.After(1 * time.Second):
}
} }
return nil return nil
@@ -85,32 +205,37 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun,
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device // Try to open existing tun device
var file *os.File var fd int
var err error var err error
deviceName := c.GetString("tun.dev", "") deviceName := c.GetString("tun.dev", "")
if deviceName != "" { if deviceName != "" {
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0) fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
} }
if errors.Is(err, fs.ErrNotExist) || deviceName == "" { if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
// If the device doesn't already exist, request a new one and rename it // If the device doesn't already exist, request a new one and rename it
file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0) fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
rawConn, err := file.SyscallConn() // Read the name of the interface
if err != nil { var name [16]byte
return nil, fmt.Errorf("SyscallConn: %v", err) arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
if ctrlErr == nil {
// set broadcast mode and multicast
ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
}
if ctrlErr == nil {
// turn on link-layer mode, to support ipv6
ifhead := uint32(1)
ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
} }
var name [16]byte
var ctrlErr error
rawConn.Control(func(fd uintptr) {
// Read the name of the interface
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg)))
})
if ctrlErr != nil { if ctrlErr != nil {
return nil, err return nil, err
} }
@@ -122,11 +247,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
// If the name doesn't match the desired interface name, rename it now // If the name doesn't match the desired interface name, rename it now
if ifName != deviceName { if ifName != deviceName {
s, err := syscall.Socket( s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
syscall.AF_INET,
syscall.SOCK_DGRAM,
syscall.IPPROTO_IP,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -149,11 +270,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
} }
t := &tun{ t := &tun{
ReadWriteCloser: file,
Device: deviceName, Device: deviceName,
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU), MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l, l: l,
devFd: fd,
} }
err = t.reload(c, true) err = t.reload(c, true)
@@ -172,38 +293,111 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
} }
func (t *tun) addIp(cidr netip.Prefix) error { func (t *tun) addIp(cidr netip.Prefix) error {
var err error if cidr.Addr().Is4() {
// TODO use syscalls instead of exec.Command ifr := ifreqAlias4{
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) Name: t.deviceBytes(),
t.l.Debug("command: ", cmd.String()) Addr: unix.RawSockaddrInet4{
if err = cmd.Run(); err != nil { Len: unix.SizeofSockaddrInet4,
return fmt.Errorf("failed to run 'ifconfig': %s", err) Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
},
DstAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: getBroadcast(cidr).As4(),
},
MaskAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(cidr).As4(),
},
VHid: 0,
}
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
} }
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device) if cidr.Addr().Is6() {
t.l.Debug("command: ", cmd.String()) ifr := ifreqAlias6{
if err = cmd.Run(); err != nil { Name: t.deviceBytes(),
return fmt.Errorf("failed to run 'route add': %s", err) Addr: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: cidr.Addr().As16(),
},
PrefixMask: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(cidr).As16(),
},
Lifetime: addrLifetime{
Expire: 0,
Preferred: 0,
Vltime: 0xffffffff,
Pltime: 0xffffffff,
},
Flags: IN6_IFF_NODAD,
}
s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
} }
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)) return fmt.Errorf("unknown address type %v", cidr)
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
return t.addRoutes(false)
} }
func (t *tun) Activate() error { func (t *tun) Activate() error {
// Setup our default MTU
err := t.setMTU()
if err != nil {
return err
}
linkAddr, err := getLinkAddr(t.Device)
if err != nil {
return err
}
if linkAddr == nil {
return fmt.Errorf("unable to discover link_addr for tun interface")
}
t.linkAddr = linkAddr
for i := range t.vpnNetworks { for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i]) err := t.addIp(t.vpnNetworks[i])
if err != nil { if err != nil {
return err return err
} }
} }
return nil
return t.addRoutes(false)
}
func (t *tun) setMTU() error {
// Set the MTU on the device
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
return err
} }
func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) reload(c *config.C, initial bool) error {
@@ -268,15 +462,16 @@ func (t *tun) addRoutes(logErrors bool) error {
 			continue
 		}
-		cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
-		t.l.Debug("command: ", cmd.String())
-		if err := cmd.Run(); err != nil {
-			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
+		err := addRoute(r.Cidr, t.linkAddr)
+		if err != nil {
+			retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
 			if logErrors {
 				retErr.Log(t.l)
 			} else {
 				return retErr
 			}
+		} else {
+			t.l.WithField("route", r).Info("Added route")
 		}
 	}
@@ -289,9 +484,8 @@
 			continue
 		}
-		cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
-		t.l.Debug("command: ", cmd.String())
-		if err := cmd.Run(); err != nil {
+		err := delRoute(r.Cidr, t.linkAddr)
+		if err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 		} else {
 			t.l.WithField("route", r).Info("Removed route")
@@ -306,3 +500,120 @@
 	}
 	return
 }
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP,
Seq: 1,
}
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
fmt.Println("DOING CHANGE")
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
}
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
// getLinkAddr Gets the link address for the interface of the given name
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
if err != nil {
return nil, err
}
msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
if err != nil {
return nil, err
}
for _, m := range msgs {
switch m := m.(type) {
case *netroute.InterfaceMessage:
if m.Name == name {
sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
if ok {
return sa, nil
}
}
}
}
return nil, nil
}
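Note (editor's illustration, not part of the change set): a rough sketch of how getLinkAddr, addRoute, and delRoute above could fit together on FreeBSD; the function name and arguments are hypothetical.

// exampleRoutes is an illustrative sketch only: look up the interface's link address
// once, then add and remove a route against it using the helpers defined above.
func exampleRoutes(ifName string, cidr netip.Prefix) error {
	linkAddr, err := getLinkAddr(ifName)
	if err != nil {
		return err
	}
	if linkAddr == nil {
		return fmt.Errorf("no link address found for %s", ifName)
	}

	// Route the prefix via the tun interface's link address; an EEXIST from the
	// kernel is retried as an RTM_CHANGE inside addRoute.
	if err := addRoute(cidr, linkAddr); err != nil {
		return err
	}

	// ...and tear it down again.
	return delRoute(cidr, linkAddr)
}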


@@ -5,7 +5,6 @@ package overlay
 import (
 	"fmt"
-	"io"
 	"net"
 	"net/netip"
 	"os"
@@ -17,15 +16,19 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/overlay/vhostnet"
+	"github.com/slackhq/nebula/packet"
 	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
+	"github.com/slackhq/nebula/util/virtio"
 	"github.com/vishvananda/netlink"
 	"golang.org/x/sys/unix"
 )

 type tun struct {
-	io.ReadWriteCloser
+	file *os.File
 	fd   int
+	vdev []*vhostnet.Device

 	Device      string
 	vpnNetworks []netip.Prefix
 	MaxMTU      int
@@ -40,6 +43,7 @@ type tun struct {
 	useSystemRoutes           bool
 	useSystemRoutesBufferSize int
+	isV6 bool

 	l *logrus.Logger
 }
@@ -102,7 +106,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
 	}

 	var req ifReq
-	req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
+	req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_TUN_EXCL | unix.IFF_VNET_HDR | unix.IFF_NAPI)
 	if multiqueue {
 		req.Flags |= unix.IFF_MULTI_QUEUE
 	}
@@ -112,20 +116,47 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
} }
name := strings.Trim(string(req.Name[:]), "\x00") name := strings.Trim(string(req.Name[:]), "\x00")
if err = unix.SetNonblock(fd, true); err != nil {
_ = unix.Close(fd)
return nil, fmt.Errorf("make file descriptor non-blocking: %w", err)
}
file := os.NewFile(uintptr(fd), "/dev/net/tun") file := os.NewFile(uintptr(fd), "/dev/net/tun")
err = unix.IoctlSetPointerInt(fd, unix.TUNSETVNETHDRSZ, virtio.NetHdrSize)
if err != nil {
return nil, fmt.Errorf("set vnethdr size: %w", err)
}
flags := 0
//flags = //unix.TUN_F_CSUM //| unix.TUN_F_TSO4 | unix.TUN_F_USO4 | unix.TUN_F_TSO6 | unix.TUN_F_USO6
err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, flags)
if err != nil {
return nil, fmt.Errorf("set offloads: %w", err)
}
t, err := newTunGeneric(c, l, file, vpnNetworks) t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.fd = fd
t.Device = name t.Device = name
vdev, err := vhostnet.NewDevice(
vhostnet.WithBackendFD(fd),
vhostnet.WithQueueSize(8192), //todo config
)
if err != nil {
return nil, err
}
t.vdev = []*vhostnet.Device{vdev}
return t, nil return t, nil
} }
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
t := &tun{ t := &tun{
ReadWriteCloser: file, file: file,
fd: int(file.Fd()), fd: int(file.Fd()),
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500), TXQueueLen: c.GetInt("tun.tx_queue", 500),
@@ -133,6 +164,9 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0), useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
l: l, l: l,
} }
if len(vpnNetworks) != 0 {
t.isV6 = vpnNetworks[0].Addr().Is6() //todo what about multi-IP?
}
err := t.reload(c, true) err := t.reload(c, true)
if err != nil { if err != nil {
@@ -216,7 +250,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }

-func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
+func (t *tun) NewMultiQueueReader() (TunDev, error) {
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	if err != nil {
 		return nil, err
@@ -229,9 +263,17 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 		return nil, err
 	}

-	file := os.NewFile(uintptr(fd), "/dev/net/tun")
+	vdev, err := vhostnet.NewDevice(
+		vhostnet.WithBackendFD(fd),
+		vhostnet.WithQueueSize(8192), //todo config
+	)
+	if err != nil {
+		return nil, err
+	}

-	return file, nil
+	t.vdev = append(t.vdev, vdev)
+	return t, nil
 }

 func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -239,29 +281,6 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
 	return r
 }

-func (t *tun) Write(b []byte) (int, error) {
-	var nn int
-	maximum := len(b)
-	for {
-		n, err := unix.Write(t.fd, b[nn:maximum])
-		if n > 0 {
-			nn += n
-		}
-		if nn == len(b) {
-			return nn, err
-		}
-		if err != nil {
-			return nn, err
-		}
-		if n == 0 {
-			return nn, io.ErrUnexpectedEOF
-		}
-	}
-}
-
 func (t *tun) deviceBytes() (o [16]byte) {
 	for i, c := range t.Device {
 		o[i] = byte(c)
@@ -293,7 +312,6 @@ func (t *tun) addIPs(link netlink.Link) error {
 	//add all new addresses
 	for i := range newAddrs {
-		//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
 		//AddrReplace still adds new IPs, but if their properties change it will change them as well
 		if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
 			return err
@@ -361,6 +379,11 @@ func (t *tun) Activate() error {
 		t.l.WithError(err).Error("Failed to set tun tx queue length")
 	}

+	const modeNone = 1
+	if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
+		t.l.WithError(err).Warn("Failed to disable link local address generation")
+	}
+
 	if err = t.addIPs(link); err != nil {
 		return err
 	}
@@ -638,6 +661,11 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 		return
 	}

+	if r.Dst == nil {
+		t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
+		return
+	}
+
 	dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
 	if !ok {
 		t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
@@ -665,8 +693,14 @@ func (t *tun) Close() error {
 		close(t.routeChan)
 	}

-	if t.ReadWriteCloser != nil {
-		_ = t.ReadWriteCloser.Close()
+	for _, v := range t.vdev {
+		if v != nil {
+			_ = v.Close()
+		}
+	}
+
+	if t.file != nil {
+		_ = t.file.Close()
 	}

 	if t.ioctlFd > 0 {
@@ -675,3 +709,65 @@ func (t *tun) Close() error {
return nil return nil
} }
func (t *tun) ReadMany(p []*packet.VirtIOPacket, q int) (int, error) {
n, err := t.vdev[q].ReceivePackets(p) //we are TXing
if err != nil {
return 0, err
}
return n, nil
}
func (t *tun) Write(b []byte) (int, error) {
maximum := len(b) //we are RXing
//todo garbagey
out := packet.NewOut()
x, err := t.AllocSeg(out, 0)
if err != nil {
return 0, err
}
copy(out.SegmentPayloads[x], b)
err = t.vdev[0].TransmitPacket(out, true)
if err != nil {
t.l.WithError(err).Error("Transmitting packet")
return 0, err
}
return maximum, nil
}
func (t *tun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
idx, buf, err := t.vdev[q].GetPacketForTx()
if err != nil {
return 0, err
}
x := pkt.UseSegment(idx, buf, t.isV6)
return x, nil
}
func (t *tun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
if err := t.vdev[q].TransmitPacket(x, kick); err != nil {
t.l.WithError(err).Error("Transmitting packet")
return 0, err
}
return 1, nil
}
func (t *tun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
maximum := len(x) //we are RXing
if maximum == 0 {
return 0, nil
}
err := t.vdev[q].TransmitPackets(x)
if err != nil {
t.l.WithError(err).Error("Transmitting packet")
return 0, err
}
return maximum, nil
}
func (t *tun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
return t.vdev[q].ReceiveQueue.OfferDescriptorChains(pkt.Chains, kick)
}


@@ -4,13 +4,12 @@
 package overlay

 import (
-	"errors"
 	"fmt"
 	"io"
 	"net/netip"
 	"os"
-	"os/exec"
 	"regexp"
-	"strconv"
 	"sync/atomic"
 	"syscall"
 	"unsafe"
@@ -20,11 +19,42 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
+	netroute "golang.org/x/net/route"
+	"golang.org/x/sys/unix"
 )

-type ifreqDestroy struct {
-	Name [16]byte
-	pad  [16]byte
-}
+const (
+	SIOCAIFADDR_IN6 = 0x8080696b
+	TUNSIFHEAD      = 0x80047442
+	TUNSIFMODE      = 0x80047458
+)
+
+type ifreqAlias4 struct {
+	Name     [unix.IFNAMSIZ]byte
+	Addr     unix.RawSockaddrInet4
+	DstAddr  unix.RawSockaddrInet4
+	MaskAddr unix.RawSockaddrInet4
+}
+
+type ifreqAlias6 struct {
+	Name       [unix.IFNAMSIZ]byte
+	Addr       unix.RawSockaddrInet6
+	DstAddr    unix.RawSockaddrInet6
+	PrefixMask unix.RawSockaddrInet6
+	Flags      uint32
+	Lifetime   addrLifetime
+}
+
+type ifreq struct {
+	Name [unix.IFNAMSIZ]byte
+	data int
+}
+
+type addrLifetime struct {
+	Expire    uint64
+	Preferred uint64
+	Vltime    uint32
+	Pltime    uint32
+}

 type tun struct {
@@ -34,40 +64,18 @@ type tun struct {
 	Routes    atomic.Pointer[[]Route]
 	routeTree atomic.Pointer[bart.Table[routing.Gateways]]
 	l         *logrus.Logger
+	f         *os.File
-	io.ReadWriteCloser
+	fd        int
 }
func (t *tun) Close() error { var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
if t.ReadWriteCloser != nil {
if err := t.ReadWriteCloser.Close(); err != nil {
return err
}
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifreq := ifreqDestroy{Name: t.deviceBytes()}
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
return err
}
return nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
} }
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device // Try to open tun device
var file *os.File
var err error var err error
deviceName := c.GetString("tun.dev", "") deviceName := c.GetString("tun.dev", "")
if deviceName == "" { if deviceName == "" {
@@ -77,13 +85,19 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified") return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
} }
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0) fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = unix.SetNonblock(fd, true)
if err != nil {
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
}
t := &tun{ t := &tun{
ReadWriteCloser: file, f: os.NewFile(uintptr(fd), ""),
fd: fd,
Device: deviceName, Device: deviceName,
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU), MTU: c.GetInt("tun.mtu", DefaultMTU),
@@ -105,40 +119,225 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
return t, nil return t, nil
} }
func (t *tun) Close() error {
if t.f != nil {
if err := t.f.Close(); err != nil {
return fmt.Errorf("error closing tun file: %w", err)
}
// t.f.Close should have handled it for us but let's be extra sure
_ = unix.Close(t.fd)
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ifr := ifreq{Name: t.deviceBytes()}
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifr)))
return err
}
return nil
}
func (t *tun) Read(to []byte) (int, error) {
rc, err := t.f.SyscallConn()
if err != nil {
return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err)
}
var errno syscall.Errno
var n uintptr
err = rc.Read(func(fd uintptr) bool {
// first 4 bytes is protocol family, in network byte order
head := [4]byte{}
iovecs := []syscall.Iovec{
{&head[0], 4},
{&to[0], uint64(len(to))},
}
n, _, errno = syscall.Syscall(syscall.SYS_READV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
if errno.Temporary() {
// We got an EAGAIN, EINTR, or EWOULDBLOCK, go again
return false
}
return true
})
if err != nil {
if err == syscall.EBADF || err.Error() == "use of closed file" {
// Go doesn't export poll.ErrFileClosing but happily reports it to us so here we are
// https://github.com/golang/go/blob/master/src/internal/poll/fd_poll_runtime.go#L121
return 0, os.ErrClosed
}
return 0, fmt.Errorf("failed to make read call for tun: %w", err)
}
if errno != 0 {
return 0, fmt.Errorf("failed to make inner read call for tun: %w", errno)
}
// fix bytes read number to exclude header
bytesRead := int(n)
if bytesRead < 0 {
return bytesRead, nil
} else if bytesRead < 4 {
return 0, nil
} else {
return bytesRead - 4, nil
}
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
if len(from) <= 1 {
return 0, syscall.EIO
}
ipVer := from[0] >> 4
var head [4]byte
// first 4 bytes is protocol family, in network byte order
if ipVer == 4 {
head[3] = syscall.AF_INET
} else if ipVer == 6 {
head[3] = syscall.AF_INET6
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
rc, err := t.f.SyscallConn()
if err != nil {
return 0, err
}
var errno syscall.Errno
var n uintptr
err = rc.Write(func(fd uintptr) bool {
iovecs := []syscall.Iovec{
{&head[0], 4},
{&from[0], uint64(len(from))},
}
n, _, errno = syscall.Syscall(syscall.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
// According to NetBSD documentation for TUN, writes will only return errors in which
// this packet will never be delivered so just go on living life.
return true
})
if err != nil {
return 0, err
}
if errno != 0 {
return 0, errno
}
return int(n) - 4, err
}
func (t *tun) addIp(cidr netip.Prefix) error {
if cidr.Addr().Is4() {
var req ifreqAlias4
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.DstAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.MaskAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(cidr).As4(),
}
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
}
return nil
}
if cidr.Addr().Is6() {
var req ifreqAlias6
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: cidr.Addr().As16(),
}
req.PrefixMask = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(cidr).As16(),
}
req.Lifetime = addrLifetime{
Vltime: 0xffffffff,
Pltime: 0xffffffff,
}
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
}
return fmt.Errorf("unknown address type %v", cidr)
}
func (t *tun) Activate() error {
mode := int32(unix.IFF_BROADCAST)
err := ioctl(uintptr(t.fd), TUNSIFMODE, uintptr(unsafe.Pointer(&mode)))
if err != nil {
return fmt.Errorf("failed to set tun device mode: %w", err)
}
v := 1
err = ioctl(uintptr(t.fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&v)))
if err != nil {
return fmt.Errorf("failed to set tun device head: %w", err)
}
err = t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
if err != nil {
return fmt.Errorf("failed to set tun mtu: %w", err)
}
for i := range t.vpnNetworks {
err = t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return t.addRoutes(false)
}
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
return err
}
func (t *tun) reload(c *config.C, initial bool) error {
@@ -197,21 +396,23 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
err := addRoute(r.Cidr, t.vpnNetworks)
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
@@ -224,10 +425,8 @@ func (t *tun) removeRoutes(routes []Route) error {
continue
}
err := delRoute(r.Cidr, t.vpnNetworks)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
@@ -242,3 +441,109 @@ func (t *tun) deviceBytes() (o [16]byte) {
}
return
}
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
Seq: 1,
}
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
}
} else {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
}
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
}
} else {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
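// selectGateway is used above but defined elsewhere in the tree. As a rough
// assumption only, a minimal version would return the first vpn network whose
// address family matches the destination prefix; the helper below is a
// hypothetical sketch of that behaviour, not the actual implementation.
func selectGatewaySketch(dst netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
	for _, gw := range gateways {
		if gw.Addr().Is4() == dst.Addr().Is4() {
			return gw, nil
		}
	}
	return netip.Prefix{}, fmt.Errorf("no gateway found for %v", dst)
}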


@@ -4,23 +4,50 @@
package overlay
import (
"errors"
"fmt"
"io"
"net/netip"
"os"
"regexp"
"sync/atomic"
"syscall"
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
)
const (
SIOCAIFADDR_IN6 = 0x8080691a
)
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime [2]uint32
}
type ifreq struct {
Name [unix.IFNAMSIZ]byte
data int
}
type tun struct {
Device string
vpnNetworks []netip.Prefix
@@ -28,44 +55,42 @@ type tun struct {
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
f *os.File
fd int
// cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
deviceName := c.GetString("tun.dev", "")
if deviceName == "" {
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
}
if !deviceNameRE.MatchString(deviceName) {
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
}
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
if err != nil {
return nil, err
}
err = unix.SetNonblock(fd, true)
if err != nil {
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
}
t := &tun{
f: os.NewFile(uintptr(fd), ""),
fd: fd,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
@@ -87,6 +112,154 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
return t, nil
}
func (t *tun) Close() error {
if t.f != nil {
if err := t.f.Close(); err != nil {
return fmt.Errorf("error closing tun file: %w", err)
}
// t.f.Close should have handled it for us but let's be extra sure
_ = unix.Close(t.fd)
}
return nil
}
func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4)
n, err := t.f.Read(buf)
copy(to, buf[4:])
return n - 4, err
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
buf := t.out
if cap(buf) < len(from)+4 {
buf = make([]byte, len(from)+4)
t.out = buf
}
buf = buf[:len(from)+4]
if len(from) == 0 {
return 0, syscall.EIO
}
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
buf[3] = syscall.AF_INET
} else if ipVer == 6 {
buf[3] = syscall.AF_INET6
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
copy(buf[4:], from)
n, err := t.f.Write(buf)
return n - 4, err
}
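// A minimal sketch (not in the diff) of the framing Write performs: the 4-byte
// head carries the address family as a big-endian word, so only its last byte
// is populated. The helper below is hypothetical and mirrors the branch above.
func tunHeadFor(packet []byte) ([4]byte, error) {
	var head [4]byte
	if len(packet) == 0 {
		return head, syscall.EIO
	}
	switch packet[0] >> 4 {
	case 4:
		head[3] = syscall.AF_INET
	case 6:
		head[3] = syscall.AF_INET6
	default:
		return head, fmt.Errorf("unable to determine IP version from packet")
	}
	return head, nil
}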
func (t *tun) addIp(cidr netip.Prefix) error {
if cidr.Addr().Is4() {
var req ifreqAlias4
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.DstAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: cidr.Addr().As4(),
}
req.MaskAddr = unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(cidr).As4(),
}
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
}
err = addRoute(cidr, t.vpnNetworks)
if err != nil {
return fmt.Errorf("failed to set route for vpn network %v: %w", cidr, err)
}
return nil
}
if cidr.Addr().Is6() {
var req ifreqAlias6
req.Name = t.deviceBytes()
req.Addr = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: cidr.Addr().As16(),
}
req.PrefixMask = unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(cidr).As16(),
}
req.Lifetime[0] = 0xffffffff
req.Lifetime[1] = 0xffffffff
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
}
return nil
}
return fmt.Errorf("unknown address type %v", cidr)
}
func (t *tun) Activate() error {
err := t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
if err != nil {
return fmt.Errorf("failed to set tun mtu: %w", err)
}
for i := range t.vpnNetworks {
err = t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return t.addRoutes(false)
}
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(s)
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
return err
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
@@ -124,86 +297,11 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) addIp(cidr netip.Prefix) error {
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
// Unsafe path routes
return t.addRoutes(false)
}
func (t *tun) Activate() error {
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
}
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
//TODO: CERT-V2 is this right?
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
}
}
return nil
}
func (t *tun) removeRoutes(routes []Route) error {
for _, r := range routes {
if !r.Install {
continue
}
//TODO: CERT-V2 is this right?
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
@@ -213,43 +311,159 @@ func (t *tun) Name() string {
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
err := addRoute(r.Cidr, t.vpnNetworks)
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
return nil
}
func (t *tun) removeRoutes(routes []Route) error {
for _, r := range routes {
if !r.Install {
continue
}
err := delRoute(r.Cidr, t.vpnNetworks)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
}
return
}
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
Seq: 1,
}
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
}
} else {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
}
if prefix.Addr().Is4() {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
}
} else {
gw, err := selectGateway(prefix, gateways)
if err != nil {
return err
}
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
} }


@@ -1,11 +1,13 @@
package overlay
import (
"fmt"
"io"
"net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/routing"
)
@@ -36,6 +38,10 @@ type UserDevice struct {
inboundWriter *io.PipeWriter
}
func (d *UserDevice) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
return nil
}
func (d *UserDevice) Activate() error {
return nil
}
@@ -46,7 +52,7 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
return routing.Gateways{routing.NewGateway(ip, 1)}
}
func (d *UserDevice) NewMultiQueueReader() (TunDev, error) {
return d, nil
}
@@ -65,3 +71,19 @@ func (d *UserDevice) Close() error {
d.outboundWriter.Close()
return nil
}
func (d *UserDevice) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
return d.Read(b[0].Payload)
}
func (d *UserDevice) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
return 0, fmt.Errorf("user: AllocSeg not implemented")
}
func (d *UserDevice) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
return 0, fmt.Errorf("user: WriteOne not implemented")
}
func (d *UserDevice) WriteMany(x []*packet.OutPacket, q int) (int, error) {
return 0, fmt.Errorf("user: WriteMany not implemented")
}

overlay/vhost/README.md

@@ -0,0 +1,23 @@
Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
MIT License
Copyright (c) 2025 Hetzner Cloud GmbH
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

overlay/vhost/doc.go

@@ -0,0 +1,4 @@
// Package vhost implements the basic ioctl requests needed to interact with the
// kernel-level virtio server that provides accelerated virtio devices for
// networking and more.
package vhost

overlay/vhost/ioctl.go

@@ -0,0 +1,218 @@
package vhost
import (
"fmt"
"unsafe"
"github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/util/virtio"
"golang.org/x/sys/unix"
)
const (
// vhostIoctlGetFeatures can be used to retrieve the features supported by
// the vhost implementation in the kernel.
//
// Response payload: [virtio.Feature]
// Kernel name: VHOST_GET_FEATURES
vhostIoctlGetFeatures = 0x8008af00
// vhostIoctlSetFeatures can be used to communicate the features supported
// by this virtio implementation to the kernel.
//
// Request payload: [virtio.Feature]
// Kernel name: VHOST_SET_FEATURES
vhostIoctlSetFeatures = 0x4008af00
// vhostIoctlSetOwner can be used to set the current process as the
// exclusive owner of a control file descriptor.
//
// Request payload: none
// Kernel name: VHOST_SET_OWNER
vhostIoctlSetOwner = 0x0000af01
// vhostIoctlSetMemoryLayout can be used to set up or modify the memory
// layout which describes the IOTLB mappings in the kernel.
//
// Request payload: [MemoryLayout] with custom serialization
// Kernel name: VHOST_SET_MEM_TABLE
vhostIoctlSetMemoryLayout = 0x4008af03
// vhostIoctlSetQueueSize can be used to set the size of the virtqueue.
//
// Request payload: [QueueState]
// Kernel name: VHOST_SET_VRING_NUM
vhostIoctlSetQueueSize = 0x4008af10
// vhostIoctlSetQueueAddress can be used to set the addresses of the
// different parts of the virtqueue.
//
// Request payload: [QueueAddresses]
// Kernel name: VHOST_SET_VRING_ADDR
vhostIoctlSetQueueAddress = 0x4028af11
// vhostIoctlSetAvailableRingBase can be used to set the index of the next
// available ring entry the device will process.
//
// Request payload: [QueueState]
// Kernel name: VHOST_SET_VRING_BASE
vhostIoctlSetAvailableRingBase = 0x4008af12
// vhostIoctlSetQueueKickEventFD can be used to set the event file
// descriptor to signal the device when descriptor chains were added to the
// available ring.
//
// Request payload: [QueueFile]
// Kernel name: VHOST_SET_VRING_KICK
vhostIoctlSetQueueKickEventFD = 0x4008af20
// vhostIoctlSetQueueCallEventFD can be used to set the event file
// descriptor that gets signaled by the device when descriptor chains have
// been used by it.
//
// Request payload: [QueueFile]
// Kernel name: VHOST_SET_VRING_CALL
vhostIoctlSetQueueCallEventFD = 0x4008af21
)
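// The request numbers above follow the usual Linux _IO/_IOR/_IOW encoding for
// ioctl group 0xAF (VHOST_VIRTIO): direction in the top two bits, payload size
// in bits 16-29, the group in bits 8-15 and the command number in the low byte.
// As an illustrative check (not part of the change), the helper below yields
// 0x8008af00 for (dir=2, size=8, nr=0x00), 0x4008af00 for (dir=1, size=8,
// nr=0x00) and 0x0000af01 for (dir=0, size=0, nr=0x01), matching the constants
// for VHOST_GET_FEATURES, VHOST_SET_FEATURES and VHOST_SET_OWNER.
func vhostIoctlNumber(dir, size, nr uint32) uint32 {
	const group = 0xaf
	return dir<<30 | size<<16 | group<<8 | nr
}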
// QueueState is an ioctl request payload that can hold a queue index and any
// 32-bit number.
//
// Kernel name: vhost_vring_state
type QueueState struct {
// QueueIndex is the index of the virtqueue.
QueueIndex uint32
// Num is any 32-bit number, depending on the request.
Num uint32
}
// QueueAddresses is an ioctl request payload that can hold the addresses of the
// different parts of a virtqueue.
//
// Kernel name: vhost_vring_addr
type QueueAddresses struct {
// QueueIndex is the index of the virtqueue.
QueueIndex uint32
// Flags are not used in this implementation.
Flags uint32
// DescriptorTableAddress is the address of the descriptor table in user
// space memory. It must be 16-byte aligned.
DescriptorTableAddress uintptr
// UsedRingAddress is the address of the used ring in user space memory. It
// must be 4-byte aligned.
UsedRingAddress uintptr
// AvailableRingAddress is the address of the available ring in user space
// memory. It must be 2-byte aligned.
AvailableRingAddress uintptr
// LogAddress is used for optional logging support, which is not supported by
// this implementation.
LogAddress uintptr
}
// QueueFile is an ioctl request payload that can hold a queue index and a file
// descriptor.
//
// Kernel name: vhost_vring_file
type QueueFile struct {
// QueueIndex is the index of the virtqueue.
QueueIndex uint32
// FD is the file descriptor of the file. Pass -1 to unbind from a file.
FD int32
}
// IoctlPtr is a copy of the similarly named unexported function from the Go
// unix package. This is needed to do custom ioctl requests not supported by the
// standard library.
func IoctlPtr(fd int, req uint, arg unsafe.Pointer) error {
_, _, err := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(arg))
if err != 0 {
return fmt.Errorf("ioctl request %d: %w", req, err)
}
return nil
}
// GetFeatures requests the supported feature bits from the virtio device
// associated with the given control file descriptor.
func GetFeatures(controlFD int) (virtio.Feature, error) {
var features virtio.Feature
if err := IoctlPtr(controlFD, vhostIoctlGetFeatures, unsafe.Pointer(&features)); err != nil {
return 0, fmt.Errorf("get features: %w", err)
}
return features, nil
}
// SetFeatures communicates the feature bits supported by this implementation
// to the virtio device associated with the given control file descriptor.
func SetFeatures(controlFD int, features virtio.Feature) error {
if err := IoctlPtr(controlFD, vhostIoctlSetFeatures, unsafe.Pointer(&features)); err != nil {
return fmt.Errorf("set features: %w", err)
}
return nil
}
// OwnControlFD sets the current process as the exclusive owner for the
// given control file descriptor. This must be called before interacting with
// the control file descriptor in any other way.
func OwnControlFD(controlFD int) error {
if err := IoctlPtr(controlFD, vhostIoctlSetOwner, unsafe.Pointer(nil)); err != nil {
return fmt.Errorf("set control file descriptor owner: %w", err)
}
return nil
}
// SetMemoryLayout sets up or modifies the memory layout for the kernel-level
// virtio device associated with the given control file descriptor.
func SetMemoryLayout(controlFD int, layout MemoryLayout) error {
payload := layout.serializePayload()
if err := IoctlPtr(controlFD, vhostIoctlSetMemoryLayout, unsafe.Pointer(&payload[0])); err != nil {
return fmt.Errorf("set memory layout: %w", err)
}
return nil
}
// RegisterQueue registers a virtio queue with the kernel-level virtio server.
// The virtqueue will be linked to the given control file descriptor and will
// have the given index. The kernel will use this queue until the control file
// descriptor is closed.
func RegisterQueue(controlFD int, queueIndex uint32, queue *virtqueue.SplitQueue) error {
if err := IoctlPtr(controlFD, vhostIoctlSetQueueSize, unsafe.Pointer(&QueueState{
QueueIndex: queueIndex,
Num: uint32(queue.Size()),
})); err != nil {
return fmt.Errorf("set queue size: %w", err)
}
if err := IoctlPtr(controlFD, vhostIoctlSetQueueAddress, unsafe.Pointer(&QueueAddresses{
QueueIndex: queueIndex,
Flags: 0,
DescriptorTableAddress: queue.DescriptorTable().Address(),
UsedRingAddress: queue.UsedRing().Address(),
AvailableRingAddress: queue.AvailableRing().Address(),
LogAddress: 0,
})); err != nil {
return fmt.Errorf("set queue addresses: %w", err)
}
if err := IoctlPtr(controlFD, vhostIoctlSetAvailableRingBase, unsafe.Pointer(&QueueState{
QueueIndex: queueIndex,
Num: 0,
})); err != nil {
return fmt.Errorf("set available ring base: %w", err)
}
if err := IoctlPtr(controlFD, vhostIoctlSetQueueKickEventFD, unsafe.Pointer(&QueueFile{
QueueIndex: queueIndex,
FD: int32(queue.KickEventFD()),
})); err != nil {
return fmt.Errorf("set kick event file descriptor: %w", err)
}
if err := IoctlPtr(controlFD, vhostIoctlSetQueueCallEventFD, unsafe.Pointer(&QueueFile{
QueueIndex: queueIndex,
FD: int32(queue.CallEventFD()),
})); err != nil {
return fmt.Errorf("set call event file descriptor: %w", err)
}
return nil
}


@@ -0,0 +1,21 @@
package vhost_test
import (
"testing"
"unsafe"
"github.com/slackhq/nebula/overlay/vhost"
"github.com/stretchr/testify/assert"
)
func TestQueueState_Size(t *testing.T) {
assert.EqualValues(t, 8, unsafe.Sizeof(vhost.QueueState{}))
}
func TestQueueAddresses_Size(t *testing.T) {
assert.EqualValues(t, 40, unsafe.Sizeof(vhost.QueueAddresses{}))
}
func TestQueueFile_Size(t *testing.T) {
assert.EqualValues(t, 8, unsafe.Sizeof(vhost.QueueFile{}))
}

overlay/vhost/memory.go

@@ -0,0 +1,73 @@
package vhost
import (
"encoding/binary"
"fmt"
"unsafe"
"github.com/slackhq/nebula/overlay/virtqueue"
)
// MemoryRegion describes a region of userspace memory which is being made
// accessible to a vhost device.
//
// Kernel name: vhost_memory_region
type MemoryRegion struct {
// GuestPhysicalAddress is the physical address of the memory region within
// the guest, when virtualization is used. When no virtualization is used,
// this should be the same as UserspaceAddress.
GuestPhysicalAddress uintptr
// Size is the size of the memory region.
Size uint64
// UserspaceAddress is the virtual address in the userspace of the host
// where the memory region can be found.
UserspaceAddress uintptr
// Padding and room for flags. Currently unused.
_ uint64
}
// MemoryLayout is a list of [MemoryRegion]s.
type MemoryLayout []MemoryRegion
// NewMemoryLayoutForQueues returns a new [MemoryLayout] that describes the
// memory pages used by the descriptor tables of the given queues.
func NewMemoryLayoutForQueues(queues []*virtqueue.SplitQueue) MemoryLayout {
regions := make([]MemoryRegion, 0)
for _, queue := range queues {
for address, size := range queue.DescriptorTable().BufferAddresses() {
regions = append(regions, MemoryRegion{
// There is no virtualization in play here, so the guest address
// is the same as in the host's userspace.
GuestPhysicalAddress: address,
Size: uint64(size),
UserspaceAddress: address,
})
}
}
return regions
}
// serializePayload serializes the list of memory regions into a format that is
// compatible to the vhost_memory kernel struct. The returned byte slice can be
// used as a payload for the vhostIoctlSetMemoryLayout ioctl.
func (regions MemoryLayout) serializePayload() []byte {
regionCount := len(regions)
regionSize := int(unsafe.Sizeof(MemoryRegion{}))
payload := make([]byte, 8+regionCount*regionSize)
// The first 32 bits contain the number of memory regions. The following 32
// bits are padding.
binary.LittleEndian.PutUint32(payload[0:4], uint32(regionCount))
if regionCount > 0 {
// The underlying byte array of the slice should already have the correct
// format, so just copy that.
copied := copy(payload[8:], unsafe.Slice((*byte)(unsafe.Pointer(&regions[0])), regionCount*regionSize))
if copied != regionCount*regionSize {
panic(fmt.Sprintf("copied only %d bytes of the memory regions, but expected %d",
copied, regionCount*regionSize))
}
}
return payload
}


@@ -0,0 +1,42 @@
package vhost
import (
"testing"
"unsafe"
"github.com/stretchr/testify/assert"
)
func TestMemoryRegion_Size(t *testing.T) {
assert.EqualValues(t, 32, unsafe.Sizeof(MemoryRegion{}))
}
func TestMemoryLayout_SerializePayload(t *testing.T) {
layout := MemoryLayout([]MemoryRegion{
{
GuestPhysicalAddress: 42,
Size: 100,
UserspaceAddress: 142,
}, {
GuestPhysicalAddress: 99,
Size: 100,
UserspaceAddress: 99,
},
})
payload := layout.serializePayload()
assert.Equal(t, []byte{
0x02, 0x00, 0x00, 0x00, // nregions
0x00, 0x00, 0x00, 0x00, // padding
// region 0
0x2a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // guest_phys_addr
0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // memory_size
0x8e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // userspace_addr
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // flags_padding
// region 1
0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // guest_phys_addr
0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // memory_size
0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // userspace_addr
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // flags_padding
}, payload)
}


@@ -0,0 +1,23 @@
Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
MIT License
Copyright (c) 2025 Hetzner Cloud GmbH
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

overlay/vhostnet/device.go

@@ -0,0 +1,427 @@
package vhostnet
import (
"context"
"errors"
"fmt"
"os"
"runtime"
"github.com/slackhq/nebula/overlay/vhost"
"github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/util/virtio"
"golang.org/x/sys/unix"
)
// ErrDeviceClosed is returned when the [Device] is closed while operations are
// still running.
var ErrDeviceClosed = errors.New("device was closed")
// The indexes for the receive and transmit queues.
const (
receiveQueueIndex = 0
transmitQueueIndex = 1
)
// Device represents a vhost networking device within the kernel-level virtio
// implementation and provides methods to interact with it.
type Device struct {
initialized bool
controlFD int
fullTable bool
ReceiveQueue *virtqueue.SplitQueue
TransmitQueue *virtqueue.SplitQueue
}
// NewDevice initializes a new vhost networking device within the
// kernel-level virtio implementation, sets up the virtqueues and returns a
// [Device] instance that can be used to communicate with that vhost device.
//
// There are multiple options that can be passed to this constructor to
// influence device creation:
// - [WithQueueSize]
// - [WithBackendFD]
// - [WithBackendDevice]
//
// Remember to call [Device.Close] after use to free up resources.
func NewDevice(options ...Option) (*Device, error) {
var err error
opts := optionDefaults
opts.apply(options)
if err = opts.validate(); err != nil {
return nil, fmt.Errorf("invalid options: %w", err)
}
dev := Device{
controlFD: -1,
}
// Clean up a partially initialized device when something fails.
defer func() {
if err != nil {
_ = dev.Close()
}
}()
// Retrieve a new control file descriptor. This will be used to configure
// the vhost networking device in the kernel.
dev.controlFD, err = unix.Open("/dev/vhost-net", os.O_RDWR, 0666)
if err != nil {
return nil, fmt.Errorf("get control file descriptor: %w", err)
}
if err = vhost.OwnControlFD(dev.controlFD); err != nil {
return nil, fmt.Errorf("own control file descriptor: %w", err)
}
// Advertise the supported features. This isn't much for now.
// TODO: Add feature options and implement proper feature negotiation.
getFeatures, err := vhost.GetFeatures(dev.controlFD) //0x1033D008000 but why
if err != nil {
return nil, fmt.Errorf("get features: %w", err)
}
if getFeatures == 0 {
}
//const funky = virtio.Feature(1 << 27)
//features := virtio.FeatureVersion1 | funky // | todo virtio.FeatureNetMergeRXBuffers
features := virtio.FeatureVersion1 | virtio.FeatureNetMergeRXBuffers
if err = vhost.SetFeatures(dev.controlFD, features); err != nil {
return nil, fmt.Errorf("set features: %w", err)
}
itemSize := os.Getpagesize() * 4 //todo config
// Initialize and register the queues needed for the networking device.
if dev.ReceiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize, itemSize); err != nil {
return nil, fmt.Errorf("create receive queue: %w", err)
}
if dev.TransmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize, itemSize); err != nil {
return nil, fmt.Errorf("create transmit queue: %w", err)
}
// Set up memory mappings for all buffers used by the queues. This has to
// happen before a backend for the queues can be registered.
memoryLayout := vhost.NewMemoryLayoutForQueues(
[]*virtqueue.SplitQueue{dev.ReceiveQueue, dev.TransmitQueue},
)
if err = vhost.SetMemoryLayout(dev.controlFD, memoryLayout); err != nil {
return nil, fmt.Errorf("setup memory layout: %w", err)
}
// Set the queue backends. This activates the queues within the kernel.
if err = SetQueueBackend(dev.controlFD, receiveQueueIndex, opts.backendFD); err != nil {
return nil, fmt.Errorf("set receive queue backend: %w", err)
}
if err = SetQueueBackend(dev.controlFD, transmitQueueIndex, opts.backendFD); err != nil {
return nil, fmt.Errorf("set transmit queue backend: %w", err)
}
// Fully populate the receive queue with available buffers which the device
// can write new packets into.
if err = dev.refillReceiveQueue(); err != nil {
return nil, fmt.Errorf("refill receive queue: %w", err)
}
if err = dev.refillTransmitQueue(); err != nil {
return nil, fmt.Errorf("refill transmit queue: %w", err)
}
dev.initialized = true
// Make sure to clean up even when the device gets garbage collected without
// Close being called first.
devPtr := &dev
runtime.SetFinalizer(devPtr, (*Device).Close)
return devPtr, nil
}
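// newDeviceExample is an illustrative sketch, not part of the change: it shows
// how the options above are intended to be combined, assuming tunFD is a
// TAP/TUN or raw-socket file descriptor opened elsewhere. The queue size must
// be a power of two (see options.go).
func newDeviceExample(tunFD int) (*Device, error) {
	return NewDevice(
		WithQueueSize(256),
		WithBackendFD(tunFD),
	)
}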
// refillReceiveQueue offers as many new device-writable buffers to the device
// as the queue can fit. The device will then use these to write received
// packets.
func (dev *Device) refillReceiveQueue() error {
for {
_, err := dev.ReceiveQueue.OfferInDescriptorChains()
if err != nil {
if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
// Queue is full, job is done.
return nil
}
return fmt.Errorf("offer descriptor chain: %w", err)
}
}
}
func (dev *Device) refillTransmitQueue() error {
//for {
// desc, err := dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
// if err != nil {
// if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
// // Queue is full, job is done.
// return nil
// }
// return fmt.Errorf("offer descriptor chain: %w", err)
// } else {
// dev.TransmitQueue.UsedRing().InitOfferSingle(desc, 0)
// }
//}
return nil
}
// Close cleans up the vhost networking device within the kernel and releases
// all resources used for it.
// The implementation will try to release as many resources as possible and
// collect potential errors before returning them.
func (dev *Device) Close() error {
dev.initialized = false
// Closing the control file descriptor will unregister all queues from the
// kernel.
if dev.controlFD >= 0 {
if err := unix.Close(dev.controlFD); err != nil {
// Return an error and do not continue, because the memory used for
// the queues should not be released before they were unregistered
// from the kernel.
return fmt.Errorf("close control file descriptor: %w", err)
}
dev.controlFD = -1
}
var errs []error
if dev.ReceiveQueue != nil {
if err := dev.ReceiveQueue.Close(); err == nil {
dev.ReceiveQueue = nil
} else {
errs = append(errs, fmt.Errorf("close receive queue: %w", err))
}
}
if dev.TransmitQueue != nil {
if err := dev.TransmitQueue.Close(); err == nil {
dev.TransmitQueue = nil
} else {
errs = append(errs, fmt.Errorf("close transmit queue: %w", err))
}
}
if len(errs) == 0 {
// Everything was cleaned up. No need to run the finalizer anymore.
runtime.SetFinalizer(dev, nil)
}
return errors.Join(errs...)
}
// ensureInitialized is used as a guard to prevent methods from being called on
// an uninitialized instance.
func (dev *Device) ensureInitialized() {
if !dev.initialized {
panic("device is not initialized")
}
}
// createQueue creates a new virtqueue and registers it with the vhost device
// using the given index.
func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) {
var (
queue *virtqueue.SplitQueue
err error
)
if queue, err = virtqueue.NewSplitQueue(queueSize, itemSize); err != nil {
return nil, fmt.Errorf("create virtqueue: %w", err)
}
if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil {
return nil, fmt.Errorf("register virtqueue with index %d: %w", queueIndex, err)
}
return queue, nil
}
// truncateBuffers returns a new list of buffers whose combined length matches
// exactly the specified length. When the specified length exceeds the length of
// the buffers, this is an error. When it is smaller, the buffer list will be
// truncated accordingly.
func truncateBuffers(buffers [][]byte, length int) (out [][]byte) {
for _, buffer := range buffers {
if length < len(buffer) {
out = append(out, buffer[:length])
return
}
out = append(out, buffer)
length -= len(buffer)
}
if length > 0 {
panic("length exceeds the combined length of all buffers")
}
return
}
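// truncateBuffersExample is an illustrative sketch (not part of the change) of
// the behaviour documented above: two buffers of length 3 and 4 truncated to a
// combined length of 5 keep the first buffer whole and trim the second to its
// first 2 bytes.
func truncateBuffersExample() [][]byte {
	out := truncateBuffers([][]byte{{1, 2, 3}, {4, 5, 6, 7}}, 5)
	// len(out) == 2, len(out[0]) == 3, len(out[1]) == 2
	return out
}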
func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
var err error
var idx uint16
if !dev.fullTable {
idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
if err == virtqueue.ErrNotEnoughFreeDescriptors {
dev.fullTable = true
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
}
} else {
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
}
if err != nil {
return 0, nil, fmt.Errorf("transmit queue: %w", err)
}
buf, err := dev.TransmitQueue.GetDescriptorItem(idx)
if err != nil {
return 0, nil, fmt.Errorf("get descriptor chain: %w", err)
}
return idx, buf, nil
}
func (dev *Device) TransmitPacket(pkt *packet.OutPacket, kick bool) error {
if len(pkt.SegmentIDs) == 0 {
return nil
}
for idx := range pkt.SegmentIDs {
segmentID := pkt.SegmentIDs[idx]
dev.TransmitQueue.SetDescSize(segmentID, len(pkt.Segments[idx]))
}
err := dev.TransmitQueue.OfferDescriptorChains(pkt.SegmentIDs, false)
if err != nil {
return fmt.Errorf("offer descriptor chains: %w", err)
}
pkt.Reset()
if kick {
if err := dev.TransmitQueue.Kick(); err != nil {
return err
}
}
return nil
}
func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error {
if len(pkts) == 0 {
return nil
}
for i := range pkts {
if err := dev.TransmitPacket(pkts[i], false); err != nil {
return err
}
}
if err := dev.TransmitQueue.Kick(); err != nil {
return err
}
return nil
}
// TODO: Make above methods cancelable by taking a context.Context argument?
// TODO: Implement zero-copy variants to transmit and receive packets?
// processChains processes as many chains as needed to create one packet. The number of processed chains is returned.
func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) {
//read first element to see how many descriptors we need:
pkt.Reset()
err := dev.ReceiveQueue.GetDescriptorInbuffers(uint16(chains[0].DescriptorIndex), &pkt.ChainRefs)
if err != nil {
return 0, fmt.Errorf("get descriptor chain: %w", err)
}
if len(pkt.ChainRefs) == 0 {
return 1, nil
}
// The specification requires that the first descriptor chain starts
// with a virtio-net header. It is not clear whether it is also
// required to be fully contained in the first buffer of that
// descriptor chain, but it is reasonable to assume that this is
// always the case.
// The decode method already does the buffer length check.
if err = pkt.Header.Decode(pkt.ChainRefs[0][0:]); err != nil {
// The device misbehaved. There is no way we can gracefully
// recover from this, because we don't know how many of the
// following descriptor chains belong to this packet.
return 0, fmt.Errorf("decode vnethdr: %w", err)
}
//we have the header now: what do we need to do?
if int(pkt.Header.NumBuffers) > len(chains) {
return 0, fmt.Errorf("number of buffers is greater than number of chains %d", len(chains))
}
if int(pkt.Header.NumBuffers) != 1 {
return 0, fmt.Errorf("too smol-brain to handle more than one chain right now: %d chains", len(chains))
}
if chains[0].Length > 16000 {
//todo!
return 1, fmt.Errorf("too big packet length: %d", chains[0].Length)
}
//shift the buffer out of out:
pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length]
pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex))
return 1, nil
//cursor := n - virtio.NetHdrSize
//
//if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 {
// pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize]
// return 1, nil
//}
//
//i := 1
//// we used chain 0 already
//for i = 1; i < len(chains); i++ {
// n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length))
// if err != nil {
// // When this fails we may miss to free some descriptor chains. We
// // could try to mitigate this by deferring the freeing somehow, but
// // it's not worth the hassle. When this method fails, the queue will
// // be in a broken state anyway.
// return i, fmt.Errorf("get descriptor chain: %w", err)
// }
// cursor += n
//}
////todo this has to be wrong
//pkt.Payload = pkt.Payload[:cursor]
//return i, nil
}
func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
//todo optimize?
var chains []virtqueue.UsedElement
var err error
//if len(dev.extraRx) == 0 {
chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), len(out))
if err != nil {
return 0, err
}
if len(chains) == 0 {
return 0, nil
}
//} else {
// chains = dev.extraRx
//}
numPackets := 0
chainsIdx := 0
for numPackets = 0; chainsIdx < len(chains); numPackets++ {
if numPackets >= len(out) {
return numPackets, fmt.Errorf("dropping %d packets, no room", len(chains)-numPackets)
}
numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:])
if err != nil {
return 0, err
}
chainsIdx += numChains
}
// Now that we have copied all buffers, we can recycle the used descriptor chains
//if err = dev.ReceiveQueue.OfferDescriptorChains(chains); err != nil {
// return 0, err
//}
return numPackets, nil
}

overlay/vhostnet/doc.go

@@ -0,0 +1,3 @@
// Package vhostnet implements methods to initialize vhost networking devices
// within the kernel-level virtio implementation and communicate with them.
package vhostnet

overlay/vhostnet/ioctl.go

@@ -0,0 +1,31 @@
package vhostnet
import (
"fmt"
"unsafe"
"github.com/slackhq/nebula/overlay/vhost"
)
const (
// vhostNetIoctlSetBackend can be used to attach a virtqueue to a RAW socket
// or TAP device.
//
// Request payload: [vhost.QueueFile]
// Kernel name: VHOST_NET_SET_BACKEND
vhostNetIoctlSetBackend = 0x4008af30
)
// SetQueueBackend attaches a virtqueue of the vhost networking device
// described by controlFD to the given backend file descriptor.
// The backend file descriptor can either be a RAW socket or a TAP device. When
// it is -1, the queue will be detached.
func SetQueueBackend(controlFD int, queueIndex uint32, backendFD int) error {
if err := vhost.IoctlPtr(controlFD, vhostNetIoctlSetBackend, unsafe.Pointer(&vhost.QueueFile{
QueueIndex: queueIndex,
FD: int32(backendFD),
})); err != nil {
return fmt.Errorf("set queue backend file descriptor: %w", err)
}
return nil
}
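// detachQueuesExample is an illustrative sketch, not part of the change: once
// the receive (index 0) and transmit (index 1) queues have been attached, the
// same call with -1 detaches them again, as noted above.
func detachQueuesExample(controlFD int) error {
	if err := SetQueueBackend(controlFD, 0, -1); err != nil {
		return err
	}
	return SetQueueBackend(controlFD, 1, -1)
}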


@@ -0,0 +1,69 @@
package vhostnet
import (
"errors"
"github.com/slackhq/nebula/overlay/virtqueue"
)
type optionValues struct {
queueSize int
backendFD int
}
func (o *optionValues) apply(options []Option) {
for _, option := range options {
option(o)
}
}
func (o *optionValues) validate() error {
if o.queueSize == -1 {
return errors.New("queue size is required")
}
if err := virtqueue.CheckQueueSize(o.queueSize); err != nil {
return err
}
if o.backendFD == -1 {
return errors.New("backend file descriptor is required")
}
return nil
}
var optionDefaults = optionValues{
// Required.
queueSize: -1,
// Required.
backendFD: -1,
}
// Option can be passed to [NewDevice] to influence device creation.
type Option func(*optionValues)
// WithQueueSize returns an [Option] that sets the size of the TX and RX queues
// that are to be created for the device. It specifies the number of
// entries/buffers each queue can hold. This also affects the memory
// consumption.
// This is required and must be an integer from 1 to 32768 that is also a power
// of 2.
func WithQueueSize(queueSize int) Option {
return func(o *optionValues) { o.queueSize = queueSize }
}
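// The "power of 2 from 1 to 32768" rule above is enforced by
// virtqueue.CheckQueueSize, which is not shown in this diff; as an assumption,
// the check amounts to something like this hypothetical helper:
func isValidQueueSizeSketch(q int) bool {
	return q > 0 && q <= 32768 && q&(q-1) == 0
}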
// WithBackendFD returns an [Option] that sets the file descriptor of the
// backend that will be used for the queues of the device. The device will write
// and read packets to/from that backend. The file descriptor can either be of a
// RAW socket or TUN/TAP device.
// Either this or [WithBackendDevice] is required.
func WithBackendFD(backendFD int) Option {
return func(o *optionValues) { o.backendFD = backendFD }
}
//// WithBackendDevice returns an [Option] that sets the given TAP device as the
//// backend that will be used for the queues of the device. The device will
//// write and read packets to/from that backend. The TAP device should have been
//// created with the [tuntap.WithVirtioNetHdr] option enabled.
//// Either this or [WithBackendFD] is required.
//func WithBackendDevice(dev *tuntap.Device) Option {
// return func(o *optionValues) { o.backendFD = int(dev.File().Fd()) }
//}


@@ -0,0 +1,23 @@
Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
MIT License
Copyright (c) 2025 Hetzner Cloud GmbH
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -0,0 +1,140 @@
package virtqueue
import (
"fmt"
"unsafe"
)
// availableRingFlag is a flag that describes an [AvailableRing].
type availableRingFlag uint16
const (
// availableRingFlagNoInterrupt is used by the guest to advise the host to
// not interrupt it when consuming a buffer. It's unreliable, so it's simply
// an optimization.
availableRingFlagNoInterrupt availableRingFlag = 1 << iota
)
// availableRingSize is the number of bytes needed to store an [AvailableRing]
// with the given queue size in memory.
func availableRingSize(queueSize int) int {
return 6 + 2*queueSize
}
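// Worked example (illustrative): the layout is 2 bytes of flags, 2 bytes of
// ring index, queueSize 2-byte ring entries and a trailing 2-byte used_event
// field, so availableRingSize(256) == 6 + 2*256 == 518 bytes.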
// availableRingAlignment is the minimum alignment of an [AvailableRing]
// in memory, as required by the virtio spec.
const availableRingAlignment = 2
// AvailableRing is used by the driver to offer descriptor chains to the device.
// Each ring entry refers to the head of a descriptor chain. It is only written
// to by the driver and read by the device.
//
// Because the size of the ring depends on the queue size, we cannot define a
// Go struct with a static size that maps to the memory of the ring. Instead,
// this struct only contains pointers to the corresponding memory areas.
type AvailableRing struct {
initialized bool
// flags that describe this ring.
flags *availableRingFlag
// ringIndex indicates where the driver would put the next entry into the
// ring (modulo the queue size).
ringIndex *uint16
// ring references buffers using the index of the head of the descriptor
// chain in the [DescriptorTable]. It wraps around at queue size.
ring []uint16
// usedEvent is not used by this implementation, but we reserve it anyway to
// avoid issues in case a device may try to access it, contrary to the
// virtio specification.
usedEvent *uint16
}
// newAvailableRing creates an available ring that uses the given underlying
// memory. The length of the memory slice must match the size needed for the
// ring (see [availableRingSize]) for the given queue size.
func newAvailableRing(queueSize int, mem []byte) *AvailableRing {
ringSize := availableRingSize(queueSize)
if len(mem) != ringSize {
panic(fmt.Sprintf("memory size (%v) does not match required size "+
"for available ring: %v", len(mem), ringSize))
}
return &AvailableRing{
initialized: true,
flags: (*availableRingFlag)(unsafe.Pointer(&mem[0])),
ringIndex: (*uint16)(unsafe.Pointer(&mem[2])),
ring: unsafe.Slice((*uint16)(unsafe.Pointer(&mem[4])), queueSize),
usedEvent: (*uint16)(unsafe.Pointer(&mem[ringSize-2])),
}
}
// Address returns the pointer to the beginning of the ring in memory.
// Do not modify the memory directly to not interfere with this implementation.
func (r *AvailableRing) Address() uintptr {
if !r.initialized {
panic("available ring is not initialized")
}
return uintptr(unsafe.Pointer(r.flags))
}
// offerElements adds the descriptor chain heads referenced by the given
// [UsedElement]s to the available ring and advances the ring index accordingly
// to make the device process the new descriptor chains.
func (r *AvailableRing) offerElements(chains []UsedElement) {
//always called under lock
//r.mu.Lock()
//defer r.mu.Unlock()
// Add descriptor chain heads to the ring.
for offset, x := range chains {
// The 16-bit ring index may overflow. This is expected and is not an
// issue because the size of the ring array (which equals the queue
// size) is always a power of 2 and smaller than the highest possible
// 16-bit value.
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
r.ring[insertIndex] = x.GetHead()
}
// Increase the ring index by the number of descriptor chains added to the
// ring.
*r.ringIndex += uint16(len(chains))
}
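// offer adds the given descriptor chain heads to the available ring and
// advances the ring index accordingly to make the device process the new
// descriptor chains.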
func (r *AvailableRing) offer(chains []uint16) {
//always called under lock
//r.mu.Lock()
//defer r.mu.Unlock()
// Add descriptor chain heads to the ring.
for offset, x := range chains {
// The 16-bit ring index may overflow. This is expected and is not an
// issue because the size of the ring array (which equals the queue
// size) is always a power of 2 and smaller than the highest possible
// 16-bit value.
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
r.ring[insertIndex] = x
}
// Increase the ring index by the number of descriptor chains added to the
// ring.
*r.ringIndex += uint16(len(chains))
}
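// offerSingle adds a single descriptor chain head to the available ring and
// advances the ring index by one.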
func (r *AvailableRing) offerSingle(x uint16) {
//always called under lock
//r.mu.Lock()
//defer r.mu.Unlock()
offset := 0
// Add descriptor chain heads to the ring.
// The 16-bit ring index may overflow. This is expected and is not an
// issue because the size of the ring array (which equals the queue
// size) is always a power of 2 and smaller than the highest possible
// 16-bit value.
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
r.ring[insertIndex] = x
// Increase the ring index by the number of descriptor chains added to the ring.
*r.ringIndex += 1
}
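For readers skimming the ring arithmetic above, here is a minimal standalone sketch (not part of the diff) of how offer picks ring slots: the 16-bit ring index is allowed to overflow while the slot is chosen modulo the queue size. Names and values are illustrative only.

package main

import "fmt"

func main() {
    const queueSize = 8 // must be a power of 2
    ring := make([]uint16, queueSize)
    ringIndex := uint16(65534) // about to overflow
    heads := []uint16{42, 33, 69}
    for offset, head := range heads {
        // uint16 addition wraps at 65536; the slot index wraps at queueSize.
        insertIndex := int(ringIndex+uint16(offset)) % queueSize
        ring[insertIndex] = head
    }
    ringIndex += uint16(len(heads))
    fmt.Println(ring, ringIndex) // [69 0 0 0 0 0 42 33] 1
}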

View File

@@ -0,0 +1,71 @@
package virtqueue
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAvailableRing_MemoryLayout(t *testing.T) {
const queueSize = 2
memory := make([]byte, availableRingSize(queueSize))
r := newAvailableRing(queueSize, memory)
*r.flags = 0x01ff
*r.ringIndex = 1
r.ring[0] = 0x1234
r.ring[1] = 0x5678
assert.Equal(t, []byte{
0xff, 0x01,
0x01, 0x00,
0x34, 0x12,
0x78, 0x56,
0x00, 0x00,
}, memory)
}
func TestAvailableRing_Offer(t *testing.T) {
const queueSize = 8
chainHeads := []uint16{42, 33, 69}
tests := []struct {
name string
startRingIndex uint16
expectedRingIndex uint16
expectedRing []uint16
}{
{
name: "no overflow",
startRingIndex: 0,
expectedRingIndex: 3,
expectedRing: []uint16{42, 33, 69, 0, 0, 0, 0, 0},
},
{
name: "ring overflow",
startRingIndex: 6,
expectedRingIndex: 9,
expectedRing: []uint16{69, 0, 0, 0, 0, 0, 42, 33},
},
{
name: "index overflow",
startRingIndex: 65535,
expectedRingIndex: 2,
expectedRing: []uint16{33, 69, 0, 0, 0, 0, 0, 42},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
memory := make([]byte, availableRingSize(queueSize))
r := newAvailableRing(queueSize, memory)
*r.ringIndex = tt.startRingIndex
r.offer(chainHeads)
assert.Equal(t, tt.expectedRingIndex, *r.ringIndex)
assert.Equal(t, tt.expectedRing, r.ring)
})
}
}

View File

@@ -0,0 +1,43 @@
package virtqueue
// descriptorFlag is a flag that describes a [Descriptor].
type descriptorFlag uint16
const (
// descriptorFlagHasNext marks a descriptor chain as continuing via the next
// field.
descriptorFlagHasNext descriptorFlag = 1 << iota
// descriptorFlagWritable marks a buffer as device write-only (otherwise
// device read-only).
descriptorFlagWritable
// descriptorFlagIndirect means the buffer contains a list of buffer
// descriptors to provide an additional layer of indirection.
// Only allowed when the [virtio.FeatureIndirectDescriptors] feature was
// negotiated.
descriptorFlagIndirect
)
// descriptorSize is the number of bytes needed to store a [Descriptor] in
// memory.
const descriptorSize = 16
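// On 64-bit platforms that is: an 8-byte address, a 4-byte length, 2 bytes of
// flags and a 2-byte next index, matching the virtio spec's struct virtq_desc.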
// Descriptor describes (a part of) a buffer which is either read-only for the
// device or write-only for the device (depending on [descriptorFlagWritable]).
// Multiple descriptors can be chained to produce a "descriptor chain" that can
// contain both device-readable and device-writable buffers. Device-readable
// descriptors always come first in a chain. A single, large buffer may be
// split up by chaining multiple similar descriptors that reference different
// memory pages. This is required because buffers may exceed a single page size
// and the memory referenced by a single descriptor is expected to be contiguous.
type Descriptor struct {
// address is the address to the continuous memory holding the data for this
// descriptor.
address uintptr
// length is the amount of bytes stored at address.
length uint32
// flags that describe this descriptor.
flags descriptorFlag
// next contains the index of the next descriptor continuing this descriptor
// chain when the [descriptorFlagHasNext] flag is set.
next uint16
}

View File

@@ -0,0 +1,12 @@
package virtqueue
import (
"testing"
"unsafe"
"github.com/stretchr/testify/assert"
)
func TestDescriptor_Size(t *testing.T) {
assert.EqualValues(t, descriptorSize, unsafe.Sizeof(Descriptor{}))
}

View File

@@ -0,0 +1,641 @@
package virtqueue
import (
"errors"
"fmt"
"math"
"unsafe"
"golang.org/x/sys/unix"
)
var (
// ErrDescriptorChainEmpty is returned when a descriptor chain would contain
// no buffers, which is not allowed.
ErrDescriptorChainEmpty = errors.New("empty descriptor chains are not allowed")
// ErrNotEnoughFreeDescriptors is returned when the free descriptors are
// exhausted, meaning that the queue is full.
ErrNotEnoughFreeDescriptors = errors.New("not enough free descriptors, queue is full")
// ErrInvalidDescriptorChain is returned when a descriptor chain is not
// valid for a given operation.
ErrInvalidDescriptorChain = errors.New("invalid descriptor chain")
)
// noFreeHead is used to mark when all descriptors are in use and we have no
// free chain. This value is impossible to occur as an index naturally, because
// it exceeds the maximum queue size.
const noFreeHead = uint16(math.MaxUint16)
// descriptorTableSize is the number of bytes needed to store a
// [DescriptorTable] with the given queue size in memory.
func descriptorTableSize(queueSize int) int {
return descriptorSize * queueSize
}
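// For example, a queue of 256 descriptors needs 16*256 = 4096 bytes, i.e.
// exactly one 4 KiB page.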
// descriptorTableAlignment is the minimum alignment of a [DescriptorTable]
// in memory, as required by the virtio spec.
const descriptorTableAlignment = 16
// DescriptorTable is a table that holds [Descriptor]s, addressed via their
// index in the slice.
type DescriptorTable struct {
descriptors []Descriptor
// freeHeadIndex is the index of the head of the descriptor chain which
// contains all currently unused descriptors. When all descriptors are in
// use, this has the special value of noFreeHead.
freeHeadIndex uint16
// freeNum tracks the number of descriptors which are currently not in use.
freeNum uint16
bufferBase uintptr
bufferSize int
itemSize int
}
// newDescriptorTable creates a descriptor table that uses the given underlying
// memory. The length of the memory slice must match the size needed for the
// descriptor table (see [descriptorTableSize]) for the given queue size.
//
// Before this descriptor table can be used, [initialize] must be called.
func newDescriptorTable(queueSize int, mem []byte, itemSize int) *DescriptorTable {
dtSize := descriptorTableSize(queueSize)
if len(mem) != dtSize {
panic(fmt.Sprintf("memory size (%v) does not match required size "+
"for descriptor table: %v", len(mem), dtSize))
}
return &DescriptorTable{
descriptors: unsafe.Slice((*Descriptor)(unsafe.Pointer(&mem[0])), queueSize),
// We have no free descriptors until they have been initialized.
freeHeadIndex: noFreeHead,
freeNum: 0,
itemSize: itemSize, //todo configurable? needs to be page-aligned
}
}
// Address returns the pointer to the beginning of the descriptor table in
// memory. Do not modify the memory directly to not interfere with this
// implementation.
func (dt *DescriptorTable) Address() uintptr {
if dt.descriptors == nil {
panic("descriptor table is not initialized")
}
// Note: this is the start of the descriptor table itself, not of the packet
// buffer region tracked by dt.bufferBase.
return uintptr(unsafe.Pointer(&dt.descriptors[0]))
}
func (dt *DescriptorTable) Size() uintptr {
if dt.descriptors == nil {
panic("descriptor table is not initialized")
}
return uintptr(dt.bufferSize)
}
// BufferAddresses returns a map of pointer->size for all allocations used by the table
func (dt *DescriptorTable) BufferAddresses() map[uintptr]int {
if dt.descriptors == nil {
panic("descriptor table is not initialized")
}
return map[uintptr]int{dt.bufferBase: dt.bufferSize}
}
// initializeDescriptors allocates one contiguous buffer region and assigns each
// descriptor in the table an itemSize-sized slice of it. While preallocating
// all buffers may be a bit wasteful, it makes dealing with descriptors way
// easier. Without this preallocation, we would have to allocate and free
// memory on demand, increasing complexity.
//
// All descriptors will be marked as free and will form a free chain. The
// addresses of all descriptors will be populated while their length remains
// zero.
func (dt *DescriptorTable) initializeDescriptors() error {
numDescriptors := len(dt.descriptors)
// Allocate ONE large region for all buffers
totalSize := dt.itemSize * numDescriptors
basePtr, err := unix.MmapPtr(-1, 0, nil, uintptr(totalSize),
unix.PROT_READ|unix.PROT_WRITE,
unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
if err != nil {
return fmt.Errorf("allocate buffer memory for descriptors: %w", err)
}
// Store the base for cleanup later
dt.bufferBase = uintptr(basePtr)
dt.bufferSize = totalSize
for i := range dt.descriptors {
dt.descriptors[i] = Descriptor{
address: dt.bufferBase + uintptr(i*dt.itemSize),
length: 0,
// All descriptors should form a free chain that loops around.
flags: descriptorFlagHasNext,
next: uint16((i + 1) % len(dt.descriptors)),
}
}
// All descriptors are free to use now.
dt.freeHeadIndex = 0
dt.freeNum = uint16(len(dt.descriptors))
return nil
}
// releaseBuffers releases all allocated buffers for this descriptor table.
// The implementation will try to release as many buffers as possible and
// collect potential errors before returning them.
// The descriptor table should no longer be used after calling this.
func (dt *DescriptorTable) releaseBuffers() error {
for i := range dt.descriptors {
descriptor := &dt.descriptors[i]
descriptor.address = 0
}
// As a safety measure, make sure no descriptors can be used anymore.
dt.freeHeadIndex = noFreeHead
dt.freeNum = 0
if dt.bufferBase != 0 {
// The pointer points to memory not managed by Go, so this conversion
// is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
err := unix.MunmapPtr(unsafe.Pointer(dt.bufferBase), uintptr(dt.bufferSize))
if err != nil {
return fmt.Errorf("release buffer memory: %w", err)
}
dt.bufferBase = 0
}
return nil
}
// createDescriptorChain creates a new descriptor chain within the descriptor
// table which contains a number of device-readable buffers (out buffers) and
// device-writable buffers (in buffers).
//
// All buffers in the outBuffers slice will be concatenated by chaining
// descriptors, one for each buffer in the slice. The size of the single buffers
// must not exceed the per-descriptor buffer size (itemSize).
// When numInBuffers is greater than zero, the given number of device-writable
// descriptors will be appended to the end of the chain, each referencing a
// whole itemSize-sized buffer.
//
// The index of the head of the new descriptor chain will be returned. Callers
// should make sure to free the descriptor chain using [freeDescriptorChain]
// after it was used by the device.
//
// When there are not enough free descriptors to hold the given number of
// buffers, an [ErrNotEnoughFreeDescriptors] will be returned. In this case, the
// caller should try again after some descriptor chains were used by the device
// and returned back into the free chain.
func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffers int) (uint16, error) {
// Calculate the number of descriptors needed to build the chain.
numDesc := uint16(len(outBuffers) + numInBuffers)
// Descriptor chains must always contain at least one descriptor.
if numDesc < 1 {
return 0, ErrDescriptorChainEmpty
}
// Do we still have enough free descriptors?
if numDesc > dt.freeNum {
return 0, ErrNotEnoughFreeDescriptors
}
// Above validation ensured that there is at least one free descriptor, so
// the free descriptor chain head should be valid.
if dt.freeHeadIndex == noFreeHead {
panic("free descriptor chain head is unset but there should be free descriptors")
}
// To avoid having to iterate over the whole table to find the descriptor
// pointing to the head just to replace the free head, we instead always
// create descriptor chains from the descriptors coming after the head.
// This way we only have to touch the head as a last resort, when all other
// descriptors are already used.
head := dt.descriptors[dt.freeHeadIndex].next
next := head
tail := head
for i, buffer := range outBuffers {
desc := &dt.descriptors[next]
checkUnusedDescriptorLength(next, desc)
if len(buffer) > dt.itemSize {
// The caller should already prevent that from happening.
panic(fmt.Sprintf("out buffer %d has size %d which exceeds desc length %d", i, len(buffer), dt.itemSize))
}
// Copy the buffer to the memory referenced by the descriptor.
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
copy(unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), dt.itemSize), buffer)
desc.length = uint32(len(buffer))
// Clear the flags in case there were any others set.
desc.flags = descriptorFlagHasNext
tail = next
next = desc.next
}
for range numInBuffers {
desc := &dt.descriptors[next]
checkUnusedDescriptorLength(next, desc)
// Give the device the maximum available number of bytes to write into.
desc.length = uint32(dt.itemSize)
// Mark the descriptor as device-writable.
desc.flags = descriptorFlagHasNext | descriptorFlagWritable
tail = next
next = desc.next
}
// The last descriptor should end the chain.
tailDesc := &dt.descriptors[tail]
tailDesc.flags &= ^descriptorFlagHasNext
tailDesc.next = 0 // Not necessary to clear this, it's just for looks.
dt.freeNum -= numDesc
if dt.freeNum == 0 {
// The last descriptor in the chain should be the free chain head
// itself.
if tail != dt.freeHeadIndex {
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
}
// When this new chain takes up all remaining descriptors, we no longer
// have a free chain.
dt.freeHeadIndex = noFreeHead
} else {
// We took some descriptors out of the free chain, so make sure to close
// the circle again.
dt.descriptors[dt.freeHeadIndex].next = next
}
return head, nil
}
func (dt *DescriptorTable) CreateDescriptorForOutputs() (uint16, error) {
//todo just fill the damn table
// Do we still have enough free descriptors?
if 1 > dt.freeNum {
return 0, ErrNotEnoughFreeDescriptors
}
// Above validation ensured that there is at least one free descriptor, so
// the free descriptor chain head should be valid.
if dt.freeHeadIndex == noFreeHead {
panic("free descriptor chain head is unset but there should be free descriptors")
}
// To avoid having to iterate over the whole table to find the descriptor
// pointing to the head just to replace the free head, we instead always
// create descriptor chains from the descriptors coming after the head.
// This way we only have to touch the head as a last resort, when all other
// descriptors are already used.
head := dt.descriptors[dt.freeHeadIndex].next
desc := &dt.descriptors[head]
next := desc.next
checkUnusedDescriptorLength(head, desc)
// Give the device the maximum available number of bytes to write into.
desc.length = uint32(dt.itemSize)
desc.flags = 0 // descriptorFlagWritable
desc.next = 0 // Not necessary to clear this, it's just for looks.
dt.freeNum -= 1
if dt.freeNum == 0 {
// The last descriptor in the chain should be the free chain head
// itself.
if next != dt.freeHeadIndex {
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
}
// When this new chain takes up all remaining descriptors, we no longer
// have a free chain.
dt.freeHeadIndex = noFreeHead
} else {
// We took some descriptors out of the free chain, so make sure to close
// the circle again.
dt.descriptors[dt.freeHeadIndex].next = next
}
return head, nil
}
func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) {
// Do we still have enough free descriptors?
if 1 > dt.freeNum {
return 0, ErrNotEnoughFreeDescriptors
}
// Above validation ensured that there is at least one free descriptor, so
// the free descriptor chain head should be valid.
if dt.freeHeadIndex == noFreeHead {
panic("free descriptor chain head is unset but there should be free descriptors")
}
// To avoid having to iterate over the whole table to find the descriptor
// pointing to the head just to replace the free head, we instead always
// create descriptor chains from the descriptors coming after the head.
// This way we only have to touch the head as a last resort, when all other
// descriptors are already used.
head := dt.descriptors[dt.freeHeadIndex].next
desc := &dt.descriptors[head]
next := desc.next
checkUnusedDescriptorLength(head, desc)
// Give the device the maximum available number of bytes to write into.
desc.length = uint32(dt.itemSize)
desc.flags = descriptorFlagWritable
desc.next = 0 // Not necessary to clear this, it's just for looks.
dt.freeNum -= 1
if dt.freeNum == 0 {
// The last descriptor in the chain should be the free chain head
// itself.
if next != dt.freeHeadIndex {
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
}
// When this new chain takes up all remaining descriptors, we no longer
// have a free chain.
dt.freeHeadIndex = noFreeHead
} else {
// We took some descriptors out of the free chain, so make sure to close
// the circle again.
dt.descriptors[dt.freeHeadIndex].next = next
}
return head, nil
}
// TODO: Implement a zero-copy variant of createDescriptorChain?
// getDescriptorChain returns the device-readable buffers (out buffers) and
// device-writable buffers (in buffers) of the descriptor chain that starts with
// the given head index. The descriptor chain must have been created using
// [createDescriptorChain] and must not have been freed yet (meaning that the
// head index must not be contained in the free chain).
//
// Be careful to only access the returned buffer slices when the device has not
// yet or is no longer using them. They must not be accessed after
// [freeDescriptorChain] has been called.
func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
if int(head) >= len(dt.descriptors) {
return nil, nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
// Iterate over the chain. The iteration is limited to the queue size to
// avoid ending up in an endless loop when things go very wrong.
next := head
for range len(dt.descriptors) {
if next == dt.freeHeadIndex {
return nil, nil, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[next]
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
if desc.flags&descriptorFlagWritable == 0 {
outBuffers = append(outBuffers, bs)
} else {
inBuffers = append(inBuffers, bs)
}
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
break
}
// Detect loops.
if desc.next == head {
return nil, nil, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
}
next = desc.next
}
return
}
func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
if int(head) >= len(dt.descriptors) {
return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[head] //todo this is a pretty nasty hack with no checks
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
return bs, nil
}
func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
if int(head) >= len(dt.descriptors) {
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
// Iterate over the chain. The iteration is limited to the queue size to
// avoid ending up in an endless loop when things go very wrong.
next := head
for range len(dt.descriptors) {
if next == dt.freeHeadIndex {
return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[next]
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
if desc.flags&descriptorFlagWritable == 0 {
return fmt.Errorf("there should not be an outbuffer in %d", head)
} else {
*inBuffers = append(*inBuffers, bs)
}
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
break
}
// Detect loops.
if desc.next == head {
return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
}
next = desc.next
}
return nil
}
func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
if int(head) >= len(dt.descriptors) {
return 0, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
// Iterate over the chain. The iteration is limited to the queue size to
// avoid ending up in an endless loop when things go very wrong.
length := 0
//find length
next := head
for range len(dt.descriptors) {
if next == dt.freeHeadIndex {
return 0, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[next]
if desc.flags&descriptorFlagWritable == 0 {
return 0, fmt.Errorf("receive queue contains device-readable buffer")
}
length += int(desc.length)
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
break
}
// Detect loops.
if desc.next == head {
return 0, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
}
next = desc.next
}
if maxLen > 0 {
//todo length = min(maxLen, length)
}
//set out to length:
out = out[:length]
//now do the copying
copied := 0
for range len(dt.descriptors) {
desc := &dt.descriptors[next]
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), min(uint32(length-copied), desc.length))
copied += copy(out[copied:], bs)
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
break
}
// we did this already, no need to detect loops.
next = desc.next
}
if copied != length {
panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", length, copied))
}
return length, nil
}
// freeDescriptorChain can be used to free a descriptor chain when it is no
// longer in use. The descriptor chain that starts with the given index will be
// put back into the free chain, so the descriptors can be used for later calls
// of [createDescriptorChain].
// The descriptor chain must have been created using [createDescriptorChain] and
// must not have been freed yet (meaning that the head index must not be
// contained in the free chain).
func (dt *DescriptorTable) freeDescriptorChain(head uint16) error {
if int(head) >= len(dt.descriptors) {
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
// Iterate over the chain. The iteration is limited to the queue size to
// avoid ending up in an endless loop when things go very wrong.
next := head
var tailDesc *Descriptor
var chainLen uint16
for range len(dt.descriptors) {
if next == dt.freeHeadIndex {
return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[next]
chainLen++
// Set the length of all unused descriptors back to zero.
desc.length = 0
// Unset all flags except the next flag.
desc.flags &= descriptorFlagHasNext
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
tailDesc = desc
break
}
// Detect loops.
if desc.next == head {
return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
}
next = desc.next
}
if tailDesc == nil {
// A descriptor chain longer than the queue size but without loops
// should be impossible.
panic(fmt.Sprintf("could not find a tail for descriptor chain starting at %d", head))
}
// The tail descriptor does not have the next flag set, but when it comes
// back into the free chain, it should have.
tailDesc.flags = descriptorFlagHasNext
if dt.freeHeadIndex == noFreeHead {
// The whole free chain was used up, so we turn this returned descriptor
// chain into the new free chain by completing the circle and using its
// head.
tailDesc.next = head
dt.freeHeadIndex = head
} else {
// Attach the returned chain at the beginning of the free chain but
// right after the free chain head.
freeHeadDesc := &dt.descriptors[dt.freeHeadIndex]
tailDesc.next = freeHeadDesc.next
freeHeadDesc.next = head
}
dt.freeNum += chainLen
return nil
}
// checkUnusedDescriptorLength asserts that the length of an unused descriptor
// is zero, as it should be.
// This is not a requirement by the virtio spec but rather a thing we do to
// notice when our algorithm goes sideways.
func checkUnusedDescriptorLength(index uint16, desc *Descriptor) {
if desc.length != 0 {
panic(fmt.Sprintf("descriptor %d should be unused but has a non-zero length", index))
}
}
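Taken together, the functions above form a create/inspect/free lifecycle for descriptor chains. The following is a minimal sketch of that lifecycle (not part of the diff); it assumes it lives in the same package, since these helpers are unexported, that "os" is imported, and that one page per descriptor buffer is an acceptable itemSize.

func exampleDescriptorLifecycle() error {
    const queueSize = 8
    itemSize := os.Getpagesize() // assumed per-descriptor buffer size
    // The real code hands newDescriptorTable a slice of mmap'd queue memory;
    // a plain Go slice is enough to illustrate the flow.
    mem := make([]byte, descriptorTableSize(queueSize))
    dt := newDescriptorTable(queueSize, mem, itemSize)
    if err := dt.initializeDescriptors(); err != nil {
        return err
    }
    defer dt.releaseBuffers()

    // Offer one device-readable buffer followed by one device-writable buffer.
    head, err := dt.createDescriptorChain([][]byte{[]byte("hello")}, 1)
    if err != nil {
        return err
    }
    // ... the device would consume the chain here ...
    outBufs, inBufs, err := dt.getDescriptorChain(head)
    if err != nil {
        return err
    }
    _, _ = outBufs, inBufs // out: the copied "hello"; in: one itemSize buffer
    // Put the descriptors back into the free chain once the device is done.
    return dt.freeDescriptorChain(head)
}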

View File

@@ -0,0 +1,407 @@
package virtqueue
import (
"os"
"testing"
"unsafe"
"github.com/stretchr/testify/assert"
)
func TestDescriptorTable_InitializeDescriptors(t *testing.T) {
const queueSize = 32
dt := DescriptorTable{
descriptors: make([]Descriptor, queueSize),
itemSize:    os.Getpagesize(),
}
assert.NoError(t, dt.initializeDescriptors())
t.Cleanup(func() {
assert.NoError(t, dt.releaseBuffers())
})
for i, descriptor := range dt.descriptors {
assert.NotZero(t, descriptor.address)
assert.Zero(t, descriptor.length)
assert.EqualValues(t, descriptorFlagHasNext, descriptor.flags)
assert.EqualValues(t, (i+1)%queueSize, descriptor.next)
}
}
func TestDescriptorTable_DescriptorChains(t *testing.T) {
// Use a very short queue size to not make this test overly verbose.
const queueSize = 8
pageSize := os.Getpagesize() * 2
// Initialize descriptor table.
dt := DescriptorTable{
descriptors: make([]Descriptor, queueSize),
itemSize:    pageSize,
}
assert.NoError(t, dt.initializeDescriptors())
t.Cleanup(func() {
assert.NoError(t, dt.releaseBuffers())
})
// Some utilities for easier checking if the descriptor table looks as
// expected.
type desc struct {
buffer []byte
flags descriptorFlag
next uint16
}
assertDescriptorTable := func(expected [queueSize]desc) {
for i := 0; i < queueSize; i++ {
actualDesc := &dt.descriptors[i]
expectedDesc := &expected[i]
assert.Equal(t, uint32(len(expectedDesc.buffer)), actualDesc.length)
if len(expectedDesc.buffer) > 0 {
//goland:noinspection GoVetUnsafePointer
assert.EqualValues(t,
unsafe.Slice((*byte)(unsafe.Pointer(actualDesc.address)), actualDesc.length),
expectedDesc.buffer)
}
assert.Equal(t, expectedDesc.flags, actualDesc.flags)
if expectedDesc.flags&descriptorFlagHasNext != 0 {
assert.Equal(t, expectedDesc.next, actualDesc.next)
}
}
}
// Initial state: All descriptors are in the free chain.
assert.Equal(t, uint16(0), dt.freeHeadIndex)
assert.Equal(t, uint16(8), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
// Free head.
flags: descriptorFlagHasNext,
next: 1,
},
{
flags: descriptorFlagHasNext,
next: 2,
},
{
flags: descriptorFlagHasNext,
next: 3,
},
{
flags: descriptorFlagHasNext,
next: 4,
},
{
flags: descriptorFlagHasNext,
next: 5,
},
{
flags: descriptorFlagHasNext,
next: 6,
},
{
flags: descriptorFlagHasNext,
next: 7,
},
{
flags: descriptorFlagHasNext,
next: 0,
},
})
// Create the first chain.
firstChain, err := dt.createDescriptorChain([][]byte{
makeTestBuffer(t, 26),
makeTestBuffer(t, 256),
}, 1)
assert.NoError(t, err)
assert.Equal(t, uint16(1), firstChain)
// Now there should be a new chain next to the free chain.
assert.Equal(t, uint16(0), dt.freeHeadIndex)
assert.Equal(t, uint16(5), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
// Free head.
flags: descriptorFlagHasNext,
next: 4,
},
{
// Head of first chain.
buffer: makeTestBuffer(t, 26),
flags: descriptorFlagHasNext,
next: 2,
},
{
buffer: makeTestBuffer(t, 256),
flags: descriptorFlagHasNext,
next: 3,
},
{
// Tail of first chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
flags: descriptorFlagHasNext,
next: 5,
},
{
flags: descriptorFlagHasNext,
next: 6,
},
{
flags: descriptorFlagHasNext,
next: 7,
},
{
flags: descriptorFlagHasNext,
next: 0,
},
})
// Create a second chain with only a single in buffer.
secondChain, err := dt.createDescriptorChain(nil, 1)
assert.NoError(t, err)
assert.Equal(t, uint16(4), secondChain)
// Now there should be two chains next to the free chain.
assert.Equal(t, uint16(0), dt.freeHeadIndex)
assert.Equal(t, uint16(4), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
// Free head.
flags: descriptorFlagHasNext,
next: 5,
},
{
// Head of the first chain.
buffer: makeTestBuffer(t, 26),
flags: descriptorFlagHasNext,
next: 2,
},
{
buffer: makeTestBuffer(t, 256),
flags: descriptorFlagHasNext,
next: 3,
},
{
// Tail of the first chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
// Head and tail of the second chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
flags: descriptorFlagHasNext,
next: 6,
},
{
flags: descriptorFlagHasNext,
next: 7,
},
{
flags: descriptorFlagHasNext,
next: 0,
},
})
// Create a third chain taking up all remaining descriptors.
thirdChain, err := dt.createDescriptorChain([][]byte{
makeTestBuffer(t, 42),
makeTestBuffer(t, 96),
makeTestBuffer(t, 33),
makeTestBuffer(t, 222),
}, 0)
assert.NoError(t, err)
assert.Equal(t, uint16(5), thirdChain)
// Now there should be three chains and no free chain.
assert.Equal(t, noFreeHead, dt.freeHeadIndex)
assert.Equal(t, uint16(0), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
// Tail of the third chain.
buffer: makeTestBuffer(t, 222),
},
{
// Head of the first chain.
buffer: makeTestBuffer(t, 26),
flags: descriptorFlagHasNext,
next: 2,
},
{
buffer: makeTestBuffer(t, 256),
flags: descriptorFlagHasNext,
next: 3,
},
{
// Tail of the first chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
// Head and tail of the second chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
// Head of the third chain.
buffer: makeTestBuffer(t, 42),
flags: descriptorFlagHasNext,
next: 6,
},
{
buffer: makeTestBuffer(t, 96),
flags: descriptorFlagHasNext,
next: 7,
},
{
buffer: makeTestBuffer(t, 33),
flags: descriptorFlagHasNext,
next: 0,
},
})
// Free the third chain.
assert.NoError(t, dt.freeDescriptorChain(thirdChain))
// Now there should be two chains and a free chain again.
assert.Equal(t, uint16(5), dt.freeHeadIndex)
assert.Equal(t, uint16(4), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
flags: descriptorFlagHasNext,
next: 5,
},
{
// Head of the first chain.
buffer: makeTestBuffer(t, 26),
flags: descriptorFlagHasNext,
next: 2,
},
{
buffer: makeTestBuffer(t, 256),
flags: descriptorFlagHasNext,
next: 3,
},
{
// Tail of the first chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
// Head and tail of the second chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
// Free head.
flags: descriptorFlagHasNext,
next: 6,
},
{
flags: descriptorFlagHasNext,
next: 7,
},
{
flags: descriptorFlagHasNext,
next: 0,
},
})
// Free the first chain.
assert.NoError(t, dt.freeDescriptorChain(firstChain))
// Now there should be only a single chain next to the free chain.
assert.Equal(t, uint16(5), dt.freeHeadIndex)
assert.Equal(t, uint16(7), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
flags: descriptorFlagHasNext,
next: 5,
},
{
flags: descriptorFlagHasNext,
next: 2,
},
{
flags: descriptorFlagHasNext,
next: 3,
},
{
flags: descriptorFlagHasNext,
next: 6,
},
{
// Head and tail of the second chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
// Free head.
flags: descriptorFlagHasNext,
next: 1,
},
{
flags: descriptorFlagHasNext,
next: 7,
},
{
flags: descriptorFlagHasNext,
next: 0,
},
})
// Free the second chain.
assert.NoError(t, dt.freeDescriptorChain(secondChain))
// Now all descriptors should be in the free chain again.
assert.Equal(t, uint16(5), dt.freeHeadIndex)
assert.Equal(t, uint16(8), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
flags: descriptorFlagHasNext,
next: 5,
},
{
flags: descriptorFlagHasNext,
next: 2,
},
{
flags: descriptorFlagHasNext,
next: 3,
},
{
flags: descriptorFlagHasNext,
next: 6,
},
{
flags: descriptorFlagHasNext,
next: 1,
},
{
// Free head.
flags: descriptorFlagHasNext,
next: 4,
},
{
flags: descriptorFlagHasNext,
next: 7,
},
{
flags: descriptorFlagHasNext,
next: 0,
},
})
}
func makeTestBuffer(t *testing.T, length int) []byte {
t.Helper()
buf := make([]byte, length)
for i := 0; i < length; i++ {
buf[i] = byte(length - i)
}
return buf
}

overlay/virtqueue/doc.go Normal file
View File

@@ -0,0 +1,7 @@
// Package virtqueue implements the driver-side for a virtio queue as described
// in the specification:
// https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-270006
// This package does not make assumptions about the device that consumes the
// queue. It rather just allocates the queue structures in memory and provides
// methods to interact with it.
package virtqueue

View File

@@ -0,0 +1,45 @@
package virtqueue
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gvisor.dev/gvisor/pkg/eventfd"
)
// Tests how an eventfd and a waiting goroutine can be gracefully closed.
// Extends the eventfd test suite:
// https://github.com/google/gvisor/blob/0799336d64be65eb97d330606c30162dc3440cab/pkg/eventfd/eventfd_test.go
func TestEventFD_CancelWait(t *testing.T) {
efd, err := eventfd.Create()
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, efd.Close())
})
var stop bool
done := make(chan struct{})
go func() {
for !stop {
_ = efd.Wait()
}
close(done)
}()
select {
case <-done:
t.Fatalf("goroutine ended early")
case <-time.After(500 * time.Millisecond):
}
stop = true
assert.NoError(t, efd.Notify())
select {
case <-done:
break
case <-time.After(5 * time.Second):
t.Error("goroutine did not end")
}
}

overlay/virtqueue/size.go Normal file
View File

@@ -0,0 +1,33 @@
package virtqueue
import (
"errors"
"fmt"
)
// ErrQueueSizeInvalid is returned when a queue size is invalid.
var ErrQueueSizeInvalid = errors.New("queue size is invalid")
// CheckQueueSize checks if the given value would be a valid size for a
// virtqueue and returns an [ErrQueueSizeInvalid], if not.
func CheckQueueSize(queueSize int) error {
if queueSize <= 0 {
return fmt.Errorf("%w: %d is too small", ErrQueueSizeInvalid, queueSize)
}
// The queue size must always be a power of 2.
// This ensures that ring indexes wrap correctly when the 16-bit integers
// overflow.
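// For example, 24 (0b11000) & 23 (0b10111) = 0b10000 != 0, while for a power
// of 2 such as 32, 32 & 31 == 0.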
if queueSize&(queueSize-1) != 0 {
return fmt.Errorf("%w: %d is not a power of 2", ErrQueueSizeInvalid, queueSize)
}
// The largest power of 2 that fits into a 16-bit integer is 32768.
// 2 * 32768 would be 65536 which no longer fits.
if queueSize > 32768 {
return fmt.Errorf("%w: %d is larger than the maximum possible queue size 32768",
ErrQueueSizeInvalid, queueSize)
}
return nil
}

View File

@@ -0,0 +1,59 @@
package virtqueue
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestCheckQueueSize(t *testing.T) {
tests := []struct {
name string
queueSize int
containsErr string
}{
{
name: "negative",
queueSize: -1,
containsErr: "too small",
},
{
name: "zero",
queueSize: 0,
containsErr: "too small",
},
{
name: "not a power of 2",
queueSize: 24,
containsErr: "not a power of 2",
},
{
name: "too large",
queueSize: 65536,
containsErr: "larger than the maximum",
},
{
name: "valid 1",
queueSize: 1,
},
{
name: "valid 256",
queueSize: 256,
},
{
name: "valid 32768",
queueSize: 32768,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckQueueSize(tt.queueSize)
if tt.containsErr != "" {
assert.ErrorContains(t, err, tt.containsErr)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -0,0 +1,530 @@
package virtqueue
import (
"context"
"errors"
"fmt"
"os"
"syscall"
"github.com/slackhq/nebula/overlay/eventfd"
"golang.org/x/sys/unix"
)
// SplitQueue is a virtqueue that consists of several parts, where each part is
// writeable by either the driver or the device, but not both.
type SplitQueue struct {
// size is the size of the queue.
size int
// buf is the underlying memory used for the queue.
buf []byte
descriptorTable *DescriptorTable
availableRing *AvailableRing
usedRing *UsedRing
// kickEventFD is used to signal the device when descriptor chains were
// added to the available ring.
kickEventFD eventfd.EventFD
// callEventFD is used by the device to signal when it has used descriptor
// chains and put them in the used ring.
callEventFD eventfd.EventFD
// stop is used by [SplitQueue.Close] to cancel the goroutine that handles
// used buffer notifications. It blocks until the goroutine ended.
stop func() error
itemSize int
epoll eventfd.Epoll
more int
}
// NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size
// specifies the number of entries/buffers the queue can hold. This also affects
// the memory consumption.
func NewSplitQueue(queueSize int, itemSize int) (_ *SplitQueue, err error) {
if err = CheckQueueSize(queueSize); err != nil {
return nil, err
}
if itemSize%os.Getpagesize() != 0 {
return nil, errors.New("split queue size must be multiple of os.Getpagesize()")
}
sq := SplitQueue{
size: queueSize,
itemSize: itemSize,
}
// Clean up a partially initialized queue when something fails.
defer func() {
if err != nil {
_ = sq.Close()
}
}()
// There are multiple ways for how the memory for the virtqueue could be
// allocated. We could use Go native structs with arrays inside them, but
// this wouldn't allow us to make the queue size configurable. And including
// a slice in the Go structs wouldn't work, because this would just put the
// Go slice descriptor into the memory region which the virtio device will
// not understand.
// Additionally, Go does not allow us to ensure a correct alignment of the
// parts of the virtqueue, as it is required by the virtio specification.
//
// To resolve this, let's just allocate the memory manually by allocating
// one or more memory pages, depending on the queue size. Making the
// virtqueue start at the beginning of a page is not strictly necessary, as
// the virtio specification does not require it to be contiguous in the
// physical memory of the host (e.g. the vhost implementation in the kernel
// always uses copy_from_user to access it), but this makes it very easy to
// guarantee the alignment. Also, it is not required for the virtqueue parts
// to be in the same memory region, as we pass separate pointers to them to
// the device, but this design just makes things easier to implement.
//
// One added benefit of allocating the memory manually is, that we have full
// control over its lifetime and don't risk the garbage collector to collect
// our valuable structures while the device still works with them.
// The descriptor table is at the start of the page, so alignment is not an
// issue here.
descriptorTableStart := 0
descriptorTableEnd := descriptorTableStart + descriptorTableSize(queueSize)
availableRingStart := align(descriptorTableEnd, availableRingAlignment)
availableRingEnd := availableRingStart + availableRingSize(queueSize)
usedRingStart := align(availableRingEnd, usedRingAlignment)
usedRingEnd := usedRingStart + usedRingSize(queueSize)
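// With queueSize = 256, for example, this works out to:
//   descriptor table: [0, 4096)    16*256 bytes
//   available ring:   [4096, 4614) 6 + 2*256 bytes (2-byte aligned)
//   used ring:        [4616, 6670) 6 + 8*256 bytes (4-byte aligned)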
sq.buf, err = unix.Mmap(-1, 0, usedRingEnd,
unix.PROT_READ|unix.PROT_WRITE,
unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
if err != nil {
return nil, fmt.Errorf("allocate virtqueue buffer: %w", err)
}
sq.descriptorTable = newDescriptorTable(queueSize, sq.buf[descriptorTableStart:descriptorTableEnd], sq.itemSize)
sq.availableRing = newAvailableRing(queueSize, sq.buf[availableRingStart:availableRingEnd])
sq.usedRing = newUsedRing(queueSize, sq.buf[usedRingStart:usedRingEnd])
sq.kickEventFD, err = eventfd.New()
if err != nil {
return nil, fmt.Errorf("create kick event file descriptor: %w", err)
}
sq.callEventFD, err = eventfd.New()
if err != nil {
return nil, fmt.Errorf("create call event file descriptor: %w", err)
}
if err = sq.descriptorTable.initializeDescriptors(); err != nil {
return nil, fmt.Errorf("initialize descriptors: %w", err)
}
sq.epoll, err = eventfd.NewEpoll()
if err != nil {
return nil, err
}
err = sq.epoll.AddEvent(sq.callEventFD.FD())
if err != nil {
return nil, err
}
// Consume used buffer notifications in the background.
sq.stop = sq.startConsumeUsedRing()
return &sq, nil
}
// Size returns the size of this queue, which is the number of entries/buffers
// this queue can hold.
func (sq *SplitQueue) Size() int {
return sq.size
}
// DescriptorTable returns the [DescriptorTable] behind this queue.
func (sq *SplitQueue) DescriptorTable() *DescriptorTable {
return sq.descriptorTable
}
// AvailableRing returns the [AvailableRing] behind this queue.
func (sq *SplitQueue) AvailableRing() *AvailableRing {
return sq.availableRing
}
// UsedRing returns the [UsedRing] behind this queue.
func (sq *SplitQueue) UsedRing() *UsedRing {
return sq.usedRing
}
// KickEventFD returns the kick event file descriptor behind this queue.
// The returned file descriptor should be used with great care to not interfere
// with this implementation.
func (sq *SplitQueue) KickEventFD() int {
return sq.kickEventFD.FD()
}
// CallEventFD returns the call event file descriptor behind this queue.
// The returned file descriptor should be used with great care to not interfere
// with this implementation.
func (sq *SplitQueue) CallEventFD() int {
return sq.callEventFD.FD()
}
// startConsumeUsedRing returns a stop function that wakes any goroutine blocked
// on the call event file descriptor so that [SplitQueue.Close] can shut down
// cleanly; it no longer starts a background goroutine itself. todo rename
func (sq *SplitQueue) startConsumeUsedRing() func() error {
return func() error {
// The goroutine blocks until it receives a signal on the event file
// descriptor, so it will never notice the context being canceled.
// To resolve this, we can just produce a fake-signal ourselves to wake
// it up.
if err := sq.callEventFD.Kick(); err != nil {
return fmt.Errorf("wake up goroutine: %w", err)
}
return nil
}
}
// BlockAndGetHeads waits for the device to signal that it has used descriptor chains and returns all [UsedElement]s
func (sq *SplitQueue) BlockAndGetHeads(ctx context.Context) ([]UsedElement, error) {
var n int
var err error
for ctx.Err() == nil {
// Wait for a signal from the device.
if n, err = sq.epoll.Block(); err != nil {
return nil, fmt.Errorf("wait: %w", err)
}
if n > 0 {
stillNeedToTake, out := sq.usedRing.take(-1)
sq.more = stillNeedToTake
if stillNeedToTake == 0 {
_ = sq.epoll.Clear() //???
}
return out, nil
}
}
return nil, ctx.Err()
}
func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
var n int
var err error
for ctx.Err() == nil {
out, ok := sq.usedRing.takeOne()
if ok {
return out, nil
}
// Wait for a signal from the device.
if n, err = sq.epoll.Block(); err != nil {
return 0, fmt.Errorf("wait: %w", err)
}
if n > 0 {
out, ok = sq.usedRing.takeOne()
if ok {
_ = sq.epoll.Clear() //???
return out, nil
} else {
continue //???
}
}
}
return 0, ctx.Err()
}
func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) {
var n int
var err error
for ctx.Err() == nil {
//we have leftovers in the fridge
if sq.more > 0 {
stillNeedToTake, out := sq.usedRing.take(maxToTake)
sq.more = stillNeedToTake
return out, nil
}
//look inside the fridge
stillNeedToTake, out := sq.usedRing.take(maxToTake)
if len(out) > 0 {
sq.more = stillNeedToTake
return out, nil
}
//fridge is empty I guess
// Wait for a signal from the device.
if n, err = sq.epoll.Block(); err != nil {
return nil, fmt.Errorf("wait: %w", err)
}
if n > 0 {
_ = sq.epoll.Clear() //???
stillNeedToTake, out = sq.usedRing.take(maxToTake)
sq.more = stillNeedToTake
return out, nil
}
}
return nil, ctx.Err()
}
// OfferInDescriptorChains adds a single device-writable (in) descriptor,
// referencing a whole itemSize-sized buffer, and makes it available to the
// device.
//
// When the queue is full and no more descriptors can be added, an
// [ErrNotEnoughFreeDescriptors] is returned; callers should retry after the
// device has used and freed some descriptor chains.
//
// After defining the descriptor in the [DescriptorTable], the index of its
// head is made available to the device using the [AvailableRing] and is
// returned by this method. Callers should hand used descriptors back (for
// example via [SplitQueue.FreeDescriptorChain] or by re-offering them) once
// the device is done with them. When this does not happen, the queue will run
// full and further offers will fail.
func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
// Create a descriptor chain for the given buffers.
var (
head uint16
err error
)
for {
head, err = sq.descriptorTable.createDescriptorForInputs()
if err == nil {
break
}
// I don't wanna use errors.Is, it's slow
//goland:noinspection GoDirectComparisonOfErrors
if err == ErrNotEnoughFreeDescriptors {
return 0, err
} else {
return 0, fmt.Errorf("create descriptor chain: %w", err)
}
}
// Make the descriptor chain available to the device.
sq.availableRing.offerSingle(head)
// Notify the device to make it process the updated available ring.
if err := sq.kickEventFD.Kick(); err != nil {
return head, fmt.Errorf("notify device: %w", err)
}
return head, nil
}
func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte) ([]uint16, error) {
// TODO change this
// Each descriptor can only hold itemSize bytes, so split large out
// buffers into multiple smaller ones.
outBuffers = splitBuffers(outBuffers, sq.itemSize)
chains := make([]uint16, len(outBuffers))
// Create a descriptor chain for the given buffers.
var (
head uint16
err error
)
for i := range outBuffers {
for {
bufs := [][]byte{prepend, outBuffers[i]}
head, err = sq.descriptorTable.createDescriptorChain(bufs, 0)
if err == nil {
break
}
// I don't wanna use errors.Is, it's slow
//goland:noinspection GoDirectComparisonOfErrors
if err == ErrNotEnoughFreeDescriptors {
// Wait for more free descriptors to be put back into the queue.
// If the number of free descriptors is still not sufficient, we'll
// land here again.
//todo should never happen
syscall.Syscall(syscall.SYS_SCHED_YIELD, 0, 0, 0) // Cheap barrier
continue
}
return nil, fmt.Errorf("create descriptor chain: %w", err)
}
chains[i] = head
}
// Make the descriptor chain available to the device.
sq.availableRing.offer(chains)
// Notify the device to make it process the updated available ring.
if err := sq.kickEventFD.Kick(); err != nil {
return chains, fmt.Errorf("notify device: %w", err)
}
return chains, nil
}
// GetDescriptorChain returns the device-readable buffers (out buffers) and
// device-writable buffers (in buffers) of the descriptor chain with the given
// head index.
// The head index must be one that was returned by a previous call to
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
// freed yet.
//
// Be careful to only access the returned buffer slices when the device is no
// longer using them. They must not be accessed after
// [SplitQueue.FreeDescriptorChain] has been called.
func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
return sq.descriptorTable.getDescriptorChain(head)
}
func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) {
sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize)
return sq.descriptorTable.getDescriptorItem(head)
}
func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
return sq.descriptorTable.getDescriptorChainContents(head, out, maxLen)
}
func (sq *SplitQueue) GetDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
return sq.descriptorTable.getDescriptorInbuffers(head, inBuffers)
}
// FreeDescriptorChain frees the descriptor chain with the given head index.
// The head index must be one that was returned by a previous call to
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
// freed yet.
//
// This creates new room in the queue which can be used by following
// [SplitQueue.OfferDescriptorChain] calls.
// When there are outstanding calls for [SplitQueue.OfferDescriptorChain] that
// are waiting for free room in the queue, they may become unblocked by this.
func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
//not called under lock
if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
return fmt.Errorf("free: %w", err)
}
return nil
}
func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
//not called under lock
sq.descriptorTable.descriptors[int(head)].length = uint32(sz)
}
func (sq *SplitQueue) OfferDescriptorChains(chains []uint16, kick bool) error {
//todo not doing this may break eventually?
//not called under lock
//if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
// return fmt.Errorf("free: %w", err)
//}
// Make the descriptor chain available to the device.
sq.availableRing.offer(chains)
// Notify the device to make it process the updated available ring.
if kick {
return sq.Kick()
}
return nil
}
func (sq *SplitQueue) Kick() error {
if err := sq.kickEventFD.Kick(); err != nil {
return fmt.Errorf("notify device: %w", err)
}
return nil
}
// Close releases all resources used for this queue.
// The implementation will try to release as many resources as possible and
// collect potential errors before returning them.
func (sq *SplitQueue) Close() error {
var errs []error
if sq.stop != nil {
// This has to happen before the event file descriptors may be closed.
if err := sq.stop(); err != nil {
errs = append(errs, fmt.Errorf("stop consume used ring: %w", err))
}
// Make sure that this code block is executed only once.
sq.stop = nil
}
if err := sq.kickEventFD.Close(); err != nil {
errs = append(errs, fmt.Errorf("close kick event file descriptor: %w", err))
}
if err := sq.callEventFD.Close(); err != nil {
errs = append(errs, fmt.Errorf("close call event file descriptor: %w", err))
}
if err := sq.descriptorTable.releaseBuffers(); err != nil {
errs = append(errs, fmt.Errorf("release descriptor buffers: %w", err))
}
if sq.buf != nil {
if err := unix.Munmap(sq.buf); err == nil {
sq.buf = nil
} else {
errs = append(errs, fmt.Errorf("unmap virtqueue buffer: %w", err))
}
}
return errors.Join(errs...)
}
// ensureInitialized is used as a guard to prevent methods to be called on an
// uninitialized instance.
func (sq *SplitQueue) ensureInitialized() {
if sq.buf == nil {
panic("used ring is not initialized")
}
}
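// align rounds index up to the next multiple of alignment, e.g.
// align(4614, 4) == 4616 and align(4096, 2) == 4096.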
func align(index, alignment int) int {
remainder := index % alignment
if remainder == 0 {
return index
}
return index + alignment - remainder
}
// splitBuffers processes a list of buffers and splits each buffer that is
// larger than the size limit into multiple smaller buffers.
// If none of the buffers is too large, the input slice is returned unchanged
// to avoid an allocation for now.
func splitBuffers(buffers [][]byte, sizeLimit int) [][]byte {
for i := range buffers {
if len(buffers[i]) > sizeLimit {
return reallySplitBuffers(buffers, sizeLimit)
}
}
return buffers
}
func reallySplitBuffers(buffers [][]byte, sizeLimit int) [][]byte {
result := make([][]byte, 0, len(buffers))
for _, buffer := range buffers {
for added := 0; added < len(buffer); added += sizeLimit {
if len(buffer)-added <= sizeLimit {
result = append(result, buffer[added:])
break
}
result = append(result, buffer[added:added+sizeLimit])
}
}
return result
}
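To tie the pieces of SplitQueue together, here is a hedged end-to-end sketch of the transmit path from outside the package (not part of the diff). The import path is assumed from the overlay/eventfd import above, the 12-byte prepend is only a placeholder for whatever header the device expects, and a real program would first hand the queue's addresses and event file descriptors to a device such as vhost; without one, BlockAndGetHeads would block forever.

package main

import (
    "context"
    "log"
    "os"

    "github.com/slackhq/nebula/overlay/virtqueue" // assumed import path
)

func main() {
    // 256 entries, one page-sized buffer per descriptor (itemSize must be a
    // multiple of the page size).
    sq, err := virtqueue.NewSplitQueue(256, os.Getpagesize())
    if err != nil {
        log.Fatal(err)
    }
    defer sq.Close()

    // A device would be configured with sq.DescriptorTable().Address(),
    // sq.AvailableRing().Address(), sq.UsedRing().Address(),
    // sq.KickEventFD() and sq.CallEventFD() before buffers are offered.

    // Transmit: offer one packet as a device-readable chain, prepended with a
    // placeholder header (its size here is an assumption).
    hdr := make([]byte, 12)
    if _, err := sq.OfferOutDescriptorChains(hdr, [][]byte{[]byte("packet bytes")}); err != nil {
        log.Fatal(err)
    }
    // Wait until the device reports the chain as used, then recycle it.
    used, err := sq.BlockAndGetHeads(context.Background())
    if err != nil {
        log.Fatal(err)
    }
    for _, e := range used {
        if err := sq.FreeDescriptorChain(e.GetHead()); err != nil {
            log.Fatal(err)
        }
    }
}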

View File

@@ -0,0 +1,105 @@
package virtqueue
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSplitQueue_MemoryAlignment(t *testing.T) {
tests := []struct {
name string
queueSize int
}{
{
name: "minimal queue size",
queueSize: 1,
},
{
name: "small queue size",
queueSize: 8,
},
{
name: "large queue size",
queueSize: 256,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sq, err := NewSplitQueue(tt.queueSize, os.Getpagesize())
require.NoError(t, err)
assert.Zero(t, sq.descriptorTable.Address()%descriptorTableAlignment)
assert.Zero(t, sq.availableRing.Address()%availableRingAlignment)
assert.Zero(t, sq.usedRing.Address()%usedRingAlignment)
})
}
}
func TestSplitBuffers(t *testing.T) {
const sizeLimit = 16
tests := []struct {
name string
buffers [][]byte
expected [][]byte
}{
{
name: "no buffers",
buffers: make([][]byte, 0),
expected: make([][]byte, 0),
},
{
name: "small",
buffers: [][]byte{
make([]byte, 11),
},
expected: [][]byte{
make([]byte, 11),
},
},
{
name: "exact size",
buffers: [][]byte{
make([]byte, sizeLimit),
},
expected: [][]byte{
make([]byte, sizeLimit),
},
},
{
name: "large",
buffers: [][]byte{
make([]byte, 42),
},
expected: [][]byte{
make([]byte, 16),
make([]byte, 16),
make([]byte, 10),
},
},
{
name: "mixed",
buffers: [][]byte{
make([]byte, 7),
make([]byte, 30),
make([]byte, 15),
make([]byte, 32),
},
expected: [][]byte{
make([]byte, 7),
make([]byte, 16),
make([]byte, 14),
make([]byte, 15),
make([]byte, 16),
make([]byte, 16),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := splitBuffers(tt.buffers, sizeLimit)
assert.Equal(t, tt.expected, actual)
})
}
}

View File

@@ -0,0 +1,21 @@
package virtqueue
// usedElementSize is the number of bytes needed to store a [UsedElement] in
// memory.
const usedElementSize = 8
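// That is two 32-bit fields (the descriptor index and the written length),
// matching the layout of the virtio spec's struct virtq_used_elem.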
// UsedElement is an element of the [UsedRing] and describes a descriptor chain
// that was used by the device.
type UsedElement struct {
// DescriptorIndex is the index of the head of the used descriptor chain in
// the [DescriptorTable].
// The index is 32-bit here for padding reasons.
DescriptorIndex uint32
// Length is the number of bytes written into the device writable portion of
// the buffer described by the descriptor chain.
Length uint32
}
func (u *UsedElement) GetHead() uint16 {
return uint16(u.DescriptorIndex)
}

View File

@@ -0,0 +1,12 @@
package virtqueue
import (
"testing"
"unsafe"
"github.com/stretchr/testify/assert"
)
func TestUsedElement_Size(t *testing.T) {
assert.EqualValues(t, usedElementSize, unsafe.Sizeof(UsedElement{}))
}

View File

@@ -0,0 +1,184 @@
package virtqueue
import (
"fmt"
"unsafe"
)
// usedRingFlag is a flag that describes a [UsedRing].
type usedRingFlag uint16
const (
// usedRingFlagNoNotify is used by the host to advise the guest to not
// kick it when adding a buffer. It's unreliable, so it's simply an
// optimization. Guest will still kick when it's out of buffers.
usedRingFlagNoNotify usedRingFlag = 1 << iota
)
// usedRingSize is the number of bytes needed to store a [UsedRing] with the
// given queue size in memory.
func usedRingSize(queueSize int) int {
return 6 + usedElementSize*queueSize
}
// usedRingAlignment is the minimum alignment of a [UsedRing] in memory, as
// required by the virtio spec.
const usedRingAlignment = 4
// UsedRing is where the device returns descriptor chains once it is done with
// them. Each ring entry is a [UsedElement]. It is only written to by the device
// and read by the driver.
//
// Because the size of the ring depends on the queue size, we cannot define a
// Go struct with a static size that maps to the memory of the ring. Instead,
// this struct only contains pointers to the corresponding memory areas.
type UsedRing struct {
initialized bool
// flags that describe this ring.
flags *usedRingFlag
// ringIndex indicates where the device would put the next entry into the
// ring (modulo the queue size).
ringIndex *uint16
// ring contains the [UsedElement]s. It wraps around at queue size.
ring []UsedElement
// availableEvent is not used by this implementation, but we reserve it
// anyway to avoid issues in case a device writes to it, contrary to the
// virtio specification.
availableEvent *uint16
// lastIndex is the internal ringIndex up to which all [UsedElement]s were
// processed.
lastIndex uint16
//mu sync.Mutex
}
// newUsedRing creates a used ring that uses the given underlying memory. The
// length of the memory slice must match the size needed for the ring (see
// [usedRingSize]) for the given queue size.
func newUsedRing(queueSize int, mem []byte) *UsedRing {
ringSize := usedRingSize(queueSize)
if len(mem) != ringSize {
panic(fmt.Sprintf("memory size (%v) does not match required size "+
"for used ring: %v", len(mem), ringSize))
}
r := UsedRing{
initialized: true,
flags: (*usedRingFlag)(unsafe.Pointer(&mem[0])),
ringIndex: (*uint16)(unsafe.Pointer(&mem[2])),
ring: unsafe.Slice((*UsedElement)(unsafe.Pointer(&mem[4])), queueSize),
availableEvent: (*uint16)(unsafe.Pointer(&mem[ringSize-2])),
}
r.lastIndex = *r.ringIndex
return &r
}
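As a point of reference, a used ring can be mapped over any correctly sized chunk of memory; a minimal sketch (not part of this diff, and glossing over the usedRingAlignment guarantee the real split queue provides):
// Illustrative sketch: create a UsedRing over plain heap memory. The real
// queue carves this region out of its shared, alignment-checked allocation.
func exampleNewUsedRing() *UsedRing {
    const queueSize = 8
    mem := make([]byte, usedRingSize(queueSize)) // 6 + 8*8 = 70 bytes
    return newUsedRing(queueSize, mem)
}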
// Address returns the pointer to the beginning of the ring in memory.
// Do not modify the memory directly, as that may interfere with this implementation.
func (r *UsedRing) Address() uintptr {
if !r.initialized {
panic("used ring is not initialized")
}
return uintptr(unsafe.Pointer(r.flags))
}
// take returns new [UsedElement]s that the device put into the ring and that
// weren't already returned by a previous call to this method. If maxToTake > 0,
// at most maxToTake elements are returned; the first return value is how many
// new elements remain in the ring after this call.
// Callers must serialize access: the internal lock was removed.
func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
//r.mu.Lock()
//defer r.mu.Unlock()
ringIndex := *r.ringIndex
if ringIndex == r.lastIndex {
// Nothing new.
return 0, nil
}
// Calculate the number of new used elements that we can read from the ring.
// The 16-bit ring index may wrap; the unsigned subtraction already accounts
// for that, so count is always in [0, 65535].
count := int(ringIndex - r.lastIndex)
stillNeedToTake := 0
if maxToTake > 0 {
stillNeedToTake = count - maxToTake
if stillNeedToTake < 0 {
stillNeedToTake = 0
}
count = min(count, maxToTake)
}
// The number of new elements can never exceed the queue size.
if count > len(r.ring) {
panic("used ring contains more new elements than the ring is long")
}
elems := make([]UsedElement, count)
for i := range count {
elems[i] = r.ring[r.lastIndex%uint16(len(r.ring))]
r.lastIndex++
}
return stillNeedToTake, elems
}
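The count arithmetic above leans on unsigned 16-bit subtraction: even when the device's ring index wraps past 65535, the difference still comes out as the number of unread entries. A tiny illustration (not part of this diff):
// Illustrative sketch: uint16 subtraction yields the correct element count
// across a wrap of the device's ring index.
func exampleRingIndexDelta() int {
    var deviceIndex uint16 = 2   // device has wrapped past 65535
    var lastIndex uint16 = 65534 // driver consumed up to here
    return int(deviceIndex - lastIndex) // 4: entries 65534, 65535, 0, 1 are new
}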
func (r *UsedRing) takeOne() (uint16, bool) {
//r.mu.Lock()
//defer r.mu.Unlock()
ringIndex := *r.ringIndex
if ringIndex == r.lastIndex {
// Nothing new.
return 0xffff, false
}
// Calculate the number of new used elements that we can read from the ring.
// The 16-bit ring index may wrap; the unsigned subtraction already accounts
// for that, so count is always in [0, 65535].
count := int(ringIndex - r.lastIndex)
// The number of new elements can never exceed the queue size.
if count > len(r.ring) {
panic("used ring contains more new elements than the ring is long")
}
if count == 0 {
return 0xffff, false
}
out := r.ring[r.lastIndex%uint16(len(r.ring))].GetHead()
r.lastIndex++
return out, true
}
// InitOfferSingle is only used to pre-fill the used queue at startup, and should not be used if the device is running!
func (r *UsedRing) InitOfferSingle(x uint16, size int) {
//always called under lock
//r.mu.Lock()
//defer r.mu.Unlock()
// Add the descriptor chain head to the ring.
// The 16-bit ring index may overflow. This is expected and is not an
// issue because the size of the ring array (which equals the queue
// size) is always a power of 2 and smaller than the highest possible
// 16-bit value.
insertIndex := int(*r.ringIndex) % len(r.ring)
r.ring[insertIndex] = UsedElement{
DescriptorIndex: uint32(x),
Length: uint32(size),
}
// Advance the ring index past the element that was just added.
*r.ringIndex += 1
}
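A hedged sketch of how InitOfferSingle might be used while the device is stopped, handing every descriptor slot to the ring exactly once (the queue size, segment size, and the loop-over-all-slots pattern are assumptions, not code from this diff):
// Illustrative sketch: pre-populate the used ring before the device starts, so
// each descriptor chain head is offered exactly once.
func prefillUsedRing(r *UsedRing, queueSize, segSize int) {
    for i := 0; i < queueSize; i++ {
        r.InitOfferSingle(uint16(i), segSize)
    }
}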

View File

@@ -0,0 +1,136 @@
package virtqueue
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestUsedRing_MemoryLayout(t *testing.T) {
const queueSize = 2
memory := make([]byte, usedRingSize(queueSize))
r := newUsedRing(queueSize, memory)
*r.flags = 0x01ff
*r.ringIndex = 1
r.ring[0] = UsedElement{
DescriptorIndex: 0x0123,
Length: 0x4567,
}
r.ring[1] = UsedElement{
DescriptorIndex: 0x89ab,
Length: 0xcdef,
}
assert.Equal(t, []byte{
0xff, 0x01,
0x01, 0x00,
0x23, 0x01, 0x00, 0x00,
0x67, 0x45, 0x00, 0x00,
0xab, 0x89, 0x00, 0x00,
0xef, 0xcd, 0x00, 0x00,
0x00, 0x00,
}, memory)
}
//func TestUsedRing_Take(t *testing.T) {
// const queueSize = 8
//
// tests := []struct {
// name string
// ring []UsedElement
// ringIndex uint16
// lastIndex uint16
// expected []UsedElement
// }{
// {
// name: "nothing new",
// ring: []UsedElement{
// {DescriptorIndex: 1},
// {DescriptorIndex: 2},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// {},
// {},
// {},
// {},
// },
// ringIndex: 4,
// lastIndex: 4,
// expected: nil,
// },
// {
// name: "no overflow",
// ring: []UsedElement{
// {DescriptorIndex: 1},
// {DescriptorIndex: 2},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// {},
// {},
// {},
// {},
// },
// ringIndex: 4,
// lastIndex: 1,
// expected: []UsedElement{
// {DescriptorIndex: 2},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// },
// },
// {
// name: "ring overflow",
// ring: []UsedElement{
// {DescriptorIndex: 9},
// {DescriptorIndex: 10},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// {DescriptorIndex: 5},
// {DescriptorIndex: 6},
// {DescriptorIndex: 7},
// {DescriptorIndex: 8},
// },
// ringIndex: 10,
// lastIndex: 7,
// expected: []UsedElement{
// {DescriptorIndex: 8},
// {DescriptorIndex: 9},
// {DescriptorIndex: 10},
// },
// },
// {
// name: "index overflow",
// ring: []UsedElement{
// {DescriptorIndex: 9},
// {DescriptorIndex: 10},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// {DescriptorIndex: 5},
// {DescriptorIndex: 6},
// {DescriptorIndex: 7},
// {DescriptorIndex: 8},
// },
// ringIndex: 2,
// lastIndex: 65535,
// expected: []UsedElement{
// {DescriptorIndex: 8},
// {DescriptorIndex: 9},
// {DescriptorIndex: 10},
// },
// },
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// memory := make([]byte, usedRingSize(queueSize))
// r := newUsedRing(queueSize, memory)
//
// copy(r.ring, tt.ring)
// *r.ringIndex = tt.ringIndex
// r.lastIndex = tt.lastIndex
//
// assert.Equal(t, tt.expected, r.take())
// })
// }
//}

70
packet/outpacket.go Normal file
View File

@@ -0,0 +1,70 @@
package packet
import (
"github.com/slackhq/nebula/util/virtio"
"golang.org/x/sys/unix"
)
type OutPacket struct {
Segments [][]byte
SegmentPayloads [][]byte
SegmentHeaders [][]byte
SegmentIDs []uint16
//todo virtio header?
SegSize int
SegCounter int
Valid bool
wasSegmented bool
Scratch []byte
}
func NewOut() *OutPacket {
out := new(OutPacket)
out.Segments = make([][]byte, 0, 64)
out.SegmentHeaders = make([][]byte, 0, 64)
out.SegmentPayloads = make([][]byte, 0, 64)
out.SegmentIDs = make([]uint16, 0, 64)
out.Scratch = make([]byte, Size)
return out
}
func (pkt *OutPacket) Reset() {
pkt.Segments = pkt.Segments[:0]
pkt.SegmentPayloads = pkt.SegmentPayloads[:0]
pkt.SegmentHeaders = pkt.SegmentHeaders[:0]
pkt.SegmentIDs = pkt.SegmentIDs[:0]
pkt.SegSize = 0
pkt.Valid = false
pkt.wasSegmented = false
}
// UseSegment registers seg as the next segment of this packet. The first
// virtio.NetHdrSize+14 bytes of seg are used for the virtio-net header plus a
// 14-byte Ethernet header; the remainder is exposed via SegmentPayloads.
// It returns the index of the segment within the packet.
func (pkt *OutPacket) UseSegment(segID uint16, seg []byte, isV6 bool) int {
pkt.Valid = true
pkt.SegmentIDs = append(pkt.SegmentIDs, segID)
pkt.Segments = append(pkt.Segments, seg) //todo do we need this?
vhdr := virtio.NetHdr{ //todo
Flags: unix.VIRTIO_NET_HDR_F_DATA_VALID,
GSOType: unix.VIRTIO_NET_HDR_GSO_NONE,
HdrLen: 0,
GSOSize: 0,
CsumStart: 0,
CsumOffset: 0,
NumBuffers: 0,
}
hdr := seg[0 : virtio.NetHdrSize+14]
_ = vhdr.Encode(hdr)
// The last two bytes of the 14-byte Ethernet header carry the ethertype.
if isV6 {
hdr[virtio.NetHdrSize+14-2] = 0x86 // IPv6 (0x86dd)
hdr[virtio.NetHdrSize+14-1] = 0xdd
} else {
hdr[virtio.NetHdrSize+14-2] = 0x08 // IPv4 (0x0800)
hdr[virtio.NetHdrSize+14-1] = 0x00
}
pkt.SegmentHeaders = append(pkt.SegmentHeaders, hdr)
pkt.SegmentPayloads = append(pkt.SegmentPayloads, seg[virtio.NetHdrSize+14:])
return len(pkt.SegmentIDs) - 1
}
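Each segment handed to UseSegment is expected to carry a virtio-net header immediately followed by a 14-byte Ethernet header, with the payload starting after both. A hedged caller sketch (the segment ID and buffer are assumptions; NewOut and UseSegment come from this diff):
// Illustrative sketch: register one segment with an OutPacket and obtain the
// writable payload view that starts after the virtio-net + Ethernet headers.
func exampleUseSegment(seg []byte, isV6 bool) []byte {
    pkt := NewOut()
    pkt.Reset()
    idx := pkt.UseSegment(0, seg, isV6) // seg must be longer than virtio.NetHdrSize+14 bytes
    return pkt.SegmentPayloads[idx]     // caller fills this with the IP packet
}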

119
packet/packet.go Normal file
View File

@@ -0,0 +1,119 @@
package packet
import (
"encoding/binary"
"iter"
"net/netip"
"slices"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
const Size = 0xffff
type Packet struct {
Payload []byte
Control []byte
Name []byte
SegSize int
//todo should this hold out as well?
OutLen int
wasSegmented bool
isV4 bool
}
func New(isV4 bool) *Packet {
return &Packet{
Payload: make([]byte, Size),
Control: make([]byte, unix.CmsgSpace(2)),
Name: make([]byte, unix.SizeofSockaddrInet6),
isV4: isV4,
}
}
func (p *Packet) AddrPort() netip.AddrPort {
var ip netip.Addr
// It's ok to skip the ok check here; the slicing is the only error that can occur, and it would panic
if p.isV4 {
ip, _ = netip.AddrFromSlice(p.Name[4:8])
} else {
ip, _ = netip.AddrFromSlice(p.Name[8:24])
}
return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(p.Name[2:4]))
}
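AddrPort assumes Name holds the raw sockaddr filled in by the kernel: the port in network byte order at bytes 2..3, an IPv4 address at bytes 4..7 (sockaddr_in), or an IPv6 address at bytes 8..23 (sockaddr_in6). A test-style sketch for the IPv4 case (illustrative, not part of this diff):
// Illustrative sketch: hand-build a sockaddr_in in Name and decode it.
func exampleAddrPortV4() netip.AddrPort {
    p := New(true)
    binary.BigEndian.PutUint16(p.Name[2:4], 4242) // port, network byte order
    copy(p.Name[4:8], []byte{192, 0, 2, 1})       // IPv4 address
    return p.AddrPort()                           // 192.0.2.1:4242
}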
func (p *Packet) updateCtrl(ctrlLen int) {
p.SegSize = len(p.Payload)
p.wasSegmented = false
if ctrlLen == 0 {
return
}
if len(p.Control) == 0 {
return
}
cmsgs, err := unix.ParseSocketControlMessage(p.Control)
if err != nil {
return // oh well
}
for _, c := range cmsgs {
if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
p.wasSegmented = true
p.SegSize = int(binary.LittleEndian.Uint16(c.Data[:2]))
return
}
}
}
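updateCtrl only ever sees a UDP_GRO control message if the receiving socket was put into GRO mode; that setsockopt is not part of this file, so the following is an assumption about how the Linux listener is configured elsewhere:
// Illustrative sketch (Linux-only): ask the kernel to coalesce received UDP
// datagrams and report the segment size via a UDP_GRO cmsg.
func enableGRO(fd int) error {
    return unix.SetsockoptInt(fd, unix.SOL_UDP, unix.UDP_GRO, 1)
}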
// Update sets a Packet into "just received, not processed" state
func (p *Packet) Update(ctrlLen int) {
p.OutLen = -1
p.updateCtrl(ctrlLen)
}
func (p *Packet) SetSegSizeForTX() {
p.SegSize = len(p.Payload)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&p.Control[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
hdr.SetLen(syscall.CmsgLen(2))
binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize))
}
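The control buffer built by SetSegSizeForTX is only useful once it rides along on an actual send; a hedged sketch of pairing it with sendmsg (the fd and destination sockaddr are assumptions, and the socket must support UDP GSO):
// Illustrative sketch: send a GSO super-packet whose UDP_SEGMENT size was
// encoded into p.Control by SetSegSizeForTX.
func sendSegmented(fd int, p *Packet, to unix.Sockaddr) error {
    p.SetSegSizeForTX()
    return unix.Sendmsg(fd, p.Payload, p.Control, to, 0)
}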
func (p *Packet) CompatibleForSegmentationWith(otherP *Packet, currentTotalSize int) bool {
//same dest
if !slices.Equal(p.Name, otherP.Name) {
return false
}
//don't get too big
if len(p.Payload)+currentTotalSize >= 0xffff {
return false
}
//same body len
//todo allow single different size at end
if len(p.Payload) != len(otherP.Payload) {
return false //todo technically you can cram one extra in
}
return true
}
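One way CompatibleForSegmentationWith could be used is as the predicate for greedily grouping packets that can share a single GSO send; the grouping strategy below is an assumption, not code from this diff:
// Illustrative sketch: greedily bucket packets that can be sent as one
// segmented write, using CompatibleForSegmentationWith as the predicate.
func groupForSegmentation(pkts []*Packet) [][]*Packet {
    var groups [][]*Packet
    for _, p := range pkts {
        placed := false
        for i, g := range groups {
            total := len(g) * len(g[0].Payload) // group members share a payload size
            if p.CompatibleForSegmentationWith(g[0], total) {
                groups[i] = append(g, p)
                placed = true
                break
            }
        }
        if !placed {
            groups = append(groups, []*Packet{p})
        }
    }
    return groups
}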
func (p *Packet) Segments() iter.Seq[[]byte] {
return func(yield func([]byte) bool) {
//cursor := 0
for offset := 0; offset < len(p.Payload); offset += p.SegSize {
end := offset + p.SegSize
if end > len(p.Payload) {
end = len(p.Payload)
}
if !yield(p.Payload[offset:end]) {
return
}
}
}
}
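The Segments iterator yields SegSize-byte views of Payload (the final one may be shorter), which makes consuming a GRO batch a plain range loop. A minimal usage sketch (not part of this diff):
// Illustrative sketch: process each coalesced datagram in a received packet.
func forEachSegment(p *Packet, handle func([]byte)) {
    for seg := range p.Segments() {
        handle(seg)
    }
}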

37
packet/virtio.go Normal file
View File

@@ -0,0 +1,37 @@
package packet
import (
"github.com/slackhq/nebula/util/virtio"
)
type VirtIOPacket struct {
Payload []byte
Header virtio.NetHdr
Chains []uint16
ChainRefs [][]byte
// OfferDescriptorChains(chains []uint16, kick bool) error
}
func NewVIO() *VirtIOPacket {
out := new(VirtIOPacket)
out.Payload = nil
out.ChainRefs = make([][]byte, 0, 4)
out.Chains = make([]uint16, 0, 8)
return out
}
func (v *VirtIOPacket) Reset() {
v.Payload = nil
v.ChainRefs = v.ChainRefs[:0]
v.Chains = v.Chains[:0]
}
type VirtIOTXPacket struct {
VirtIOPacket
}
func NewVIOTX(isV4 bool) *VirtIOTXPacket {
out := new(VirtIOTXPacket)
out.VirtIOPacket = *NewVIO()
return out
}

View File

@@ -180,6 +180,7 @@ func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) {
pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true),
pkcs11.NewAttribute(pkcs11.CKA_WRAP, true),
pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true),
+pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, NoiseKeySize),
}
// Set up the parameters which include the peer's public key

52
pki.go
View File

@@ -100,41 +100,36 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
currentState := p.cs.Load()
if newState.v1Cert != nil {
if currentState.v1Cert == nil {
-return util.NewContextualError("v1 certificate was added, restart required", nil, err)
+//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
-}
+} else {
// did IP in cert change? if so, don't set
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
return util.NewContextualError(
"Networks in new cert was different from old",
-m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
+m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1},
nil,
)
}
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
return util.NewContextualError(
-"Curve in new cert was different from old",
+"Curve in new v1 cert was different from old",
-m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
+m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1},
nil,
)
}
+}
-} else if currentState.v1Cert != nil {
-//TODO: CERT-V2 we should be able to tear this down
-return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
}
if newState.v2Cert != nil {
if currentState.v2Cert == nil {
-return util.NewContextualError("v2 certificate was added, restart required", nil, err)
+//adding certs is fine, actually
-}
+} else {
// did IP in cert change? if so, don't set
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
return util.NewContextualError(
"Networks in new cert was different from old",
-m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
+m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2},
nil,
)
}
@@ -142,13 +137,25 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
return util.NewContextualError(
"Curve in new cert was different from old",
-m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
+m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2},
nil,
)
}
+}
} else if currentState.v2Cert != nil {
-return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
+//newState.v1Cert is non-nil bc empty certstates aren't permitted
+if newState.v1Cert == nil {
+return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
+}
+//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
+if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
+return util.NewContextualError(
+"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
+m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
+nil,
+)
+}
}
// Cipher cant be hot swapped so just leave it at what it was before
@@ -173,7 +180,6 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
p.cs.Store(newState)
-//TODO: CERT-V2 newState needs a stringer that does json
if initial {
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
} else {
@@ -359,7 +365,9 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil) return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
} }
//TODO: CERT-V2 make sure v2 has v1s address if v1.Networks()[0] != v2.Networks()[0] {
return nil, util.NewContextualError("v1 and v2 networks are not the same", nil, nil)
}
cs.initiatingVersion = dv cs.initiatingVersion = dv
} }
@@ -515,10 +523,14 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
}
-for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
-l.WithField("fingerprint", fp).Info("Blocklisting cert")
+bl := c.GetStringSlice("pki.blocklist", []string{})
+if len(bl) > 0 {
+for _, fp := range bl {
caPool.BlocklistFingerprint(fp)
}
+l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates")
+}
return caPool, nil
}

View File

@@ -7,11 +7,11 @@ import (
"slices" "slices"
"sort" "sort"
"strconv" "strconv"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/wadey/synctrace"
) )
// forEachFunc is used to benefit folks that want to do work inside the lock // forEachFunc is used to benefit folks that want to do work inside the lock
@@ -185,12 +185,12 @@ func (hr *hostnamesResults) GetAddrs() []netip.AddrPort {
// It serves as a local cache of query replies, host update notifications, and locally learned addresses
type RemoteList struct {
// Every interaction with internals requires a lock!
-synctrace.RWMutex
+sync.RWMutex
// The full list of vpn addresses assigned to this host
vpnAddrs []netip.Addr
-// A deduplicated set of addresses. Any accessor should lock beforehand.
+// A deduplicated set of underlay addresses. Any accessor should lock beforehand.
addrs []netip.AddrPort
// A set of relay addresses. VpnIp addresses that the remote identified as relays.
@@ -202,7 +202,9 @@ type RemoteList struct {
cache map[netip.Addr]*cache
hr *hostnamesResults
-shouldAdd func(netip.Addr) bool
+// shouldAdd is a nillable function that decides if x should be added to addrs.
+shouldAdd func(vpnAddrs []netip.Addr, x netip.Addr) bool
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake
@@ -213,9 +215,8 @@ type RemoteList struct {
}
// NewRemoteList creates a new empty RemoteList
-func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
+func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func([]netip.Addr, netip.Addr) bool) *RemoteList {
r := &RemoteList{
-RWMutex: synctrace.NewRWMutex("remote-list"),
vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
addrs: make([]netip.AddrPort, 0),
relays: make([]netip.Addr, 0),
@@ -369,6 +370,15 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
return c
}
+// RefreshFromHandshake locks and updates the RemoteList to account for data learned upon a completed handshake
+func (r *RemoteList) RefreshFromHandshake(vpnAddrs []netip.Addr) {
+r.Lock()
+r.badRemotes = nil
+r.vpnAddrs = make([]netip.Addr, len(vpnAddrs))
+copy(r.vpnAddrs, vpnAddrs)
+r.Unlock()
+}
// ResetBlockedRemotes locks and clears the blocked remotes list
func (r *RemoteList) ResetBlockedRemotes() {
r.Lock()
@@ -578,7 +588,7 @@ func (r *RemoteList) unlockedCollect() {
dnsAddrs := r.hr.GetAddrs()
for _, addr := range dnsAddrs {
-if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
+if r.shouldAdd == nil || r.shouldAdd(r.vpnAddrs, addr.Addr()) {
if !r.unlockedIsBad(addr) {
addrs = append(addrs, addr)
}

View File

@@ -16,8 +16,8 @@ import (
"github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"go.yaml.in/yaml/v3"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"gopkg.in/yaml.v3"
) )
type m = map[string]any type m = map[string]any

View File

@@ -5,10 +5,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"sync"
"github.com/armon/go-radix" "github.com/armon/go-radix"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/wadey/synctrace"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@@ -28,7 +28,7 @@ type SSHServer struct {
listener net.Listener
// Locks the conns/counter to avoid concurrent map access
-connsLock synctrace.Mutex
+connsLock sync.Mutex
conns map[int]*session
counter int
}
@@ -41,7 +41,6 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
l: l,
commands: radix.New(),
conns: make(map[int]*session),
-connsLock: synctrace.NewMutex("ssh-server-conns"),
}
cc := ssh.CertChecker{

View File

@@ -1,9 +1,8 @@
package nebula
import (
+"sync"
"time"
-"github.com/wadey/synctrace"
)
// How many timer objects should be cached
@@ -35,7 +34,7 @@ type TimerWheel[T any] struct {
}
type LockingTimerWheel[T any] struct {
-m synctrace.Mutex
+m sync.Mutex
t *TimerWheel[T]
}
@@ -82,9 +81,8 @@ func NewTimerWheel[T any](min, max time.Duration) *TimerWheel[T] {
}
// NewLockingTimerWheel is version of TimerWheel that is safe for concurrent use with a small performance penalty
-func NewLockingTimerWheel[T any](name string, min, max time.Duration) *LockingTimerWheel[T] {
+func NewLockingTimerWheel[T any](min, max time.Duration) *LockingTimerWheel[T] {
return &LockingTimerWheel[T]{
-m: synctrace.NewMutex(name),
t: NewTimerWheel[T](min, max),
}
}

View File

@@ -4,13 +4,13 @@ import (
"net/netip" "net/netip"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/packet"
) )
const MTU = 9001 const MTU = 9001
type EncReader func( type EncReader func(
addr netip.AddrPort, []*packet.Packet,
payload []byte,
) )
type Conn interface { type Conn interface {
@@ -19,6 +19,8 @@ type Conn interface {
ListenOut(r EncReader)
WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C)
+Prep(pkt *packet.Packet, addr netip.AddrPort) error
+WriteBatch(pkt []*packet.Packet) (int, error)
Close() error
}

5
udp/errors.go Normal file
View File

@@ -0,0 +1,5 @@
package udp
import "errors"
var ErrInvalidIPv6RemoteForSocket = errors.New("listener is IPv4, but writing to IPv6 remote")

View File

@@ -3,20 +3,62 @@
package udp
-// Darwin support is primarily implemented in udp_generic, besides NewListenConfig
import (
+"context"
+"encoding/binary"
+"errors"
"fmt"
"net"
"net/netip"
"syscall"
+"unsafe"
"github.com/sirupsen/logrus"
+"github.com/slackhq/nebula/config"
"golang.org/x/sys/unix"
)
+type StdConn struct {
+*net.UDPConn
+isV4 bool
+sysFd uintptr
+l *logrus.Logger
+}
+var _ Conn = &StdConn{}
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
-return NewGenericListener(l, ip, port, multi, batch)
+lc := NewListenConfig(multi)
+pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
+if err != nil {
+return nil, err
+}
+if uc, ok := pc.(*net.UDPConn); ok {
+c := &StdConn{UDPConn: uc, l: l}
+rc, err := uc.SyscallConn()
+if err != nil {
+return nil, fmt.Errorf("failed to open udp socket: %w", err)
+}
+err = rc.Control(func(fd uintptr) {
+c.sysFd = fd
+})
+if err != nil {
+return nil, fmt.Errorf("failed to get udp fd: %w", err)
+}
+la, err := c.LocalAddr()
+if err != nil {
+return nil, err
+}
+c.isV4 = la.Addr().Is4()
+return c, nil
+}
+return nil, fmt.Errorf("unexpected PacketConn: %T %#v", pc, pc)
}
func NewListenConfig(multi bool) net.ListenConfig {
@@ -43,16 +85,116 @@ func NewListenConfig(multi bool) net.ListenConfig {
}
}
-func (u *GenericConn) Rebind() error {
-rc, err := u.UDPConn.SyscallConn()
-if err != nil {
-return err
-}
+//go:linkname sendto golang.org/x/sys/unix.sendto
+//go:noescape
+func sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen int32) (err error)
+func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
+var sa unsafe.Pointer
+var addrLen int32
+if u.isV4 {
+if ap.Addr().Is6() {
+return ErrInvalidIPv6RemoteForSocket
+}
+// v4 socket: build a sockaddr_in (not a sockaddr_in6) so the struct matches SizeofSockaddrInet4
+var rsa unix.RawSockaddrInet4
+rsa.Family = unix.AF_INET
+rsa.Addr = ap.Addr().As4()
+binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
+sa = unsafe.Pointer(&rsa)
+addrLen = syscall.SizeofSockaddrInet4
+} else {
+var rsa unix.RawSockaddrInet6
+rsa.Family = unix.AF_INET6
+rsa.Addr = ap.Addr().As16()
+binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
+sa = unsafe.Pointer(&rsa)
+addrLen = syscall.SizeofSockaddrInet6
+}
+// Golang stdlib doesn't handle EAGAIN correctly in some situations so we do writes ourselves
+// See https://github.com/golang/go/issues/73919
+for {
+//_, _, err := unix.Syscall6(unix.SYS_SENDTO, u.sysFd, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), 0, sa, addrLen)
+err := sendto(int(u.sysFd), b, 0, sa, addrLen)
+if err == nil {
+// Written, get out before the error handling
+return nil
+}
+if errors.Is(err, syscall.EINTR) {
+// Write was interrupted, retry
+continue
+}
+if errors.Is(err, syscall.EAGAIN) {
+return &net.OpError{Op: "sendto", Err: unix.EWOULDBLOCK}
+}
+if errors.Is(err, syscall.EBADF) {
+return net.ErrClosed
+}
+return &net.OpError{Op: "sendto", Err: err}
+}
+}
+func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
+a := u.UDPConn.LocalAddr()
+switch v := a.(type) {
+case *net.UDPAddr:
+addr, ok := netip.AddrFromSlice(v.IP)
+if !ok {
+return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP)
+}
+return netip.AddrPortFrom(addr, uint16(v.Port)), nil
+default:
+return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a)
+}
+}
+func (u *StdConn) ReloadConfig(c *config.C) {
+// TODO
+}
+func NewUDPStatsEmitter(udpConns []Conn) func() {
+// No UDP stats for non-linux
+return func() {}
+}
+func (u *StdConn) ListenOut(r EncReader) {
+buffer := make([]byte, MTU)
+for {
+// Just read one packet at a time
+n, rua, err := u.ReadFromUDPAddrPort(buffer)
+if err != nil {
+if errors.Is(err, net.ErrClosed) {
+u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
+return
+}
+u.l.WithError(err).Error("unexpected udp socket receive error")
+}
+r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
+}
+}
+func (u *StdConn) Rebind() error {
+var err error
+if u.isV4 {
+err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, 0)
+} else {
+err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, 0)
+}
-return rc.Control(func(fd uintptr) {
-err := syscall.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0)
if err != nil {
u.l.WithError(err).Error("Failed to rebind udp socket")
}
-})
+return nil
}

View File

@@ -1,6 +1,7 @@
-//go:build (!linux || android) && !e2e_testing
+//go:build (!linux || android) && !e2e_testing && !darwin
// +build !linux android
// +build !e2e_testing
+// +build !darwin
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
// means it can be used on platforms like Darwin and Windows.

Some files were not shown because too many files have changed in this diff.