Compare commits

..

38 Commits

Author SHA1 Message Date
Jay Wren
2400e2392b lint
* reduce staticcheck warnings
2025-04-14 14:33:46 -04:00
dependabot[bot]
18279ed17b Bump github.com/miekg/dns from 1.1.64 to 1.1.65 (#1384)
Some checks failed
gofmt / Run gofmt (push) Successful in 25s
smoke-extra / Run extra smoke tests (push) Failing after 19s
smoke / Run multi node smoke test (push) Failing after 1m27s
Build and test / Build all and test on ubuntu-linux (push) Failing after 26m1s
Build and test / Build and test on linux with boringcrypto (push) Failing after 3m4s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2m46s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
Bumps [github.com/miekg/dns](https://github.com/miekg/dns) from 1.1.64 to 1.1.65.
- [Changelog](https://github.com/miekg/dns/blob/master/Makefile.release)
- [Commits](https://github.com/miekg/dns/compare/v1.1.64...v1.1.65)

---
updated-dependencies:
- dependency-name: github.com/miekg/dns
  dependency-version: 1.1.65
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-04-08 11:40:34 -04:00
dependabot[bot]
c7fb3ad9cf Bump the golang-x-dependencies group with 4 updates (#1382)
Bumps the golang-x-dependencies group with 4 updates: [golang.org/x/crypto](https://github.com/golang/crypto), [golang.org/x/sync](https://github.com/golang/sync), [golang.org/x/sys](https://github.com/golang/sys) and [golang.org/x/term](https://github.com/golang/term).


Updates `golang.org/x/crypto` from 0.36.0 to 0.37.0
- [Commits](https://github.com/golang/crypto/compare/v0.36.0...v0.37.0)

Updates `golang.org/x/sync` from 0.12.0 to 0.13.0
- [Commits](https://github.com/golang/sync/compare/v0.12.0...v0.13.0)

Updates `golang.org/x/sys` from 0.31.0 to 0.32.0
- [Commits](https://github.com/golang/sys/compare/v0.31.0...v0.32.0)

Updates `golang.org/x/term` from 0.30.0 to 0.31.0
- [Commits](https://github.com/golang/term/compare/v0.30.0...v0.31.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-version: 0.37.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/sync
  dependency-version: 0.13.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/sys
  dependency-version: 0.32.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/term
  dependency-version: 0.31.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-04-08 11:39:31 -04:00
John Maguire
d4a7df3083 Rename pki.default_version to pki.initiating_version (#1381)
Some checks failed
gofmt / Run gofmt (push) Successful in 9s
smoke-extra / Run extra smoke tests (push) Failing after 20s
smoke / Run multi node smoke test (push) Failing after 1m26s
Build and test / Build all and test on ubuntu-linux (push) Failing after 21m13s
Build and test / Build and test on linux with boringcrypto (push) Failing after 3m19s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2m47s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
2025-04-07 18:08:29 -04:00
Zeroday BYTE
e83a1c6c84 Update config.go (#1353)
Some checks failed
gofmt / Run gofmt (push) Successful in 10s
smoke-extra / Run extra smoke tests (push) Failing after 19s
smoke / Run multi node smoke test (push) Failing after 1m27s
Build and test / Build all and test on ubuntu-linux (push) Failing after 21m57s
Build and test / Build and test on linux with boringcrypto (push) Failing after 3m23s
Build and test / Build and test on linux with pkcs11 (push) Failing after 3m2s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
2025-04-03 13:11:20 -05:00
Wade Simmons
f5d096dd2b move to golang.org/x/term (#1372)
Some checks failed
gofmt / Run gofmt (push) Successful in 35s
smoke-extra / Run extra smoke tests (push) Failing after 20s
smoke / Run multi node smoke test (push) Failing after 1m28s
Build and test / Build all and test on ubuntu-linux (push) Failing after 19m52s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2m34s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2m44s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
The `golang.org/x/crypto/ssh/terminal` was deprecated and moved to
`golang.org/x/term`. We already use the new package in
`cmd/nebula-cert`, so fix our remaining reference here.

See:

- https://github.com/golang/go/issues/31044
2025-04-02 09:11:34 -04:00
dependabot[bot]
e2d6f4e444 Bump github.com/miekg/dns from 1.1.63 to 1.1.64 (#1363)
Some checks failed
gofmt / Run gofmt (push) Successful in 25s
smoke-extra / Run extra smoke tests (push) Failing after 19s
smoke / Run multi node smoke test (push) Failing after 1m25s
Build and test / Build all and test on ubuntu-linux (push) Failing after 19m34s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2m35s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2m24s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
2025-04-01 16:28:27 -05:00
dependabot[bot]
d99fd60e06 Bump Apple-Actions/import-codesign-certs from 3 to 5 (#1364) 2025-04-01 16:26:23 -05:00
dependabot[bot]
e4bae15825 Bump google.golang.org/protobuf in the protobuf-dependencies group (#1365) 2025-04-01 16:23:35 -05:00
dependabot[bot]
58ead4116f Bump github.com/gaissmai/bart from 0.18.1 to 0.20.1 (#1369) 2025-04-01 16:10:20 -05:00
John Maguire
e136d1d47a Update example config with default_local_cidr_any changes (#1373) 2025-04-01 16:08:03 -05:00
dependabot[bot]
d2adebf26d Bump golangci/golangci-lint-action from 6 to 7 (#1361)
Some checks failed
gofmt / Run gofmt (push) Successful in 10s
smoke-extra / Run extra smoke tests (push) Failing after 19s
smoke / Run multi node smoke test (push) Failing after 1m26s
Build and test / Build all and test on ubuntu-linux (push) Failing after 19m6s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2m32s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2m36s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
* Bump golangci/golangci-lint-action from 6 to 7

Bumps [golangci/golangci-lint-action](https://github.com/golangci/golangci-lint-action) from 6 to 7.
- [Release notes](https://github.com/golangci/golangci-lint-action/releases)
- [Commits](https://github.com/golangci/golangci-lint-action/compare/v6...v7)

---
updated-dependencies:
- dependency-name: golangci/golangci-lint-action
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

* use latest golangci-lint

* pin to v2.0

* golangci-lint migrate

* make the tests happy

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Wade Simmons <wsimmons@slack-corp.com>
2025-04-01 13:24:19 -04:00
Wade Simmons
36bc9dd261 fix parseUnsafeRoutes for yaml.v3 (#1371)
We switched to yaml.v3 with #1148, but missed this spot that was still
casting into `map[any]any` when yaml.v3 makes it `map[string]any`. Also
clean up a few more `interface{}` that were added as we changed them all
to `any` with #1148.
2025-04-01 09:49:26 -04:00
Wade Simmons
879852c32a upgrade to yaml.v3 (#1148)
Some checks failed
gofmt / Run gofmt (push) Successful in 37s
smoke-extra / Run extra smoke tests (push) Failing after 20s
smoke / Run multi node smoke test (push) Failing after 1m25s
Build and test / Build all and test on ubuntu-linux (push) Failing after 18m51s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2m44s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2m27s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
* upgrade to yaml.v3

The main nice fix here is that maps unmarshal into `map[string]any`
instead of `map[any]any`, so it cleans things up a bit.

* add config.AsBool

Since yaml.v3 doesn't automatically convert yes to bool now, for
backwards compat

* use type aliases for m

* more cleanup

* more cleanup

* more cleanup

* go mod cleanup
2025-03-31 16:08:34 -04:00
dependabot[bot]
75faa5f2e5 Bump golang.org/x/net in the golang-x-dependencies group (#1370)
Bumps the golang-x-dependencies group with 1 update: [golang.org/x/net](https://github.com/golang/net).


Updates `golang.org/x/net` from 0.37.0 to 0.38.0
- [Commits](https://github.com/golang/net/compare/v0.37.0...v0.38.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-31 16:05:07 -04:00
Caleb Jasik
4444ed166a Add certVersion field to logs when logging the cert name in handshakes (#1359)
Some checks failed
gofmt / Run gofmt (push) Successful in 10s
smoke-extra / Run extra smoke tests (push) Failing after 19s
smoke / Run multi node smoke test (push) Failing after 1m24s
Build and test / Build all and test on ubuntu-linux (push) Failing after 19m24s
Build and test / Build and test on linux with boringcrypto (push) Failing after 4m13s
Build and test / Build and test on linux with pkcs11 (push) Failing after 3m27s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
2025-03-25 16:08:36 -05:00
dioss-Machiel
f86953ca56 Implement ECMP for unsafe_routes (#1332)
Some checks failed
gofmt / Run gofmt (push) Successful in 27s
smoke-extra / Run extra smoke tests (push) Failing after 18s
smoke / Run multi node smoke test (push) Failing after 1m26s
Build and test / Build all and test on ubuntu-linux (push) Failing after 21m43s
Build and test / Build and test on linux with boringcrypto (push) Failing after 3m45s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2m59s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
2025-03-24 17:15:59 -05:00
Wade Simmons
3de36c99b6 build with go1.24 (#1338)
Some checks failed
gofmt / Run gofmt (push) Successful in 40s
smoke-extra / Run extra smoke tests (push) Failing after 20s
smoke / Run multi node smoke test (push) Failing after 1m32s
Build and test / Build all and test on ubuntu-linux (push) Failing after 20m31s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2m48s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2m57s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
This doesn't change our go.mod, which still only requires go1.22 as a minimum. It only changes our builds to use go1.24 so we have the latest improvements.
2025-03-14 13:49:27 -04:00
Caleb Jasik
50473bd2a8 Update example config to listen on :: by default (#1351)
Some checks failed
gofmt / Run gofmt (push) Successful in 10s
smoke-extra / Run extra smoke tests (push) Failing after 19s
smoke / Run multi node smoke test (push) Failing after 1m27s
Build and test / Build all and test on ubuntu-linux (push) Failing after 19m16s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2m41s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2m56s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
2025-03-12 22:53:16 -05:00
jampe
1d3c85338c add so_mark sockopt support (#1331)
Some checks failed
gofmt / Run gofmt (push) Successful in 10s
smoke-extra / Run extra smoke tests (push) Failing after 20s
smoke / Run multi node smoke test (push) Failing after 1m29s
Build and test / Build all and test on ubuntu-linux (push) Failing after 19m23s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2m45s
Build and test / Build and test on linux with pkcs11 (push) Failing after 3m39s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
2025-03-12 09:35:33 -05:00
Aleksandr Zykov
2fb018ced8 Fixed homebrew formula path (#1219)
Some checks failed
gofmt / Run gofmt (push) Successful in 12s
smoke-extra / Run extra smoke tests (push) Failing after 19s
smoke / Run multi node smoke test (push) Failing after 1m28s
Build and test / Build all and test on ubuntu-linux (push) Failing after 19m40s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2m56s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2m47s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
2025-03-11 22:58:52 -05:00
Caleb Jasik
088af8edb2 Enable running testifylint in CI (#1350)
Some checks failed
gofmt / Run gofmt (push) Successful in 10s
smoke-extra / Run extra smoke tests (push) Failing after 18s
smoke / Run multi node smoke test (push) Failing after 1m28s
Build and test / Build all and test on ubuntu-linux (push) Failing after 19m44s
Build and test / Build and test on linux with boringcrypto (push) Failing after 3m1s
Build and test / Build and test on linux with pkcs11 (push) Failing after 3m6s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
2025-03-10 17:38:14 -05:00
Caleb Jasik
612637f529 Fix testifylint lint errors (#1321)
Some checks failed
gofmt / Run gofmt (push) Successful in 11s
smoke-extra / Run extra smoke tests (push) Failing after 19s
smoke / Run multi node smoke test (push) Failing after 1m28s
Build and test / Build all and test on ubuntu-linux (push) Failing after 19m3s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2m44s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2m54s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
* Fix bool-compare

* Fix empty

* Fix encoded-compare

* Fix error-is-as

* Fix error-nil

* Fix expected-actual

* Fix len
2025-03-10 10:18:34 -04:00
Wade Simmons
94e89a1045 smoke-tests: guess the lighthouse container IP better (#1347)
Currently we just assume you are using the default Docker bridge network
config of `172.17.0.0/24`. This change works to try to detect if you are
using a different config, but still only works if you are using a `/24`
and aren't running any other containers. A future PR could make this
better by launching the lighthouse container first and then fetching
what the IP address is before continuing with the configuration.
2025-03-10 10:17:54 -04:00
Caleb Jasik
f7540ad355 Remove commented out metadata.go (#1320)
Some checks failed
gofmt / Run gofmt (push) Successful in 27s
smoke-extra / Run extra smoke tests (push) Failing after 20s
smoke / Run multi node smoke test (push) Failing after 1m26s
Build and test / Build all and test on ubuntu-linux (push) Failing after 18m30s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2m52s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2m35s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
2025-03-07 14:37:07 -06:00
dependabot[bot]
096179a8c9 Bump github.com/miekg/dns from 1.1.62 to 1.1.63 (#1346)
Bumps [github.com/miekg/dns](https://github.com/miekg/dns) from 1.1.62 to 1.1.63.
- [Changelog](https://github.com/miekg/dns/blob/master/Makefile.release)
- [Commits](https://github.com/miekg/dns/compare/v1.1.62...v1.1.63)

---
updated-dependencies:
- dependency-name: github.com/miekg/dns
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-07 12:05:36 -05:00
Nate Brown
f8734ffa43 Improve logging when handshaking with an invalid cert (#1345) 2025-03-07 10:45:31 -06:00
dependabot[bot]
c58e223b3d Bump github.com/prometheus/client_golang from 1.20.4 to 1.21.1 (#1340)
Some checks failed
gofmt / Run gofmt (push) Successful in 26s
smoke-extra / Run extra smoke tests (push) Failing after 19s
smoke / Run multi node smoke test (push) Failing after 1m28s
Build and test / Build all and test on ubuntu-linux (push) Failing after 18m50s
Build and test / Build and test on linux with boringcrypto (push) Failing after 3m13s
Build and test / Build and test on linux with pkcs11 (push) Failing after 3m10s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
Bumps [github.com/prometheus/client_golang](https://github.com/prometheus/client_golang) from 1.20.4 to 1.21.1.
- [Release notes](https://github.com/prometheus/client_golang/releases)
- [Changelog](https://github.com/prometheus/client_golang/blob/main/CHANGELOG.md)
- [Commits](https://github.com/prometheus/client_golang/compare/v1.20.4...v1.21.1)

---
updated-dependencies:
- dependency-name: github.com/prometheus/client_golang
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-07 09:44:30 -05:00
Wade Simmons
c46ef43590 smoke-test-extra: cleanup ncat references (#1343)
Some checks failed
gofmt / Run gofmt (push) Successful in 39s
smoke-extra / Run extra smoke tests (push) Failing after 30s
smoke / Run multi node smoke test (push) Failing after 1m29s
Build and test / Build all and test on ubuntu-linux (push) Failing after 18m40s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2m36s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2m50s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
* smoke-extra: cleanup ncat references

We can't run the `ncat` tests unless we can make sure to install it to
all of the vagrant boxes.

* more ncat

* more cleanup
2025-03-06 15:44:41 -05:00
dependabot[bot]
775c6bc83d Bump google.golang.org/protobuf (#1344)
Bumps the protobuf-dependencies group with 1 update in the / directory: google.golang.org/protobuf.


Updates `google.golang.org/protobuf` from 1.35.1 to 1.36.5

---
updated-dependencies:
- dependency-name: google.golang.org/protobuf
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: protobuf-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-06 15:43:55 -05:00
dependabot[bot]
13799f425d Bump the golang-x-dependencies group with 5 updates (#1339)
Bumps the golang-x-dependencies group with 5 updates:

| Package | From | To |
| --- | --- | --- |
| [golang.org/x/crypto](https://github.com/golang/crypto) | `0.32.0` | `0.36.0` |
| [golang.org/x/net](https://github.com/golang/net) | `0.34.0` | `0.37.0` |
| [golang.org/x/sync](https://github.com/golang/sync) | `0.10.0` | `0.12.0` |
| [golang.org/x/sys](https://github.com/golang/sys) | `0.29.0` | `0.31.0` |
| [golang.org/x/term](https://github.com/golang/term) | `0.28.0` | `0.30.0` |


Updates `golang.org/x/crypto` from 0.32.0 to 0.36.0
- [Commits](https://github.com/golang/crypto/compare/v0.32.0...v0.36.0)

Updates `golang.org/x/net` from 0.34.0 to 0.37.0
- [Commits](https://github.com/golang/net/compare/v0.34.0...v0.37.0)

Updates `golang.org/x/sync` from 0.10.0 to 0.12.0
- [Commits](https://github.com/golang/sync/compare/v0.10.0...v0.12.0)

Updates `golang.org/x/sys` from 0.29.0 to 0.31.0
- [Commits](https://github.com/golang/sys/compare/v0.29.0...v0.31.0)

Updates `golang.org/x/term` from 0.28.0 to 0.30.0
- [Commits](https://github.com/golang/term/compare/v0.28.0...v0.30.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/net
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/sync
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/term
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-06 14:30:20 -05:00
dependabot[bot]
8a090e59d7 Bump github.com/gaissmai/bart from 0.13.0 to 0.18.1 (#1341)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Nate Brown <nbrown.us@gmail.com>
2025-03-06 13:26:29 -06:00
Wade Simmons
9feda811a6 bump go.mod to go1.23 (#1342)
* bump go.mod to go1.23

* 1.23.6
2025-03-06 13:21:49 -05:00
dependabot[bot]
750e4a81bf Bump the golang-x-dependencies group across 1 directory with 5 updates (#1303)
Bumps the golang-x-dependencies group with 3 updates in the / directory: [golang.org/x/crypto](https://github.com/golang/crypto), [golang.org/x/net](https://github.com/golang/net) and [golang.org/x/sync](https://github.com/golang/sync).


Updates `golang.org/x/crypto` from 0.28.0 to 0.32.0
- [Commits](https://github.com/golang/crypto/compare/v0.28.0...v0.32.0)

Updates `golang.org/x/net` from 0.30.0 to 0.34.0
- [Commits](https://github.com/golang/net/compare/v0.30.0...v0.34.0)

Updates `golang.org/x/sync` from 0.8.0 to 0.10.0
- [Commits](https://github.com/golang/sync/compare/v0.8.0...v0.10.0)

Updates `golang.org/x/sys` from 0.26.0 to 0.29.0
- [Commits](https://github.com/golang/sys/compare/v0.26.0...v0.29.0)

Updates `golang.org/x/term` from 0.25.0 to 0.28.0
- [Commits](https://github.com/golang/term/compare/v0.25.0...v0.28.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/net
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/sync
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/term
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-06 12:57:05 -05:00
Wade Simmons
32d3a6e091 build with go1.23 (#1198)
* make boringcrypto: add checklinkname flag for go1.23

Starting with go1.23, we need to set -checklinkname=0 when building for
boringcrypto because we need to use go:linkname to access `newGCMTLS`.

Note that this does break builds when using a go version less than
go1.23.0. We can probably assume that someone using this Makefile and
manually building is using the latest release of Go though.

See:

- https://go.dev/doc/go1.23#linker

* build with go1.23

This doesn't change our go.mod, which still only requires go1.22 as
a minimum, only changes our builds to use go1.23 so we have the latest
improvements.

* fix `make test-boringcrypto` as well

* also fix boringcrypto e2e test
2025-03-06 12:54:20 -05:00
Wade Simmons
351dbd6059 smoke-extra: support Ubuntu 24.04 (#1311)
Ubuntu 24.04 doesn't include vagrant anymore, so add the hashicorp
source
2025-03-06 12:29:38 -05:00
Nate Brown
d97ed57a19 V2 certificate format (#1216)
Co-authored-by: Nate Brown <nbrown.us@gmail.com>
Co-authored-by: Jack Doan <jackdoan@rivian.com>
Co-authored-by: brad-defined <77982333+brad-defined@users.noreply.github.com>
Co-authored-by: Jack Doan <me@jackdoan.com>
2025-03-06 11:28:26 -06:00
Ian VanSchooten
2b427a7e89 Update slack invitation link (#1308)
Some checks failed
gofmt / Run gofmt (push) Successful in 46s
smoke-extra / Run extra smoke tests (push) Failing after 31s
smoke / Run multi node smoke test (push) Failing after 1m12s
Build and test / Build all and test on ubuntu-linux (push) Failing after 15m11s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2m10s
Build and test / Build and test on linux with pkcs11 (push) Failing after 1m58s
Build and test / Build and test on ${{ matrix.os }} (macos-latest) (push) Has been cancelled
Build and test / Build and test on ${{ matrix.os }} (windows-latest) (push) Has been cancelled
2025-01-13 13:35:53 -05:00
106 changed files with 1976 additions and 1331 deletions

View File

@@ -18,7 +18,7 @@ jobs:
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: '1.22' go-version: '1.24'
check-latest: true check-latest: true
- name: Install goimports - name: Install goimports

View File

@@ -14,7 +14,7 @@ jobs:
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: '1.22' go-version: '1.24'
check-latest: true check-latest: true
- name: Build - name: Build
@@ -37,7 +37,7 @@ jobs:
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: '1.22' go-version: '1.24'
check-latest: true check-latest: true
- name: Build - name: Build
@@ -70,12 +70,12 @@ jobs:
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: '1.22' go-version: '1.24'
check-latest: true check-latest: true
- name: Import certificates - name: Import certificates
if: env.HAS_SIGNING_CREDS == 'true' if: env.HAS_SIGNING_CREDS == 'true'
uses: Apple-Actions/import-codesign-certs@v3 uses: Apple-Actions/import-codesign-certs@v5
with: with:
p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }} p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }}
p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }} p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }}

View File

@@ -27,6 +27,9 @@ jobs:
go-version-file: 'go.mod' go-version-file: 'go.mod'
check-latest: true check-latest: true
- name: add hashicorp source
run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list
- name: install vagrant - name: install vagrant
run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox

View File

@@ -22,7 +22,7 @@ jobs:
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: '1.22' go-version: '1.24'
check-latest: true check-latest: true
- name: build - name: build

View File

@@ -5,6 +5,10 @@ set -e -x
rm -rf ./build rm -rf ./build
mkdir ./build mkdir ./build
# TODO: Assumes your docker bridge network is a /24, and the first container that launches will be .1
# - We could make this better by launching the lighthouse first and then fetching what IP it is.
NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{ end }}' | cut -d. -f1-3)"
( (
cd build cd build
@@ -21,16 +25,16 @@ mkdir ./build
../genconfig.sh >lighthouse1.yml ../genconfig.sh >lighthouse1.yml
HOST="host2" \ HOST="host2" \
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \ LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
../genconfig.sh >host2.yml ../genconfig.sh >host2.yml
HOST="host3" \ HOST="host3" \
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \ LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \ INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
../genconfig.sh >host3.yml ../genconfig.sh >host3.yml
HOST="host4" \ HOST="host4" \
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \ LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \ OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
../genconfig.sh >host4.yml ../genconfig.sh >host4.yml

View File

@@ -29,13 +29,13 @@ docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
docker run --name host2 --rm "$CONTAINER" -config host2.yml -test docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
vagrant up vagrant up
vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
sleep 1 sleep 1
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
sleep 1 sleep 1
vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" & vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/ [host3] /' &
sleep 15 sleep 15
# grab tcpdump pcaps for debugging # grab tcpdump pcaps for debugging
@@ -46,8 +46,8 @@ docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host
# vagrant ssh -c "tcpdump -i nebula1 -q -w - -U" 2>logs/host3.inside.log >logs/host3.inside.pcap & # vagrant ssh -c "tcpdump -i nebula1 -q -w - -U" 2>logs/host3.inside.log >logs/host3.inside.pcap &
# vagrant ssh -c "tcpdump -i eth0 -q -w - -U" 2>logs/host3.outside.log >logs/host3.outside.pcap & # vagrant ssh -c "tcpdump -i eth0 -q -w - -U" 2>logs/host3.outside.log >logs/host3.outside.pcap &
docker exec host2 ncat -nklv 0.0.0.0 2000 & #docker exec host2 ncat -nklv 0.0.0.0 2000 &
vagrant ssh -c "ncat -nklv 0.0.0.0 2000" & #vagrant ssh -c "ncat -nklv 0.0.0.0 2000" &
#docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 & #docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
#vagrant ssh -c "ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000" & #vagrant ssh -c "ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000" &
@@ -68,11 +68,11 @@ docker exec host2 ping -c1 192.168.100.1
# Should fail because not allowed by host3 inbound firewall # Should fail because not allowed by host3 inbound firewall
! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1 ! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
set +x #set +x
echo #echo
echo " *** Testing ncat from host2" #echo " *** Testing ncat from host2"
echo #echo
set -x #set -x
# Should fail because not allowed by host3 inbound firewall # Should fail because not allowed by host3 inbound firewall
#! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1 #! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1
#! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1 #! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
@@ -82,18 +82,18 @@ echo
echo " *** Testing ping from host3" echo " *** Testing ping from host3"
echo echo
set -x set -x
vagrant ssh -c "ping -c1 192.168.100.1" vagrant ssh -c "ping -c1 192.168.100.1" -- -T
vagrant ssh -c "ping -c1 192.168.100.2" vagrant ssh -c "ping -c1 192.168.100.2" -- -T
set +x #set +x
echo #echo
echo " *** Testing ncat from host3" #echo " *** Testing ncat from host3"
echo #echo
set -x #set -x
#vagrant ssh -c "ncat -nzv -w5 192.168.100.2 2000" #vagrant ssh -c "ncat -nzv -w5 192.168.100.2 2000"
#vagrant ssh -c "ncat -nzuv -w5 192.168.100.2 3000" | grep -q host2 #vagrant ssh -c "ncat -nzuv -w5 192.168.100.2 3000" | grep -q host2
vagrant ssh -c "sudo xargs kill </nebula/pid" vagrant ssh -c "sudo xargs kill </nebula/pid" -- -T
docker exec host2 sh -c 'kill 1' docker exec host2 sh -c 'kill 1'
docker exec lighthouse1 sh -c 'kill 1' docker exec lighthouse1 sh -c 'kill 1'
sleep 1 sleep 1

View File

@@ -22,7 +22,7 @@ jobs:
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: '1.22' go-version: '1.24'
check-latest: true check-latest: true
- name: Build - name: Build
@@ -31,6 +31,11 @@ jobs:
- name: Vet - name: Vet
run: make vet run: make vet
- name: golangci-lint
uses: golangci/golangci-lint-action@v7
with:
version: v2.0
- name: Test - name: Test
run: make test run: make test
@@ -55,7 +60,7 @@ jobs:
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: '1.22' go-version: '1.24'
check-latest: true check-latest: true
- name: Build - name: Build
@@ -65,7 +70,7 @@ jobs:
run: make test-boringcrypto run: make test-boringcrypto
- name: End 2 end - name: End 2 end
run: make e2evv GOEXPERIMENT=boringcrypto CGO_ENABLED=1 run: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0"
test-linux-pkcs11: test-linux-pkcs11:
name: Build and test on linux with pkcs11 name: Build and test on linux with pkcs11
@@ -97,7 +102,7 @@ jobs:
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: '1.22' go-version: '1.24'
check-latest: true check-latest: true
- name: Build nebula - name: Build nebula
@@ -109,6 +114,11 @@ jobs:
- name: Vet - name: Vet
run: make vet run: make vet
- name: golangci-lint
uses: golangci/golangci-lint-action@v7
with:
version: v2.0
- name: Test - name: Test
run: make test run: make test

23
.golangci.yaml Normal file
View File

@@ -0,0 +1,23 @@
version: "2"
linters:
default: none
enable:
- testifylint
exclusions:
generated: lax
presets:
- comments
- common-false-positives
- legacy
- std-error-handling
paths:
- third_party$
- builtin$
- examples$
formatters:
exclusions:
generated: lax
paths:
- third_party$
- builtin$
- examples$

View File

@@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
### Changed
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
intended to target an `unsafe_routes` entry must explicitly declare it via the
`local_cidr` field. This is almost always the intended behavior. This flag is
deprecated and will be removed in a future release.
## [1.9.4] - 2024-09-09 ## [1.9.4] - 2024-09-09
### Added ### Added

View File

@@ -137,6 +137,8 @@ build/linux-mips-softfloat/%: LDFLAGS += -s -w
# boringcrypto # boringcrypto
build/linux-amd64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1 build/linux-amd64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
build/linux-arm64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1 build/linux-arm64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
build/linux-amd64-boringcrypto/%: LDFLAGS += -checklinkname=0
build/linux-arm64-boringcrypto/%: LDFLAGS += -checklinkname=0
build/%/nebula: .FORCE build/%/nebula: .FORCE
GOOS=$(firstword $(subst -, , $*)) \ GOOS=$(firstword $(subst -, , $*)) \
@@ -170,7 +172,7 @@ test:
go test -v ./... go test -v ./...
test-boringcrypto: test-boringcrypto:
GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -v ./... GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -ldflags "-checklinkname=0" -v ./...
test-pkcs11: test-pkcs11:
CGO_ENABLED=1 go test -v -tags pkcs11 ./... CGO_ENABLED=1 go test -v -tags pkcs11 ./...

View File

@@ -12,7 +12,7 @@ Further documentation can be found [here](https://nebula.defined.net/docs/).
You can read more about Nebula [here](https://medium.com/p/884110a5579). You can read more about Nebula [here](https://medium.com/p/884110a5579).
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/enQtOTA5MDI4NDg3MTg4LTkwY2EwNTI4NzQyMzc0M2ZlODBjNWI3NTY1MzhiOThiMmZlZjVkMTI0NGY4YTMyNjUwMWEyNzNkZTJmYzQxOGU). You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ).
## Supported Platforms ## Supported Platforms
@@ -47,7 +47,7 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
$ sudo apk add nebula $ sudo apk add nebula
``` ```
- [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/nebula.rb) - [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/n/nebula.rb)
``` ```
$ brew install nebula $ brew install nebula
``` ```

View File

@@ -36,7 +36,7 @@ type AllowListNameRule struct {
func NewLocalAllowListFromConfig(c *config.C, k string) (*LocalAllowList, error) { func NewLocalAllowListFromConfig(c *config.C, k string) (*LocalAllowList, error) {
var nameRules []AllowListNameRule var nameRules []AllowListNameRule
handleKey := func(key string, value interface{}) (bool, error) { handleKey := func(key string, value any) (bool, error) {
if key == "interfaces" { if key == "interfaces" {
var err error var err error
nameRules, err = getAllowListInterfaces(k, value) nameRules, err = getAllowListInterfaces(k, value)
@@ -70,7 +70,7 @@ func NewRemoteAllowListFromConfig(c *config.C, k, rangesKey string) (*RemoteAllo
// If the handleKey func returns true, the rest of the parsing is skipped // If the handleKey func returns true, the rest of the parsing is skipped
// for this key. This allows parsing of special values like `interfaces`. // for this key. This allows parsing of special values like `interfaces`.
func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) { func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value any) (bool, error)) (*AllowList, error) {
r := c.Get(k) r := c.Get(k)
if r == nil { if r == nil {
return nil, nil return nil, nil
@@ -81,8 +81,8 @@ func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, va
// If the handleKey func returns true, the rest of the parsing is skipped // If the handleKey func returns true, the rest of the parsing is skipped
// for this key. This allows parsing of special values like `interfaces`. // for this key. This allows parsing of special values like `interfaces`.
func newAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) { func newAllowList(k string, raw any, handleKey func(key string, value any) (bool, error)) (*AllowList, error) {
rawMap, ok := raw.(map[interface{}]interface{}) rawMap, ok := raw.(map[string]any)
if !ok { if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw) return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
} }
@@ -100,12 +100,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false} rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false} rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
for rawKey, rawValue := range rawMap { for rawCIDR, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
if handleKey != nil { if handleKey != nil {
handled, err := handleKey(rawCIDR, rawValue) handled, err := handleKey(rawCIDR, rawValue)
if err != nil { if err != nil {
@@ -116,7 +111,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
} }
} }
value, ok := rawValue.(bool) value, ok := config.AsBool(rawValue)
if !ok { if !ok {
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue) return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
} }
@@ -173,22 +168,18 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
return &AllowList{cidrTree: tree}, nil return &AllowList{cidrTree: tree}, nil
} }
func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) { func getAllowListInterfaces(k string, v any) ([]AllowListNameRule, error) {
var nameRules []AllowListNameRule var nameRules []AllowListNameRule
rawRules, ok := v.(map[interface{}]interface{}) rawRules, ok := v.(map[string]any)
if !ok { if !ok {
return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v) return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
} }
firstEntry := true firstEntry := true
var allValues bool var allValues bool
for rawName, rawAllow := range rawRules { for name, rawAllow := range rawRules {
name, ok := rawName.(string) allow, ok := config.AsBool(rawAllow)
if !ok {
return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
}
allow, ok := rawAllow.(bool)
if !ok { if !ok {
return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow) return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
} }
@@ -224,16 +215,11 @@ func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error
remoteAllowRanges := new(bart.Table[*AllowList]) remoteAllowRanges := new(bart.Table[*AllowList])
rawMap, ok := value.(map[interface{}]interface{}) rawMap, ok := value.(map[string]any)
if !ok { if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value) return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
} }
for rawKey, rawValue := range rawMap { for rawCIDR, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
allowList, err := newAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil) allowList, err := newAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -9,42 +9,43 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestNewAllowListFromConfig(t *testing.T) { func TestNewAllowListFromConfig(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(l)
c.Settings["allowlist"] = map[interface{}]interface{}{ c.Settings["allowlist"] = map[string]any{
"192.168.0.0": true, "192.168.0.0": true,
} }
r, err := newAllowListFromConfig(c, "allowlist", nil) r, err := newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'") require.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
assert.Nil(t, r) assert.Nil(t, r)
c.Settings["allowlist"] = map[interface{}]interface{}{ c.Settings["allowlist"] = map[string]any{
"192.168.0.0/16": "abc", "192.168.0.0/16": "abc",
} }
r, err = newAllowListFromConfig(c, "allowlist", nil) _, err = newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc") require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
c.Settings["allowlist"] = map[interface{}]interface{}{ c.Settings["allowlist"] = map[string]any{
"192.168.0.0/16": true, "192.168.0.0/16": true,
"10.0.0.0/8": false, "10.0.0.0/8": false,
} }
r, err = newAllowListFromConfig(c, "allowlist", nil) _, err = newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0") require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
c.Settings["allowlist"] = map[interface{}]interface{}{ c.Settings["allowlist"] = map[string]any{
"0.0.0.0/0": true, "0.0.0.0/0": true,
"10.0.0.0/8": false, "10.0.0.0/8": false,
"10.42.42.0/24": true, "10.42.42.0/24": true,
"fd00::/8": true, "fd00::/8": true,
"fd00:fd00::/16": false, "fd00:fd00::/16": false,
} }
r, err = newAllowListFromConfig(c, "allowlist", nil) _, err = newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0") require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
c.Settings["allowlist"] = map[interface{}]interface{}{ c.Settings["allowlist"] = map[string]any{
"0.0.0.0/0": true, "0.0.0.0/0": true,
"10.0.0.0/8": false, "10.0.0.0/8": false,
"10.42.42.0/24": true, "10.42.42.0/24": true,
@@ -54,7 +55,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
assert.NotNil(t, r) assert.NotNil(t, r)
} }
c.Settings["allowlist"] = map[interface{}]interface{}{ c.Settings["allowlist"] = map[string]any{
"0.0.0.0/0": true, "0.0.0.0/0": true,
"10.0.0.0/8": false, "10.0.0.0/8": false,
"10.42.42.0/24": true, "10.42.42.0/24": true,
@@ -69,36 +70,36 @@ func TestNewAllowListFromConfig(t *testing.T) {
// Test interface names // Test interface names
c.Settings["allowlist"] = map[interface{}]interface{}{ c.Settings["allowlist"] = map[string]any{
"interfaces": map[interface{}]interface{}{ "interfaces": map[string]any{
`docker.*`: "foo", `docker.*`: "foo",
}, },
} }
lr, err := NewLocalAllowListFromConfig(c, "allowlist") _, err = NewLocalAllowListFromConfig(c, "allowlist")
assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo") require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
c.Settings["allowlist"] = map[interface{}]interface{}{ c.Settings["allowlist"] = map[string]any{
"interfaces": map[interface{}]interface{}{ "interfaces": map[string]any{
`docker.*`: false, `docker.*`: false,
`eth.*`: true, `eth.*`: true,
}, },
} }
lr, err = NewLocalAllowListFromConfig(c, "allowlist") _, err = NewLocalAllowListFromConfig(c, "allowlist")
assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value") require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
c.Settings["allowlist"] = map[interface{}]interface{}{ c.Settings["allowlist"] = map[string]any{
"interfaces": map[interface{}]interface{}{ "interfaces": map[string]any{
`docker.*`: false, `docker.*`: false,
}, },
} }
lr, err = NewLocalAllowListFromConfig(c, "allowlist") lr, err := NewLocalAllowListFromConfig(c, "allowlist")
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.NotNil(t, lr) assert.NotNil(t, lr)
} }
} }
func TestAllowList_Allow(t *testing.T) { func TestAllowList_Allow(t *testing.T) {
assert.Equal(t, true, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1"))) assert.True(t, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1")))
tree := new(bart.Table[bool]) tree := new(bart.Table[bool])
tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true) tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true)
@@ -111,17 +112,17 @@ func TestAllowList_Allow(t *testing.T) {
tree.Insert(netip.MustParsePrefix("::2/128"), false) tree.Insert(netip.MustParsePrefix("::2/128"), false)
al := &AllowList{cidrTree: tree} al := &AllowList{cidrTree: tree}
assert.Equal(t, true, al.Allow(netip.MustParseAddr("1.1.1.1"))) assert.True(t, al.Allow(netip.MustParseAddr("1.1.1.1")))
assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.0.0.4"))) assert.False(t, al.Allow(netip.MustParseAddr("10.0.0.4")))
assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.42.42"))) assert.True(t, al.Allow(netip.MustParseAddr("10.42.42.42")))
assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.42.42.41"))) assert.False(t, al.Allow(netip.MustParseAddr("10.42.42.41")))
assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.0.1"))) assert.True(t, al.Allow(netip.MustParseAddr("10.42.0.1")))
assert.Equal(t, true, al.Allow(netip.MustParseAddr("::1"))) assert.True(t, al.Allow(netip.MustParseAddr("::1")))
assert.Equal(t, false, al.Allow(netip.MustParseAddr("::2"))) assert.False(t, al.Allow(netip.MustParseAddr("::2")))
} }
func TestLocalAllowList_AllowName(t *testing.T) { func TestLocalAllowList_AllowName(t *testing.T) {
assert.Equal(t, true, ((*LocalAllowList)(nil)).AllowName("docker0")) assert.True(t, ((*LocalAllowList)(nil)).AllowName("docker0"))
rules := []AllowListNameRule{ rules := []AllowListNameRule{
{Name: regexp.MustCompile("^docker.*$"), Allow: false}, {Name: regexp.MustCompile("^docker.*$"), Allow: false},
@@ -129,9 +130,9 @@ func TestLocalAllowList_AllowName(t *testing.T) {
} }
al := &LocalAllowList{nameRules: rules} al := &LocalAllowList{nameRules: rules}
assert.Equal(t, false, al.AllowName("docker0")) assert.False(t, al.AllowName("docker0"))
assert.Equal(t, false, al.AllowName("tun0")) assert.False(t, al.AllowName("tun0"))
assert.Equal(t, true, al.AllowName("eth0")) assert.True(t, al.AllowName("eth0"))
rules = []AllowListNameRule{ rules = []AllowListNameRule{
{Name: regexp.MustCompile("^eth.*$"), Allow: true}, {Name: regexp.MustCompile("^eth.*$"), Allow: true},
@@ -139,7 +140,7 @@ func TestLocalAllowList_AllowName(t *testing.T) {
} }
al = &LocalAllowList{nameRules: rules} al = &LocalAllowList{nameRules: rules}
assert.Equal(t, false, al.AllowName("docker0")) assert.False(t, al.AllowName("docker0"))
assert.Equal(t, true, al.AllowName("eth0")) assert.True(t, al.AllowName("eth0"))
assert.Equal(t, true, al.AllowName("ens5")) assert.True(t, al.AllowName("ens5"))
} }

10
bits.go
View File

@@ -18,7 +18,7 @@ type Bits struct {
func NewBits(bits uint64) *Bits { func NewBits(bits uint64) *Bits {
return &Bits{ return &Bits{
length: bits, length: bits,
bits: make([]bool, bits, bits), bits: make([]bool, bits),
current: 0, current: 0,
lostCounter: metrics.GetOrRegisterCounter("network.packets.lost", nil), lostCounter: metrics.GetOrRegisterCounter("network.packets.lost", nil),
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil), dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
@@ -28,7 +28,7 @@ func NewBits(bits uint64) *Bits {
func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool { func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
// If i is the next number, return true. // If i is the next number, return true.
if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) { if i > b.current || (i == 0 && !b.firstSeen && b.current < b.length) {
return true return true
} }
@@ -51,7 +51,7 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// If i is the next number, return true and update current. // If i is the next number, return true and update current.
if i == b.current+1 { if i == b.current+1 {
// Report missed packets, we can only understand what was missed after the first window has been gone through // Report missed packets, we can only understand what was missed after the first window has been gone through
if i > b.length && b.bits[i%b.length] == false { if i > b.length && !b.bits[i%b.length] {
b.lostCounter.Inc(1) b.lostCounter.Inc(1)
} }
b.bits[i%b.length] = true b.bits[i%b.length] = true
@@ -104,7 +104,7 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
} }
// Allow for the 0 packet to come in within the first window // Allow for the 0 packet to come in within the first window
if i == 0 && b.firstSeen == false && b.current < b.length { if i == 0 && !b.firstSeen && b.current < b.length {
b.firstSeen = true b.firstSeen = true
b.bits[i%b.length] = true b.bits[i%b.length] = true
return true return true
@@ -122,7 +122,7 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
return false return false
} }
if b.bits[i%b.length] == true { if b.bits[i%b.length] {
if l.Level >= logrus.DebugLevel { if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}). l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}).
Debug("Receive window") Debug("Receive window")

View File

@@ -15,10 +15,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
input, err := netip.ParseAddr("10.0.10.182") input, err := netip.ParseAddr("10.0.10.182")
assert.NoError(t, err) require.NoError(t, err)
expected, err := netip.ParseAddr("192.168.1.182") expected, err := netip.ParseAddr("192.168.1.182")
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input)) assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input))
@@ -28,10 +28,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
assert.NoError(t, err) require.NoError(t, err)
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef") expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef")
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
@@ -41,10 +41,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
assert.NoError(t, err) require.NoError(t, err)
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef") expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef")
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
@@ -54,10 +54,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
assert.NoError(t, err) require.NoError(t, err)
expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef") expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef")
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
} }

View File

@@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestNewCAPoolFromBytes(t *testing.T) { func TestNewCAPoolFromBytes(t *testing.T) {
@@ -82,32 +83,32 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
} }
p, err := NewCAPoolFromPEM([]byte(noNewLines)) p, err := NewCAPoolFromPEM([]byte(noNewLines))
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
pp, err := NewCAPoolFromPEM([]byte(withNewLines)) pp, err := NewCAPoolFromPEM([]byte(withNewLines))
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
// expired cert, no valid certs // expired cert, no valid certs
ppp, err := NewCAPoolFromPEM([]byte(expired)) ppp, err := NewCAPoolFromPEM([]byte(expired))
assert.Equal(t, ErrExpired, err) assert.Equal(t, ErrExpired, err)
assert.Equal(t, ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired") assert.Equal(t, "expired", ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name())
// expired cert, with valid certs // expired cert, with valid certs
pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...)) pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...))
assert.Equal(t, ErrExpired, err) assert.Equal(t, ErrExpired, err)
assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
assert.Equal(t, pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired") assert.Equal(t, "expired", pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name())
assert.Equal(t, len(pppp.CAs), 3) assert.Len(t, pppp.CAs, 3)
ppppp, err := NewCAPoolFromPEM([]byte(p256)) ppppp, err := NewCAPoolFromPEM([]byte(p256))
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name) assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name)
assert.Equal(t, len(ppppp.CAs), 1) assert.Len(t, ppppp.CAs, 1)
} }
func TestCertificateV1_Verify(t *testing.T) { func TestCertificateV1_Verify(t *testing.T) {
@@ -115,21 +116,21 @@ func TestCertificateV1_Verify(t *testing.T) {
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
caPool := NewCAPool() caPool := NewCAPool()
assert.NoError(t, caPool.AddCA(ca)) require.NoError(t, caPool.AddCA(ca))
f, err := c.Fingerprint() f, err := c.Fingerprint()
assert.Nil(t, err) require.NoError(t, err)
caPool.BlocklistFingerprint(f) caPool.BlocklistFingerprint(f)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.EqualError(t, err, "certificate is in the block list") require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist() caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
assert.EqualError(t, err, "root certificate is expired") require.EqualError(t, err, "root certificate is expired")
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
@@ -138,11 +139,11 @@ func TestCertificateV1_Verify(t *testing.T) {
// Test group assertion // Test group assertion
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
caPem, err := ca.MarshalPEM() caPem, err := ca.MarshalPEM()
assert.Nil(t, err) require.NoError(t, err)
caPool = NewCAPool() caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem) b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
@@ -150,9 +151,9 @@ func TestCertificateV1_Verify(t *testing.T) {
}) })
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
} }
func TestCertificateV1_VerifyP256(t *testing.T) { func TestCertificateV1_VerifyP256(t *testing.T) {
@@ -160,21 +161,21 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
caPool := NewCAPool() caPool := NewCAPool()
assert.NoError(t, caPool.AddCA(ca)) require.NoError(t, caPool.AddCA(ca))
f, err := c.Fingerprint() f, err := c.Fingerprint()
assert.Nil(t, err) require.NoError(t, err)
caPool.BlocklistFingerprint(f) caPool.BlocklistFingerprint(f)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.EqualError(t, err, "certificate is in the block list") require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist() caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
assert.EqualError(t, err, "root certificate is expired") require.EqualError(t, err, "root certificate is expired")
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
@@ -183,11 +184,11 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
// Test group assertion // Test group assertion
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
caPem, err := ca.MarshalPEM() caPem, err := ca.MarshalPEM()
assert.Nil(t, err) require.NoError(t, err)
caPool = NewCAPool() caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem) b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
@@ -196,7 +197,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
} }
func TestCertificateV1_Verify_IPs(t *testing.T) { func TestCertificateV1_Verify_IPs(t *testing.T) {
@@ -205,11 +206,11 @@ func TestCertificateV1_Verify_IPs(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
caPem, err := ca.MarshalPEM() caPem, err := ca.MarshalPEM()
assert.Nil(t, err) require.NoError(t, err)
caPool := NewCAPool() caPool := NewCAPool()
b, err := caPool.AddCAFromPEM(caPem) b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
// ip is outside the network // ip is outside the network
@@ -245,25 +246,25 @@ func TestCertificateV1_Verify_IPs(t *testing.T) {
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
// Exact matches // Exact matches
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
// Exact matches reversed // Exact matches reversed
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
// Exact matches reversed with just 1 // Exact matches reversed with just 1
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
} }
func TestCertificateV1_Verify_Subnets(t *testing.T) { func TestCertificateV1_Verify_Subnets(t *testing.T) {
@@ -272,11 +273,11 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
caPem, err := ca.MarshalPEM() caPem, err := ca.MarshalPEM()
assert.Nil(t, err) require.NoError(t, err)
caPool := NewCAPool() caPool := NewCAPool()
b, err := caPool.AddCAFromPEM(caPem) b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
// ip is outside the network // ip is outside the network
@@ -311,27 +312,27 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) {
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
// Exact matches // Exact matches
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
// Exact matches reversed // Exact matches reversed
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
// Exact matches reversed with just 1 // Exact matches reversed with just 1
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
} }
func TestCertificateV2_Verify(t *testing.T) { func TestCertificateV2_Verify(t *testing.T) {
@@ -339,21 +340,21 @@ func TestCertificateV2_Verify(t *testing.T) {
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
caPool := NewCAPool() caPool := NewCAPool()
assert.NoError(t, caPool.AddCA(ca)) require.NoError(t, caPool.AddCA(ca))
f, err := c.Fingerprint() f, err := c.Fingerprint()
assert.Nil(t, err) require.NoError(t, err)
caPool.BlocklistFingerprint(f) caPool.BlocklistFingerprint(f)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.EqualError(t, err, "certificate is in the block list") require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist() caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
assert.EqualError(t, err, "root certificate is expired") require.EqualError(t, err, "root certificate is expired")
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
@@ -362,11 +363,11 @@ func TestCertificateV2_Verify(t *testing.T) {
// Test group assertion // Test group assertion
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
caPem, err := ca.MarshalPEM() caPem, err := ca.MarshalPEM()
assert.Nil(t, err) require.NoError(t, err)
caPool = NewCAPool() caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem) b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
@@ -374,9 +375,9 @@ func TestCertificateV2_Verify(t *testing.T) {
}) })
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
} }
func TestCertificateV2_VerifyP256(t *testing.T) { func TestCertificateV2_VerifyP256(t *testing.T) {
@@ -384,21 +385,21 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
caPool := NewCAPool() caPool := NewCAPool()
assert.NoError(t, caPool.AddCA(ca)) require.NoError(t, caPool.AddCA(ca))
f, err := c.Fingerprint() f, err := c.Fingerprint()
assert.Nil(t, err) require.NoError(t, err)
caPool.BlocklistFingerprint(f) caPool.BlocklistFingerprint(f)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.EqualError(t, err, "certificate is in the block list") require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist() caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
assert.EqualError(t, err, "root certificate is expired") require.EqualError(t, err, "root certificate is expired")
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
@@ -407,11 +408,11 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
// Test group assertion // Test group assertion
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
caPem, err := ca.MarshalPEM() caPem, err := ca.MarshalPEM()
assert.Nil(t, err) require.NoError(t, err)
caPool = NewCAPool() caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem) b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
@@ -420,7 +421,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
} }
func TestCertificateV2_Verify_IPs(t *testing.T) { func TestCertificateV2_Verify_IPs(t *testing.T) {
@@ -429,11 +430,11 @@ func TestCertificateV2_Verify_IPs(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
caPem, err := ca.MarshalPEM() caPem, err := ca.MarshalPEM()
assert.Nil(t, err) require.NoError(t, err)
caPool := NewCAPool() caPool := NewCAPool()
b, err := caPool.AddCAFromPEM(caPem) b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
// ip is outside the network // ip is outside the network
@@ -469,25 +470,25 @@ func TestCertificateV2_Verify_IPs(t *testing.T) {
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
// Exact matches // Exact matches
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
// Exact matches reversed // Exact matches reversed
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
// Exact matches reversed with just 1 // Exact matches reversed with just 1
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
} }
func TestCertificateV2_Verify_Subnets(t *testing.T) { func TestCertificateV2_Verify_Subnets(t *testing.T) {
@@ -496,11 +497,11 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
caPem, err := ca.MarshalPEM() caPem, err := ca.MarshalPEM()
assert.Nil(t, err) require.NoError(t, err)
caPool := NewCAPool() caPool := NewCAPool()
b, err := caPool.AddCAFromPEM(caPem) b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
// ip is outside the network // ip is outside the network
@@ -535,25 +536,25 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) {
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
// Exact matches // Exact matches
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
// Exact matches reversed // Exact matches reversed
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
// Exact matches reversed with just 1 // Exact matches reversed with just 1
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
assert.Nil(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err) require.NoError(t, err)
} }

View File

@@ -113,10 +113,10 @@ func (cc *CachedCertificate) String() string {
return cc.Certificate.String() return cc.Certificate.String()
} }
// RecombineAndValidate will attempt to unmarshal a certificate received in a handshake. // Recombine will attempt to unmarshal a certificate received in a handshake.
// Handshakes save space by placing the peers public key in a different part of the packet, we have to // Handshakes save space by placing the peers public key in a different part of the packet, we have to
// reassemble the actual certificate structure with that in mind. // reassemble the actual certificate structure with that in mind.
func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) { func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certificate, error) {
if publicKey == nil { if publicKey == nil {
return nil, ErrNoPeerStaticKey return nil, ErrNoPeerStaticKey
} }
@@ -125,29 +125,15 @@ func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve
return nil, ErrNoPayload return nil, ErrNoPayload
} }
c, err := unmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve)
if err != nil {
return nil, fmt.Errorf("error unmarshaling cert: %w", err)
}
cc, err := caPool.VerifyCertificate(time.Now(), c)
if err != nil {
return nil, fmt.Errorf("certificate validation failed: %w", err)
}
return cc, nil
}
func unmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) {
var c Certificate var c Certificate
var err error var err error
switch v { switch v {
// Implementations must ensure the result is a valid cert! // Implementations must ensure the result is a valid cert!
case VersionPre1, Version1: case VersionPre1, Version1:
c, err = unmarshalCertificateV1(b, publicKey) c, err = unmarshalCertificateV1(rawCertBytes, publicKey)
case Version2: case Version2:
c, err = unmarshalCertificateV2(b, publicKey, curve) c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve)
default: default:
//TODO: CERT-V2 make a static var //TODO: CERT-V2 make a static var
return nil, fmt.Errorf("unknown certificate version %d", v) return nil, fmt.Errorf("unknown certificate version %d", v)

View File

@@ -20,8 +20,6 @@ import (
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
) )
const publicKeyLen = 32
type certificateV1 struct { type certificateV1 struct {
details detailsV1 details detailsV1
signature []byte signature []byte
@@ -41,7 +39,7 @@ type detailsV1 struct {
curve Curve curve Curve
} }
type m map[string]interface{} type m = map[string]any
func (c *certificateV1) Version() Version { func (c *certificateV1) Version() Version {
return Version1 return Version1

View File

@@ -39,14 +39,14 @@ func TestCertificateV1_Marshal(t *testing.T) {
} }
b, err := nc.Marshal() b, err := nc.Marshal()
assert.Nil(t, err) require.NoError(t, err)
//t.Log("Cert size:", len(b)) //t.Log("Cert size:", len(b))
nc2, err := unmarshalCertificateV1(b, nil) nc2, err := unmarshalCertificateV1(b, nil)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, nc.Version(), Version1) assert.Equal(t, Version1, nc.Version())
assert.Equal(t, nc.Curve(), Curve_CURVE25519) assert.Equal(t, Curve_CURVE25519, nc.Curve())
assert.Equal(t, nc.Signature(), nc2.Signature()) assert.Equal(t, nc.Signature(), nc2.Signature())
assert.Equal(t, nc.Name(), nc2.Name()) assert.Equal(t, nc.Name(), nc2.Name())
assert.Equal(t, nc.NotBefore(), nc2.NotBefore()) assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
@@ -99,8 +99,8 @@ func TestCertificateV1_MarshalJSON(t *testing.T) {
} }
b, err := nc.MarshalJSON() b, err := nc.MarshalJSON()
assert.Nil(t, err) require.NoError(t, err)
assert.Equal( assert.JSONEq(
t, t,
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}", "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}",
string(b), string(b),
@@ -110,47 +110,47 @@ func TestCertificateV1_MarshalJSON(t *testing.T) {
func TestCertificateV1_VerifyPrivateKey(t *testing.T) { func TestCertificateV1_VerifyPrivateKey(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
assert.Nil(t, err) require.NoError(t, err)
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) _, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
assert.Nil(t, err) require.NoError(t, err)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
assert.NotNil(t, err) require.Error(t, err)
c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
assert.Nil(t, err) require.NoError(t, err)
_, priv2 := X25519Keypair() _, priv2 := X25519Keypair()
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
assert.NotNil(t, err) require.Error(t, err)
} }
func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) { func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
err := ca.VerifyPrivateKey(Curve_P256, caKey) err := ca.VerifyPrivateKey(Curve_P256, caKey)
assert.Nil(t, err) require.NoError(t, err)
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) _, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
assert.Nil(t, err) require.NoError(t, err)
err = ca.VerifyPrivateKey(Curve_P256, caKey2) err = ca.VerifyPrivateKey(Curve_P256, caKey2)
assert.NotNil(t, err) require.Error(t, err)
c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
assert.Equal(t, Curve_P256, curve) assert.Equal(t, Curve_P256, curve)
err = c.VerifyPrivateKey(Curve_P256, rawPriv) err = c.VerifyPrivateKey(Curve_P256, rawPriv)
assert.Nil(t, err) require.NoError(t, err)
_, priv2 := P256Keypair() _, priv2 := P256Keypair()
err = c.VerifyPrivateKey(Curve_P256, priv2) err = c.VerifyPrivateKey(Curve_P256, priv2)
assert.NotNil(t, err) require.Error(t, err)
} }
// Ensure that upgrading the protobuf library does not change how certificates // Ensure that upgrading the protobuf library does not change how certificates
@@ -182,11 +182,11 @@ func TestMarshalingCertificateV1Consistency(t *testing.T) {
} }
b, err := nc.Marshal() b, err := nc.Marshal()
require.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
b, err = proto.Marshal(nc.getRawDetails()) b, err = proto.Marshal(nc.getRawDetails())
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
} }
@@ -201,7 +201,7 @@ func TestUnmarshalCertificateV1(t *testing.T) {
// Test that we don't panic with an invalid certificate (#332) // Test that we don't panic with an invalid certificate (#332)
data := []byte("\x98\x00\x00") data := []byte("\x98\x00\x00")
_, err := unmarshalCertificateV1(data, nil) _, err := unmarshalCertificateV1(data, nil)
assert.EqualError(t, err, "encoded Details was nil") require.EqualError(t, err, "encoded Details was nil")
} }
func appendByteSlices(b ...[]byte) []byte { func appendByteSlices(b ...[]byte) []byte {

View File

@@ -45,14 +45,14 @@ func TestCertificateV2_Marshal(t *testing.T) {
nc.rawDetails = db nc.rawDetails = db
b, err := nc.Marshal() b, err := nc.Marshal()
require.Nil(t, err) require.NoError(t, err)
//t.Log("Cert size:", len(b)) //t.Log("Cert size:", len(b))
nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519) nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, nc.Version(), Version2) assert.Equal(t, Version2, nc.Version())
assert.Equal(t, nc.Curve(), Curve_CURVE25519) assert.Equal(t, Curve_CURVE25519, nc.Curve())
assert.Equal(t, nc.Signature(), nc2.Signature()) assert.Equal(t, nc.Signature(), nc2.Signature())
assert.Equal(t, nc.Name(), nc2.Name()) assert.Equal(t, nc.Name(), nc2.Name())
assert.Equal(t, nc.NotBefore(), nc2.NotBefore()) assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
@@ -113,16 +113,16 @@ func TestCertificateV2_MarshalJSON(t *testing.T) {
signature: []byte("1234567890abcedf1234567890abcedf1234567890abcedf1234567890abcedf"), signature: []byte("1234567890abcedf1234567890abcedf1234567890abcedf1234567890abcedf"),
} }
b, err := nc.MarshalJSON() _, err := nc.MarshalJSON()
assert.ErrorIs(t, err, ErrMissingDetails) require.ErrorIs(t, err, ErrMissingDetails)
rd, err := nc.details.Marshal() rd, err := nc.details.Marshal()
assert.NoError(t, err) require.NoError(t, err)
nc.rawDetails = rd nc.rawDetails = rd
b, err = nc.MarshalJSON() b, err := nc.MarshalJSON()
assert.Nil(t, err) require.NoError(t, err)
assert.Equal( assert.JSONEq(
t, t,
"{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}", "{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}",
string(b), string(b),
@@ -132,85 +132,86 @@ func TestCertificateV2_MarshalJSON(t *testing.T) {
func TestCertificateV2_VerifyPrivateKey(t *testing.T) { func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
assert.Nil(t, err) require.NoError(t, err)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16]) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16])
assert.ErrorIs(t, err, ErrInvalidPrivateKey) require.ErrorIs(t, err, ErrInvalidPrivateKey)
_, caKey2, err := ed25519.GenerateKey(rand.Reader) _, caKey2, err := ed25519.GenerateKey(rand.Reader)
require.Nil(t, err) require.NoError(t, err)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
assert.Nil(t, err) require.NoError(t, err)
_, priv2 := X25519Keypair() _, priv2 := X25519Keypair()
err = c.VerifyPrivateKey(Curve_P256, priv2) err = c.VerifyPrivateKey(Curve_P256, priv2)
assert.ErrorIs(t, err, ErrPublicPrivateCurveMismatch) require.ErrorIs(t, err, ErrPublicPrivateCurveMismatch)
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16]) err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16])
assert.ErrorIs(t, err, ErrInvalidPrivateKey) require.ErrorIs(t, err, ErrInvalidPrivateKey)
ac, ok := c.(*certificateV2) ac, ok := c.(*certificateV2)
require.True(t, ok) require.True(t, ok)
ac.curve = Curve(99) ac.curve = Curve(99)
err = c.VerifyPrivateKey(Curve(99), priv2) err = c.VerifyPrivateKey(Curve(99), priv2)
assert.EqualError(t, err, "invalid curve: 99") require.EqualError(t, err, "invalid curve: 99")
ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
assert.Nil(t, err) require.NoError(t, err)
err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16]) err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16])
assert.ErrorIs(t, err, ErrInvalidPrivateKey) require.ErrorIs(t, err, ErrInvalidPrivateKey)
c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil) c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv) _, _, curve, err = UnmarshalPrivateKeyFromPEM(priv)
assert.Equal(t, err, nil)
assert.Equal(t, curve, Curve_P256)
err = c.VerifyPrivateKey(Curve_P256, priv[:16]) err = c.VerifyPrivateKey(Curve_P256, priv[:16])
assert.ErrorIs(t, err, ErrInvalidPrivateKey) require.ErrorIs(t, err, ErrInvalidPrivateKey)
err = c.VerifyPrivateKey(Curve_P256, priv) err = c.VerifyPrivateKey(Curve_P256, priv)
assert.ErrorIs(t, err, ErrInvalidPrivateKey) require.ErrorIs(t, err, ErrInvalidPrivateKey)
aCa, ok := ca2.(*certificateV2) aCa, ok := ca2.(*certificateV2)
require.True(t, ok) require.True(t, ok)
aCa.curve = Curve(99) aCa.curve = Curve(99)
err = aCa.VerifyPrivateKey(Curve(99), priv2) err = aCa.VerifyPrivateKey(Curve(99), priv2)
assert.EqualError(t, err, "invalid curve: 99") require.EqualError(t, err, "invalid curve: 99")
} }
func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) { func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
err := ca.VerifyPrivateKey(Curve_P256, caKey) err := ca.VerifyPrivateKey(Curve_P256, caKey)
assert.Nil(t, err) require.NoError(t, err)
_, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) _, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
assert.Nil(t, err) require.NoError(t, err)
err = ca.VerifyPrivateKey(Curve_P256, caKey2) err = ca.VerifyPrivateKey(Curve_P256, caKey2)
assert.NotNil(t, err) require.Error(t, err)
c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
assert.Equal(t, Curve_P256, curve) assert.Equal(t, Curve_P256, curve)
err = c.VerifyPrivateKey(Curve_P256, rawPriv) err = c.VerifyPrivateKey(Curve_P256, rawPriv)
assert.Nil(t, err) require.NoError(t, err)
_, priv2 := P256Keypair() _, priv2 := P256Keypair()
err = c.VerifyPrivateKey(Curve_P256, priv2) err = c.VerifyPrivateKey(Curve_P256, priv2)
assert.NotNil(t, err) require.Error(t, err)
} }
func TestCertificateV2_Copy(t *testing.T) { func TestCertificateV2_Copy(t *testing.T) {
@@ -223,7 +224,7 @@ func TestCertificateV2_Copy(t *testing.T) {
func TestUnmarshalCertificateV2(t *testing.T) { func TestUnmarshalCertificateV2(t *testing.T) {
data := []byte("\x98\x00\x00") data := []byte("\x98\x00\x00")
_, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519) _, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519)
assert.EqualError(t, err, "bad wire format") require.EqualError(t, err, "bad wire format")
} }
func TestCertificateV2_marshalForSigningStability(t *testing.T) { func TestCertificateV2_marshalForSigningStability(t *testing.T) {
@@ -261,6 +262,7 @@ func TestCertificateV2_marshalForSigningStability(t *testing.T) {
assert.Equal(t, expectedRawDetails, db) assert.Equal(t, expectedRawDetails, db)
expectedForSigning, err := hex.DecodeString(expectedRawDetailsStr + "00313233343536373839306162636564666768696a313233343536373839306162") expectedForSigning, err := hex.DecodeString(expectedRawDetailsStr + "00313233343536373839306162636564666768696a313233343536373839306162")
require.NoError(t, err)
b, err := nc.marshalForSigning() b, err := nc.marshalForSigning()
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, expectedForSigning, b) assert.Equal(t, expectedForSigning, b)

View File

@@ -227,6 +227,9 @@ func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) {
} }
func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) { func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) {
// Are we testing the compilers types here?
// No value of int32 is lewss than math.MinInt32.
// By definition these checks can never be true.
if params.Version < math.MinInt32 || params.Version > math.MaxInt32 { if params.Version < math.MinInt32 || params.Version > math.MaxInt32 {
return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32) return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32)
} }

View File

@@ -4,19 +4,20 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/argon2" "golang.org/x/crypto/argon2"
) )
func TestNewArgon2Parameters(t *testing.T) { func TestNewArgon2Parameters(t *testing.T) {
p := NewArgon2Parameters(64*1024, 4, 3) p := NewArgon2Parameters(64*1024, 4, 3)
assert.EqualValues(t, &Argon2Parameters{ assert.Equal(t, &Argon2Parameters{
version: argon2.Version, version: argon2.Version,
Memory: 64 * 1024, Memory: 64 * 1024,
Parallelism: 4, Parallelism: 4,
Iterations: 3, Iterations: 3,
}, p) }, p)
p = NewArgon2Parameters(2*1024*1024, 2, 1) p = NewArgon2Parameters(2*1024*1024, 2, 1)
assert.EqualValues(t, &Argon2Parameters{ assert.Equal(t, &Argon2Parameters{
version: argon2.Version, version: argon2.Version,
Memory: 2 * 1024 * 1024, Memory: 2 * 1024 * 1024,
Parallelism: 2, Parallelism: 2,
@@ -61,35 +62,39 @@ qrlJ69wer3ZUHFXA
// Success test case // Success test case
curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle) curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
assert.Len(t, k, 64) assert.Len(t, k, 64)
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
// Fail due to short key // Fail due to short key
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") require.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
assert.Equal(t, curve, Curve_CURVE25519)
// Fail due to invalid banner // Fail due to invalid banner
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") require.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.Equal(t, curve, Curve_CURVE25519)
// Fail due to ivalid PEM format, because // Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary. // it's missing the requisite pre-encapsulation boundary.
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
assert.EqualError(t, err, "input did not contain a valid PEM encoded block") require.EqualError(t, err, "input did not contain a valid PEM encoded block")
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.Equal(t, curve, Curve_CURVE25519)
// Fail due to invalid passphrase // Fail due to invalid passphrase
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey) curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey)
assert.EqualError(t, err, "invalid passphrase or corrupt private key") require.EqualError(t, err, "invalid passphrase or corrupt private key")
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, rest, []byte{}) assert.Equal(t, []byte{}, rest)
assert.Equal(t, curve, Curve_CURVE25519)
} }
func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) { func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {
@@ -99,14 +104,14 @@ func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {
bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
kdfParams := NewArgon2Parameters(64*1024, 4, 3) kdfParams := NewArgon2Parameters(64*1024, 4, 3)
key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams) key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams)
assert.Nil(t, err) require.NoError(t, err)
// Verify the "key" can be decrypted successfully // Verify the "key" can be decrypted successfully
curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key) curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key)
assert.Len(t, k, 64) assert.Len(t, k, 64)
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, []byte{}) assert.Equal(t, []byte{}, rest)
assert.Nil(t, err) require.NoError(t, err)
// EncryptAndMarshalEd25519PrivateKey does not create any errors itself // EncryptAndMarshalEd25519PrivateKey does not create any errors itself
} }

View File

@@ -21,6 +21,9 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ
switch curve { switch curve {
case Curve_CURVE25519: case Curve_CURVE25519:
pub, priv, err = ed25519.GenerateKey(rand.Reader) pub, priv, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
case Curve_P256: case Curve_P256:
privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil { if err != nil {

View File

@@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestUnmarshalCertificateFromPEM(t *testing.T) { func TestUnmarshalCertificateFromPEM(t *testing.T) {
@@ -35,20 +36,20 @@ bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
cert, rest, err := UnmarshalCertificateFromPEM(certBundle) cert, rest, err := UnmarshalCertificateFromPEM(certBundle)
assert.NotNil(t, cert) assert.NotNil(t, cert)
assert.Equal(t, rest, append(badBanner, invalidPem...)) assert.Equal(t, rest, append(badBanner, invalidPem...))
assert.Nil(t, err) require.NoError(t, err)
// Fail due to invalid banner. // Fail due to invalid banner.
cert, rest, err = UnmarshalCertificateFromPEM(rest) cert, rest, err = UnmarshalCertificateFromPEM(rest)
assert.Nil(t, cert) assert.Nil(t, cert)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "bytes did not contain a proper certificate banner") require.EqualError(t, err, "bytes did not contain a proper certificate banner")
// Fail due to ivalid PEM format, because // Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary. // it's missing the requisite pre-encapsulation boundary.
cert, rest, err = UnmarshalCertificateFromPEM(rest) cert, rest, err = UnmarshalCertificateFromPEM(rest)
assert.Nil(t, cert) assert.Nil(t, cert)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "input did not contain a valid PEM encoded block") require.EqualError(t, err, "input did not contain a valid PEM encoded block")
} }
func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) { func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) {
@@ -84,33 +85,36 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
assert.Len(t, k, 64) assert.Len(t, k, 64)
assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
assert.Nil(t, err) require.NoError(t, err)
// Success test case // Success test case
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
assert.Len(t, k, 32) assert.Len(t, k, 32)
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_P256, curve) assert.Equal(t, Curve_P256, curve)
assert.Nil(t, err) require.NoError(t, err)
// Fail due to short key // Fail due to short key
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key") require.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key")
// Fail due to invalid banner // Fail due to invalid banner
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner") require.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner")
// Fail due to ivalid PEM format, because // Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary. // it's missing the requisite pre-encapsulation boundary.
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "input did not contain a valid PEM encoded block") require.EqualError(t, err, "input did not contain a valid PEM encoded block")
} }
func TestUnmarshalPrivateKeyFromPEM(t *testing.T) { func TestUnmarshalPrivateKeyFromPEM(t *testing.T) {
@@ -146,33 +150,36 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
assert.Len(t, k, 32) assert.Len(t, k, 32)
assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
assert.Nil(t, err) require.NoError(t, err)
// Success test case // Success test case
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
assert.Len(t, k, 32) assert.Len(t, k, 32)
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_P256, curve) assert.Equal(t, Curve_P256, curve)
assert.Nil(t, err) require.NoError(t, err)
// Fail due to short key // Fail due to short key
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key") require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key")
// Fail due to invalid banner // Fail due to invalid banner
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "bytes did not contain a proper private key banner") require.EqualError(t, err, "bytes did not contain a proper private key banner")
// Fail due to ivalid PEM format, because // Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary. // it's missing the requisite pre-encapsulation boundary.
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "input did not contain a valid PEM encoded block") require.EqualError(t, err, "input did not contain a valid PEM encoded block")
} }
func TestUnmarshalPublicKeyFromPEM(t *testing.T) { func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
@@ -200,9 +207,9 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
// Success test case // Success test case
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
assert.Equal(t, 32, len(k)) assert.Len(t, k, 32)
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
// Fail due to short key // Fail due to short key
@@ -210,13 +217,13 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
// Fail due to invalid banner // Fail due to invalid banner
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
assert.EqualError(t, err, "bytes did not contain a proper public key banner") require.EqualError(t, err, "bytes did not contain a proper public key banner")
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
// Fail due to ivalid PEM format, because // Fail due to ivalid PEM format, because
@@ -225,7 +232,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "input did not contain a valid PEM encoded block") require.EqualError(t, err, "input did not contain a valid PEM encoded block")
} }
func TestUnmarshalX25519PublicKey(t *testing.T) { func TestUnmarshalX25519PublicKey(t *testing.T) {
@@ -259,34 +266,37 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
// Success test case // Success test case
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
assert.Equal(t, 32, len(k)) assert.Len(t, k, 32)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
// Success test case // Success test case
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Equal(t, 65, len(k)) assert.Len(t, k, 65)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_P256, curve) assert.Equal(t, Curve_P256, curve)
// Fail due to short key // Fail due to short key
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
// Fail due to invalid banner // Fail due to invalid banner
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.EqualError(t, err, "bytes did not contain a proper public key banner") assert.Equal(t, Curve_CURVE25519, curve)
require.EqualError(t, err, "bytes did not contain a proper public key banner")
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
// Fail due to ivalid PEM format, because // Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary. // it's missing the requisite pre-encapsulation boundary.
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "input did not contain a valid PEM encoded block") require.EqualError(t, err, "input did not contain a valid PEM encoded block")
} }

View File

@@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestCertificateV1_Sign(t *testing.T) { func TestCertificateV1_Sign(t *testing.T) {
@@ -36,15 +37,16 @@ func TestCertificateV1_Sign(t *testing.T) {
} }
pub, priv, err := ed25519.GenerateKey(rand.Reader) pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv) c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv)
assert.Nil(t, err) require.NoError(t, err)
assert.NotNil(t, c) assert.NotNil(t, c)
assert.True(t, c.CheckSignature(pub)) assert.True(t, c.CheckSignature(pub))
b, err := c.Marshal() b, err := c.Marshal()
assert.Nil(t, err) require.NoError(t, err)
uc, err := unmarshalCertificateV1(b, nil) uc, err := unmarshalCertificateV1(b, nil)
assert.Nil(t, err) require.NoError(t, err)
assert.NotNil(t, uc) assert.NotNil(t, uc)
} }
@@ -73,18 +75,18 @@ func TestCertificateV1_SignP256(t *testing.T) {
} }
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.NoError(t, err) require.NoError(t, err)
pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
rawPriv := priv.D.FillBytes(make([]byte, 32)) rawPriv := priv.D.FillBytes(make([]byte, 32))
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv) c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv)
assert.Nil(t, err) require.NoError(t, err)
assert.NotNil(t, c) assert.NotNil(t, c)
assert.True(t, c.CheckSignature(pub)) assert.True(t, c.CheckSignature(pub))
b, err := c.Marshal() b, err := c.Marshal()
assert.Nil(t, err) require.NoError(t, err)
uc, err := unmarshalCertificateV1(b, nil) uc, err := unmarshalCertificateV1(b, nil)
assert.Nil(t, err) require.NoError(t, err)
assert.NotNil(t, uc) assert.NotNil(t, uc)
} }

View File

@@ -22,6 +22,9 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti
switch curve { switch curve {
case cert.Curve_CURVE25519: case cert.Curve_CURVE25519:
pub, priv, err = ed25519.GenerateKey(rand.Reader) pub, priv, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
case cert.Curve_P256: case cert.Curve_P256:
privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil { if err != nil {

View File

@@ -81,7 +81,7 @@ func parseArgonParameters(memory uint, parallelism uint, iterations uint) (*cert
return cert.NewArgon2Parameters(uint32(memory), uint8(parallelism), uint32(iterations)), nil return cert.NewArgon2Parameters(uint32(memory), uint8(parallelism), uint32(iterations)), nil
} }
func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error { func ca(args []string, out io.Writer, _ io.Writer, pr PasswordReader) error {
cf := newCaFlags() cf := newCaFlags()
err := cf.set.Parse(args) err := cf.set.Parse(args)
if err != nil { if err != nil {

View File

@@ -14,6 +14,7 @@ import (
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_caSummary(t *testing.T) { func Test_caSummary(t *testing.T) {
@@ -89,75 +90,75 @@ func Test_ca(t *testing.T) {
assertHelpError(t, ca( assertHelpError(t, ca(
[]string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw, []string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
), "-name is required") ), "-name is required")
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
// ipv4 only ips // ipv4 only ips
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100") assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100")
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
// ipv4 only subnets // ipv4 only subnets
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100") assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100")
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
// failed key write // failed key write
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
// create temp key file // create temp key file
keyF, err := os.CreateTemp("", "test.key") keyF, err := os.CreateTemp("", "test.key")
assert.Nil(t, err) require.NoError(t, err)
assert.Nil(t, os.Remove(keyF.Name())) require.NoError(t, os.Remove(keyF.Name()))
// failed cert write // failed cert write
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
// create temp cert file // create temp cert file
crtF, err := os.CreateTemp("", "test.crt") crtF, err := os.CreateTemp("", "test.crt")
assert.Nil(t, err) require.NoError(t, err)
assert.Nil(t, os.Remove(crtF.Name())) require.NoError(t, os.Remove(crtF.Name()))
assert.Nil(t, os.Remove(keyF.Name())) require.NoError(t, os.Remove(keyF.Name()))
// test proper cert with removed empty groups and subnets // test proper cert with removed empty groups and subnets
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, ca(args, ob, eb, nopw)) require.NoError(t, ca(args, ob, eb, nopw))
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
// read cert and key files // read cert and key files
rb, _ := os.ReadFile(keyF.Name()) rb, _ := os.ReadFile(keyF.Name())
lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb) lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb)
assert.Equal(t, cert.Curve_CURVE25519, c) assert.Equal(t, cert.Curve_CURVE25519, c)
assert.Len(t, b, 0) assert.Empty(t, b)
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, lKey, 64) assert.Len(t, lKey, 64)
rb, _ = os.ReadFile(crtF.Name()) rb, _ = os.ReadFile(crtF.Name())
lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
assert.Len(t, b, 0) assert.Empty(t, b)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "test", lCrt.Name()) assert.Equal(t, "test", lCrt.Name())
assert.Len(t, lCrt.Networks(), 0) assert.Empty(t, lCrt.Networks())
assert.True(t, lCrt.IsCA()) assert.True(t, lCrt.IsCA())
assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups()) assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups())
assert.Len(t, lCrt.UnsafeNetworks(), 0) assert.Empty(t, lCrt.UnsafeNetworks())
assert.Len(t, lCrt.PublicKey(), 32) assert.Len(t, lCrt.PublicKey(), 32)
assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore())) assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore()))
assert.Equal(t, "", lCrt.Issuer()) assert.Empty(t, lCrt.Issuer())
assert.True(t, lCrt.CheckSignature(lCrt.PublicKey())) assert.True(t, lCrt.CheckSignature(lCrt.PublicKey()))
// test encrypted key // test encrypted key
@@ -166,15 +167,15 @@ func Test_ca(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, ca(args, ob, eb, testpw)) require.NoError(t, ca(args, ob, eb, testpw))
assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, pwPromptOb, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
// read encrypted key file and verify default params // read encrypted key file and verify default params
rb, _ = os.ReadFile(keyF.Name()) rb, _ = os.ReadFile(keyF.Name())
k, _ := pem.Decode(rb) k, _ := pem.Decode(rb)
ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes) ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes)
assert.Nil(t, err) require.NoError(t, err)
// we won't know salt in advance, so just check start of string // we won't know salt in advance, so just check start of string
assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory) assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory)
assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism) assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism)
@@ -184,8 +185,8 @@ func Test_ca(t *testing.T) {
var curve cert.Curve var curve cert.Curve
curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb) curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb)
assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Equal(t, cert.Curve_CURVE25519, curve)
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, b, 0) assert.Empty(t, b)
assert.Len(t, lKey, 64) assert.Len(t, lKey, 64)
// test when reading passsword results in an error // test when reading passsword results in an error
@@ -194,9 +195,9 @@ func Test_ca(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Error(t, ca(args, ob, eb, errpw)) require.Error(t, ca(args, ob, eb, errpw))
assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, pwPromptOb, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
// test when user fails to enter a password // test when user fails to enter a password
os.Remove(keyF.Name()) os.Remove(keyF.Name())
@@ -204,9 +205,9 @@ func Test_ca(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
// create valid cert/key for overwrite tests // create valid cert/key for overwrite tests
os.Remove(keyF.Name()) os.Remove(keyF.Name())
@@ -214,24 +215,24 @@ func Test_ca(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, ca(args, ob, eb, nopw)) require.NoError(t, ca(args, ob, eb, nopw))
// test that we won't overwrite existing certificate file // test that we won't overwrite existing certificate file
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
// test that we won't overwrite existing key file // test that we won't overwrite existing key file
os.Remove(keyF.Name()) os.Remove(keyF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
os.Remove(keyF.Name()) os.Remove(keyF.Name())
} }

View File

@@ -29,7 +29,7 @@ func newKeygenFlags() *keygenFlags {
return &cf return &cf
} }
func keygen(args []string, out io.Writer, errOut io.Writer) error { func keygen(args []string, _ io.Writer, _ io.Writer) error {
cf := newKeygenFlags() cf := newKeygenFlags()
err := cf.set.Parse(args) err := cf.set.Parse(args)
if err != nil { if err != nil {

View File

@@ -7,6 +7,7 @@ import (
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_keygenSummary(t *testing.T) { func Test_keygenSummary(t *testing.T) {
@@ -36,59 +37,59 @@ func Test_keygen(t *testing.T) {
// required args // required args
assertHelpError(t, keygen([]string{"-out-pub", "nope"}, ob, eb), "-out-key is required") assertHelpError(t, keygen([]string{"-out-pub", "nope"}, ob, eb), "-out-key is required")
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
assertHelpError(t, keygen([]string{"-out-key", "nope"}, ob, eb), "-out-pub is required") assertHelpError(t, keygen([]string{"-out-key", "nope"}, ob, eb), "-out-pub is required")
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
// failed key write // failed key write
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"} args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"}
assert.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) require.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
// create temp key file // create temp key file
keyF, err := os.CreateTemp("", "test.key") keyF, err := os.CreateTemp("", "test.key")
assert.Nil(t, err) require.NoError(t, err)
defer os.Remove(keyF.Name()) defer os.Remove(keyF.Name())
// failed pub write // failed pub write
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()} args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()}
assert.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError) require.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError)
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
// create temp pub file // create temp pub file
pubF, err := os.CreateTemp("", "test.pub") pubF, err := os.CreateTemp("", "test.pub")
assert.Nil(t, err) require.NoError(t, err)
defer os.Remove(pubF.Name()) defer os.Remove(pubF.Name())
// test proper keygen // test proper keygen
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()} args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, keygen(args, ob, eb)) require.NoError(t, keygen(args, ob, eb))
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
// read cert and key files // read cert and key files
rb, _ := os.ReadFile(keyF.Name()) rb, _ := os.ReadFile(keyF.Name())
lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Equal(t, cert.Curve_CURVE25519, curve)
assert.Len(t, b, 0) assert.Empty(t, b)
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, lKey, 32) assert.Len(t, lKey, 32)
rb, _ = os.ReadFile(pubF.Name()) rb, _ = os.ReadFile(pubF.Name())
lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb) lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb)
assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Equal(t, cert.Curve_CURVE25519, curve)
assert.Len(t, b, 0) assert.Empty(t, b)
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, lPub, 32) assert.Len(t, lPub, 32)
} }

View File

@@ -17,7 +17,7 @@ func (he *helpError) Error() string {
return he.s return he.s
} }
func newHelpErrorf(s string, v ...interface{}) error { func newHelpErrorf(s string, v ...any) error {
return &helpError{s: fmt.Sprintf(s, v...)} return &helpError{s: fmt.Sprintf(s, v...)}
} }

View File

@@ -3,12 +3,12 @@ package main
import ( import (
"bytes" "bytes"
"errors" "errors"
"fmt"
"io" "io"
"os" "os"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_help(t *testing.T) { func Test_help(t *testing.T) {
@@ -76,10 +76,10 @@ func assertHelpError(t *testing.T, err error, msg string) {
case *helpError: case *helpError:
// good // good
default: default:
t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg)) t.Fatalf("err was not a helpError: %q, expected %q", err, msg)
} }
assert.EqualError(t, err, msg) require.EqualError(t, err, msg)
} }
func optionalPkcs11String(msg string) string { func optionalPkcs11String(msg string) string {

View File

@@ -10,7 +10,7 @@ func p11Supported() bool {
return false return false
} }
func p11Flag(set *flag.FlagSet) *string { func p11Flag(_ *flag.FlagSet) *string {
var ret = "" var ret = ""
return &ret return &ret
} }

View File

@@ -1,12 +1,12 @@
package main package main
import ( import (
"bytes"
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"os" "os"
"strings"
"github.com/skip2/go-qrcode" "github.com/skip2/go-qrcode"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
@@ -29,7 +29,7 @@ func newPrintFlags() *printFlags {
return &pf return &pf
} }
func printCert(args []string, out io.Writer, errOut io.Writer) error { func printCert(args []string, out io.Writer, _ io.Writer) error {
pf := newPrintFlags() pf := newPrintFlags()
err := pf.set.Parse(args) err := pf.set.Parse(args)
if err != nil { if err != nil {
@@ -72,7 +72,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
qrBytes = append(qrBytes, b...) qrBytes = append(qrBytes, b...)
} }
if rawCert == nil || len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" { if len(rawCert) == 0 || len(bytes.TrimSpace(rawCert)) == 0 {
break break
} }

View File

@@ -12,6 +12,7 @@ import (
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_printSummary(t *testing.T) { func Test_printSummary(t *testing.T) {
@@ -42,30 +43,30 @@ func Test_printCert(t *testing.T) {
// no path // no path
err := printCert([]string{}, ob, eb) err := printCert([]string{}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
assertHelpError(t, err, "-path is required") assertHelpError(t, err, "-path is required")
// no cert at path // no cert at path
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
err = printCert([]string{"-path", "does_not_exist"}, ob, eb) err = printCert([]string{"-path", "does_not_exist"}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
assert.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError) require.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError)
// invalid cert at path // invalid cert at path
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
tf, err := os.CreateTemp("", "print-cert") tf, err := os.CreateTemp("", "print-cert")
assert.Nil(t, err) require.NoError(t, err)
defer os.Remove(tf.Name()) defer os.Remove(tf.Name())
tf.WriteString("-----BEGIN NOPE-----") tf.WriteString("-----BEGIN NOPE-----")
err = printCert([]string{"-path", tf.Name()}, ob, eb) err = printCert([]string{"-path", tf.Name()}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
assert.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block") require.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block")
// test multiple certs // test multiple certs
ob.Reset() ob.Reset()
@@ -84,7 +85,7 @@ func Test_printCert(t *testing.T) {
fp, _ := c.Fingerprint() fp, _ := c.Fingerprint()
pk := hex.EncodeToString(c.PublicKey()) pk := hex.EncodeToString(c.PublicKey())
sig := hex.EncodeToString(c.Signature()) sig := hex.EncodeToString(c.Signature())
assert.Nil(t, err) require.NoError(t, err)
assert.Equal( assert.Equal(
t, t,
//"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n", //"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
@@ -154,7 +155,7 @@ func Test_printCert(t *testing.T) {
`, `,
ob.String(), ob.String(),
) )
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
// test json // test json
ob.Reset() ob.Reset()
@@ -169,14 +170,14 @@ func Test_printCert(t *testing.T) {
fp, _ = c.Fingerprint() fp, _ = c.Fingerprint()
pk = hex.EncodeToString(c.PublicKey()) pk = hex.EncodeToString(c.PublicKey())
sig = hex.EncodeToString(c.Signature()) sig = hex.EncodeToString(c.Signature())
assert.Nil(t, err) require.NoError(t, err)
assert.Equal( assert.Equal(
t, t,
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}] `[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
`, `,
ob.String(), ob.String(),
) )
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
} }
// NewTestCaCert will generate a CA cert // NewTestCaCert will generate a CA cert

View File

@@ -13,6 +13,7 @@ import (
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
@@ -103,17 +104,17 @@ func Test_signCert(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
// failed to unmarshal key // failed to unmarshal key
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
caKeyF, err := os.CreateTemp("", "sign-cert.key") caKeyF, err := os.CreateTemp("", "sign-cert.key")
assert.Nil(t, err) require.NoError(t, err)
defer os.Remove(caKeyF.Name()) defer os.Remove(caKeyF.Name())
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block") require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@@ -125,7 +126,7 @@ func Test_signCert(t *testing.T) {
// failed to read cert // failed to read cert
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@@ -133,11 +134,11 @@ func Test_signCert(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
caCrtF, err := os.CreateTemp("", "sign-cert.crt") caCrtF, err := os.CreateTemp("", "sign-cert.crt")
assert.Nil(t, err) require.NoError(t, err)
defer os.Remove(caCrtF.Name()) defer os.Remove(caCrtF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@@ -148,7 +149,7 @@ func Test_signCert(t *testing.T) {
// failed to read pub // failed to read pub
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@@ -156,11 +157,11 @@ func Test_signCert(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
inPubF, err := os.CreateTemp("", "in.pub") inPubF, err := os.CreateTemp("", "in.pub")
assert.Nil(t, err) require.NoError(t, err)
defer os.Remove(inPubF.Name()) defer os.Remove(inPubF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@@ -210,14 +211,14 @@ func Test_signCert(t *testing.T) {
// mismatched ca key // mismatched ca key
_, caPriv2, _ := ed25519.GenerateKey(rand.Reader) _, caPriv2, _ := ed25519.GenerateKey(rand.Reader)
caKeyF2, err := os.CreateTemp("", "sign-cert-2.key") caKeyF2, err := os.CreateTemp("", "sign-cert-2.key")
assert.Nil(t, err) require.NoError(t, err)
defer os.Remove(caKeyF2.Name()) defer os.Remove(caKeyF2.Name())
caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2)) caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2))
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@@ -225,34 +226,34 @@ func Test_signCert(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
// create temp key file // create temp key file
keyF, err := os.CreateTemp("", "test.key") keyF, err := os.CreateTemp("", "test.key")
assert.Nil(t, err) require.NoError(t, err)
os.Remove(keyF.Name()) os.Remove(keyF.Name())
// failed cert write // failed cert write
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
os.Remove(keyF.Name()) os.Remove(keyF.Name())
// create temp cert file // create temp cert file
crtF, err := os.CreateTemp("", "test.crt") crtF, err := os.CreateTemp("", "test.crt")
assert.Nil(t, err) require.NoError(t, err)
os.Remove(crtF.Name()) os.Remove(crtF.Name())
// test proper cert with removed empty groups and subnets // test proper cert with removed empty groups and subnets
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, nopw)) require.NoError(t, signCert(args, ob, eb, nopw))
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@@ -260,14 +261,14 @@ func Test_signCert(t *testing.T) {
rb, _ := os.ReadFile(keyF.Name()) rb, _ := os.ReadFile(keyF.Name())
lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Equal(t, cert.Curve_CURVE25519, curve)
assert.Len(t, b, 0) assert.Empty(t, b)
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, lKey, 32) assert.Len(t, lKey, 32)
rb, _ = os.ReadFile(crtF.Name()) rb, _ = os.ReadFile(crtF.Name())
lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
assert.Len(t, b, 0) assert.Empty(t, b)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "test", lCrt.Name()) assert.Equal(t, "test", lCrt.Name())
assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String()) assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String())
@@ -295,15 +296,15 @@ func Test_signCert(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
assert.Nil(t, signCert(args, ob, eb, nopw)) require.NoError(t, signCert(args, ob, eb, nopw))
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
// read cert file and check pub key matches in-pub // read cert file and check pub key matches in-pub
rb, _ = os.ReadFile(crtF.Name()) rb, _ = os.ReadFile(crtF.Name())
lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb) lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb)
assert.Len(t, b, 0) assert.Empty(t, b)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, lCrt.PublicKey(), inPub) assert.Equal(t, lCrt.PublicKey(), inPub)
// test refuse to sign cert with duration beyond root // test refuse to sign cert with duration beyond root
@@ -312,7 +313,7 @@ func Test_signCert(t *testing.T) {
os.Remove(keyF.Name()) os.Remove(keyF.Name())
os.Remove(crtF.Name()) os.Remove(crtF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate") require.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@@ -320,14 +321,14 @@ func Test_signCert(t *testing.T) {
os.Remove(keyF.Name()) os.Remove(keyF.Name())
os.Remove(crtF.Name()) os.Remove(crtF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, nopw)) require.NoError(t, signCert(args, ob, eb, nopw))
// test that we won't overwrite existing key file // test that we won't overwrite existing key file
os.Remove(crtF.Name()) os.Remove(crtF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name()) require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@@ -335,14 +336,14 @@ func Test_signCert(t *testing.T) {
os.Remove(keyF.Name()) os.Remove(keyF.Name())
os.Remove(crtF.Name()) os.Remove(crtF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, nopw)) require.NoError(t, signCert(args, ob, eb, nopw))
// test that we won't overwrite existing certificate file // test that we won't overwrite existing certificate file
os.Remove(keyF.Name()) os.Remove(keyF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@@ -355,11 +356,11 @@ func Test_signCert(t *testing.T) {
eb.Reset() eb.Reset()
caKeyF, err = os.CreateTemp("", "sign-cert.key") caKeyF, err = os.CreateTemp("", "sign-cert.key")
assert.Nil(t, err) require.NoError(t, err)
defer os.Remove(caKeyF.Name()) defer os.Remove(caKeyF.Name())
caCrtF, err = os.CreateTemp("", "sign-cert.crt") caCrtF, err = os.CreateTemp("", "sign-cert.crt")
assert.Nil(t, err) require.NoError(t, err)
defer os.Remove(caCrtF.Name()) defer os.Remove(caCrtF.Name())
// generate the encrypted key // generate the encrypted key
@@ -374,7 +375,7 @@ func Test_signCert(t *testing.T) {
// test with the proper password // test with the proper password
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, testpw)) require.NoError(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@@ -384,7 +385,7 @@ func Test_signCert(t *testing.T) {
testpw.password = []byte("invalid password") testpw.password = []byte("invalid password")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Error(t, signCert(args, ob, eb, testpw)) require.Error(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@@ -393,7 +394,7 @@ func Test_signCert(t *testing.T) {
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Error(t, signCert(args, ob, eb, nopw)) require.Error(t, signCert(args, ob, eb, nopw))
// normally the user hitting enter on the prompt would add newlines between these // normally the user hitting enter on the prompt would add newlines between these
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@@ -403,7 +404,7 @@ func Test_signCert(t *testing.T) {
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Error(t, signCert(args, ob, eb, errpw)) require.Error(t, signCert(args, ob, eb, errpw))
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
} }

View File

@@ -1,12 +1,12 @@
package main package main
import ( import (
"bytes"
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"os" "os"
"strings"
"time" "time"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
@@ -52,7 +52,7 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
return fmt.Errorf("error while adding ca cert to pool: %w", err) return fmt.Errorf("error while adding ca cert to pool: %w", err)
} }
if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" { if len(rawCACert) == 0 || len(bytes.TrimSpace(rawCACert)) == 0 {
break break
} }
} }

View File

@@ -3,13 +3,13 @@ package main
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"errors"
"os" "os"
"testing" "testing"
"time" "time"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
@@ -38,33 +38,33 @@ func Test_verify(t *testing.T) {
// required args // required args
assertHelpError(t, verify([]string{"-ca", "derp"}, ob, eb), "-crt is required") assertHelpError(t, verify([]string{"-ca", "derp"}, ob, eb), "-crt is required")
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
assertHelpError(t, verify([]string{"-crt", "derp"}, ob, eb), "-ca is required") assertHelpError(t, verify([]string{"-crt", "derp"}, ob, eb), "-ca is required")
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
// no ca at path // no ca at path
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb) err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
assert.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError) require.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError)
// invalid ca at path // invalid ca at path
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
caFile, err := os.CreateTemp("", "verify-ca") caFile, err := os.CreateTemp("", "verify-ca")
assert.Nil(t, err) require.NoError(t, err)
defer os.Remove(caFile.Name()) defer os.Remove(caFile.Name())
caFile.WriteString("-----BEGIN NOPE-----") caFile.WriteString("-----BEGIN NOPE-----")
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
assert.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block") require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
// make a ca for later // make a ca for later
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
@@ -76,28 +76,28 @@ func Test_verify(t *testing.T) {
// no crt at path // no crt at path
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
assert.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError) require.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
// invalid crt at path // invalid crt at path
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
certFile, err := os.CreateTemp("", "verify-cert") certFile, err := os.CreateTemp("", "verify-cert")
assert.Nil(t, err) require.NoError(t, err)
defer os.Remove(certFile.Name()) defer os.Remove(certFile.Name())
certFile.WriteString("-----BEGIN NOPE-----") certFile.WriteString("-----BEGIN NOPE-----")
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
assert.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block") require.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
// unverifiable cert at path // unverifiable cert at path
crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
// Slightly evil hack to modify the certificate after it was sealed to generate an invalid signature // Slightly evil hack to modify the certificate after it was sealed to generate an invalid signature
pub := crt.PublicKey() pub := crt.PublicKey()
for i, _ := range pub { for i := range pub {
pub[i] = 0 pub[i] = 0
} }
b, _ = crt.MarshalPEM() b, _ = crt.MarshalPEM()
@@ -106,9 +106,9 @@ func Test_verify(t *testing.T) {
certFile.Write(b) certFile.Write(b)
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
assert.True(t, errors.Is(err, cert.ErrSignatureMismatch)) require.ErrorIs(t, err, cert.ErrSignatureMismatch)
// verified cert at path // verified cert at path
crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
@@ -118,7 +118,7 @@ func Test_verify(t *testing.T) {
certFile.Write(b) certFile.Write(b)
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Equal(t, "", eb.String()) assert.Empty(t, eb.String())
assert.Nil(t, err) require.NoError(t, err)
} }

View File

@@ -51,10 +51,7 @@ func (p *program) Stop(s service.Service) error {
func fileExists(filename string) bool { func fileExists(filename string) bool {
_, err := os.Stat(filename) _, err := os.Stat(filename)
if os.IsNotExist(err) { return !os.IsNotExist(err)
return false
}
return true
} }
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) { func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {

View File

@@ -17,14 +17,14 @@ import (
"dario.cat/mergo" "dario.cat/mergo"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v3"
) )
type C struct { type C struct {
path string path string
files []string files []string
Settings map[interface{}]interface{} Settings map[string]any
oldSettings map[interface{}]interface{} oldSettings map[string]any
callbacks []func(*C) callbacks []func(*C)
l *logrus.Logger l *logrus.Logger
reloadLock sync.Mutex reloadLock sync.Mutex
@@ -32,7 +32,7 @@ type C struct {
func NewC(l *logrus.Logger) *C { func NewC(l *logrus.Logger) *C {
return &C{ return &C{
Settings: make(map[interface{}]interface{}), Settings: make(map[string]any),
l: l, l: l,
} }
} }
@@ -63,7 +63,7 @@ func (c *C) Load(path string) error {
func (c *C) LoadString(raw string) error { func (c *C) LoadString(raw string) error {
if raw == "" { if raw == "" {
return errors.New("Empty configuration") return errors.New("empty configuration")
} }
return c.parseRaw([]byte(raw)) return c.parseRaw([]byte(raw))
} }
@@ -92,8 +92,8 @@ func (c *C) HasChanged(k string) bool {
} }
var ( var (
nv interface{} nv any
ov interface{} ov any
) )
if k == "" { if k == "" {
@@ -147,7 +147,7 @@ func (c *C) ReloadConfig() {
c.reloadLock.Lock() c.reloadLock.Lock()
defer c.reloadLock.Unlock() defer c.reloadLock.Unlock()
c.oldSettings = make(map[interface{}]interface{}) c.oldSettings = make(map[string]any)
for k, v := range c.Settings { for k, v := range c.Settings {
c.oldSettings[k] = v c.oldSettings[k] = v
} }
@@ -167,7 +167,7 @@ func (c *C) ReloadConfigString(raw string) error {
c.reloadLock.Lock() c.reloadLock.Lock()
defer c.reloadLock.Unlock() defer c.reloadLock.Unlock()
c.oldSettings = make(map[interface{}]interface{}) c.oldSettings = make(map[string]any)
for k, v := range c.Settings { for k, v := range c.Settings {
c.oldSettings[k] = v c.oldSettings[k] = v
} }
@@ -201,7 +201,7 @@ func (c *C) GetStringSlice(k string, d []string) []string {
return d return d
} }
rv, ok := r.([]interface{}) rv, ok := r.([]any)
if !ok { if !ok {
return d return d
} }
@@ -215,13 +215,13 @@ func (c *C) GetStringSlice(k string, d []string) []string {
} }
// GetMap will get the map for k or return the default d if not found or invalid // GetMap will get the map for k or return the default d if not found or invalid
func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} { func (c *C) GetMap(k string, d map[string]any) map[string]any {
r := c.Get(k) r := c.Get(k)
if r == nil { if r == nil {
return d return d
} }
v, ok := r.(map[interface{}]interface{}) v, ok := r.(map[string]any)
if !ok { if !ok {
return d return d
} }
@@ -243,7 +243,7 @@ func (c *C) GetInt(k string, d int) int {
// GetUint32 will get the uint32 for k or return the default d if not found or invalid // GetUint32 will get the uint32 for k or return the default d if not found or invalid
func (c *C) GetUint32(k string, d uint32) uint32 { func (c *C) GetUint32(k string, d uint32) uint32 {
r := c.GetInt(k, int(d)) r := c.GetInt(k, int(d))
if uint64(r) > uint64(math.MaxUint32) { if r < 0 || uint64(r) > uint64(math.MaxUint32) {
return d return d
} }
return uint32(r) return uint32(r)
@@ -266,6 +266,22 @@ func (c *C) GetBool(k string, d bool) bool {
return v return v
} }
func AsBool(v any) (value bool, ok bool) {
switch x := v.(type) {
case bool:
return x, true
case string:
switch x {
case "y", "yes":
return true, true
case "n", "no":
return false, true
}
}
return false, false
}
// GetDuration will get the duration for k or return the default d if not found or invalid // GetDuration will get the duration for k or return the default d if not found or invalid
func (c *C) GetDuration(k string, d time.Duration) time.Duration { func (c *C) GetDuration(k string, d time.Duration) time.Duration {
r := c.GetString(k, "") r := c.GetString(k, "")
@@ -276,7 +292,7 @@ func (c *C) GetDuration(k string, d time.Duration) time.Duration {
return v return v
} }
func (c *C) Get(k string) interface{} { func (c *C) Get(k string) any {
return c.get(k, c.Settings) return c.get(k, c.Settings)
} }
@@ -284,10 +300,10 @@ func (c *C) IsSet(k string) bool {
return c.get(k, c.Settings) != nil return c.get(k, c.Settings) != nil
} }
func (c *C) get(k string, v interface{}) interface{} { func (c *C) get(k string, v any) any {
parts := strings.Split(k, ".") parts := strings.Split(k, ".")
for _, p := range parts { for _, p := range parts {
m, ok := v.(map[interface{}]interface{}) m, ok := v.(map[string]any)
if !ok { if !ok {
return nil return nil
} }
@@ -346,7 +362,7 @@ func (c *C) addFile(path string, direct bool) error {
} }
func (c *C) parseRaw(b []byte) error { func (c *C) parseRaw(b []byte) error {
var m map[interface{}]interface{} var m map[string]any
err := yaml.Unmarshal(b, &m) err := yaml.Unmarshal(b, &m)
if err != nil { if err != nil {
@@ -358,7 +374,7 @@ func (c *C) parseRaw(b []byte) error {
} }
func (c *C) parse() error { func (c *C) parse() error {
var m map[interface{}]interface{} var m map[string]any
for _, path := range c.files { for _, path := range c.files {
b, err := os.ReadFile(path) b, err := os.ReadFile(path)
@@ -366,7 +382,7 @@ func (c *C) parse() error {
return err return err
} }
var nm map[interface{}]interface{} var nm map[string]any
err = yaml.Unmarshal(b, &nm) err = yaml.Unmarshal(b, &nm)
if err != nil { if err != nil {
return err return err

View File

@@ -10,7 +10,7 @@ import (
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v3"
) )
func TestConfig_Load(t *testing.T) { func TestConfig_Load(t *testing.T) {
@@ -19,20 +19,20 @@ func TestConfig_Load(t *testing.T) {
// invalid yaml // invalid yaml
c := NewC(l) c := NewC(l)
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}") require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[string]interface {}")
// simple multi config merge // simple multi config merge
c = NewC(l) c = NewC(l)
os.RemoveAll(dir) os.RemoveAll(dir)
os.Mkdir(dir, 0755) os.Mkdir(dir, 0755)
assert.Nil(t, err) require.NoError(t, err)
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644) os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644)
assert.Nil(t, c.Load(dir)) require.NoError(t, c.Load(dir))
expected := map[interface{}]interface{}{ expected := map[string]any{
"outer": map[interface{}]interface{}{ "outer": map[string]any{
"inner": "override", "inner": "override",
}, },
"new": "hi", "new": "hi",
@@ -44,12 +44,12 @@ func TestConfig_Get(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
// test simple type // test simple type
c := NewC(l) c := NewC(l)
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"} c.Settings["firewall"] = map[string]any{"outbound": "hi"}
assert.Equal(t, "hi", c.Get("firewall.outbound")) assert.Equal(t, "hi", c.Get("firewall.outbound"))
// test complex type // test complex type
inner := []map[interface{}]interface{}{{"port": "1", "code": "2"}} inner := []map[string]any{{"port": "1", "code": "2"}}
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": inner} c.Settings["firewall"] = map[string]any{"outbound": inner}
assert.EqualValues(t, inner, c.Get("firewall.outbound")) assert.EqualValues(t, inner, c.Get("firewall.outbound"))
// test missing // test missing
@@ -59,7 +59,7 @@ func TestConfig_Get(t *testing.T) {
func TestConfig_GetStringSlice(t *testing.T) { func TestConfig_GetStringSlice(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
c := NewC(l) c := NewC(l)
c.Settings["slice"] = []interface{}{"one", "two"} c.Settings["slice"] = []any{"one", "two"}
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{})) assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
} }
@@ -67,28 +67,28 @@ func TestConfig_GetBool(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
c := NewC(l) c := NewC(l)
c.Settings["bool"] = true c.Settings["bool"] = true
assert.Equal(t, true, c.GetBool("bool", false)) assert.True(t, c.GetBool("bool", false))
c.Settings["bool"] = "true" c.Settings["bool"] = "true"
assert.Equal(t, true, c.GetBool("bool", false)) assert.True(t, c.GetBool("bool", false))
c.Settings["bool"] = false c.Settings["bool"] = false
assert.Equal(t, false, c.GetBool("bool", true)) assert.False(t, c.GetBool("bool", true))
c.Settings["bool"] = "false" c.Settings["bool"] = "false"
assert.Equal(t, false, c.GetBool("bool", true)) assert.False(t, c.GetBool("bool", true))
c.Settings["bool"] = "Y" c.Settings["bool"] = "Y"
assert.Equal(t, true, c.GetBool("bool", false)) assert.True(t, c.GetBool("bool", false))
c.Settings["bool"] = "yEs" c.Settings["bool"] = "yEs"
assert.Equal(t, true, c.GetBool("bool", false)) assert.True(t, c.GetBool("bool", false))
c.Settings["bool"] = "N" c.Settings["bool"] = "N"
assert.Equal(t, false, c.GetBool("bool", true)) assert.False(t, c.GetBool("bool", true))
c.Settings["bool"] = "nO" c.Settings["bool"] = "nO"
assert.Equal(t, false, c.GetBool("bool", true)) assert.False(t, c.GetBool("bool", true))
} }
func TestConfig_HasChanged(t *testing.T) { func TestConfig_HasChanged(t *testing.T) {
@@ -101,14 +101,14 @@ func TestConfig_HasChanged(t *testing.T) {
// Test key change // Test key change
c = NewC(l) c = NewC(l)
c.Settings["test"] = "hi" c.Settings["test"] = "hi"
c.oldSettings = map[interface{}]interface{}{"test": "no"} c.oldSettings = map[string]any{"test": "no"}
assert.True(t, c.HasChanged("test")) assert.True(t, c.HasChanged("test"))
assert.True(t, c.HasChanged("")) assert.True(t, c.HasChanged(""))
// No key change // No key change
c = NewC(l) c = NewC(l)
c.Settings["test"] = "hi" c.Settings["test"] = "hi"
c.oldSettings = map[interface{}]interface{}{"test": "hi"} c.oldSettings = map[string]any{"test": "hi"}
assert.False(t, c.HasChanged("test")) assert.False(t, c.HasChanged("test"))
assert.False(t, c.HasChanged("")) assert.False(t, c.HasChanged(""))
} }
@@ -117,11 +117,11 @@ func TestConfig_ReloadConfig(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
done := make(chan bool, 1) done := make(chan bool, 1)
dir, err := os.MkdirTemp("", "config-test") dir, err := os.MkdirTemp("", "config-test")
assert.Nil(t, err) require.NoError(t, err)
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
c := NewC(l) c := NewC(l)
assert.Nil(t, c.Load(dir)) require.NoError(t, c.Load(dir))
assert.False(t, c.HasChanged("outer.inner")) assert.False(t, c.HasChanged("outer.inner"))
assert.False(t, c.HasChanged("outer")) assert.False(t, c.HasChanged("outer"))
@@ -184,11 +184,11 @@ firewall:
`), `),
} }
var m map[any]any var m map[string]any
// merge the same way config.parse() merges // merge the same way config.parse() merges
for _, b := range configs { for _, b := range configs {
var nm map[any]any var nm map[string]any
err := yaml.Unmarshal(b, &nm) err := yaml.Unmarshal(b, &nm)
require.NoError(t, err) require.NoError(t, err)
@@ -205,15 +205,15 @@ firewall:
t.Logf("Merged Config as YAML:\n%s", mYaml) t.Logf("Merged Config as YAML:\n%s", mYaml)
// If a bug is present, some items might be replaced instead of merged like we expect // If a bug is present, some items might be replaced instead of merged like we expect
expected := map[any]any{ expected := map[string]any{
"firewall": map[any]any{ "firewall": map[string]any{
"inbound": []any{ "inbound": []any{
map[any]any{"host": "any", "port": "any", "proto": "icmp"}, map[string]any{"host": "any", "port": "any", "proto": "icmp"},
map[any]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"}, map[string]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"},
map[any]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}}, map[string]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}},
"outbound": []any{ "outbound": []any{
map[any]any{"host": "any", "port": "any", "proto": "any"}}}, map[string]any{"host": "any", "port": "any", "proto": "any"}}},
"listen": map[any]any{ "listen": map[string]any{
"host": "0.0.0.0", "host": "0.0.0.0",
"port": 4242, "port": 4242,
}, },

View File

@@ -154,7 +154,7 @@ func (n *connectionManager) Run(ctx context.Context) {
defer clockSource.Stop() defer clockSource.Stop()
p := []byte("") p := []byte("")
nb := make([]byte, 12, 12) nb := make([]byte, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
for { for {
@@ -355,7 +355,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
decision = tryRehandshake decision = tryRehandshake
} else { } else {
if n.shouldSwapPrimary(hostinfo, primary) { if n.shouldSwapPrimary(hostinfo) {
decision = swapPrimary decision = swapPrimary
} else { } else {
// migrate the relays to the primary, if in use. // migrate the relays to the primary, if in use.
@@ -384,7 +384,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
} }
decision := doNothing decision := doNothing
if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo { if hostinfo.ConnectionState != nil && mainHostInfo {
if !outTraffic { if !outTraffic {
// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel. // If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
// Just maintain NAT state if configured to do so. // Just maintain NAT state if configured to do so.
@@ -421,7 +421,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
return decision, hostinfo, nil return decision, hostinfo, nil
} }
func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { func (n *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
// The primary tunnel is the most recent handshake to complete locally and should work entirely fine. // The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
// Let's sort this out. // Let's sort this out.
@@ -498,7 +498,7 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
cs := n.intf.pki.getCertState() cs := n.intf.pki.getCertState()
curCrt := hostinfo.ConnectionState.myCert curCrt := hostinfo.ConnectionState.myCert
myCrt := cs.getCertificate(curCrt.Version()) myCrt := cs.getCertificate(curCrt.Version())
if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true { if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
// The current tunnel is using the latest certificate and version, no need to rehandshake. // The current tunnel is using the latest certificate and version, no need to rehandshake.
return return
} }

View File

@@ -14,6 +14,7 @@ import (
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func newTestLighthouse() *LightHouse { func newTestLighthouse() *LightHouse {
@@ -43,10 +44,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
hostMap.preferredRanges.Store(&preferredRanges) hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{ cs := &CertState{
defaultVersion: cert.Version1, initiatingVersion: cert.Version1,
privateKey: []byte{}, privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1}, v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{}, v1HandshakeBytes: []byte{},
} }
lh := newTestLighthouse() lh := newTestLighthouse()
@@ -68,7 +69,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
punchy := NewPunchyFromConfig(l, config.NewC(l)) punchy := NewPunchyFromConfig(l, config.NewC(l))
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
p := []byte("") p := []byte("")
nb := make([]byte, 12, 12) nb := make([]byte, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
// Add an ip we have established a connection w/ to hostmap // Add an ip we have established a connection w/ to hostmap
@@ -125,10 +126,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
hostMap.preferredRanges.Store(&preferredRanges) hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{ cs := &CertState{
defaultVersion: cert.Version1, initiatingVersion: cert.Version1,
privateKey: []byte{}, privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1}, v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{}, v1HandshakeBytes: []byte{},
} }
lh := newTestLighthouse() lh := newTestLighthouse()
@@ -150,7 +151,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
punchy := NewPunchyFromConfig(l, config.NewC(l)) punchy := NewPunchyFromConfig(l, config.NewC(l))
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
p := []byte("") p := []byte("")
nb := make([]byte, 12, 12) nb := make([]byte, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
// Add an ip we have established a connection w/ to hostmap // Add an ip we have established a connection w/ to hostmap
@@ -223,9 +224,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
} }
caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA) caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA)
assert.NoError(t, err) require.NoError(t, err)
ncp := cert.NewCAPool() ncp := cert.NewCAPool()
assert.NoError(t, ncp.AddCA(caCert)) require.NoError(t, ncp.AddCA(caCert))
pubCrt, _, _ := ed25519.GenerateKey(rand.Reader) pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
tbs = &cert.TBSCertificate{ tbs = &cert.TBSCertificate{
@@ -237,10 +238,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
PublicKey: pubCrt, PublicKey: pubCrt,
} }
peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA) peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA)
assert.NoError(t, err) require.NoError(t, err)
cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert) cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
require.NoError(t, err)
cs := &CertState{ cs := &CertState{
privateKey: []byte{}, privateKey: []byte{},
v1Cert: &dummyCert{}, v1Cert: &dummyCert{},

View File

@@ -215,7 +215,7 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
hostInfo.ConnectionState, hostInfo.ConnectionState,
hostInfo, hostInfo,
[]byte{}, []byte{},
make([]byte, 12, 12), make([]byte, 12),
make([]byte, mtu), make([]byte, mtu),
) )
} }
@@ -231,7 +231,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) { if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) {
return return
} }
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12), make([]byte, mtu))
c.f.closeTunnel(h) c.f.closeTunnel(h)
c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote). c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote).
@@ -282,9 +282,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
CurrentRemote: h.remote, CurrentRemote: h.remote,
} }
for i, a := range h.vpnAddrs { copy(chi.VpnAddrs, h.vpnAddrs)
chi.VpnAddrs[i] = a
}
if h.ConnectionState != nil { if h.ConnectionState != nil {
chi.MessageCounter = h.ConnectionState.messageCounter.Load() chi.MessageCounter = h.ConnectionState.messageCounter.Load()

View File

@@ -26,13 +26,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
remote2 := netip.MustParseAddrPort("[1:2:3:4:5:6:7:8]:4444") remote2 := netip.MustParseAddrPort("[1:2:3:4:5:6:7:8]:4444")
ipNet := net.IPNet{ ipNet := net.IPNet{
IP: remote1.Addr().AsSlice(), IP: remote1.Addr().AsSlice(),
Mask: net.IPMask{255, 255, 255, 0},
} }
ipNet2 := net.IPNet{ ipNet2 := net.IPNet{
IP: remote2.Addr().AsSlice(), IP: remote2.Addr().AsSlice(),
Mask: net.IPMask{255, 255, 255, 0},
} }
remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil) remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil)
@@ -101,7 +99,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
// Make sure we don't have any unexpected fields // Make sure we don't have any unexpected fields
assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
assert.EqualValues(t, &expectedInfo, thi) assert.Equal(t, &expectedInfo, thi)
test.AssertDeepCopyEqual(t, &expectedInfo, thi) test.AssertDeepCopyEqual(t, &expectedInfo, thi)
// Make sure we don't panic if the host info doesn't have a cert yet // Make sure we don't panic if the host info doesn't have a cert yet
@@ -110,7 +108,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
}) })
} }
func assertFields(t *testing.T, expected []string, actualStruct interface{}) { func assertFields(t *testing.T, expected []string, actualStruct any) {
val := reflect.ValueOf(actualStruct).Elem() val := reflect.ValueOf(actualStruct).Elem()
fields := make([]string, val.NumField()) fields := make([]string, val.NumField())
for i := 0; i < val.NumField(); i++ { for i := 0; i < val.NumField(); i++ {

View File

@@ -38,24 +38,24 @@ func TestParsequery(t *testing.T) {
func Test_getDnsServerAddr(t *testing.T) { func Test_getDnsServerAddr(t *testing.T) {
c := config.NewC(nil) c := config.NewC(nil)
c.Settings["lighthouse"] = map[interface{}]interface{}{ c.Settings["lighthouse"] = map[string]any{
"dns": map[interface{}]interface{}{ "dns": map[string]any{
"host": "0.0.0.0", "host": "0.0.0.0",
"port": "1", "port": "1",
}, },
} }
assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c)) assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c))
c.Settings["lighthouse"] = map[interface{}]interface{}{ c.Settings["lighthouse"] = map[string]any{
"dns": map[interface{}]interface{}{ "dns": map[string]any{
"host": "::", "host": "::",
"port": "1", "port": "1",
}, },
} }
assert.Equal(t, "[::]:1", getDnsServerAddr(c)) assert.Equal(t, "[::]:1", getDnsServerAddr(c))
c.Settings["lighthouse"] = map[interface{}]interface{}{ c.Settings["lighthouse"] = map[string]any{
"dns": map[interface{}]interface{}{ "dns": map[string]any{
"host": "[::]", "host": "[::]",
"port": "1", "port": "1",
}, },
@@ -63,8 +63,8 @@ func Test_getDnsServerAddr(t *testing.T) {
assert.Equal(t, "[::]:1", getDnsServerAddr(c)) assert.Equal(t, "[::]:1", getDnsServerAddr(c))
// Make sure whitespace doesn't mess us up // Make sure whitespace doesn't mess us up
c.Settings["lighthouse"] = map[interface{}]interface{}{ c.Settings["lighthouse"] = map[string]any{
"dns": map[interface{}]interface{}{ "dns": map[string]any{
"host": "[::] ", "host": "[::] ",
"port": "1", "port": "1",
}, },

View File

@@ -19,7 +19,8 @@ import (
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2" "github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
) )
func BenchmarkHotPath(b *testing.B) { func BenchmarkHotPath(b *testing.B) {
@@ -771,7 +772,7 @@ func TestRehandshakingRelays(t *testing.T) {
"key": string(myNextPrivKey), "key": string(myNextPrivKey),
} }
rc, err := yaml.Marshal(relayConfig.Settings) rc, err := yaml.Marshal(relayConfig.Settings)
assert.NoError(t, err) require.NoError(t, err)
relayConfig.ReloadConfigString(string(rc)) relayConfig.ReloadConfigString(string(rc))
for { for {
@@ -875,7 +876,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
"key": string(myNextPrivKey), "key": string(myNextPrivKey),
} }
rc, err := yaml.Marshal(relayConfig.Settings) rc, err := yaml.Marshal(relayConfig.Settings)
assert.NoError(t, err) require.NoError(t, err)
relayConfig.ReloadConfigString(string(rc)) relayConfig.ReloadConfigString(string(rc))
for { for {
@@ -970,7 +971,7 @@ func TestRehandshaking(t *testing.T) {
"key": string(myNextPrivKey), "key": string(myNextPrivKey),
} }
rc, err := yaml.Marshal(myConfig.Settings) rc, err := yaml.Marshal(myConfig.Settings)
assert.NoError(t, err) require.NoError(t, err)
myConfig.ReloadConfigString(string(rc)) myConfig.ReloadConfigString(string(rc))
for { for {
@@ -987,17 +988,17 @@ func TestRehandshaking(t *testing.T) {
r.Log("Got the new cert") r.Log("Got the new cert")
// Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly // Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly
rc, err = yaml.Marshal(theirConfig.Settings) rc, err = yaml.Marshal(theirConfig.Settings)
assert.NoError(t, err) require.NoError(t, err)
var theirNewConfig m var theirNewConfig m
assert.NoError(t, yaml.Unmarshal(rc, &theirNewConfig)) require.NoError(t, yaml.Unmarshal(rc, &theirNewConfig))
theirFirewall := theirNewConfig["firewall"].(map[interface{}]interface{}) theirFirewall := theirNewConfig["firewall"].(map[string]any)
theirFirewall["inbound"] = []m{{ theirFirewall["inbound"] = []m{{
"proto": "any", "proto": "any",
"port": "any", "port": "any",
"group": "new group", "group": "new group",
}} }}
rc, err = yaml.Marshal(theirNewConfig) rc, err = yaml.Marshal(theirNewConfig)
assert.NoError(t, err) require.NoError(t, err)
theirConfig.ReloadConfigString(string(rc)) theirConfig.ReloadConfigString(string(rc))
r.Log("Spin until there is only 1 tunnel") r.Log("Spin until there is only 1 tunnel")
@@ -1067,7 +1068,7 @@ func TestRehandshakingLoser(t *testing.T) {
"key": string(theirNextPrivKey), "key": string(theirNextPrivKey),
} }
rc, err := yaml.Marshal(theirConfig.Settings) rc, err := yaml.Marshal(theirConfig.Settings)
assert.NoError(t, err) require.NoError(t, err)
theirConfig.ReloadConfigString(string(rc)) theirConfig.ReloadConfigString(string(rc))
for { for {
@@ -1083,17 +1084,17 @@ func TestRehandshakingLoser(t *testing.T) {
// Flip my firewall to only allowing the new group to catch the tunnels reverting incorrectly // Flip my firewall to only allowing the new group to catch the tunnels reverting incorrectly
rc, err = yaml.Marshal(myConfig.Settings) rc, err = yaml.Marshal(myConfig.Settings)
assert.NoError(t, err) require.NoError(t, err)
var myNewConfig m var myNewConfig m
assert.NoError(t, yaml.Unmarshal(rc, &myNewConfig)) require.NoError(t, yaml.Unmarshal(rc, &myNewConfig))
theirFirewall := myNewConfig["firewall"].(map[interface{}]interface{}) theirFirewall := myNewConfig["firewall"].(map[string]any)
theirFirewall["inbound"] = []m{{ theirFirewall["inbound"] = []m{{
"proto": "any", "proto": "any",
"port": "any", "port": "any",
"group": "their new group", "group": "their new group",
}} }}
rc, err = yaml.Marshal(myNewConfig) rc, err = yaml.Marshal(myNewConfig)
assert.NoError(t, err) require.NoError(t, err)
myConfig.ReloadConfigString(string(rc)) myConfig.ReloadConfigString(string(rc))
r.Log("Spin until there is only 1 tunnel") r.Log("Spin until there is only 1 tunnel")

View File

@@ -22,10 +22,10 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v3"
) )
type m map[string]interface{} type m = map[string]any
// newSimpleServer creates a nebula instance with many assumptions // newSimpleServer creates a nebula instance with many assumptions
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {

View File

@@ -13,11 +13,11 @@ pki:
# disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid. # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
#disconnect_invalid: true #disconnect_invalid: true
# default_version controls which certificate version is used in handshakes. # initiating_version controls which certificate version is used when initiating handshakes.
# This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`. # This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`.
# Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`. # Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`.
# After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed. # After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
# default_version: 1 # initiating_version: 1
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network). # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel. # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
@@ -126,8 +126,8 @@ lighthouse:
# Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined, # Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined,
# however using port 0 will dynamically assign a port and is recommended for roaming nodes. # however using port 0 will dynamically assign a port and is recommended for roaming nodes.
listen: listen:
# To listen on both any ipv4 and ipv6 use "::" # To listen on only ipv4, use "0.0.0.0"
host: 0.0.0.0 host: "::"
port: 4242 port: 4242
# Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg) # Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg)
# default is 64, does not support reload # default is 64, does not support reload
@@ -144,6 +144,11 @@ listen:
# valid values: always, never, private # valid values: always, never, private
# This setting is reloadable. # This setting is reloadable.
#send_recv_error: always #send_recv_error: always
# The so_sock option is a Linux-specific feature that allows all outgoing Nebula packets to be tagged with a specific identifier.
# This tagging enables IP rule-based filtering. For example, it supports 0.0.0.0/0 unsafe_routes,
# allowing for more precise routing decisions based on the packet tags. Default is 0 meaning no mark is set.
# This setting is reloadable.
#so_mark: 0
# Routines is the number of thread pairs to run that consume from the tun and UDP queues. # Routines is the number of thread pairs to run that consume from the tun and UDP queues.
# Currently, this defaults to 1 which means we have 1 tun queue reader and 1 # Currently, this defaults to 1 which means we have 1 tun queue reader and 1
@@ -234,7 +239,28 @@ tun:
# Unsafe routes allows you to route traffic over nebula to non-nebula nodes # Unsafe routes allows you to route traffic over nebula to non-nebula nodes
# Unsafe routes should be avoided unless you have hosts/services that cannot run nebula # Unsafe routes should be avoided unless you have hosts/services that cannot run nebula
# NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate # Supports weighted ECMP if you define a list of gateways, this can be used for load balancing or redundancy to hosts outside of nebula
# NOTES:
# * You will only see a single gateway in the routing table if you are not on linux
# * If a gateway is not reachable through the overlay another gateway will be selected to send the traffic through, ignoring weights
#
# unsafe_routes:
# # Multiple gateways without defining a weight defaults to a weight of 1, this will balance traffic equally between the three gateways
# - route: 192.168.87.0/24
# via:
# - gateway: 10.0.0.1
# - gateway: 10.0.0.2
# - gateway: 10.0.0.3
# # Multiple gateways with a weight, this will balance traffic accordingly
# - route: 192.168.87.0/24
# via:
# - gateway: 10.0.0.1
# weight: 10
# - gateway: 10.0.0.2
# weight: 5
#
# NOTE: The nebula certificate of the "via" node(s) *MUST* have the "route" defined as a subnet in its certificate
# `via`: single node or list of gateways to use for this route
# `mtu`: will default to tun mtu if this option is not specified # `mtu`: will default to tun mtu if this option is not specified
# `metric`: will default to 0 if this option is not specified # `metric`: will default to 0 if this option is not specified
# `install`: will default to true, controls whether this route is installed in the systems routing table. # `install`: will default to true, controls whether this route is installed in the systems routing table.
@@ -320,11 +346,11 @@ firewall:
outbound_action: drop outbound_action: drop
inbound_action: drop inbound_action: drop
# Controls the default value for local_cidr. Default is true, will be deprecated after v1.9 and defaulted to false. # THIS FLAG IS DEPRECATED AND WILL BE REMOVED IN A FUTURE RELEASE. (Defaults to false.)
# This setting only affects nebula hosts with subnets encoded in their certificate. A nebula host acting as an # This setting only affects nebula hosts exposing unsafe_routes. When set to false, each inbound rule must contain a
# unsafe router with `default_local_cidr_any: true` will expose their unsafe routes to every inbound rule regardless # `local_cidr` if the intention is to allow traffic to flow to an unsafe route. When set to true, every firewall rule
# of the actual destination for the packet. Setting this to false requires each inbound rule to contain a `local_cidr` # will apply to all configured unsafe_routes regardless of the actual destination of the packet, unless `local_cidr`
# if the intention is to allow traffic to flow to an unsafe route. # is explicitly defined. This is usually not the desired behavior and should be avoided!
#default_local_cidr_any: false #default_local_cidr_any: false
conntrack: conntrack:
@@ -342,11 +368,9 @@ firewall:
# group: `any` or a literal group name, ie `default-group` # group: `any` or a literal group name, ie `default-group`
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. # cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes. # local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes.
# If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network. # By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true.
# Otherwise the default is any vpn network assigned to via the certificate. # If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case.
# `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release.
# If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation.
# ca_name: An issuing CA name # ca_name: An issuing CA name
# ca_sha: An issuing CA shasum # ca_sha: An issuing CA shasum

View File

@@ -331,7 +331,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return nil return nil
} }
rs, ok := r.([]interface{}) rs, ok := r.([]any)
if !ok { if !ok {
return fmt.Errorf("%s failed to parse, should be an array of rules", table) return fmt.Errorf("%s failed to parse, should be an array of rules", table)
} }
@@ -606,7 +606,7 @@ func (f *Firewall) evict(p firewall.Packet) {
return return
} }
newT := t.Expires.Sub(time.Now()) newT := time.Until(t.Expires)
// Timeout is in the future, re-add the timer // Timeout is in the future, re-add the timer
if newT > 0 { if newT > 0 {
@@ -832,7 +832,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
} }
// Shortcut path for if groups, hosts, or cidr contained an `any` // Shortcut path for if groups, hosts, or cidr contained an `any`
if fr.Any.match(p, c) { if fr.Any.match(p) {
return true return true
} }
@@ -849,29 +849,26 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
found = true found = true
} }
if found && sg.LocalCIDR.match(p, c) { if found && sg.LocalCIDR.match(p) {
return true return true
} }
} }
if fr.Hosts != nil { if fr.Hosts != nil {
if flc, ok := fr.Hosts[c.Certificate.Name()]; ok { if flc, ok := fr.Hosts[c.Certificate.Name()]; ok {
if flc.match(p, c) { if flc.match(p) {
return true return true
} }
} }
} }
matched := false for _, v := range fr.CIDR.Supernets(netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())) {
prefix := netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen()) if v.match(p) {
fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool { return true
if prefix.Contains(p.RemoteAddr) && val.match(p, c) {
matched = true
return false
} }
return true }
})
return matched return false
} }
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
@@ -895,7 +892,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
return nil return nil
} }
func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate) bool { func (flc *firewallLocalCIDR) match(p firewall.Packet) bool {
if flc == nil { if flc == nil {
return false return false
} }
@@ -921,15 +918,15 @@ type rule struct {
CASha string CASha string
} }
func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) { func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
r := rule{} r := rule{}
m, ok := p.(map[interface{}]interface{}) m, ok := p.(map[string]any)
if !ok { if !ok {
return r, errors.New("could not parse rule") return r, errors.New("could not parse rule")
} }
toString := func(k string, m map[interface{}]interface{}) string { toString := func(k string, m map[string]any) string {
v, ok := m[k] v, ok := m[k]
if !ok { if !ok {
return "" return ""
@@ -947,7 +944,7 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
r.CASha = toString("ca_sha", m) r.CASha = toString("ca_sha", m)
// Make sure group isn't an array // Make sure group isn't an array
if v, ok := m["group"].([]interface{}); ok { if v, ok := m["group"].([]any); ok {
if len(v) > 1 { if len(v) > 1 {
return r, errors.New("group should contain a single value, an array with more than one entry was provided") return r, errors.New("group should contain a single value, an array with more than one entry was provided")
} }

View File

@@ -6,7 +6,7 @@ import (
"net/netip" "net/netip"
) )
type m map[string]interface{} type m = map[string]any
const ( const (
ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever

View File

@@ -35,22 +35,27 @@ func TestNewFirewall(t *testing.T) {
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c) fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
conntrack = fw.Conntrack
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c) fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
conntrack = fw.Conntrack
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c) fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
conntrack = fw.Conntrack
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c) fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
conntrack = fw.Conntrack
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c) fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
conntrack = fw.Conntrack
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
} }
@@ -66,61 +71,61 @@ func TestFirewall_AddRule(t *testing.T) {
assert.NotNil(t, fw.OutRules) assert.NotNil(t, fw.OutRules)
ti, err := netip.ParsePrefix("1.2.3.4/32") ti, err := netip.ParsePrefix("1.2.3.4/32")
assert.NoError(t, err) require.NoError(t, err)
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
// An empty rule is any // An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp, err := netip.ParsePrefix("0.0.0.0/0") anyIp, err := netip.ParsePrefix("0.0.0.0/0")
assert.NoError(t, err) require.NoError(t, err)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions // Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
} }
func TestFirewall_Drop(t *testing.T) { func TestFirewall_Drop(t *testing.T) {
@@ -155,16 +160,16 @@ func TestFirewall_Drop(t *testing.T) {
h.buildNetworks(c.networks, c.unsafeNetworks) h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
// Allow inbound // Allow inbound
resetConntrack(fw) resetConntrack(fw)
assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// Allow outbound because conntrack // Allow outbound because conntrack
assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) require.NoError(t, fw.Drop(p, false, &h, cp, nil))
// test remote mismatch // test remote mismatch
oldRemote := p.RemoteAddr oldRemote := p.RemoteAddr
@@ -174,29 +179,29 @@ func TestFirewall_Drop(t *testing.T) {
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
func BenchmarkFirewallTable_match(b *testing.B) { func BenchmarkFirewallTable_match(b *testing.B) {
@@ -350,14 +355,14 @@ func TestFirewall_Drop2(t *testing.T) {
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// h1/c1 lacks the proper groups // h1/c1 lacks the proper groups
assert.Error(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
// c has the proper groups // c has the proper groups
resetConntrack(fw) resetConntrack(fw)
assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
func TestFirewall_Drop3(t *testing.T) { func TestFirewall_Drop3(t *testing.T) {
@@ -428,18 +433,23 @@ func TestFirewall_Drop3(t *testing.T) {
h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks()) h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// c1 should pass because host match // c1 should pass because host match
assert.NoError(t, fw.Drop(p, true, &h1, cp, nil)) require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
// c2 should pass because ca sha match // c2 should pass because ca sha match
resetConntrack(fw) resetConntrack(fw)
assert.NoError(t, fw.Drop(p, true, &h2, cp, nil)) require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
// c3 should fail because no match // c3 should fail because no match
resetConntrack(fw) resetConntrack(fw)
assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
// Test a remote address match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
} }
func TestFirewall_DropConntrackReload(t *testing.T) { func TestFirewall_DropConntrackReload(t *testing.T) {
@@ -475,29 +485,29 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound // Allow inbound
resetConntrack(fw) resetConntrack(fw)
assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// Allow outbound because conntrack // Allow outbound because conntrack
assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) require.NoError(t, fw.Drop(p, false, &h, cp, nil))
oldFw := fw oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
// Allow outbound because conntrack and new rules allow port 10 // Allow outbound because conntrack and new rules allow port 10
assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) require.NoError(t, fw.Drop(p, false, &h, cp, nil))
oldFw = fw oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@@ -580,42 +590,42 @@ func BenchmarkLookup(b *testing.B) {
func Test_parsePort(t *testing.T) { func Test_parsePort(t *testing.T) {
_, _, err := parsePort("") _, _, err := parsePort("")
assert.EqualError(t, err, "was not a number; ``") require.EqualError(t, err, "was not a number; ``")
_, _, err = parsePort(" ") _, _, err = parsePort(" ")
assert.EqualError(t, err, "was not a number; ` `") require.EqualError(t, err, "was not a number; ` `")
_, _, err = parsePort("-") _, _, err = parsePort("-")
assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`") require.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
_, _, err = parsePort(" - ") _, _, err = parsePort(" - ")
assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `") require.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
_, _, err = parsePort("a-b") _, _, err = parsePort("a-b")
assert.EqualError(t, err, "beginning range was not a number; `a`") require.EqualError(t, err, "beginning range was not a number; `a`")
_, _, err = parsePort("1-b") _, _, err = parsePort("1-b")
assert.EqualError(t, err, "ending range was not a number; `b`") require.EqualError(t, err, "ending range was not a number; `b`")
s, e, err := parsePort(" 1 - 2 ") s, e, err := parsePort(" 1 - 2 ")
assert.Equal(t, int32(1), s) assert.Equal(t, int32(1), s)
assert.Equal(t, int32(2), e) assert.Equal(t, int32(2), e)
assert.Nil(t, err) require.NoError(t, err)
s, e, err = parsePort("0-1") s, e, err = parsePort("0-1")
assert.Equal(t, int32(0), s) assert.Equal(t, int32(0), s)
assert.Equal(t, int32(0), e) assert.Equal(t, int32(0), e)
assert.Nil(t, err) require.NoError(t, err)
s, e, err = parsePort("9919") s, e, err = parsePort("9919")
assert.Equal(t, int32(9919), s) assert.Equal(t, int32(9919), s)
assert.Equal(t, int32(9919), e) assert.Equal(t, int32(9919), e)
assert.Nil(t, err) require.NoError(t, err)
s, e, err = parsePort("any") s, e, err = parsePort("any")
assert.Equal(t, int32(0), s) assert.Equal(t, int32(0), s)
assert.Equal(t, int32(0), e) assert.Equal(t, int32(0), e)
assert.Nil(t, err) require.NoError(t, err)
} }
func TestNewFirewallFromConfig(t *testing.T) { func TestNewFirewallFromConfig(t *testing.T) {
@@ -626,55 +636,55 @@ func TestNewFirewallFromConfig(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
conf := config.NewC(l) conf := config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code // Test both port and code
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha // Test missing host, group, cidr, ca_name and ca_sha
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
// Test code/port error // Test code/port error
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error // Test proto error
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error // Test cidr parse error
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test local_cidr parse error // Test local_cidr parse error
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test both group and groups // Test both group and groups
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
} }
func TestAddFirewallRulesFromConfig(t *testing.T) { func TestAddFirewallRulesFromConfig(t *testing.T) {
@@ -682,87 +692,87 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
// Test adding tcp rule // Test adding tcp rule
conf := config.NewC(l) conf := config.NewC(l)
mf := &mockFirewall{} mf := &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding udp rule // Test adding udp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding icmp rule // Test adding icmp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding any rule // Test adding any rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with cidr // Test adding rule with cidr
cidr := netip.MustParsePrefix("10.0.0.0/8") cidr := netip.MustParsePrefix("10.0.0.0/8")
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with local_cidr // Test adding rule with local_cidr
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
// Test adding rule with ca_sha // Test adding rule with ca_sha
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name // Test adding rule with ca_name
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
// Test single group // Test single group
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test single groups // Test single groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test multiple AND groups // Test multiple AND groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test Add error // Test Add error
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
mf.nextCallReturn = errors.New("test error") mf.nextCallReturn = errors.New("test error")
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`") require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
} }
func TestFirewall_convertRule(t *testing.T) { func TestFirewall_convertRule(t *testing.T) {
@@ -771,33 +781,33 @@ func TestFirewall_convertRule(t *testing.T) {
l.SetOutput(ob) l.SetOutput(ob)
// Ensure group array of 1 is converted and a warning is printed // Ensure group array of 1 is converted and a warning is printed
c := map[interface{}]interface{}{ c := map[string]any{
"group": []interface{}{"group1"}, "group": []any{"group1"},
} }
r, err := convertRule(l, c, "test", 1) r, err := convertRule(l, c, "test", 1)
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "group1", r.Group) assert.Equal(t, "group1", r.Group)
// Ensure group array of > 1 is errord // Ensure group array of > 1 is errord
ob.Reset() ob.Reset()
c = map[interface{}]interface{}{ c = map[string]any{
"group": []interface{}{"group1", "group2"}, "group": []any{"group1", "group2"},
} }
r, err = convertRule(l, c, "test", 1) r, err = convertRule(l, c, "test", 1)
assert.Equal(t, "", ob.String()) assert.Empty(t, ob.String())
assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided") require.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
// Make sure a well formed group is alright // Make sure a well formed group is alright
ob.Reset() ob.Reset()
c = map[interface{}]interface{}{ c = map[string]any{
"group": "group1", "group": "group1",
} }
r, err = convertRule(l, c, "test", 1) r, err = convertRule(l, c, "test", 1)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, "group1", r.Group) assert.Equal(t, "group1", r.Group)
} }

36
go.mod
View File

@@ -1,8 +1,8 @@
module github.com/slackhq/nebula module github.com/slackhq/nebula
go 1.22.0 go 1.23.6
toolchain go1.22.2 toolchain go1.24.1
require ( require (
dario.cat/mergo v1.0.1 dario.cat/mergo v1.0.1
@@ -10,49 +10,47 @@ require (
github.com/armon/go-radix v1.0.0 github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.1.0 github.com/flynn/noise v1.1.0
github.com/gaissmai/bart v0.13.0 github.com/gaissmai/bart v0.20.1
github.com/gogo/protobuf v1.3.2 github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19 github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.2 github.com/kardianos/service v1.2.2
github.com/miekg/dns v1.1.62 github.com/miekg/dns v1.1.65
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
github.com/prometheus/client_golang v1.20.4 github.com/prometheus/client_golang v1.21.1
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.10.0
github.com/vishvananda/netlink v1.3.0 github.com/vishvananda/netlink v1.3.0
golang.org/x/crypto v0.28.0 golang.org/x/crypto v0.37.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.30.0 golang.org/x/net v0.38.0
golang.org/x/sync v0.8.0 golang.org/x/sync v0.13.0
golang.org/x/sys v0.26.0 golang.org/x/sys v0.32.0
golang.org/x/term v0.25.0 golang.org/x/term v0.31.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/protobuf v1.35.1 google.golang.org/protobuf v1.36.6
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
) )
require ( require (
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/bits-and-blooms/bitset v1.14.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/btree v1.1.2 // indirect github.com/google/btree v1.1.2 // indirect
github.com/klauspost/compress v1.17.9 // indirect github.com/klauspost/compress v1.17.11 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.55.0 // indirect github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect github.com/prometheus/procfs v0.15.1 // indirect
github.com/vishvananda/netns v0.0.4 // indirect github.com/vishvananda/netns v0.0.4 // indirect
golang.org/x/mod v0.18.0 // indirect golang.org/x/mod v0.23.0 // indirect
golang.org/x/time v0.5.0 // indirect golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.22.0 // indirect golang.org/x/tools v0.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
) )

60
go.sum
View File

@@ -14,8 +14,6 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bits-and-blooms/bitset v1.14.3 h1:Gd2c8lSNf9pKXom5JtD7AaKO8o7fGQ2LtFj1436qilA=
github.com/bits-and-blooms/bitset v1.14.3/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
@@ -26,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/gaissmai/bart v0.13.0 h1:pItEhXDVVebUa+i978FfQ7ye8xZc1FrMgs8nJPPWAgA= github.com/gaissmai/bart v0.20.1 h1:igNss0zDsSY8e+ophKgD9KJVPKBOo7uSVjyKCL7nIzo=
github.com/gaissmai/bart v0.13.0/go.mod h1:qSes2fnJ8hB410BW0ymHUN/eQkuGpTYyJcN8sKMYpJU= github.com/gaissmai/bart v0.20.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -70,8 +68,8 @@ github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
@@ -85,8 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= github.com/miekg/dns v1.1.65 h1:0+tIPHzUW0GCge7IiK3guGP57VAw7hoPDfApjkMD1Fc=
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ= github.com/miekg/dns v1.1.65/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -108,8 +106,8 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI= github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk=
github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
@@ -118,8 +116,8 @@ github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQy
github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc= github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
@@ -145,8 +143,8 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
@@ -158,16 +156,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -178,8 +176,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -187,8 +185,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -206,11 +204,11 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o=
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -221,8 +219,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -241,8 +239,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@@ -253,8 +251,6 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -25,7 +25,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
// If we're connecting to a v6 address we must use a v2 cert // If we're connecting to a v6 address we must use a v2 cert
cs := f.pki.getCertState() cs := f.pki.getCertState()
v := cs.defaultVersion v := cs.initiatingVersion
for _, a := range hh.hostinfo.vpnAddrs { for _, a := range hh.hostinfo.vpnAddrs {
if a.Is6() { if a.Is6() {
v = cert.Version2 v = cert.Version2
@@ -71,7 +71,8 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
hsBytes, err := hs.Marshal() hsBytes, err := hs.Marshal()
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v). f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("certVersion", v).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return false return false
} }
@@ -100,7 +101,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if crt == nil { if crt == nil {
f.l.WithField("udpAddr", addr). f.l.WithField("udpAddr", addr).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.defaultVersion). WithField("certVersion", cs.initiatingVersion).
Error("Unable to handshake with host because no certificate is available") Error("Unable to handshake with host because no certificate is available")
} }
@@ -132,13 +133,28 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return return
} }
remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool()) rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil { if err != nil {
e := f.l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake did not contain a certificate")
return
}
if f.l.Level > logrus.DebugLevel { remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
e = e.WithField("cert", remoteCert) if err != nil {
fp, err := rc.Fingerprint()
if err != nil {
fp = "<error generating certificate fingerprint>"
}
e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVpnNetworks", rc.Networks()).
WithField("certFingerprint", fp)
if f.l.Level >= logrus.DebugLevel {
e = e.WithField("cert", rc)
} }
e.Info("Invalid certificate from host") e.Info("Invalid certificate from host")
@@ -160,20 +176,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
} }
if len(remoteCert.Certificate.Networks()) == 0 { if len(remoteCert.Certificate.Networks()) == 0 {
e := f.l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) WithField("cert", remoteCert).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
if f.l.Level > logrus.DebugLevel { Info("No networks in certificate")
e = e.WithField("cert", remoteCert)
}
e.Info("Invalid vpn ip from host")
return return
} }
var vpnAddrs []netip.Addr var vpnAddrs []netip.Addr
var filteredNetworks []netip.Prefix var filteredNetworks []netip.Prefix
certName := remoteCert.Certificate.Name() certName := remoteCert.Certificate.Name()
certVersion := remoteCert.Certificate.Version()
fingerprint := remoteCert.Fingerprint fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer() issuer := remoteCert.Certificate.Issuer()
@@ -183,6 +196,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if found { if found {
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr). f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
@@ -201,6 +215,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if len(vpnAddrs) == 0 { if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
@@ -220,6 +235,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
@@ -242,6 +258,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -253,6 +270,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if hs.Details.Cert == nil { if hs.Details.Cert == nil {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -270,6 +288,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
@@ -281,6 +300,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
@@ -288,6 +308,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
} else if dKey == nil || eKey == nil { } else if dKey == nil || eKey == nil {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
@@ -322,7 +343,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if existing.SetRemoteIfPreferred(f.hostMap, addr) { if existing.SetRemoteIfPreferred(f.hostMap, addr) {
// Send a test packet to ensure the other side has also switched to // Send a test packet to ensure the other side has also switched to
// the preferred remote // the preferred remote
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12), make([]byte, mtu))
} }
msg = existing.HandshakePacket[2] msg = existing.HandshakePacket[2]
@@ -355,6 +376,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
// This means there was an existing tunnel and this handshake was older than the one we are currently based on // This means there was an existing tunnel and this handshake was older than the one we are currently based on
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("oldHandshakeTime", existing.lastHandshakeTime). WithField("oldHandshakeTime", existing.lastHandshakeTime).
WithField("newHandshakeTime", hostinfo.lastHandshakeTime). WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -364,12 +386,13 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
Info("Handshake too old") Info("Handshake too old")
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12), make([]byte, mtu))
return return
case ErrLocalIndexCollision: case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -382,6 +405,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
// And we forget to update it here // And we forget to update it here
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -398,6 +422,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if err != nil { if err != nil {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -406,6 +431,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
} else { } else {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -424,6 +450,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -434,8 +461,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
hostinfo.remotes.ResetBlockedRemotes() hostinfo.remotes.ResetBlockedRemotes()
return
} }
func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
@@ -487,35 +512,48 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
return true return true
} }
remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool()) rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil { if err != nil {
e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake did not contain a certificate")
return true
}
if f.l.Level > logrus.DebugLevel { remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
e = e.WithField("cert", remoteCert) if err != nil {
fp, err := rc.Fingerprint()
if err != nil {
fp = "<error generating certificate fingerprint>"
} }
e.Error("Invalid certificate from host") e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("certFingerprint", fp).
WithField("certVpnNetworks", rc.Networks())
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again if f.l.Level >= logrus.DebugLevel {
e = e.WithField("cert", rc)
}
e.Info("Invalid certificate from host")
return true return true
} }
if len(remoteCert.Certificate.Networks()) == 0 { if len(remoteCert.Certificate.Networks()) == 0 {
e := f.l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("cert", remoteCert).
if f.l.Level > logrus.DebugLevel { WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
e = e.WithField("cert", remoteCert) Info("No networks in certificate")
}
e.Info("Empty networks from host")
return true return true
} }
vpnNetworks := remoteCert.Certificate.Networks() vpnNetworks := remoteCert.Certificate.Networks()
certName := remoteCert.Certificate.Name() certName := remoteCert.Certificate.Name()
certVersion := remoteCert.Certificate.Version()
fingerprint := remoteCert.Fingerprint fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer() issuer := remoteCert.Certificate.Issuer()
@@ -550,6 +588,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
if len(vpnAddrs) == 0 { if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake") WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
@@ -559,7 +598,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
// Ensure the right host responded // Ensure the right host responded
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) { if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("udpAddr", addr).WithField("certName", certName). WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Incorrect host responded to handshake") Info("Incorrect host responded to handshake")
@@ -595,6 +636,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
duration := time.Since(hh.startTime).Nanoseconds() duration := time.Since(hh.startTime).Nanoseconds()
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -616,7 +658,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
} }
if len(hh.packetStore) > 0 { if len(hh.packetStore) > 0 {
nb := make([]byte, 12, 12) nb := make([]byte, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
for _, cp := range hh.packetStore { for _, cp := range hh.packetStore {
cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out) cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out)

View File

@@ -257,7 +257,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
WithField("initiatorIndex", hostinfo.localIndexId). WithField("initiatorIndex", hostinfo.localIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message sent") Info("Handshake message sent")
} else if hm.l.IsLevelEnabled(logrus.DebugLevel) { } else if hm.l.Level >= logrus.DebugLevel {
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
WithField("initiatorIndex", hostinfo.localIndexId). WithField("initiatorIndex", hostinfo.localIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).

View File

@@ -24,10 +24,10 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
lh := newTestLighthouse() lh := newTestLighthouse()
cs := &CertState{ cs := &CertState{
defaultVersion: cert.Version1, initiatingVersion: cert.Version1,
privateKey: []byte{}, privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1}, v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{}, v1HandshakeBytes: []byte{},
} }
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
@@ -44,7 +44,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
i.remotes = NewRemoteList([]netip.Addr{}, nil) i.remotes = NewRemoteList([]netip.Addr{}, nil)
// Adding something to pending should not affect the main hostmap // Adding something to pending should not affect the main hostmap
assert.Len(t, mainHM.Hosts, 0) assert.Empty(t, mainHM.Hosts)
// Confirm they are in the pending index list // Confirm they are in the pending index list
assert.Contains(t, blah.vpnIps, ip) assert.Contains(t, blah.vpnIps, ip)
@@ -65,30 +65,16 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
assert.NotContains(t, blah.vpnIps, ip) assert.NotContains(t, blah.vpnIps, ip)
} }
func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
for _, i := range tw.t.wheel {
n := i.Head
for n != nil {
c++
n = n.Next
}
}
return c
}
type mockEncWriter struct { type mockEncWriter struct {
} }
func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) { func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) {
return
} }
func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) { func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) {
return
} }
func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) { func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) {
return
} }
func (mw *mockEncWriter) Handshake(_ netip.Addr) {} func (mw *mockEncWriter) Handshake(_ netip.Addr) {}
@@ -98,5 +84,5 @@ func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
} }
func (mw *mockEncWriter) GetCertState() *CertState { func (mw *mockEncWriter) GetCertState() *CertState {
return &CertState{defaultVersion: cert.Version2} return &CertState{initiatingVersion: cert.Version2}
} }

View File

@@ -19,11 +19,11 @@ import (
// |-----------------------------------------------------------------------| // |-----------------------------------------------------------------------|
// | payload... | // | payload... |
type m map[string]interface{} type m = map[string]any
const ( const (
Version uint8 = 1 Version uint8 = 1
Len = 16 Len int = 16
) )
type MessageType uint8 type MessageType uint8

View File

@@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
type headerTest struct { type headerTest struct {
@@ -111,7 +112,7 @@ func TestHeader_String(t *testing.T) {
func TestHeader_MarshalJSON(t *testing.T) { func TestHeader_MarshalJSON(t *testing.T) {
b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON() b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON()
assert.Nil(t, err) require.NoError(t, err)
assert.Equal( assert.Equal(
t, t,
"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}", "{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",

View File

@@ -568,7 +568,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs) dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
} }
for _, addr := range hostinfo.vpnAddrs { for _, addr := range hostinfo.vpnAddrs {
hm.unlockedInnerAddHostInfo(addr, hostinfo, f) hm.unlockedInnerAddHostInfo(addr, hostinfo)
} }
hm.Indexes[hostinfo.localIndexId] = hostinfo hm.Indexes[hostinfo.localIndexId] = hostinfo
@@ -581,7 +581,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
} }
} }
func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) { func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo) {
existing := hm.Hosts[vpnAddr] existing := hm.Hosts[vpnAddr]
hm.Hosts[vpnAddr] = hostinfo hm.Hosts[vpnAddr] = hostinfo
@@ -648,7 +648,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac
// Try to send a test packet to that host, this should // Try to send a test packet to that host, this should
// cause it to detect a roaming event and switch remotes // cause it to detect a roaming event and switch remotes
ifce.sendTo(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) ifce.sendTo(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12), make([]byte, mtu))
}) })
} }
@@ -794,7 +794,7 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
} }
addr = addr.Unmap() addr = addr.Unmap()
if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false { if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
isAllowed := allowList.Allow(addr) isAllowed := allowList.Allow(addr)
if l.Level >= logrus.TraceLevel { if l.Level >= logrus.TraceLevel {
l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow") l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")

View File

@@ -210,8 +210,8 @@ func TestHostMap_reload(t *testing.T) {
assert.Empty(t, hm.GetPreferredRanges()) assert.Empty(t, hm.GetPreferredRanges())
c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]") c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]")
assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges())) assert.Equal(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))
c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]") c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges())) assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
} }

View File

@@ -8,6 +8,7 @@ import (
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/routing"
) )
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
@@ -49,7 +50,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return return
} }
hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) { hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
}) })
@@ -121,22 +122,94 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
} }
// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
func (f *Interface) Handshake(vpnAddr netip.Addr) { func (f *Interface) Handshake(vpnAddr netip.Addr) {
f.getOrHandshake(vpnAddr, nil) f.getOrHandshakeNoRouting(vpnAddr, nil)
} }
// getOrHandshake returns nil if the vpnAddr is not routable. // getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
_, found := f.myVpnNetworksTable.Lookup(vpnAddr) _, found := f.myVpnNetworksTable.Lookup(vpnAddr)
if !found { if found {
vpnAddr = f.inside.RouteFor(vpnAddr) return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
if !vpnAddr.IsValid() { }
return nil, false
} return nil, false
}
// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
destinationAddr := fwPacket.RemoteAddr
hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
// Host is inside the mesh, no routing required
if hostinfo != nil {
return hostinfo, ready
}
gateways := f.inside.RoutesFor(destinationAddr)
switch len(gateways) {
case 0:
return nil, false
case 1:
// Single gateway route
return f.handshakeManager.GetOrHandshake(gateways[0].Addr(), cacheCallback)
default:
// Multi gateway route, perform ECMP categorization
gatewayAddr, balancingOk := routing.BalancePacket(fwPacket, gateways)
if !balancingOk {
// This happens if the gateway buckets were not calculated, this _should_ never happen
f.l.Error("Gateway buckets not calculated, fallback from ECMP to random routing. Please report this bug.")
}
var handshakeInfoForChosenGateway *HandshakeHostInfo
var hhReceiver = func(hh *HandshakeHostInfo) {
handshakeInfoForChosenGateway = hh
}
// Store the handshakeHostInfo for later.
// If this node is not reachable we will attempt other nodes, if none are reachable we will
// cache the packet for this gateway.
if hostinfo, ready = f.handshakeManager.GetOrHandshake(gatewayAddr, hhReceiver); ready {
return hostinfo, true
}
// It appears the selected gateway cannot be reached, find another gateway to fallback on.
// The current implementation breaks ECMP but that seems better than no connectivity.
// If ECMP is also required when a gateway is down then connectivity status
// for each gateway needs to be kept and the weights recalculated when they go up or down.
// This would also need to interact with unsafe_route updates through reloading the config or
// use of the use_system_route_table option
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("destination", destinationAddr).
WithField("originalGateway", gatewayAddr).
Debugln("Calculated gateway for ECMP not available, attempting other gateways")
}
for i := range gateways {
// Skip the gateway that failed previously
if gateways[i].Addr() == gatewayAddr {
continue
}
// We do not need the HandshakeHostInfo since we cache the packet in the originally chosen gateway
if hostinfo, ready = f.handshakeManager.GetOrHandshake(gateways[i].Addr(), nil); ready {
return hostinfo, true
}
}
// No gateways reachable, cache the packet in the originally chosen gateway
cacheCallback(handshakeInfoForChosenGateway)
return hostinfo, false
} }
return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
} }
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
@@ -163,7 +236,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) { func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) { hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
}) })

View File

@@ -266,7 +266,7 @@ func (f *Interface) listenOut(i int) {
plaintext := make([]byte, udp.MTU) plaintext := make([]byte, udp.MTU)
h := &header.H{} h := &header.H{}
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
@@ -279,7 +279,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
packet := make([]byte, mtu) packet := make([]byte, mtu)
out := make([]byte, mtu) out := make([]byte, mtu)
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
@@ -322,7 +322,7 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
func (f *Interface) reloadFirewall(c *config.C) { func (f *Interface) reloadFirewall(c *config.C) {
//TODO: need to trigger/detect if the certificate changed too //TODO: need to trigger/detect if the certificate changed too
if c.HasChanged("firewall") == false { if !c.HasChanged("firewall") {
f.l.Debug("No firewall config change detected") f.l.Debug("No firewall config change detected")
return return
} }
@@ -410,7 +410,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
udpStats := udp.NewUDPStatsEmitter(f.writers) udpStats := udp.NewUDPStatsEmitter(f.writers)
certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil) certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
certDefaultVersion := metrics.GetOrRegisterGauge("certificate.default_version", nil) certInitiatingVersion := metrics.GetOrRegisterGauge("certificate.initiating_version", nil)
certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil) certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
for { for {
@@ -424,8 +424,8 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
certState := f.pki.getCertState() certState := f.pki.getCertState()
defaultCrt := certState.GetDefaultCertificate() defaultCrt := certState.GetDefaultCertificate()
certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second)) certExpirationGauge.Update(int64(time.Until(defaultCrt.NotAfter()) / time.Second))
certDefaultVersion.Update(int64(defaultCrt.Version())) certInitiatingVersion.Update(int64(defaultCrt.Version()))
// Report the max certificate version we are capable of using // Report the max certificate version we are capable of using
if certState.v2Cert != nil { if certState.v2Cert != nil {

View File

@@ -371,7 +371,7 @@ func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{
} }
staticList := lh.GetStaticHostList() staticList := lh.GetStaticHostList()
for lhAddr, _ := range lhMap { for lhAddr := range lhMap {
if _, ok := staticList[lhAddr]; !ok { if _, ok := staticList[lhAddr]; !ok {
return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr) return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr)
} }
@@ -422,7 +422,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
return err return err
} }
shm := c.GetMap("static_host_map", map[interface{}]interface{}{}) shm := c.GetMap("static_host_map", map[string]any{})
i := 0 i := 0
for k, v := range shm { for k, v := range shm {
@@ -436,9 +436,9 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil) return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
} }
vals, ok := v.([]interface{}) vals, ok := v.([]any)
if !ok { if !ok {
vals = []interface{}{v} vals = []any{v}
} }
remoteAddrs := []string{} remoteAddrs := []string{}
for _, v := range vals { for _, v := range vals {
@@ -654,11 +654,8 @@ func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool {
} }
_, found := lh.myVpnNetworksTable.Lookup(to) _, found := lh.myVpnNetworksTable.Lookup(to)
if found {
return false
}
return true return !found
} }
// unlockedShouldAddV4 checks if to is allowed by our allow list // unlockedShouldAddV4 checks if to is allowed by our allow list
@@ -675,11 +672,7 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo
} }
_, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr()) _, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr())
if found { return !found
return false
}
return true
} }
// unlockedShouldAddV6 checks if to is allowed by our allow list // unlockedShouldAddV6 checks if to is allowed by our allow list
@@ -696,11 +689,8 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
} }
_, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr()) _, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr())
if found {
return false
}
return true return !found
} }
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool { func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
@@ -728,7 +718,7 @@ func (lh *LightHouse) startQueryWorker() {
} }
go func() { go func() {
nb := make([]byte, 12, 12) nb := make([]byte, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
for { for {
@@ -763,7 +753,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
if hi != nil { if hi != nil {
v = hi.ConnectionState.myCert.Version() v = hi.ConnectionState.myCert.Version()
} else { } else {
v = lh.ifce.GetCertState().defaultVersion v = lh.ifce.GetCertState().initiatingVersion
} }
if v == cert.Version1 { if v == cert.Version1 {
@@ -869,7 +859,7 @@ func (lh *LightHouse) SendUpdate() {
} }
} }
nb := make([]byte, 12, 12) nb := make([]byte, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
var v1Update, v2Update []byte var v1Update, v2Update []byte
@@ -883,7 +873,7 @@ func (lh *LightHouse) SendUpdate() {
if hi != nil { if hi != nil {
v = hi.ConnectionState.myCert.Version() v = hi.ConnectionState.myCert.Version()
} else { } else {
v = lh.ifce.GetCertState().defaultVersion v = lh.ifce.GetCertState().initiatingVersion
} }
if v == cert.Version1 { if v == cert.Version1 {
if v1Update == nil { if v1Update == nil {
@@ -971,7 +961,7 @@ type LightHouseHandler struct {
func (lh *LightHouse) NewRequestHandler() *LightHouseHandler { func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
lhh := &LightHouseHandler{ lhh := &LightHouseHandler{
lh: lh, lh: lh,
nb: make([]byte, 12, 12), nb: make([]byte, 12),
out: make([]byte, mtu), out: make([]byte, mtu),
l: lh.l, l: lh.l,
pb: make([]byte, mtu), pb: make([]byte, mtu),
@@ -1114,7 +1104,7 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest) targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
var useVersion cert.Version var useVersion cert.Version
if targetHI == nil { if targetHI == nil {
useVersion = lhh.lh.ifce.GetCertState().defaultVersion useVersion = lhh.lh.ifce.GetCertState().initiatingVersion
} else { } else {
crt := targetHI.GetCert().Certificate crt := targetHI.GetCert().Certificate
useVersion = crt.Version() useVersion = crt.Version()
@@ -1162,7 +1152,7 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
if c.v4.learned != nil { if c.v4.learned != nil {
n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.learned) n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.learned)
} }
if c.v4.reported != nil && len(c.v4.reported) > 0 { if len(c.v4.reported) > 0 {
n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.reported...) n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.reported...)
} }
} }
@@ -1171,7 +1161,7 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
if c.v6.learned != nil { if c.v6.learned != nil {
n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.learned) n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.learned)
} }
if c.v6.reported != nil && len(c.v6.reported) > 0 { if len(c.v6.reported) > 0 {
n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.reported...) n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.reported...)
} }
} }
@@ -1369,7 +1359,7 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine //NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine // for each punchBack packet. We should move this into a timerwheel or a single goroutine
// managed by a channel. // managed by a channel.
w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12), make([]byte, mtu))
}() }()
} }
} }

View File

@@ -13,7 +13,8 @@ import (
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2" "github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
) )
func TestOldIPv4Only(t *testing.T) { func TestOldIPv4Only(t *testing.T) {
@@ -21,7 +22,7 @@ func TestOldIPv4Only(t *testing.T) {
b := []byte{8, 129, 130, 132, 80, 16, 10} b := []byte{8, 129, 130, 132, 80, 16, 10}
var m V4AddrPort var m V4AddrPort
err := m.Unmarshal(b) err := m.Unmarshal(b)
assert.NoError(t, err) require.NoError(t, err)
ip := netip.MustParseAddr("10.1.1.1") ip := netip.MustParseAddr("10.1.1.1")
bp := ip.As4() bp := ip.As4()
assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr()) assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr())
@@ -39,17 +40,17 @@ func Test_lhStaticMapping(t *testing.T) {
lh1 := "10.128.0.2" lh1 := "10.128.0.2"
c := config.NewC(l) c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1}}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.Nil(t, err) require.NoError(t, err)
lh2 := "10.128.0.3" lh2 := "10.128.0.3"
c = config.NewC(l) c = config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1, lh2}}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} c.Settings["static_host_map"] = map[string]any{lh1: []any{"100.1.1.1:4242"}}
_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) _, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
} }
func TestReloadLighthouseInterval(t *testing.T) { func TestReloadLighthouseInterval(t *testing.T) {
@@ -64,26 +65,26 @@ func TestReloadLighthouseInterval(t *testing.T) {
lh1 := "10.128.0.2" lh1 := "10.128.0.2"
c := config.NewC(l) c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{ c.Settings["lighthouse"] = map[string]any{
"hosts": []interface{}{lh1}, "hosts": []any{lh1},
"interval": "1s", "interval": "1s",
} }
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.NoError(t, err) require.NoError(t, err)
lh.ifce = &mockEncWriter{} lh.ifce = &mockEncWriter{}
// The first one routine is kicked off by main.go currently, lets make sure that one dies // The first one routine is kicked off by main.go currently, lets make sure that one dies
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5")) require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5"))
assert.Equal(t, int64(5), lh.interval.Load()) assert.Equal(t, int64(5), lh.interval.Load())
// Subsequent calls are killed off by the LightHouse.Reload function // Subsequent calls are killed off by the LightHouse.Reload function
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10")) require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10"))
assert.Equal(t, int64(10), lh.interval.Load()) assert.Equal(t, int64(10), lh.interval.Load())
// If this completes then nothing is stealing our reload routine // If this completes then nothing is stealing our reload routine
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11")) require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11"))
assert.Equal(t, int64(11), lh.interval.Load()) assert.Equal(t, int64(11), lh.interval.Load())
} }
@@ -99,9 +100,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
c := config.NewC(l) c := config.NewC(l)
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
if !assert.NoError(b, err) { require.NoError(b, err)
b.Fatal()
}
hAddr := netip.MustParseAddrPort("4.5.6.7:12345") hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346") hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
@@ -145,7 +144,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
}, },
} }
p, err := req.Marshal() p, err := req.Marshal()
assert.NoError(b, err) require.NoError(b, err)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, hi, p, mw) lhh.HandleRequest(rAddr, hi, p, mw)
} }
@@ -160,7 +159,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
}, },
} }
p, err := req.Marshal() p, err := req.Marshal()
assert.NoError(b, err) require.NoError(b, err)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, hi, p, mw) lhh.HandleRequest(rAddr, hi, p, mw)
@@ -193,8 +192,8 @@ func TestLighthouse_Memory(t *testing.T) {
theirVpnIp := netip.MustParseAddr("10.128.0.3") theirVpnIp := netip.MustParseAddr("10.128.0.3")
c := config.NewC(l) c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} c.Settings["listen"] = map[string]any{"port": 4242}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24") myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Table[struct{}]) nt := new(bart.Table[struct{}])
@@ -205,7 +204,7 @@ func TestLighthouse_Memory(t *testing.T) {
} }
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
lh.ifce = &mockEncWriter{} lh.ifce = &mockEncWriter{}
assert.NoError(t, err) require.NoError(t, err)
lhh := lh.NewRequestHandler() lhh := lh.NewRequestHandler()
// Test that my first update responds with just that // Test that my first update responds with just that
@@ -278,8 +277,8 @@ func TestLighthouse_Memory(t *testing.T) {
func TestLighthouse_reload(t *testing.T) { func TestLighthouse_reload(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} c.Settings["listen"] = map[string]any{"port": 4242}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24") myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Table[struct{}]) nt := new(bart.Table[struct{}])
@@ -290,19 +289,19 @@ func TestLighthouse_reload(t *testing.T) {
} }
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.NoError(t, err) require.NoError(t, err)
nc := map[interface{}]interface{}{ nc := map[string]any{
"static_host_map": map[interface{}]interface{}{ "static_host_map": map[string]any{
"10.128.0.2": []interface{}{"1.1.1.1:4242"}, "10.128.0.2": []any{"1.1.1.1:4242"},
}, },
} }
rc, err := yaml.Marshal(nc) rc, err := yaml.Marshal(nc)
assert.NoError(t, err) require.NoError(t, err)
c.ReloadConfigString(string(rc)) c.ReloadConfigString(string(rc))
err = lh.reload(c, false) err = lh.reload(c, false)
assert.NoError(t, err) require.NoError(t, err)
} }
func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply { func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
@@ -418,7 +417,7 @@ func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
} }
func (tw *testEncWriter) GetCertState() *CertState { func (tw *testEncWriter) GetCertState() *CertState {
return &CertState{defaultVersion: tw.protocolVersion} return &CertState{initiatingVersion: tw.protocolVersion}
} }
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
@@ -485,12 +484,12 @@ func Test_findNetworkUnion(t *testing.T) {
assert.Equal(t, out, afe81) assert.Equal(t, out, afe81)
//falsey cases //falsey cases
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1}) _, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1})
assert.False(t, ok) assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1}) _, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1})
assert.False(t, ok) assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81}) _, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok) assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81}) _, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok) assert.False(t, ok)
} }

View File

@@ -13,10 +13,10 @@ import (
"github.com/slackhq/nebula/sshd" "github.com/slackhq/nebula/sshd"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v3"
) )
type m map[string]interface{} type m = map[string]any
func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) { func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())

View File

@@ -17,7 +17,7 @@ type MessageMetrics struct {
func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) { func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) {
if m != nil { if m != nil {
if t >= 0 && int(t) < len(m.rx) && s >= 0 && int(s) < len(m.rx[t]) { if int(t) < len(m.rx) && int(s) < len(m.rx[t]) {
m.rx[t][s].Inc(i) m.rx[t][s].Inc(i)
} else if m.rxUnknown != nil { } else if m.rxUnknown != nil {
m.rxUnknown.Inc(i) m.rxUnknown.Inc(i)
@@ -26,7 +26,7 @@ func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int
} }
func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int64) { func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int64) {
if m != nil { if m != nil {
if t >= 0 && int(t) < len(m.tx) && s >= 0 && int(s) < len(m.tx[t]) { if int(t) < len(m.tx) && int(s) < len(m.tx[t]) {
m.tx[t][s].Inc(i) m.tx[t][s].Inc(i)
} else if m.txUnknown != nil { } else if m.txUnknown != nil {
m.txUnknown.Inc(i) m.txUnknown.Inc(i)

View File

@@ -1,18 +0,0 @@
package nebula
/*
import (
proto "google.golang.org/protobuf/proto"
)
func HandleMetaProto(p []byte) {
m := &NebulaMeta{}
err := proto.Unmarshal(p, m)
if err != nil {
l.Debugf("problem unmarshaling meta message: %s", err)
}
//fmt.Println(m)
}
*/

View File

@@ -228,7 +228,7 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) {
// sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote // sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote
func (f *Interface) sendCloseTunnel(h *HostInfo) { func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12), make([]byte, mtu))
} }
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) {

View File

@@ -12,6 +12,7 @@ import (
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
@@ -20,13 +21,13 @@ func Test_newPacket(t *testing.T) {
// length fails // length fails
err := newPacket([]byte{}, true, p) err := newPacket([]byte{}, true, p)
assert.ErrorIs(t, err, ErrPacketTooShort) require.ErrorIs(t, err, ErrPacketTooShort)
err = newPacket([]byte{0x40}, true, p) err = newPacket([]byte{0x40}, true, p)
assert.ErrorIs(t, err, ErrIPv4PacketTooShort) require.ErrorIs(t, err, ErrIPv4PacketTooShort)
err = newPacket([]byte{0x60}, true, p) err = newPacket([]byte{0x60}, true, p)
assert.ErrorIs(t, err, ErrIPv6PacketTooShort) require.ErrorIs(t, err, ErrIPv6PacketTooShort)
// length fail with ip options // length fail with ip options
h := ipv4.Header{ h := ipv4.Header{
@@ -39,15 +40,15 @@ func Test_newPacket(t *testing.T) {
b, _ := h.Marshal() b, _ := h.Marshal()
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
// not an ipv4 packet // not an ipv4 packet
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
assert.ErrorIs(t, err, ErrUnknownIPVersion) require.ErrorIs(t, err, ErrUnknownIPVersion)
// invalid ihl // invalid ihl
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
// account for variable ip header length - incoming // account for variable ip header length - incoming
h = ipv4.Header{ h = ipv4.Header{
@@ -63,7 +64,7 @@ func Test_newPacket(t *testing.T) {
b = append(b, []byte{0, 3, 0, 4}...) b = append(b, []byte{0, 3, 0, 4}...)
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr)
@@ -85,7 +86,7 @@ func Test_newPacket(t *testing.T) {
b = append(b, []byte{0, 5, 0, 6}...) b = append(b, []byte{0, 5, 0, 6}...)
err = newPacket(b, false, p) err = newPacket(b, false, p)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, uint8(2), p.Protocol) assert.Equal(t, uint8(2), p.Protocol)
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr)
@@ -111,10 +112,10 @@ func Test_newPacket_v6(t *testing.T) {
FixLengths: false, FixLengths: false,
} }
err := gopacket.SerializeLayers(buffer, opt, &ip) err := gopacket.SerializeLayers(buffer, opt, &ip)
assert.NoError(t, err) require.NoError(t, err)
err = newPacket(buffer.Bytes(), true, p) err = newPacket(buffer.Bytes(), true, p)
assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
// A good ICMP packet // A good ICMP packet
ip = layers.IPv6{ ip = layers.IPv6{
@@ -134,7 +135,7 @@ func Test_newPacket_v6(t *testing.T) {
} }
err = newPacket(buffer.Bytes(), true, p) err = newPacket(buffer.Bytes(), true, p)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol) assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
@@ -146,7 +147,7 @@ func Test_newPacket_v6(t *testing.T) {
b := buffer.Bytes() b := buffer.Bytes()
b[6] = byte(layers.IPProtocolESP) b[6] = byte(layers.IPProtocolESP)
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol) assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
@@ -158,7 +159,7 @@ func Test_newPacket_v6(t *testing.T) {
b = buffer.Bytes() b = buffer.Bytes()
b[6] = byte(layers.IPProtocolNoNextHeader) b[6] = byte(layers.IPProtocolNoNextHeader)
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol) assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
@@ -170,7 +171,7 @@ func Test_newPacket_v6(t *testing.T) {
b = buffer.Bytes() b = buffer.Bytes()
b[6] = 255 // 255 is a reserved protocol number b[6] = 255 // 255 is a reserved protocol number
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
// A good UDP packet // A good UDP packet
ip = layers.IPv6{ ip = layers.IPv6{
@@ -186,7 +187,7 @@ func Test_newPacket_v6(t *testing.T) {
DstPort: layers.UDPPort(22), DstPort: layers.UDPPort(22),
} }
err = udp.SetNetworkLayerForChecksum(&ip) err = udp.SetNetworkLayerForChecksum(&ip)
assert.NoError(t, err) require.NoError(t, err)
buffer.Clear() buffer.Clear()
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef})) err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
@@ -197,7 +198,7 @@ func Test_newPacket_v6(t *testing.T) {
// incoming // incoming
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
@@ -207,7 +208,7 @@ func Test_newPacket_v6(t *testing.T) {
// outgoing // outgoing
err = newPacket(b, false, p) err = newPacket(b, false, p)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
@@ -217,14 +218,14 @@ func Test_newPacket_v6(t *testing.T) {
// Too short UDP packet // Too short UDP packet
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
assert.ErrorIs(t, err, ErrIPv6PacketTooShort) require.ErrorIs(t, err, ErrIPv6PacketTooShort)
// A good TCP packet // A good TCP packet
b[6] = byte(layers.IPProtocolTCP) b[6] = byte(layers.IPProtocolTCP)
// incoming // incoming
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
@@ -234,7 +235,7 @@ func Test_newPacket_v6(t *testing.T) {
// outgoing // outgoing
err = newPacket(b, false, p) err = newPacket(b, false, p)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
@@ -244,7 +245,7 @@ func Test_newPacket_v6(t *testing.T) {
// Too short TCP packet // Too short TCP packet
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
assert.ErrorIs(t, err, ErrIPv6PacketTooShort) require.ErrorIs(t, err, ErrIPv6PacketTooShort)
// A good UDP packet with an AH header // A good UDP packet with an AH header
ip = layers.IPv6{ ip = layers.IPv6{
@@ -279,7 +280,7 @@ func Test_newPacket_v6(t *testing.T) {
b = append(b, udpHeader...) b = append(b, udpHeader...)
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
@@ -290,7 +291,7 @@ func Test_newPacket_v6(t *testing.T) {
// Invalid AH header // Invalid AH header
b = buffer.Bytes() b = buffer.Bytes()
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
} }
func Test_newPacket_ipv6Fragment(t *testing.T) { func Test_newPacket_ipv6Fragment(t *testing.T) {
@@ -338,7 +339,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
// Test first fragment incoming // Test first fragment incoming
err = newPacket(firstFrag, true, p) err = newPacket(firstFrag, true, p)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
@@ -348,7 +349,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
// Test first fragment outgoing // Test first fragment outgoing
err = newPacket(firstFrag, false, p) err = newPacket(firstFrag, false, p)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
@@ -377,7 +378,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
// Test second fragment incoming // Test second fragment incoming
err = newPacket(secondFrag, true, p) err = newPacket(secondFrag, true, p)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
@@ -387,7 +388,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
// Test second fragment outgoing // Test second fragment outgoing
err = newPacket(secondFrag, false, p) err = newPacket(secondFrag, false, p)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
@@ -397,7 +398,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
// Too short of a fragment packet // Too short of a fragment packet
err = newPacket(secondFrag[:len(secondFrag)-10], false, p) err = newPacket(secondFrag[:len(secondFrag)-10], false, p)
assert.ErrorIs(t, err, ErrIPv6PacketTooShort) require.ErrorIs(t, err, ErrIPv6PacketTooShort)
} }
func BenchmarkParseV6(b *testing.B) { func BenchmarkParseV6(b *testing.B) {

View File

@@ -3,6 +3,8 @@ package overlay
import ( import (
"io" "io"
"net/netip" "net/netip"
"github.com/slackhq/nebula/routing"
) )
type Device interface { type Device interface {
@@ -10,6 +12,6 @@ type Device interface {
Activate() error Activate() error
Networks() []netip.Prefix Networks() []netip.Prefix
Name() string Name() string
RouteFor(netip.Addr) netip.Addr RoutesFor(netip.Addr) routing.Gateways
NewMultiQueueReader() (io.ReadWriteCloser, error) NewMultiQueueReader() (io.ReadWriteCloser, error)
} }

View File

@@ -3,7 +3,6 @@ package overlay
import ( import (
"fmt" "fmt"
"math" "math"
"net"
"net/netip" "net/netip"
"runtime" "runtime"
"strconv" "strconv"
@@ -11,13 +10,14 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
) )
type Route struct { type Route struct {
MTU int MTU int
Metric int Metric int
Cidr netip.Prefix Cidr netip.Prefix
Via netip.Addr Via routing.Gateways
Install bool Install bool
} }
@@ -47,15 +47,17 @@ func (r Route) String() string {
return s return s
} }
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) { func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
routeTree := new(bart.Table[netip.Addr]) routeTree := new(bart.Table[routing.Gateways])
for _, r := range routes { for _, r := range routes {
if !allowMTU && r.MTU > 0 { if !allowMTU && r.MTU > 0 {
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS) l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
} }
if r.Via.IsValid() { gateways := r.Via
routeTree.Insert(r.Cidr, r.Via) if len(gateways) > 0 {
routing.CalculateBucketsForGateways(gateways)
routeTree.Insert(r.Cidr, gateways)
} }
} }
return routeTree, nil return routeTree, nil
@@ -69,7 +71,7 @@ func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
return []Route{}, nil return []Route{}, nil
} }
rawRoutes, ok := r.([]interface{}) rawRoutes, ok := r.([]any)
if !ok { if !ok {
return nil, fmt.Errorf("tun.routes is not an array") return nil, fmt.Errorf("tun.routes is not an array")
} }
@@ -80,7 +82,7 @@ func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
routes := make([]Route, len(rawRoutes)) routes := make([]Route, len(rawRoutes))
for i, r := range rawRoutes { for i, r := range rawRoutes {
m, ok := r.(map[interface{}]interface{}) m, ok := r.(map[string]any)
if !ok { if !ok {
return nil, fmt.Errorf("entry %v in tun.routes is invalid", i+1) return nil, fmt.Errorf("entry %v in tun.routes is invalid", i+1)
} }
@@ -148,7 +150,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
return []Route{}, nil return []Route{}, nil
} }
rawRoutes, ok := r.([]interface{}) rawRoutes, ok := r.([]any)
if !ok { if !ok {
return nil, fmt.Errorf("tun.unsafe_routes is not an array") return nil, fmt.Errorf("tun.unsafe_routes is not an array")
} }
@@ -159,7 +161,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
routes := make([]Route, len(rawRoutes)) routes := make([]Route, len(rawRoutes))
for i, r := range rawRoutes { for i, r := range rawRoutes {
m, ok := r.(map[interface{}]interface{}) m, ok := r.(map[string]any)
if !ok { if !ok {
return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1) return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1)
} }
@@ -201,14 +203,63 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1) return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1)
} }
via, ok := rVia.(string) var gateways routing.Gateways
if !ok {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia)
}
viaVpnIp, err := netip.ParseAddr(via) switch via := rVia.(type) {
if err != nil { case string:
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err) viaIp, err := netip.ParseAddr(via)
if err != nil {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
}
gateways = routing.Gateways{routing.NewGateway(viaIp, 1)}
case []any:
gateways = make(routing.Gateways, len(via))
for ig, v := range via {
gatewayMap, ok := v.(map[string]any)
if !ok {
return nil, fmt.Errorf("entry %v in tun.unsafe_routes[%v].via is invalid", i+1, ig+1)
}
rGateway, ok := gatewayMap["gateway"]
if !ok {
return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not present", i+1, ig+1)
}
parsedGateway, ok := rGateway.(string)
if !ok {
return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not a string", i+1, ig+1)
}
gatewayIp, err := netip.ParseAddr(parsedGateway)
if err != nil {
return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] failed to parse address: %v", i+1, ig+1, err)
}
rGatewayWeight, ok := gatewayMap["weight"]
if !ok {
rGatewayWeight = 1
}
gatewayWeight, ok := rGatewayWeight.(int)
if !ok {
_, err = strconv.ParseInt(rGatewayWeight.(string), 10, 32)
if err != nil {
return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not an integer", i+1, ig+1)
}
}
if gatewayWeight < 1 || gatewayWeight > math.MaxInt32 {
return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not in range (1-%d) : %v", i+1, ig+1, math.MaxInt32, gatewayWeight)
}
gateways[ig] = routing.NewGateway(gatewayIp, gatewayWeight)
}
default:
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string or list of gateways: found %T", i+1, rVia)
} }
rRoute, ok := m["route"] rRoute, ok := m["route"]
@@ -226,7 +277,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
} }
r := Route{ r := Route{
Via: viaVpnIp, Via: gateways,
MTU: mtu, MTU: mtu,
Metric: metric, Metric: metric,
Install: install, Install: install,
@@ -253,29 +304,3 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
return routes, nil return routes, nil
} }
func ipWithin(o *net.IPNet, i *net.IPNet) bool {
// Make sure o contains the lowest form of i
if !o.Contains(i.IP.Mask(i.Mask)) {
return false
}
// Find the max ip in i
ip4 := i.IP.To4()
if ip4 == nil {
return false
}
last := make(net.IP, len(ip4))
copy(last, ip4)
for x := range ip4 {
last[x] |= ^i.Mask[x]
}
// Make sure o contains the max
if !o.Contains(last) {
return false
}
return true
}

View File

@@ -6,94 +6,96 @@ import (
"testing" "testing"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_parseRoutes(t *testing.T) { func Test_parseRoutes(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(l)
n, err := netip.ParsePrefix("10.0.0.0/24") n, err := netip.ParsePrefix("10.0.0.0/24")
assert.NoError(t, err) require.NoError(t, err)
// test no routes config // test no routes config
routes, err := parseRoutes(c, []netip.Prefix{n}) routes, err := parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, routes, 0) assert.Empty(t, routes)
// not an array // not an array
c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"} c.Settings["tun"] = map[string]any{"routes": "hi"}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "tun.routes is not an array") require.EqualError(t, err, "tun.routes is not an array")
// no routes // no routes
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}} c.Settings["tun"] = map[string]any{"routes": []any{}}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, routes, 0) assert.Empty(t, routes)
// weird route // weird route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}} c.Settings["tun"] = map[string]any{"routes": []any{"asdf"}}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1 in tun.routes is invalid") require.EqualError(t, err, "entry 1 in tun.routes is invalid")
// no mtu // no mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}} c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{}}}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present") require.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
// bad mtu // bad mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}} c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "nope"}}}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") require.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
// low mtu // low mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}} c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "499"}}}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") require.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
// missing route // missing route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}} c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500"}}}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not present") require.EqualError(t, err, "entry 1.route in tun.routes is not present")
// unparsable route // unparsable route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "nope"}}}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") require.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// below network range // below network range
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "1.0.0.0/8"}}}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]") require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
// above network range // above network range
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}} c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "10.0.1.0/24"}}}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]") require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
// Not in multiple ranges // Not in multiple ranges
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}} c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "192.0.0.0/24"}}}
routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")}) routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]") require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
// happy case // happy case
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{ c.Settings["tun"] = map[string]any{"routes": []any{
map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"}, map[string]any{"mtu": "9000", "route": "10.0.0.0/29"},
map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"}, map[string]any{"mtu": "8000", "route": "10.0.0.1/32"},
}} }}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, routes, 2) assert.Len(t, routes, 2)
tested := 0 tested := 0
@@ -119,116 +121,141 @@ func Test_parseUnsafeRoutes(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(l)
n, err := netip.ParsePrefix("10.0.0.0/24") n, err := netip.ParsePrefix("10.0.0.0/24")
assert.NoError(t, err) require.NoError(t, err)
// test no routes config // test no routes config
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, routes, 0) assert.Empty(t, routes)
// not an array // not an array
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"} c.Settings["tun"] = map[string]any{"unsafe_routes": "hi"}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "tun.unsafe_routes is not an array") require.EqualError(t, err, "tun.unsafe_routes is not an array")
// no routes // no routes
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}} c.Settings["tun"] = map[string]any{"unsafe_routes": []any{}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, routes, 0) assert.Empty(t, routes)
// weird route // weird route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}} c.Settings["tun"] = map[string]any{"unsafe_routes": []any{"asdf"}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") require.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
// no via // no via
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}} c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
// invalid via // invalid via
for _, invalidValue := range []interface{}{ for _, invalidValue := range []any{
127, false, nil, 1.0, []string{"1", "2"}, 127, false, nil, 1.0, []string{"1", "2"},
} { } {
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}} c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": invalidValue}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue)) require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string or list of gateways: found %T", invalidValue))
} }
// unparsable via // Unparsable list of via
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": []string{"1", "2"}}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not a string or list of gateways: found []string")
// unparsable via
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": "nope"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
// unparsable gateway
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "1"}}}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] failed to parse address: ParseAddr(\"1\"): unable to parse IP")
// missing gateway element
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"weight": "1"}}}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] is not present")
// unparsable weight element
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "10.0.0.1", "weight": "a"}}}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry .weight in tun.unsafe_routes[1].via[1] is not an integer")
// missing route // missing route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
// unparsable route // unparsable route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") require.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// within network range // within network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24") require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
// below network range // below network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Len(t, routes, 1) assert.Len(t, routes, 1)
assert.Nil(t, err) require.NoError(t, err)
// above network range // above network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Len(t, routes, 1) assert.Len(t, routes, 1)
assert.Nil(t, err) require.NoError(t, err)
// no mtu // no mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Len(t, routes, 1) assert.Len(t, routes, 1)
assert.Equal(t, 0, routes[0].MTU) assert.Equal(t, 0, routes[0].MTU)
// bad mtu // bad mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}} c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "nope"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
// low mtu // low mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}} c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "499"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
// bad install // bad install
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") require.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
// happy case // happy case
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ c.Settings["tun"] = map[string]any{"unsafe_routes": []any{
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"}, map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"},
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0}, map[string]any{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0},
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1}, map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
}} }}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, err) require.NoError(t, err)
assert.Len(t, routes, 4) assert.Len(t, routes, 4)
tested := 0 tested := 0
@@ -260,38 +287,119 @@ func Test_makeRouteTree(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(l)
n, err := netip.ParsePrefix("10.0.0.0/24") n, err := netip.ParsePrefix("10.0.0.0/24")
assert.NoError(t, err) require.NoError(t, err)
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ c.Settings["tun"] = map[string]any{"unsafe_routes": []any{
map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, map[string]any{"via": "192.168.0.1", "route": "1.0.0.0/28"},
map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"}, map[string]any{"via": "192.168.0.2", "route": "1.0.0.1/32"},
}} }}
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, routes, 2) assert.Len(t, routes, 2)
routeTree, err := makeRouteTree(l, routes, true) routeTree, err := makeRouteTree(l, routes, true)
assert.NoError(t, err) require.NoError(t, err)
ip, err := netip.ParseAddr("1.0.0.2") ip, err := netip.ParseAddr("1.0.0.2")
assert.NoError(t, err) require.NoError(t, err)
r, ok := routeTree.Lookup(ip) r, ok := routeTree.Lookup(ip)
assert.True(t, ok) assert.True(t, ok)
nip, err := netip.ParseAddr("192.168.0.1") nip, err := netip.ParseAddr("192.168.0.1")
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, nip, r) assert.Equal(t, nip, r[0].Addr())
ip, err = netip.ParseAddr("1.0.0.1") ip, err = netip.ParseAddr("1.0.0.1")
assert.NoError(t, err) require.NoError(t, err)
r, ok = routeTree.Lookup(ip) r, ok = routeTree.Lookup(ip)
assert.True(t, ok) assert.True(t, ok)
nip, err = netip.ParseAddr("192.168.0.2") nip, err = netip.ParseAddr("192.168.0.2")
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, nip, r) assert.Equal(t, nip, r[0].Addr())
ip, err = netip.ParseAddr("1.1.0.1") ip, err = netip.ParseAddr("1.1.0.1")
assert.NoError(t, err) require.NoError(t, err)
r, ok = routeTree.Lookup(ip) _, ok = routeTree.Lookup(ip)
assert.False(t, ok) assert.False(t, ok)
} }
func Test_makeMultipathUnsafeRouteTree(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
n, err := netip.ParsePrefix("10.0.0.0/24")
require.NoError(t, err)
c.Settings["tun"] = map[string]any{
"unsafe_routes": []any{
map[string]any{
"route": "192.168.86.0/24",
"via": "192.168.100.10",
},
map[string]any{
"route": "192.168.87.0/24",
"via": []any{
map[string]any{
"gateway": "10.0.0.1",
},
map[string]any{
"gateway": "10.0.0.2",
},
map[string]any{
"gateway": "10.0.0.3",
},
},
},
map[string]any{
"route": "192.168.89.0/24",
"via": []any{
map[string]any{
"gateway": "10.0.0.1",
"weight": 10,
},
map[string]any{
"gateway": "10.0.0.2",
"weight": 5,
},
},
},
},
}
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Len(t, routes, 3)
routeTree, err := makeRouteTree(l, routes, true)
require.NoError(t, err)
ip, err := netip.ParseAddr("192.168.86.1")
require.NoError(t, err)
r, ok := routeTree.Lookup(ip)
assert.True(t, ok)
nip, err := netip.ParseAddr("192.168.100.10")
require.NoError(t, err)
assert.Equal(t, nip, r[0].Addr())
ip, err = netip.ParseAddr("192.168.87.1")
require.NoError(t, err)
r, ok = routeTree.Lookup(ip)
assert.True(t, ok)
expectedGateways := routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 1),
routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 1),
routing.NewGateway(netip.MustParseAddr("10.0.0.3"), 1)}
routing.CalculateBucketsForGateways(expectedGateways)
assert.ElementsMatch(t, expectedGateways, r)
ip, err = netip.ParseAddr("192.168.89.1")
require.NoError(t, err)
r, ok = routeTree.Lookup(ip)
assert.True(t, ok)
expectedGateways = routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 10),
routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 5)}
routing.CalculateBucketsForGateways(expectedGateways)
assert.ElementsMatch(t, expectedGateways, r)
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
@@ -21,7 +22,7 @@ type tun struct {
fd int fd int
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger l *logrus.Logger
} }
@@ -56,7 +57,7 @@ func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, erro
return nil, fmt.Errorf("newTun not supported in Android") return nil, fmt.Errorf("newTun not supported in Android")
} }
func (t *tun) RouteFor(ip netip.Addr) netip.Addr { func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip) r, _ := t.routeTree.Load().Lookup(ip)
return r return r
} }

View File

@@ -17,6 +17,7 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route" netroute "golang.org/x/net/route"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
@@ -28,7 +29,7 @@ type tun struct {
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
DefaultMTU int DefaultMTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr linkAddr *netroute.LinkAddr
l *logrus.Logger l *logrus.Logger
@@ -294,6 +295,7 @@ func (t *tun) activate6(network netip.Prefix) error {
Vltime: 0xffffffff, Vltime: 0xffffffff,
Pltime: 0xffffffff, Pltime: 0xffffffff,
}, },
//TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address?
Flags: _IN6_IFF_NODAD, Flags: _IN6_IFF_NODAD,
} }
@@ -341,12 +343,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (t *tun) RouteFor(ip netip.Addr) netip.Addr { func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, ok := t.routeTree.Load().Lookup(ip) r, ok := t.routeTree.Load().Lookup(ip)
if ok { if ok {
return r return r
} }
return netip.Addr{} return routing.Gateways{}
} }
// Get the LinkAddr for the interface of the given name // Get the LinkAddr for the interface of the given name
@@ -381,7 +383,7 @@ func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load() routes := *t.Routes.Load()
for _, r := range routes { for _, r := range routes {
if !r.Via.IsValid() || !r.Install { if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via // We don't allow route MTUs so only install routes with a via
continue continue
} }
@@ -392,7 +394,7 @@ func (t *tun) addRoutes(logErrors bool) error {
t.l.WithField("route", r.Cidr). t.l.WithField("route", r.Cidr).
Warnf("unable to add unsafe_route, identical route already exists") Warnf("unable to add unsafe_route, identical route already exists")
} else { } else {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors { if logErrors {
retErr.Log(t.l) retErr.Log(t.l)
} else { } else {

View File

@@ -9,6 +9,7 @@ import (
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/routing"
) )
type disabledTun struct { type disabledTun struct {
@@ -43,8 +44,8 @@ func (*disabledTun) Activate() error {
return nil return nil
} }
func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr { func (*disabledTun) RoutesFor(addr netip.Addr) routing.Gateways {
return netip.Addr{} return routing.Gateways{}
} }
func (t *disabledTun) Networks() []netip.Prefix { func (t *disabledTun) Networks() []netip.Prefix {

View File

@@ -20,6 +20,7 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
@@ -50,7 +51,7 @@ type tun struct {
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger l *logrus.Logger
io.ReadWriteCloser io.ReadWriteCloser
@@ -242,7 +243,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (t *tun) RouteFor(ip netip.Addr) netip.Addr { func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip) r, _ := t.routeTree.Load().Lookup(ip)
return r return r
} }
@@ -262,7 +263,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) addRoutes(logErrors bool) error { func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load() routes := *t.Routes.Load()
for _, r := range routes { for _, r := range routes {
if !r.Via.IsValid() || !r.Install { if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via // We don't allow route MTUs so only install routes with a via
continue continue
} }
@@ -270,7 +271,7 @@ func (t *tun) addRoutes(logErrors bool) error {
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device) cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
if logErrors { if logErrors {
retErr.Log(t.l) retErr.Log(t.l)
} else { } else {

View File

@@ -16,6 +16,7 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
@@ -23,7 +24,7 @@ type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger l *logrus.Logger
} }
@@ -79,7 +80,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (t *tun) RouteFor(ip netip.Addr) netip.Addr { func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip) r, _ := t.routeTree.Load().Lookup(ip)
return r return r
} }

View File

@@ -17,6 +17,7 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
@@ -34,7 +35,7 @@ type tun struct {
ioctlFd uintptr ioctlFd uintptr
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
routeChan chan struct{} routeChan chan struct{}
useSystemRoutes bool useSystemRoutes bool
@@ -231,7 +232,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return file, nil return file, nil
} }
func (t *tun) RouteFor(ip netip.Addr) netip.Addr { func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip) r, _ := t.routeTree.Load().Lookup(ip)
return r return r
} }
@@ -463,7 +464,7 @@ func (t *tun) addRoutes(logErrors bool) error {
err := netlink.RouteReplace(&nr) err := netlink.RouteReplace(&nr)
if err != nil { if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors { if logErrors {
retErr.Log(t.l) retErr.Log(t.l)
} else { } else {
@@ -550,20 +551,7 @@ func (t *tun) watchRoutes() {
}() }()
} }
func (t *tun) updateRoutes(r netlink.RouteUpdate) { func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
if r.Gw == nil {
// Not a gateway route, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route")
return
}
gwAddr, ok := netip.AddrFromSlice(r.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
return
}
gwAddr = gwAddr.Unmap()
withinNetworks := false withinNetworks := false
for i := range t.vpnNetworks { for i := range t.vpnNetworks {
if t.vpnNetworks[i].Contains(gwAddr) { if t.vpnNetworks[i].Contains(gwAddr) {
@@ -571,9 +559,68 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
break break
} }
} }
if !withinNetworks {
// Gateway isn't in our overlay network, ignore return withinNetworks
t.l.WithField("route", r).Debug("Ignoring route update, not in our networks") }
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
var gateways routing.Gateways
link, err := netlink.LinkByName(t.Device)
if err != nil {
t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
return gateways
}
// If this route is relevant to our interface and there is a gateway then add it
if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
gwAddr, ok := netip.AddrFromSlice(r.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
} else {
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
}
}
}
for _, p := range r.MultiPath {
// If this route is relevant to our interface and there is a gateway then add it
if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
gwAddr, ok := netip.AddrFromSlice(p.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
} else {
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
// p.Hops+1 = weight of the route
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
}
}
}
}
routing.CalculateBucketsForGateways(gateways)
return gateways
}
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
gateways := t.getGatewaysFromRoute(&r.Route)
if len(gateways) == 0 {
// No gateways relevant to our network, no routing changes required.
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
return return
} }
@@ -589,12 +636,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
newTree := t.routeTree.Load().Clone() newTree := t.routeTree.Load().Clone()
if r.Type == unix.RTM_NEWROUTE { if r.Type == unix.RTM_NEWROUTE {
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route") t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
newTree.Insert(dst, gwAddr) newTree.Insert(dst, gateways)
} else { } else {
t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
newTree.Delete(dst) newTree.Delete(dst)
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
} }
t.routeTree.Store(newTree) t.routeTree.Store(newTree)
} }

View File

@@ -18,6 +18,7 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
@@ -31,7 +32,7 @@ type tun struct {
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger l *logrus.Logger
io.ReadWriteCloser io.ReadWriteCloser
@@ -108,34 +109,26 @@ func (t *tun) addIp(cidr netip.Prefix) error {
var err error var err error
// TODO use syscalls instead of exec.Command // TODO use syscalls instead of exec.Command
if cidr.Addr().Is6() { cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
cmd := exec.Command("/sbin/ifconfig", t.Device, "inet6", cidr.Addr().String(), "prefixlen", strconv.Itoa(cidr.Bits()), "alias") t.l.Debug("command: ", cmd.String())
t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil {
if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err)
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
} else {
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
} }
return nil cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
return t.addRoutes(false)
} }
func (t *tun) Activate() error { func (t *tun) Activate() error {
@@ -145,15 +138,7 @@ func (t *tun) Activate() error {
return err return err
} }
} }
return nil
cmd := exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
return fmt.Errorf("failed to run '%s': %s", cmd, err)
}
// Unsafe path routes
return t.addRoutes(false)
} }
func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) reload(c *config.C, initial bool) error {
@@ -193,7 +178,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (t *tun) RouteFor(ip netip.Addr) netip.Addr { func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip) r, _ := t.routeTree.Load().Lookup(ip)
return r return r
} }
@@ -213,7 +198,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) addRoutes(logErrors bool) error { func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load() routes := *t.Routes.Load()
for _, r := range routes { for _, r := range routes {
if !r.Via.IsValid() || !r.Install { if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via // We don't allow route MTUs so only install routes with a via
continue continue
} }
@@ -221,7 +206,7 @@ func (t *tun) addRoutes(logErrors bool) error {
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
if logErrors { if logErrors {
retErr.Log(t.l) retErr.Log(t.l)
} else { } else {

View File

@@ -17,6 +17,7 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
@@ -25,7 +26,7 @@ type tun struct {
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger l *logrus.Logger
io.ReadWriteCloser io.ReadWriteCloser
@@ -158,7 +159,7 @@ func (t *tun) Activate() error {
return nil return nil
} }
func (t *tun) RouteFor(ip netip.Addr) netip.Addr { func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip) r, _ := t.routeTree.Load().Lookup(ip)
return r return r
} }
@@ -166,7 +167,7 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
func (t *tun) addRoutes(logErrors bool) error { func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load() routes := *t.Routes.Load()
for _, r := range routes { for _, r := range routes {
if !r.Via.IsValid() || !r.Install { if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via // We don't allow route MTUs so only install routes with a via
continue continue
} }
@@ -174,7 +175,7 @@ func (t *tun) addRoutes(logErrors bool) error {
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
if logErrors { if logErrors {
retErr.Log(t.l) retErr.Log(t.l)
} else { } else {

View File

@@ -13,13 +13,14 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
) )
type TestTun struct { type TestTun struct {
Device string Device string
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
Routes []Route Routes []Route
routeTree *bart.Table[netip.Addr] routeTree *bart.Table[routing.Gateways]
l *logrus.Logger l *logrus.Logger
closed atomic.Bool closed atomic.Bool
@@ -86,7 +87,7 @@ func (t *TestTun) Get(block bool) []byte {
// Below this is boilerplate implementation to make nebula actually work // Below this is boilerplate implementation to make nebula actually work
//********************************************************************************************************************// //********************************************************************************************************************//
func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr { func (t *TestTun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Lookup(ip) r, _ := t.routeTree.Lookup(ip)
return r return r
} }

View File

@@ -18,6 +18,7 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/wintun" "github.com/slackhq/nebula/wintun"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
@@ -31,7 +32,7 @@ type winTun struct {
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger l *logrus.Logger
tun *wintun.NativeTun tun *wintun.NativeTun
@@ -147,15 +148,18 @@ func (t *winTun) addRoutes(logErrors bool) error {
foundDefault4 := false foundDefault4 := false
for _, r := range routes { for _, r := range routes {
if !r.Via.IsValid() || !r.Install { if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via // We don't allow route MTUs so only install routes with a via
continue continue
} }
// Add our unsafe route // Add our unsafe route
err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric)) // Windows does not support multipath routes natively, so we install only a single route.
// This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally.
// In effect this provides multipath routing support to windows supporting loadbalancing and redundancy.
err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
if err != nil { if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
if logErrors { if logErrors {
retErr.Log(t.l) retErr.Log(t.l)
continue continue
@@ -198,7 +202,8 @@ func (t *winTun) removeRoutes(routes []Route) error {
continue continue
} }
err := luid.DeleteRoute(r.Cidr, r.Via) // See comment on luid.AddRoute
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
if err != nil { if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else { } else {
@@ -208,7 +213,7 @@ func (t *winTun) removeRoutes(routes []Route) error {
return nil return nil
} }
func (t *winTun) RouteFor(ip netip.Addr) netip.Addr { func (t *winTun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip) r, _ := t.routeTree.Load().Lookup(ip)
return r return r
} }

View File

@@ -6,6 +6,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
) )
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
@@ -38,9 +39,13 @@ type UserDevice struct {
func (d *UserDevice) Activate() error { func (d *UserDevice) Activate() error {
return nil return nil
} }
func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks }
func (d *UserDevice) Name() string { return "faketun0" } func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks }
func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip } func (d *UserDevice) Name() string { return "faketun0" }
func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
return routing.Gateways{routing.NewGateway(ip, 1)}
}
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return d, nil return d, nil
} }

View File

@@ -1,8 +1,6 @@
package pkclient package pkclient
import ( import (
"crypto/ecdsa"
"crypto/x509"
"fmt" "fmt"
"io" "io"
"strconv" "strconv"
@@ -50,27 +48,6 @@ func FromUrl(pkurl string) (*PKClient, error) {
return New(module, uint(slotid), pin, id, label) return New(module, uint(slotid), pin, id, label)
} }
func ecKeyToArray(key *ecdsa.PublicKey) []byte {
x := make([]byte, 32)
y := make([]byte, 32)
key.X.FillBytes(x)
key.Y.FillBytes(y)
return append([]byte{0x04}, append(x, y...)...)
}
func formatPubkeyFromPublicKeyInfoAttr(d []byte) ([]byte, error) {
e, err := x509.ParsePKIXPublicKey(d)
if err != nil {
return nil, err
}
switch t := e.(type) {
case *ecdsa.PublicKey:
return ecKeyToArray(e.(*ecdsa.PublicKey)), nil
default:
return nil, fmt.Errorf("unknown public key type: %T", t)
}
}
func (c *PKClient) Test() error { func (c *PKClient) Test() error {
pub, err := c.GetPubKey() pub, err := c.GetPubKey()
if err != nil { if err != nil {

View File

@@ -3,6 +3,8 @@
package pkclient package pkclient
import ( import (
"crypto/ecdsa"
"crypto/x509"
"encoding/asn1" "encoding/asn1"
"errors" "errors"
"fmt" "fmt"
@@ -227,3 +229,24 @@ func (c *PKClient) GetPubKey() ([]byte, error) {
return nil, fmt.Errorf("unknown public key length: %d", len(d)) return nil, fmt.Errorf("unknown public key length: %d", len(d))
} }
} }
func ecKeyToArray(key *ecdsa.PublicKey) []byte {
x := make([]byte, 32)
y := make([]byte, 32)
key.X.FillBytes(x)
key.Y.FillBytes(y)
return append([]byte{0x04}, append(x, y...)...)
}
func formatPubkeyFromPublicKeyInfoAttr(d []byte) ([]byte, error) {
e, err := x509.ParsePKIXPublicKey(d)
if err != nil {
return nil, err
}
switch t := e.(type) {
case *ecdsa.PublicKey:
return ecKeyToArray(e.(*ecdsa.PublicKey)), nil
default:
return nil, fmt.Errorf("unknown public key type: %T", t)
}
}

View File

@@ -7,10 +7,10 @@ import "errors"
type PKClient struct { type PKClient struct {
} }
var notImplemented = errors.New("not implemented") var errNotImplemented = errors.New("not implemented")
func New(hsmPath string, slotId uint, pin string, id string, label string) (*PKClient, error) { func New(hsmPath string, slotId uint, pin string, id string, label string) (*PKClient, error) {
return nil, notImplemented return nil, errNotImplemented
} }
func (c *PKClient) Close() error { func (c *PKClient) Close() error {
@@ -18,13 +18,13 @@ func (c *PKClient) Close() error {
} }
func (c *PKClient) SignASN1(data []byte) ([]byte, error) { func (c *PKClient) SignASN1(data []byte) ([]byte, error) {
return nil, notImplemented return nil, errNotImplemented
} }
func (c *PKClient) DeriveNoise(_ []byte) ([]byte, error) { func (c *PKClient) DeriveNoise(_ []byte) ([]byte, error) {
return nil, notImplemented return nil, errNotImplemented
} }
func (c *PKClient) GetPubKey() ([]byte, error) { func (c *PKClient) GetPubKey() ([]byte, error) {
return nil, notImplemented return nil, errNotImplemented
} }

41
pki.go
View File

@@ -33,10 +33,10 @@ type CertState struct {
v2Cert cert.Certificate v2Cert cert.Certificate
v2HandshakeBytes []byte v2HandshakeBytes []byte
defaultVersion cert.Version initiatingVersion cert.Version
privateKey []byte privateKey []byte
pkcs11Backed bool pkcs11Backed bool
cipher string cipher string
myVpnNetworks []netip.Prefix myVpnNetworks []netip.Prefix
myVpnNetworksTable *bart.Table[struct{}] myVpnNetworksTable *bart.Table[struct{}]
@@ -173,6 +173,7 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
p.cs.Store(newState) p.cs.Store(newState)
//TODO: CERT-V2 newState needs a stringer that does json
if initial { if initial {
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)") p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
} else { } else {
@@ -193,7 +194,7 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
} }
func (cs *CertState) GetDefaultCertificate() cert.Certificate { func (cs *CertState) GetDefaultCertificate() cert.Certificate {
c := cs.getCertificate(cs.defaultVersion) c := cs.getCertificate(cs.initiatingVersion)
if c == nil { if c == nil {
panic("No default certificate found") panic("No default certificate found")
} }
@@ -316,28 +317,28 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
return nil, errors.New("no certificates found in pki.cert") return nil, errors.New("no certificates found in pki.cert")
} }
useDefaultVersion := uint32(1) useInitiatingVersion := uint32(1)
if v1 == nil { if v1 == nil {
// The only condition that requires v2 as the default is if only a v2 certificate is present // The only condition that requires v2 as the default is if only a v2 certificate is present
// We do this to avoid having to configure it specifically in the config file // We do this to avoid having to configure it specifically in the config file
useDefaultVersion = 2 useInitiatingVersion = 2
} }
rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion) rawInitiatingVersion := c.GetUint32("pki.initiating_version", useInitiatingVersion)
var defaultVersion cert.Version var initiatingVersion cert.Version
switch rawDefaultVersion { switch rawInitiatingVersion {
case 1: case 1:
if v1 == nil { if v1 == nil {
return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert") return nil, fmt.Errorf("can not use pki.initiating_version 1 without a v1 certificate in pki.cert")
} }
defaultVersion = cert.Version1 initiatingVersion = cert.Version1
case 2: case 2:
defaultVersion = cert.Version2 initiatingVersion = cert.Version2
default: default:
return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion) return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion)
} }
return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey) return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey)
} }
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
@@ -360,7 +361,7 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
//TODO: CERT-V2 make sure v2 has v1s address //TODO: CERT-V2 make sure v2 has v1s address
cs.defaultVersion = dv cs.initiatingVersion = dv
} }
if v1 != nil { if v1 != nil {
@@ -379,8 +380,8 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
cs.v1Cert = v1 cs.v1Cert = v1
cs.v1HandshakeBytes = v1hs cs.v1HandshakeBytes = v1hs
if cs.defaultVersion == 0 { if cs.initiatingVersion == 0 {
cs.defaultVersion = cert.Version1 cs.initiatingVersion = cert.Version1
} }
} }
@@ -400,8 +401,8 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
cs.v2Cert = v2 cs.v2Cert = v2
cs.v2HandshakeBytes = v2hs cs.v2HandshakeBytes = v2hs
if cs.defaultVersion == 0 { if cs.initiatingVersion == 0 {
cs.defaultVersion = cert.Version2 cs.initiatingVersion = cert.Version2
} }
} }

View File

@@ -7,6 +7,7 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestNewPunchyFromConfig(t *testing.T) { func TestNewPunchyFromConfig(t *testing.T) {
@@ -15,39 +16,39 @@ func TestNewPunchyFromConfig(t *testing.T) {
// Test defaults // Test defaults
p := NewPunchyFromConfig(l, c) p := NewPunchyFromConfig(l, c)
assert.Equal(t, false, p.GetPunch()) assert.False(t, p.GetPunch())
assert.Equal(t, false, p.GetRespond()) assert.False(t, p.GetRespond())
assert.Equal(t, time.Second, p.GetDelay()) assert.Equal(t, time.Second, p.GetDelay())
assert.Equal(t, 5*time.Second, p.GetRespondDelay()) assert.Equal(t, 5*time.Second, p.GetRespondDelay())
// punchy deprecation // punchy deprecation
c.Settings["punchy"] = true c.Settings["punchy"] = true
p = NewPunchyFromConfig(l, c) p = NewPunchyFromConfig(l, c)
assert.Equal(t, true, p.GetPunch()) assert.True(t, p.GetPunch())
// punchy.punch // punchy.punch
c.Settings["punchy"] = map[interface{}]interface{}{"punch": true} c.Settings["punchy"] = map[string]any{"punch": true}
p = NewPunchyFromConfig(l, c) p = NewPunchyFromConfig(l, c)
assert.Equal(t, true, p.GetPunch()) assert.True(t, p.GetPunch())
// punch_back deprecation // punch_back deprecation
c.Settings["punch_back"] = true c.Settings["punch_back"] = true
p = NewPunchyFromConfig(l, c) p = NewPunchyFromConfig(l, c)
assert.Equal(t, true, p.GetRespond()) assert.True(t, p.GetRespond())
// punchy.respond // punchy.respond
c.Settings["punchy"] = map[interface{}]interface{}{"respond": true} c.Settings["punchy"] = map[string]any{"respond": true}
c.Settings["punch_back"] = false c.Settings["punch_back"] = false
p = NewPunchyFromConfig(l, c) p = NewPunchyFromConfig(l, c)
assert.Equal(t, true, p.GetRespond()) assert.True(t, p.GetRespond())
// punchy.delay // punchy.delay
c.Settings["punchy"] = map[interface{}]interface{}{"delay": "1m"} c.Settings["punchy"] = map[string]any{"delay": "1m"}
p = NewPunchyFromConfig(l, c) p = NewPunchyFromConfig(l, c)
assert.Equal(t, time.Minute, p.GetDelay()) assert.Equal(t, time.Minute, p.GetDelay())
// punchy.respond_delay // punchy.respond_delay
c.Settings["punchy"] = map[interface{}]interface{}{"respond_delay": "1m"} c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
p = NewPunchyFromConfig(l, c) p = NewPunchyFromConfig(l, c)
assert.Equal(t, time.Minute, p.GetRespondDelay()) assert.Equal(t, time.Minute, p.GetRespondDelay())
} }
@@ -56,22 +57,22 @@ func TestPunchy_reload(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(l)
delay, _ := time.ParseDuration("1m") delay, _ := time.ParseDuration("1m")
assert.NoError(t, c.LoadString(` require.NoError(t, c.LoadString(`
punchy: punchy:
delay: 1m delay: 1m
respond: false respond: false
`)) `))
p := NewPunchyFromConfig(l, c) p := NewPunchyFromConfig(l, c)
assert.Equal(t, delay, p.GetDelay()) assert.Equal(t, delay, p.GetDelay())
assert.Equal(t, false, p.GetRespond()) assert.False(t, p.GetRespond())
newDelay, _ := time.ParseDuration("10m") newDelay, _ := time.ParseDuration("10m")
assert.NoError(t, c.ReloadConfigString(` require.NoError(t, c.ReloadConfigString(`
punchy: punchy:
delay: 10m delay: 10m
respond: true respond: true
`)) `))
p.reload(c, false) p.reload(c, false)
assert.Equal(t, newDelay, p.GetDelay()) assert.Equal(t, newDelay, p.GetDelay())
assert.Equal(t, true, p.GetRespond()) assert.True(t, p.GetRespond())
} }

View File

@@ -263,9 +263,7 @@ func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort
r.RLock() r.RLock()
defer r.RUnlock() defer r.RUnlock()
c := make([]netip.AddrPort, len(r.addrs)) c := make([]netip.AddrPort, len(r.addrs))
for i, v := range r.addrs { copy(c, r.addrs)
c[i] = v
}
return c return c
} }
@@ -326,9 +324,7 @@ func (r *RemoteList) CopyCache() *CacheMap {
} }
if mc.relay != nil { if mc.relay != nil {
for _, a := range mc.relay.relay { c.Relay = append(c.Relay, mc.relay.relay...)
c.Relay = append(c.Relay, a)
}
} }
} }
@@ -362,9 +358,7 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
defer r.RUnlock() defer r.RUnlock()
c := make([]netip.AddrPort, len(r.badRemotes)) c := make([]netip.AddrPort, len(r.badRemotes))
for i, v := range r.badRemotes { copy(c, r.badRemotes)
c[i] = v
}
return c return c
} }
@@ -569,9 +563,7 @@ func (r *RemoteList) unlockedCollect() {
} }
if c.relay != nil { if c.relay != nil {
for _, v := range c.relay.relay { relays = append(relays, c.relay.relay...)
relays = append(relays, v)
}
} }
} }
@@ -635,15 +627,15 @@ func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) {
a4 := a.Addr().Is4() a4 := a.Addr().Is4()
b4 := b.Addr().Is4() b4 := b.Addr().Is4()
switch { switch {
case a4 == false && b4 == true: case !a4 && b4:
// If i is v6 and j is v4, i is less than j // If i is v6 and j is v4, i is less than j
return true return true
case a4 == true && b4 == false: case a4 && !b4:
// If j is v6 and i is v4, i is not less than j // If j is v6 and i is v4, i is not less than j
return false return false
case a4 == true && b4 == true: case a4 && b4:
// i and j are both ipv4 // i and j are both ipv4
aPrivate := a.Addr().IsPrivate() aPrivate := a.Addr().IsPrivate()
bPrivate := b.Addr().IsPrivate() bPrivate := b.Addr().IsPrivate()
@@ -691,7 +683,6 @@ func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) {
} }
r.addrs = r.addrs[:a+1] r.addrs = r.addrs[:a+1]
return
} }
// minInt returns the minimum integer of a or b // minInt returns the minimum integer of a or b

39
routing/balance.go Normal file
View File

@@ -0,0 +1,39 @@
package routing
import (
"net/netip"
"github.com/slackhq/nebula/firewall"
)
// Hashes the packet source and destination port and always returns a positive integer
// Based on 'Prospecting for Hash Functions'
// - https://nullprogram.com/blog/2018/07/31/
// - https://github.com/skeeto/hash-prospector
// [16 21f0aaad 15 d35a2d97 15] = 0.10760229515479501
func hashPacket(p *firewall.Packet) int {
x := (uint32(p.LocalPort) << 16) | uint32(p.RemotePort)
x ^= x >> 16
x *= 0x21f0aaad
x ^= x >> 15
x *= 0xd35a2d97
x ^= x >> 15
return int(x) & 0x7FFFFFFF
}
// For this function to work correctly it requires that the buckets for the gateways have been calculated
// If the contract is violated balancing will not work properly and the second return value will return false
func BalancePacket(fwPacket *firewall.Packet, gateways []Gateway) (netip.Addr, bool) {
hash := hashPacket(fwPacket)
for i := range gateways {
if hash <= gateways[i].BucketUpperBound() {
return gateways[i].Addr(), true
}
}
// If you land here then the buckets for the gateways are not properly calculated
// Fallback to random routing and let the caller know
return gateways[hash%len(gateways)].Addr(), false
}

144
routing/balance_test.go Normal file
View File

@@ -0,0 +1,144 @@
package routing
import (
"net/netip"
"testing"
"github.com/slackhq/nebula/firewall"
"github.com/stretchr/testify/assert"
)
func TestPacketsAreBalancedEqually(t *testing.T) {
gateways := []Gateway{}
gw1Addr := netip.MustParseAddr("1.0.0.1")
gw2Addr := netip.MustParseAddr("1.0.0.2")
gw3Addr := netip.MustParseAddr("1.0.0.3")
gateways = append(gateways, NewGateway(gw1Addr, 1))
gateways = append(gateways, NewGateway(gw2Addr, 1))
gateways = append(gateways, NewGateway(gw3Addr, 1))
CalculateBucketsForGateways(gateways)
gw1count := 0
gw2count := 0
gw3count := 0
iterationCount := uint16(65535)
for i := uint16(0); i < iterationCount; i++ {
packet := firewall.Packet{
LocalAddr: netip.MustParseAddr("192.168.1.1"),
RemoteAddr: netip.MustParseAddr("10.0.0.1"),
LocalPort: i,
RemotePort: 65535 - i,
Protocol: 6, // TCP
Fragment: false,
}
selectedGw, ok := BalancePacket(&packet, gateways)
assert.True(t, ok)
switch selectedGw {
case gw1Addr:
gw1count += 1
case gw2Addr:
gw2count += 1
case gw3Addr:
gw3count += 1
}
}
// Assert packets are balanced, allow variation of up to 100 packets per gateway
assert.InDeltaf(t, iterationCount/3, gw1count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count)
assert.InDeltaf(t, iterationCount/3, gw2count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count)
assert.InDeltaf(t, iterationCount/3, gw3count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count)
}
func TestPacketsAreBalancedByPriority(t *testing.T) {
gateways := []Gateway{}
gw1Addr := netip.MustParseAddr("1.0.0.1")
gw2Addr := netip.MustParseAddr("1.0.0.2")
gateways = append(gateways, NewGateway(gw1Addr, 10))
gateways = append(gateways, NewGateway(gw2Addr, 5))
CalculateBucketsForGateways(gateways)
gw1count := 0
gw2count := 0
iterationCount := uint16(65535)
for i := uint16(0); i < iterationCount; i++ {
packet := firewall.Packet{
LocalAddr: netip.MustParseAddr("192.168.1.1"),
RemoteAddr: netip.MustParseAddr("10.0.0.1"),
LocalPort: i,
RemotePort: 65535 - i,
Protocol: 6, // TCP
Fragment: false,
}
selectedGw, ok := BalancePacket(&packet, gateways)
assert.True(t, ok)
switch selectedGw {
case gw1Addr:
gw1count += 1
case gw2Addr:
gw2count += 1
}
}
iterationCountAsFloat := float32(iterationCount)
assert.InDeltaf(t, iterationCountAsFloat*(2.0/3.0), gw1count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(2.0/3.0), gw1count)
assert.InDeltaf(t, iterationCountAsFloat*(1.0/3.0), gw2count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(1.0/3.0), gw2count)
}
func TestBalancePacketDistributsRandomlyAndReturnsFalseIfBucketsNotCalculated(t *testing.T) {
gateways := []Gateway{}
gw1Addr := netip.MustParseAddr("1.0.0.1")
gw2Addr := netip.MustParseAddr("1.0.0.2")
gateways = append(gateways, NewGateway(gw1Addr, 10))
gateways = append(gateways, NewGateway(gw2Addr, 5))
iterationCount := uint16(65535)
gw1count := 0
gw2count := 0
for i := uint16(0); i < iterationCount; i++ {
packet := firewall.Packet{
LocalAddr: netip.MustParseAddr("192.168.1.1"),
RemoteAddr: netip.MustParseAddr("10.0.0.1"),
LocalPort: i,
RemotePort: 65535 - i,
Protocol: 6, // TCP
Fragment: false,
}
selectedGw, ok := BalancePacket(&packet, gateways)
assert.False(t, ok)
switch selectedGw {
case gw1Addr:
gw1count += 1
case gw2Addr:
gw2count += 1
}
}
assert.Equal(t, int(iterationCount), (gw1count + gw2count))
assert.NotEqual(t, 0, gw1count)
assert.NotEqual(t, 0, gw2count)
}

70
routing/gateway.go Normal file
View File

@@ -0,0 +1,70 @@
package routing
import (
"fmt"
"net/netip"
)
const (
// Sentinal value
BucketNotCalculated = -1
)
type Gateways []Gateway
func (g Gateways) String() string {
str := ""
for i, gw := range g {
str += gw.String()
if i < len(g)-1 {
str += ", "
}
}
return str
}
type Gateway struct {
addr netip.Addr
weight int
bucketUpperBound int
}
func NewGateway(addr netip.Addr, weight int) Gateway {
return Gateway{addr: addr, weight: weight, bucketUpperBound: BucketNotCalculated}
}
func (g *Gateway) BucketUpperBound() int {
return g.bucketUpperBound
}
func (g *Gateway) Addr() netip.Addr {
return g.addr
}
func (g *Gateway) String() string {
return fmt.Sprintf("{addr: %s, weight: %d}", g.addr, g.weight)
}
// Divide and round to nearest integer
func divideAndRound(v uint64, d uint64) uint64 {
var tmp uint64 = v + d/2
return tmp / d
}
// Implements Hash-Threshold mapping, equivalent to the implementation in the linux kernel.
// After this function returns each gateway will have a
// positive bucketUpperBound with a maximum value of 2147483647 (INT_MAX)
func CalculateBucketsForGateways(gateways []Gateway) {
var totalWeight int = 0
for i := range gateways {
totalWeight += gateways[i].weight
}
var loopWeight int = 0
for i := range gateways {
loopWeight += gateways[i].weight
gateways[i].bucketUpperBound = int(divideAndRound(uint64(loopWeight)<<31, uint64(totalWeight))) - 1
}
}

34
routing/gateway_test.go Normal file
View File

@@ -0,0 +1,34 @@
package routing
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func TestRebalance3_2Split(t *testing.T) {
gateways := []Gateway{}
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 10})
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 5})
CalculateBucketsForGateways(gateways)
assert.Equal(t, 1431655764, gateways[0].bucketUpperBound) // INT_MAX/3*2
assert.Equal(t, 2147483647, gateways[1].bucketUpperBound) // INT_MAX
}
func TestRebalanceEqualSplit(t *testing.T) {
gateways := []Gateway{}
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1})
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1})
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1})
CalculateBucketsForGateways(gateways)
assert.Equal(t, 715827882, gateways[0].bucketUpperBound) // INT_MAX/3
assert.Equal(t, 1431655764, gateways[1].bucketUpperBound) // INT_MAX/3*2
assert.Equal(t, 2147483647, gateways[2].bucketUpperBound) // INT_MAX
}

View File

@@ -13,10 +13,10 @@ import (
"github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v3"
) )
type m map[string]interface{} type m = map[string]any
func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
_, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{}) _, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{})

102
ssh.go
View File

@@ -124,10 +124,10 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
} }
rawKeys := c.Get("sshd.authorized_users") rawKeys := c.Get("sshd.authorized_users")
keys, ok := rawKeys.([]interface{}) keys, ok := rawKeys.([]any)
if ok { if ok {
for _, rk := range keys { for _, rk := range keys {
kDef, ok := rk.(map[interface{}]interface{}) kDef, ok := rk.(map[string]any)
if !ok { if !ok {
l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring") l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring")
continue continue
@@ -148,7 +148,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
continue continue
} }
case []interface{}: case []any:
for _, subK := range v { for _, subK := range v {
sk, ok := subK.(string) sk, ok := subK.(string)
if !ok { if !ok {
@@ -190,7 +190,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "list-hostmap", Name: "list-hostmap",
ShortDescription: "List all known previously connected hosts", ShortDescription: "List all known previously connected hosts",
Flags: func() (*flag.FlagSet, interface{}) { Flags: func() (*flag.FlagSet, any) {
fl := flag.NewFlagSet("", flag.ContinueOnError) fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshListHostMapFlags{} s := sshListHostMapFlags{}
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
@@ -198,7 +198,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table") fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
return fl, &s return fl, &s
}, },
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { Callback: func(fs any, a []string, w sshd.StringWriter) error {
return sshListHostMap(f.hostMap, fs, w) return sshListHostMap(f.hostMap, fs, w)
}, },
}) })
@@ -206,7 +206,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "list-pending-hostmap", Name: "list-pending-hostmap",
ShortDescription: "List all handshaking hosts", ShortDescription: "List all handshaking hosts",
Flags: func() (*flag.FlagSet, interface{}) { Flags: func() (*flag.FlagSet, any) {
fl := flag.NewFlagSet("", flag.ContinueOnError) fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshListHostMapFlags{} s := sshListHostMapFlags{}
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
@@ -214,7 +214,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table") fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
return fl, &s return fl, &s
}, },
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { Callback: func(fs any, a []string, w sshd.StringWriter) error {
return sshListHostMap(f.handshakeManager, fs, w) return sshListHostMap(f.handshakeManager, fs, w)
}, },
}) })
@@ -222,14 +222,14 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "list-lighthouse-addrmap", Name: "list-lighthouse-addrmap",
ShortDescription: "List all lighthouse map entries", ShortDescription: "List all lighthouse map entries",
Flags: func() (*flag.FlagSet, interface{}) { Flags: func() (*flag.FlagSet, any) {
fl := flag.NewFlagSet("", flag.ContinueOnError) fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshListHostMapFlags{} s := sshListHostMapFlags{}
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json") fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
return fl, &s return fl, &s
}, },
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { Callback: func(fs any, a []string, w sshd.StringWriter) error {
return sshListLighthouseMap(f.lightHouse, fs, w) return sshListLighthouseMap(f.lightHouse, fs, w)
}, },
}) })
@@ -237,7 +237,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "reload", Name: "reload",
ShortDescription: "Reloads configuration from disk, same as sending HUP to the process", ShortDescription: "Reloads configuration from disk, same as sending HUP to the process",
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { Callback: func(fs any, a []string, w sshd.StringWriter) error {
return sshReload(c, w) return sshReload(c, w)
}, },
}) })
@@ -251,7 +251,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "stop-cpu-profile", Name: "stop-cpu-profile",
ShortDescription: "Stops a cpu profile and writes output to the previously provided file", ShortDescription: "Stops a cpu profile and writes output to the previously provided file",
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { Callback: func(fs any, a []string, w sshd.StringWriter) error {
pprof.StopCPUProfile() pprof.StopCPUProfile()
return w.WriteLine("If a CPU profile was running it is now stopped") return w.WriteLine("If a CPU profile was running it is now stopped")
}, },
@@ -278,7 +278,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "log-level", Name: "log-level",
ShortDescription: "Gets or sets the current log level", ShortDescription: "Gets or sets the current log level",
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { Callback: func(fs any, a []string, w sshd.StringWriter) error {
return sshLogLevel(l, fs, a, w) return sshLogLevel(l, fs, a, w)
}, },
}) })
@@ -286,7 +286,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "log-format", Name: "log-format",
ShortDescription: "Gets or sets the current log format", ShortDescription: "Gets or sets the current log format",
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { Callback: func(fs any, a []string, w sshd.StringWriter) error {
return sshLogFormat(l, fs, a, w) return sshLogFormat(l, fs, a, w)
}, },
}) })
@@ -294,7 +294,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "version", Name: "version",
ShortDescription: "Prints the currently running version of nebula", ShortDescription: "Prints the currently running version of nebula",
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { Callback: func(fs any, a []string, w sshd.StringWriter) error {
return sshVersion(f, fs, a, w) return sshVersion(f, fs, a, w)
}, },
}) })
@@ -302,14 +302,14 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "device-info", Name: "device-info",
ShortDescription: "Prints information about the network device.", ShortDescription: "Prints information about the network device.",
Flags: func() (*flag.FlagSet, interface{}) { Flags: func() (*flag.FlagSet, any) {
fl := flag.NewFlagSet("", flag.ContinueOnError) fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshDeviceInfoFlags{} s := sshDeviceInfoFlags{}
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json") fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
return fl, &s return fl, &s
}, },
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { Callback: func(fs any, a []string, w sshd.StringWriter) error {
return sshDeviceInfo(f, fs, w) return sshDeviceInfo(f, fs, w)
}, },
}) })
@@ -317,7 +317,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "print-cert", Name: "print-cert",
ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn addr", ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn addr",
Flags: func() (*flag.FlagSet, interface{}) { Flags: func() (*flag.FlagSet, any) {
fl := flag.NewFlagSet("", flag.ContinueOnError) fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshPrintCertFlags{} s := sshPrintCertFlags{}
fl.BoolVar(&s.Json, "json", false, "outputs as json") fl.BoolVar(&s.Json, "json", false, "outputs as json")
@@ -325,7 +325,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
fl.BoolVar(&s.Raw, "raw", false, "raw prints the PEM encoded certificate, not compatible with -json or -pretty") fl.BoolVar(&s.Raw, "raw", false, "raw prints the PEM encoded certificate, not compatible with -json or -pretty")
return fl, &s return fl, &s
}, },
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { Callback: func(fs any, a []string, w sshd.StringWriter) error {
return sshPrintCert(f, fs, a, w) return sshPrintCert(f, fs, a, w)
}, },
}) })
@@ -333,13 +333,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "print-tunnel", Name: "print-tunnel",
ShortDescription: "Prints json details about a tunnel for the provided vpn addr", ShortDescription: "Prints json details about a tunnel for the provided vpn addr",
Flags: func() (*flag.FlagSet, interface{}) { Flags: func() (*flag.FlagSet, any) {
fl := flag.NewFlagSet("", flag.ContinueOnError) fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshPrintTunnelFlags{} s := sshPrintTunnelFlags{}
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json") fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json")
return fl, &s return fl, &s
}, },
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { Callback: func(fs any, a []string, w sshd.StringWriter) error {
return sshPrintTunnel(f, fs, a, w) return sshPrintTunnel(f, fs, a, w)
}, },
}) })
@@ -347,13 +347,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "print-relays", Name: "print-relays",
ShortDescription: "Prints json details about all relay info", ShortDescription: "Prints json details about all relay info",
Flags: func() (*flag.FlagSet, interface{}) { Flags: func() (*flag.FlagSet, any) {
fl := flag.NewFlagSet("", flag.ContinueOnError) fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshPrintTunnelFlags{} s := sshPrintTunnelFlags{}
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json") fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json")
return fl, &s return fl, &s
}, },
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { Callback: func(fs any, a []string, w sshd.StringWriter) error {
return sshPrintRelays(f, fs, a, w) return sshPrintRelays(f, fs, a, w)
}, },
}) })
@@ -361,13 +361,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "change-remote", Name: "change-remote",
ShortDescription: "Changes the remote address used in the tunnel for the provided vpn addr", ShortDescription: "Changes the remote address used in the tunnel for the provided vpn addr",
Flags: func() (*flag.FlagSet, interface{}) { Flags: func() (*flag.FlagSet, any) {
fl := flag.NewFlagSet("", flag.ContinueOnError) fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshChangeRemoteFlags{} s := sshChangeRemoteFlags{}
fl.StringVar(&s.Address, "address", "", "The new remote address, ip:port") fl.StringVar(&s.Address, "address", "", "The new remote address, ip:port")
return fl, &s return fl, &s
}, },
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { Callback: func(fs any, a []string, w sshd.StringWriter) error {
return sshChangeRemote(f, fs, a, w) return sshChangeRemote(f, fs, a, w)
}, },
}) })
@@ -375,13 +375,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "close-tunnel", Name: "close-tunnel",
ShortDescription: "Closes a tunnel for the provided vpn addr", ShortDescription: "Closes a tunnel for the provided vpn addr",
Flags: func() (*flag.FlagSet, interface{}) { Flags: func() (*flag.FlagSet, any) {
fl := flag.NewFlagSet("", flag.ContinueOnError) fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshCloseTunnelFlags{} s := sshCloseTunnelFlags{}
fl.BoolVar(&s.LocalOnly, "local-only", false, "Disables notifying the remote that the tunnel is shutting down") fl.BoolVar(&s.LocalOnly, "local-only", false, "Disables notifying the remote that the tunnel is shutting down")
return fl, &s return fl, &s
}, },
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { Callback: func(fs any, a []string, w sshd.StringWriter) error {
return sshCloseTunnel(f, fs, a, w) return sshCloseTunnel(f, fs, a, w)
}, },
}) })
@@ -390,13 +390,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
Name: "create-tunnel", Name: "create-tunnel",
ShortDescription: "Creates a tunnel for the provided vpn address", ShortDescription: "Creates a tunnel for the provided vpn address",
Help: "The lighthouses will be queried for real addresses but you can provide one as well.", Help: "The lighthouses will be queried for real addresses but you can provide one as well.",
Flags: func() (*flag.FlagSet, interface{}) { Flags: func() (*flag.FlagSet, any) {
fl := flag.NewFlagSet("", flag.ContinueOnError) fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshCreateTunnelFlags{} s := sshCreateTunnelFlags{}
fl.StringVar(&s.Address, "address", "", "Optionally provide a real remote address, ip:port ") fl.StringVar(&s.Address, "address", "", "Optionally provide a real remote address, ip:port ")
return fl, &s return fl, &s
}, },
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { Callback: func(fs any, a []string, w sshd.StringWriter) error {
return sshCreateTunnel(f, fs, a, w) return sshCreateTunnel(f, fs, a, w)
}, },
}) })
@@ -405,13 +405,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
Name: "query-lighthouse", Name: "query-lighthouse",
ShortDescription: "Query the lighthouses for the provided vpn address", ShortDescription: "Query the lighthouses for the provided vpn address",
Help: "This command is asynchronous. Only currently known udp addresses will be printed.", Help: "This command is asynchronous. Only currently known udp addresses will be printed.",
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { Callback: func(fs any, a []string, w sshd.StringWriter) error {
return sshQueryLighthouse(f, fs, a, w) return sshQueryLighthouse(f, fs, a, w)
}, },
}) })
} }
func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error { func sshListHostMap(hl controlHostLister, a any, w sshd.StringWriter) error {
fs, ok := a.(*sshListHostMapFlags) fs, ok := a.(*sshListHostMapFlags)
if !ok { if !ok {
return nil return nil
@@ -451,7 +451,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
return nil return nil
} }
func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWriter) error { func sshListLighthouseMap(lightHouse *LightHouse, a any, w sshd.StringWriter) error {
fs, ok := a.(*sshListHostMapFlags) fs, ok := a.(*sshListHostMapFlags)
if !ok { if !ok {
return nil return nil
@@ -505,7 +505,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
return nil return nil
} }
func sshStartCpuProfile(fs interface{}, a []string, w sshd.StringWriter) error { func sshStartCpuProfile(fs any, a []string, w sshd.StringWriter) error {
if len(a) == 0 { if len(a) == 0 {
err := w.WriteLine("No path to write profile provided") err := w.WriteLine("No path to write profile provided")
return err return err
@@ -527,11 +527,11 @@ func sshStartCpuProfile(fs interface{}, a []string, w sshd.StringWriter) error {
return err return err
} }
func sshVersion(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { func sshVersion(ifce *Interface, _ any, _ []string, w sshd.StringWriter) error {
return w.WriteLine(fmt.Sprintf("%s", ifce.version)) return w.WriteLine(ifce.version)
} }
func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { func sshQueryLighthouse(ifce *Interface, _ any, a []string, w sshd.StringWriter) error {
if len(a) == 0 { if len(a) == 0 {
return w.WriteLine("No vpn address was provided") return w.WriteLine("No vpn address was provided")
} }
@@ -553,7 +553,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
return json.NewEncoder(w.GetWriter()).Encode(cm) return json.NewEncoder(w.GetWriter()).Encode(cm)
} }
func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { func sshCloseTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
flags, ok := fs.(*sshCloseTunnelFlags) flags, ok := fs.(*sshCloseTunnelFlags)
if !ok { if !ok {
return nil return nil
@@ -584,7 +584,7 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
hostInfo.ConnectionState, hostInfo.ConnectionState,
hostInfo, hostInfo,
[]byte{}, []byte{},
make([]byte, 12, 12), make([]byte, 12),
make([]byte, mtu), make([]byte, mtu),
) )
} }
@@ -593,7 +593,7 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine("Closed") return w.WriteLine("Closed")
} }
func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { func sshCreateTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
flags, ok := fs.(*sshCreateTunnelFlags) flags, ok := fs.(*sshCreateTunnelFlags)
if !ok { if !ok {
return nil return nil
@@ -614,12 +614,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
if hostInfo != nil { if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already exists")) return w.WriteLine("Tunnel already exists")
} }
hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnAddr) hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnAddr)
if hostInfo != nil { if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) return w.WriteLine("Tunnel already handshaking")
} }
var addr netip.AddrPort var addr netip.AddrPort
@@ -638,7 +638,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine("Created") return w.WriteLine("Created")
} }
func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { func sshChangeRemote(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
flags, ok := fs.(*sshChangeRemoteFlags) flags, ok := fs.(*sshChangeRemoteFlags)
if !ok { if !ok {
return nil return nil
@@ -675,7 +675,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine("Changed") return w.WriteLine("Changed")
} }
func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error { func sshGetHeapProfile(fs any, a []string, w sshd.StringWriter) error {
if len(a) == 0 { if len(a) == 0 {
return w.WriteLine("No path to write profile provided") return w.WriteLine("No path to write profile provided")
} }
@@ -696,7 +696,7 @@ func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error {
return err return err
} }
func sshMutexProfileFraction(fs interface{}, a []string, w sshd.StringWriter) error { func sshMutexProfileFraction(fs any, a []string, w sshd.StringWriter) error {
if len(a) == 0 { if len(a) == 0 {
rate := runtime.SetMutexProfileFraction(-1) rate := runtime.SetMutexProfileFraction(-1)
return w.WriteLine(fmt.Sprintf("Current value: %d", rate)) return w.WriteLine(fmt.Sprintf("Current value: %d", rate))
@@ -711,7 +711,7 @@ func sshMutexProfileFraction(fs interface{}, a []string, w sshd.StringWriter) er
return w.WriteLine(fmt.Sprintf("New value: %d. Old value: %d", newRate, oldRate)) return w.WriteLine(fmt.Sprintf("New value: %d. Old value: %d", newRate, oldRate))
} }
func sshGetMutexProfile(fs interface{}, a []string, w sshd.StringWriter) error { func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error {
if len(a) == 0 { if len(a) == 0 {
return w.WriteLine("No path to write profile provided") return w.WriteLine("No path to write profile provided")
} }
@@ -735,7 +735,7 @@ func sshGetMutexProfile(fs interface{}, a []string, w sshd.StringWriter) error {
return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a)) return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a))
} }
func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error { func sshLogLevel(l *logrus.Logger, _ any, a []string, w sshd.StringWriter) error {
if len(a) == 0 { if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
} }
@@ -749,7 +749,7 @@ func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWrit
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
} }
func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error { func sshLogFormat(l *logrus.Logger, _ any, a []string, w sshd.StringWriter) error {
if len(a) == 0 { if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
} }
@@ -767,7 +767,7 @@ func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWri
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
} }
func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
args, ok := fs.(*sshPrintCertFlags) args, ok := fs.(*sshPrintCertFlags)
if !ok { if !ok {
return nil return nil
@@ -822,10 +822,10 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
return w.WriteLine(cert.String()) return w.WriteLine(cert.String())
} }
func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { func sshPrintRelays(ifce *Interface, fs any, _ []string, w sshd.StringWriter) error {
args, ok := fs.(*sshPrintTunnelFlags) args, ok := fs.(*sshPrintTunnelFlags)
if !ok { if !ok {
w.WriteLine(fmt.Sprintf("sshPrintRelays failed to convert args type")) w.WriteLine("sshPrintRelays failed to convert args type")
return nil return nil
} }
@@ -919,7 +919,7 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return nil return nil
} }
func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { func sshPrintTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
args, ok := fs.(*sshPrintTunnelFlags) args, ok := fs.(*sshPrintTunnelFlags)
if !ok { if !ok {
return nil return nil
@@ -951,7 +951,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges())) return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges()))
} }
func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error { func sshDeviceInfo(ifce *Interface, fs any, w sshd.StringWriter) error {
data := struct { data := struct {
Name string `json:"name"` Name string `json:"name"`

View File

@@ -12,7 +12,7 @@ import (
// CommandFlags is a function called before help or command execution to parse command line flags // CommandFlags is a function called before help or command execution to parse command line flags
// It should return a flag.FlagSet instance and a pointer to the struct that will contain parsed flags // It should return a flag.FlagSet instance and a pointer to the struct that will contain parsed flags
type CommandFlags func() (*flag.FlagSet, interface{}) type CommandFlags func() (*flag.FlagSet, any)
// CommandCallback is the function called when your command should execute. // CommandCallback is the function called when your command should execute.
// fs will be a a pointer to the struct provided by Command.Flags callback, if there was one. -h and -help are reserved // fs will be a a pointer to the struct provided by Command.Flags callback, if there was one. -h and -help are reserved
@@ -21,7 +21,7 @@ type CommandFlags func() (*flag.FlagSet, interface{})
// w is the writer to use when sending messages back to the client. // w is the writer to use when sending messages back to the client.
// If an error is returned by the callback it is logged locally, the callback should handle messaging errors to the user // If an error is returned by the callback it is logged locally, the callback should handle messaging errors to the user
// where appropriate // where appropriate
type CommandCallback func(fs interface{}, a []string, w StringWriter) error type CommandCallback func(fs any, a []string, w StringWriter) error
type Command struct { type Command struct {
Name string Name string
@@ -34,7 +34,7 @@ type Command struct {
func execCommand(c *Command, args []string, w StringWriter) error { func execCommand(c *Command, args []string, w StringWriter) error {
var ( var (
fl *flag.FlagSet fl *flag.FlagSet
fs interface{} fs any
) )
if c.Flags != nil { if c.Flags != nil {
@@ -85,7 +85,7 @@ func lookupCommand(c *radix.Tree, sCmd string) (*Command, error) {
func matchCommand(c *radix.Tree, cmd string) []string { func matchCommand(c *radix.Tree, cmd string) []string {
cmds := make([]string, 0) cmds := make([]string, 0)
c.WalkPrefix(cmd, func(found string, v interface{}) bool { c.WalkPrefix(cmd, func(found string, v any) bool {
cmds = append(cmds, found) cmds = append(cmds, found)
return false return false
}) })
@@ -95,7 +95,7 @@ func matchCommand(c *radix.Tree, cmd string) []string {
func allCommands(c *radix.Tree) []*Command { func allCommands(c *radix.Tree) []*Command {
cmds := make([]*Command, 0) cmds := make([]*Command, 0)
c.WalkPrefix("", func(found string, v interface{}) bool { c.WalkPrefix("", func(found string, v any) bool {
cmd, ok := v.(*Command) cmd, ok := v.(*Command)
if ok { if ok {
cmds = append(cmds, cmd) cmds = append(cmds, cmd)

View File

@@ -23,7 +23,6 @@ type SSHServer struct {
trustedCAs []ssh.PublicKey trustedCAs []ssh.PublicKey
// List of available commands // List of available commands
helpCommand *Command
commands *radix.Tree commands *radix.Tree
listener net.Listener listener net.Listener
@@ -43,7 +42,7 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
conns: make(map[int]*session), conns: make(map[int]*session),
} }
cc := ssh.CertChecker{ cc := &ssh.CertChecker{
IsUserAuthority: func(auth ssh.PublicKey) bool { IsUserAuthority: func(auth ssh.PublicKey) bool {
for _, ca := range s.trustedCAs { for _, ca := range s.trustedCAs {
if bytes.Equal(ca.Marshal(), auth.Marshal()) { if bytes.Equal(ca.Marshal(), auth.Marshal()) {
@@ -77,16 +76,17 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
}, },
} }
s.certChecker = cc
s.config = &ssh.ServerConfig{ s.config = &ssh.ServerConfig{
PublicKeyCallback: cc.Authenticate, PublicKeyCallback: cc.Authenticate,
ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"), ServerVersion: "SSH-2.0-Nebula???",
} }
s.RegisterCommand(&Command{ s.RegisterCommand(&Command{
Name: "help", Name: "help",
ShortDescription: "prints available commands or help <command> for specific usage info", ShortDescription: "prints available commands or help <command> for specific usage info",
Callback: func(a interface{}, args []string, w StringWriter) error { Callback: func(a any, args []string, w StringWriter) error {
return helpCallback(s.commands, args, w) return helpCallback(s.commands, args, w)
}, },
}) })

View File

@@ -9,13 +9,13 @@ import (
"github.com/armon/go-radix" "github.com/armon/go-radix"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal" "golang.org/x/term"
) )
type session struct { type session struct {
l *logrus.Entry l *logrus.Entry
c *ssh.ServerConn c *ssh.ServerConn
term *terminal.Terminal term *term.Terminal
commands *radix.Tree commands *radix.Tree
exitChan chan bool exitChan chan bool
} }
@@ -31,7 +31,7 @@ func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.New
s.commands.Insert("logout", &Command{ s.commands.Insert("logout", &Command{
Name: "logout", Name: "logout",
ShortDescription: "Ends the current session", ShortDescription: "Ends the current session",
Callback: func(a interface{}, args []string, w StringWriter) error { Callback: func(a any, args []string, w StringWriter) error {
s.Close() s.Close()
return nil return nil
}, },
@@ -106,8 +106,8 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
} }
} }
func (s *session) createTerm(channel ssh.Channel) *terminal.Terminal { func (s *session) createTerm(channel ssh.Channel) *term.Terminal {
term := terminal.NewTerminal(channel, s.c.User()+"@nebula > ") term := term.NewTerminal(channel, s.c.User()+"@nebula > ")
term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) { term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
// key 9 is tab // key 9 is tab
if key == 9 { if key == 9 {
@@ -170,7 +170,6 @@ func (s *session) dispatchCommand(line string, w StringWriter) {
} }
_ = execCommand(c, args[1:], w) _ = execCommand(c, args[1:], w)
return
} }
func (s *session) Close() { func (s *session) Close() {

View File

@@ -13,7 +13,7 @@ import (
// AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory // AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory
// There is currently a special case for `time.loc` (as this code traverses into unexported fields) // There is currently a special case for `time.loc` (as this code traverses into unexported fields)
func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) { func AssertDeepCopyEqual(t *testing.T, a any, b any) {
v1 := reflect.ValueOf(a) v1 := reflect.ValueOf(a)
v2 := reflect.ValueOf(b) v2 := reflect.ValueOf(b)

Some files were not shown because too many files have changed in this diff Show More