From 4eb1da09586191a9e6ca0d320471c94cb30896d5 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Wed, 29 May 2024 12:52:52 -0400 Subject: [PATCH 01/67] remove deadlock in GetOrHandshake (#1151) We had a rare deadlock in GetOrHandshake because we kept the hostmap lock when we do the call to StartHandshake. StartHandshake can block while sending to the lighthouse query worker channel, and that worker needs to be able to grab the hostmap lock to do its work. Other calls for StartHandshake don't hold the hostmap lock so we should be able to drop it here. This lock was originally added with: https://github.com/slackhq/nebula/pull/954 --- handshake_manager.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/handshake_manager.go b/handshake_manager.go index 640227a..2372ced 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -356,10 +356,11 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger // GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present // The 2nd argument will be true if the hostinfo is ready to transmit traffic func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) { - // Check the main hostmap and maintain a read lock if our host is not there hm.mainHostMap.RLock() - if h, ok := hm.mainHostMap.Hosts[vpnIp]; ok { - hm.mainHostMap.RUnlock() + h, ok := hm.mainHostMap.Hosts[vpnIp] + hm.mainHostMap.RUnlock() + + if ok { // Do not attempt promotion if you are a lighthouse if !hm.lightHouse.amLighthouse { h.TryPromoteBest(hm.mainHostMap.GetPreferredRanges(), hm.f) @@ -367,7 +368,6 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han return h, true } - defer hm.mainHostMap.RUnlock() return hm.StartHandshake(vpnIp, cacheCb), false } From a92056a7db2fcae11078d677a88a471cd6be707e Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Wed, 29 May 2024 14:06:46 -0400 Subject: [PATCH 02/67] v1.9.1 (#1152) Update CHANGELOG for Nebula v1.9.1 --- CHANGELOG.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b7b3e01..184eea8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.9.1] - 2024-05-29 + +### Fixed + +- Fixed a potential deadlock in GetOrHandshake. (#1151) + ## [1.9.0] - 2024-05-07 ### Deprecated @@ -626,7 +632,8 @@ created.) - Initial public release. -[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.0...HEAD +[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.1...HEAD +[1.9.1]: https://github.com/slackhq/nebula/releases/tag/v1.9.1 [1.9.0]: https://github.com/slackhq/nebula/releases/tag/v1.9.0 [1.8.2]: https://github.com/slackhq/nebula/releases/tag/v1.8.2 [1.8.1]: https://github.com/slackhq/nebula/releases/tag/v1.8.1 From d9cae9e0627954e71d3b5a2e85daf19000167d95 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Mon, 3 Jun 2024 15:40:51 -0400 Subject: [PATCH 03/67] ensure messageCounter is set before handshake is complete (#1154) Ensure we set messageCounter to 2 before the handshake is marked as complete. --- handshake_ix.go | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/handshake_ix.go b/handshake_ix.go index 8727b16..b86ecab 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -1,6 +1,7 @@ package nebula import ( + "fmt" "time" "github.com/flynn/noise" @@ -321,7 +322,11 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by } f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) - hostinfo.ConnectionState.messageCounter.Store(2) + prev := hostinfo.ConnectionState.messageCounter.Swap(2) + if prev > 2 { + panic(fmt.Errorf("invalid state: messageCounter > 2 before handshake complete: %v", prev)) + } + hostinfo.remotes.ResetBlockedRemotes() return @@ -463,12 +468,15 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha // Build up the radix for the firewall if we have subnets in the cert hostinfo.CreateRemoteCIDR(remoteCert) + prev := hostinfo.ConnectionState.messageCounter.Swap(2) + if prev > 2 { + panic(fmt.Errorf("invalid state: messageCounter > 2 before handshake complete: %v", prev)) + } + // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp f.handshakeManager.Complete(hostinfo, f) f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) - hostinfo.ConnectionState.messageCounter.Store(2) - if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore)) } From 249ae41fec4b9f587c09aabcc712b3fa5febb9da Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Mon, 3 Jun 2024 15:50:02 -0400 Subject: [PATCH 04/67] v1.9.2 (#1155) Update CHANGELOG for Nebula v1.9.2 --- CHANGELOG.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 184eea8..555d82a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.9.2] - 2024-06-03 + +### Fixed + +- Ensure messageCounter is set before handshake is complete. (#1154) + ## [1.9.1] - 2024-05-29 ### Fixed @@ -632,7 +638,8 @@ created.) - Initial public release. -[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.1...HEAD +[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.2...HEAD +[1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2 [1.9.1]: https://github.com/slackhq/nebula/releases/tag/v1.9.1 [1.9.0]: https://github.com/slackhq/nebula/releases/tag/v1.9.0 [1.8.2]: https://github.com/slackhq/nebula/releases/tag/v1.8.2 From 4c066d8c3257cb800f0aad09a1f53a37ebfa1686 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Thu, 6 Jun 2024 13:03:07 -0400 Subject: [PATCH 05/67] initialize messageCounter to 2 instead of verifying later (#1156) Clean up the messageCounter checks added in #1154. Instead of checking that messageCounter is still at 2, just initialize it to 2 and only increment for non-handshake messages. Handshake packets will always be packets 1 and 2. --- connection_state.go | 2 ++ handshake_ix.go | 11 ----------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/connection_state.go b/connection_state.go index 8ef8b3a..1dd3c8c 100644 --- a/connection_state.go +++ b/connection_state.go @@ -72,6 +72,8 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i window: b, myCert: certState.Certificate, } + // always start the counter from 2, as packet 1 and packet 2 are handshake packets. + ci.messageCounter.Add(2) return ci } diff --git a/handshake_ix.go b/handshake_ix.go index b86ecab..d0bee86 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -1,7 +1,6 @@ package nebula import ( - "fmt" "time" "github.com/flynn/noise" @@ -47,7 +46,6 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { } h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1) - ci.messageCounter.Add(1) msg, _, _, err := ci.H.WriteMessage(h, hsBytes) if err != nil { @@ -322,10 +320,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by } f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) - prev := hostinfo.ConnectionState.messageCounter.Swap(2) - if prev > 2 { - panic(fmt.Errorf("invalid state: messageCounter > 2 before handshake complete: %v", prev)) - } hostinfo.remotes.ResetBlockedRemotes() @@ -468,11 +462,6 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha // Build up the radix for the firewall if we have subnets in the cert hostinfo.CreateRemoteCIDR(remoteCert) - prev := hostinfo.ConnectionState.messageCounter.Swap(2) - if prev > 2 { - panic(fmt.Errorf("invalid state: messageCounter > 2 before handshake complete: %v", prev)) - } - // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp f.handshakeManager.Complete(hostinfo, f) f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) From b14bad586ac4eb922fe11c1a4f360e223bd8dc8b Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Thu, 6 Jun 2024 13:17:07 -0400 Subject: [PATCH 06/67] v1.9.3 (#1160) Update CHANGELOG for Nebula v1.9.3 --- CHANGELOG.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 555d82a..f763b69 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.9.3] - 2024-06-06 + +### Fixed + +- Initialize messageCounter to 2 instead of verifying later. (#1156) + ## [1.9.2] - 2024-06-03 ### Fixed @@ -638,7 +644,8 @@ created.) - Initial public release. -[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.2...HEAD +[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.3...HEAD +[1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3 [1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2 [1.9.1]: https://github.com/slackhq/nebula/releases/tag/v1.9.1 [1.9.0]: https://github.com/slackhq/nebula/releases/tag/v1.9.0 From 40cfd00e8770ff212e1d6766edb41ffe75f6fea3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 10 Jun 2024 16:08:43 -0400 Subject: [PATCH 07/67] Bump the golang-x-dependencies group with 4 updates (#1161) Bumps the golang-x-dependencies group with 4 updates: [golang.org/x/crypto](https://github.com/golang/crypto), [golang.org/x/net](https://github.com/golang/net), [golang.org/x/sys](https://github.com/golang/sys) and [golang.org/x/term](https://github.com/golang/term). Updates `golang.org/x/crypto` from 0.23.0 to 0.24.0 - [Commits](https://github.com/golang/crypto/compare/v0.23.0...v0.24.0) Updates `golang.org/x/net` from 0.25.0 to 0.26.0 - [Commits](https://github.com/golang/net/compare/v0.25.0...v0.26.0) Updates `golang.org/x/sys` from 0.20.0 to 0.21.0 - [Commits](https://github.com/golang/sys/compare/v0.20.0...v0.21.0) Updates `golang.org/x/term` from 0.20.0 to 0.21.0 - [Commits](https://github.com/golang/term/compare/v0.20.0...v0.21.0) --- updated-dependencies: - dependency-name: golang.org/x/crypto dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/net dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/sys dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/term dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 8 ++++---- go.sum | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index b1f7215..6705e52 100644 --- a/go.mod +++ b/go.mod @@ -22,12 +22,12 @@ require ( github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/stretchr/testify v1.9.0 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.23.0 + golang.org/x/crypto v0.24.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 - golang.org/x/net v0.25.0 + golang.org/x/net v0.26.0 golang.org/x/sync v0.7.0 - golang.org/x/sys v0.20.0 - golang.org/x/term v0.20.0 + golang.org/x/sys v0.21.0 + golang.org/x/term v0.21.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 diff --git a/go.sum b/go.sum index 0e67186..24d69dc 100644 --- a/go.sum +++ b/go.sum @@ -147,8 +147,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= +golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= @@ -167,8 +167,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -195,11 +195,11 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw= -golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= +golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= +golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= From d372df56ab0c087e65c33f7621b7a32f174e493f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 24 Jun 2024 14:45:52 -0400 Subject: [PATCH 08/67] Bump google.golang.org/protobuf in the protobuf-dependencies group (#1167) Bumps the protobuf-dependencies group with 1 update: google.golang.org/protobuf. Updates `google.golang.org/protobuf` from 1.34.1 to 1.34.2 --- updated-dependencies: - dependency-name: google.golang.org/protobuf dependency-type: direct:production update-type: version-update:semver-patch dependency-group: protobuf-dependencies ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 6705e52..7ee3d68 100644 --- a/go.mod +++ b/go.mod @@ -31,7 +31,7 @@ require ( golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/protobuf v1.34.1 + google.golang.org/protobuf v1.34.2 gopkg.in/yaml.v2 v2.4.0 gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe ) diff --git a/go.sum b/go.sum index 24d69dc..9d9d8ce 100644 --- a/go.sum +++ b/go.sum @@ -230,8 +230,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= -google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From 506ba5ab5b62bc14397738854d103c60296769d8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 24 Jun 2024 14:46:27 -0400 Subject: [PATCH 09/67] Bump github.com/miekg/dns from 1.1.59 to 1.1.61 (#1168) Bumps [github.com/miekg/dns](https://github.com/miekg/dns) from 1.1.59 to 1.1.61. - [Changelog](https://github.com/miekg/dns/blob/master/Makefile.release) - [Commits](https://github.com/miekg/dns/compare/v1.1.59...v1.1.61) --- updated-dependencies: - dependency-name: github.com/miekg/dns dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 6 +++--- go.sum | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 7ee3d68..e69072e 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.2 - github.com/miekg/dns v1.1.59 + github.com/miekg/dns v1.1.61 github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f github.com/prometheus/client_golang v1.19.0 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 @@ -46,8 +46,8 @@ require ( github.com/prometheus/common v0.48.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect - golang.org/x/mod v0.16.0 // indirect + golang.org/x/mod v0.18.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.19.0 // indirect + golang.org/x/tools v0.22.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 9d9d8ce..131cf56 100644 --- a/go.sum +++ b/go.sum @@ -77,8 +77,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs= -github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk= +github.com/miekg/dns v1.1.61 h1:nLxbwF3XxhwVSm8g9Dghm9MHPaUZuqhPiGL+675ZmEs= +github.com/miekg/dns v1.1.61/go.mod h1:mnAarhS3nWaW+NVP2wTkYVIZyHNJ098SJZUki3eykwQ= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= @@ -155,8 +155,8 @@ golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPI golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= -golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= +golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -210,8 +210,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= -golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= +golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= +golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= From 97e9834f82678c9324bdddf571265be77b63b1df Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Mon, 24 Jun 2024 14:47:14 -0400 Subject: [PATCH 10/67] cleanup SK_MEMINFO vars (#1162) We had to manually define these types before, but the latest release of `golang.org/x/sys` adds these definitions: - https://github.com/golang/sys/commit/6dfb94eaa3bd0fcaa615f58e915f7214ce078beb Since we just updated with this PR, we can clean this up now: - https://github.com/slackhq/nebula/pull/1161 --- udp/udp_linux.go | 33 +++++++-------------------------- 1 file changed, 7 insertions(+), 26 deletions(-) diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 1151c89..02c8ce0 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -27,25 +27,6 @@ type StdConn struct { batch int } -var x int - -// From linux/sock_diag.h -const ( - _SK_MEMINFO_RMEM_ALLOC = iota - _SK_MEMINFO_RCVBUF - _SK_MEMINFO_WMEM_ALLOC - _SK_MEMINFO_SNDBUF - _SK_MEMINFO_FWD_ALLOC - _SK_MEMINFO_WMEM_QUEUED - _SK_MEMINFO_OPTMEM - _SK_MEMINFO_BACKLOG - _SK_MEMINFO_DROPS - - _SK_MEMINFO_VARS -) - -type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32 - func maybeIPV4(ip net.IP) (net.IP, bool) { ip4 := ip.To4() if ip4 != nil { @@ -316,8 +297,8 @@ func (u *StdConn) ReloadConfig(c *config.C) { } } -func (u *StdConn) getMemInfo(meminfo *_SK_MEMINFO) error { - var vallen uint32 = 4 * _SK_MEMINFO_VARS +func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { + var vallen uint32 = 4 * unix.SK_MEMINFO_VARS _, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) if err != 0 { return err @@ -332,12 +313,12 @@ func (u *StdConn) Close() error { func NewUDPStatsEmitter(udpConns []Conn) func() { // Check if our kernel supports SO_MEMINFO before registering the gauges - var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge - var meminfo _SK_MEMINFO + var udpGauges [][unix.SK_MEMINFO_VARS]metrics.Gauge + var meminfo [unix.SK_MEMINFO_VARS]uint32 if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil { - udpGauges = make([][_SK_MEMINFO_VARS]metrics.Gauge, len(udpConns)) + udpGauges = make([][unix.SK_MEMINFO_VARS]metrics.Gauge, len(udpConns)) for i := range udpConns { - udpGauges[i] = [_SK_MEMINFO_VARS]metrics.Gauge{ + udpGauges[i] = [unix.SK_MEMINFO_VARS]metrics.Gauge{ metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", i), nil), metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", i), nil), metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", i), nil), @@ -354,7 +335,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() { return func() { for i, gauges := range udpGauges { if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil { - for j := 0; j < _SK_MEMINFO_VARS; j++ { + for j := 0; j < unix.SK_MEMINFO_VARS; j++ { gauges[j].Update(int64(meminfo[j])) } } From 8109cf2170375776f6e430d53d338bc5b2f91ffb Mon Sep 17 00:00:00 2001 From: Caleb Jasik Date: Mon, 24 Jun 2024 13:50:17 -0500 Subject: [PATCH 11/67] Add puncuation to doc comment (#1164) * Add puncuation to doc comment * Fix list formatting inside `EncryptDanger` doc comment --- lighthouse.go | 2 +- noise.go | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lighthouse.go b/lighthouse.go index aa54c4b..df68e1e 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -1151,7 +1151,7 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i } } -// ipMaskContains checks if testIp is contained by ip after applying a cidr +// ipMaskContains checks if testIp is contained by ip after applying a cidr. // zeros is 32 - bits from net.IPMask.Size() func ipMaskContains(ip iputil.VpnIp, zeros iputil.VpnIp, testIp iputil.VpnIp) bool { return (testIp^ip)>>zeros == 0 diff --git a/noise.go b/noise.go index 91ad2c0..57990a7 100644 --- a/noise.go +++ b/noise.go @@ -28,11 +28,11 @@ func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState { // EncryptDanger encrypts and authenticates a given payload. // // out is a destination slice to hold the output of the EncryptDanger operation. -// - ad is additional data, which will be authenticated and appended to out, but not encrypted. -// - plaintext is encrypted, authenticated and appended to out. -// - n is a nonce value which must never be re-used with this key. -// - nb is a buffer used for temporary storage in the implementation of this call, which should -// be re-used by callers to minimize garbage collection. +// - ad is additional data, which will be authenticated and appended to out, but not encrypted. +// - plaintext is encrypted, authenticated and appended to out. +// - n is a nonce value which must never be re-used with this key. +// - nb is a buffer used for temporary storage in the implementation of this call, which should +// be re-used by callers to minimize garbage collection. func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) { if s != nil { // TODO: Is this okay now that we have made messageCounter atomic? From a76723eaf5e089196970ad4e5c60f07ea102b869 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 24 Jun 2024 14:54:05 -0400 Subject: [PATCH 12/67] Bump Apple-Actions/import-codesign-certs from 2 to 3 (#1146) Bumps [Apple-Actions/import-codesign-certs](https://github.com/apple-actions/import-codesign-certs) from 2 to 3. - [Release notes](https://github.com/apple-actions/import-codesign-certs/releases) - [Commits](https://github.com/apple-actions/import-codesign-certs/compare/v2...v3) --- updated-dependencies: - dependency-name: Apple-Actions/import-codesign-certs dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c8cf3f8..8f53207 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -75,7 +75,7 @@ jobs: - name: Import certificates if: env.HAS_SIGNING_CREDS == 'true' - uses: Apple-Actions/import-codesign-certs@v2 + uses: Apple-Actions/import-codesign-certs@v3 with: p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }} p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }} From b9aace1e58a8cd45716ceaad2b237c5a01ed0298 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 24 Jun 2024 14:54:51 -0400 Subject: [PATCH 13/67] Bump github.com/prometheus/client_golang from 1.19.0 to 1.19.1 (#1147) Bumps [github.com/prometheus/client_golang](https://github.com/prometheus/client_golang) from 1.19.0 to 1.19.1. - [Release notes](https://github.com/prometheus/client_golang/releases) - [Changelog](https://github.com/prometheus/client_golang/blob/main/CHANGELOG.md) - [Commits](https://github.com/prometheus/client_golang/compare/v1.19.0...v1.19.1) --- updated-dependencies: - dependency-name: github.com/prometheus/client_golang dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index e69072e..dc9e01e 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/kardianos/service v1.2.2 github.com/miekg/dns v1.1.61 github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f - github.com/prometheus/client_golang v1.19.0 + github.com/prometheus/client_golang v1.19.1 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/sirupsen/logrus v1.9.3 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e diff --git a/go.sum b/go.sum index 131cf56..32099f2 100644 --- a/go.sum +++ b/go.sum @@ -96,8 +96,8 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= -github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU= -github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k= +github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE= +github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= From e6009b849145c039a088cb84dc9c6f349bb42f78 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Tue, 2 Jul 2024 11:50:51 -0400 Subject: [PATCH 14/67] github actions: use macos-latest (#1171) macos-11 was deprecated and removed: > The macos-11 label has been deprecated and will no longer be available after 28 June 2024. We can just use macos-latest instead. --- .github/workflows/release.yml | 2 +- .github/workflows/test.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8f53207..31987db 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -64,7 +64,7 @@ jobs: name: Build Universal Darwin env: HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }} - runs-on: macos-11 + runs-on: macos-latest steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 844eaf2..65a6e3e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -72,7 +72,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [windows-latest, macos-11] + os: [windows-latest, macos-latest] steps: - uses: actions/checkout@v4 From 00458302caf8132923d53df42b1fc4143e8a6d14 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 29 Jul 2024 11:42:33 -0400 Subject: [PATCH 15/67] Bump the golang-x-dependencies group with 4 updates (#1174) Bumps the golang-x-dependencies group with 4 updates: [golang.org/x/crypto](https://github.com/golang/crypto), [golang.org/x/net](https://github.com/golang/net), [golang.org/x/sys](https://github.com/golang/sys) and [golang.org/x/term](https://github.com/golang/term). Updates `golang.org/x/crypto` from 0.24.0 to 0.25.0 - [Commits](https://github.com/golang/crypto/compare/v0.24.0...v0.25.0) Updates `golang.org/x/net` from 0.26.0 to 0.27.0 - [Commits](https://github.com/golang/net/compare/v0.26.0...v0.27.0) Updates `golang.org/x/sys` from 0.21.0 to 0.22.0 - [Commits](https://github.com/golang/sys/compare/v0.21.0...v0.22.0) Updates `golang.org/x/term` from 0.21.0 to 0.22.0 - [Commits](https://github.com/golang/term/compare/v0.21.0...v0.22.0) --- updated-dependencies: - dependency-name: golang.org/x/crypto dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/net dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/sys dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/term dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 8 ++++---- go.sum | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index dc9e01e..bec08c4 100644 --- a/go.mod +++ b/go.mod @@ -22,12 +22,12 @@ require ( github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/stretchr/testify v1.9.0 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.24.0 + golang.org/x/crypto v0.25.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 - golang.org/x/net v0.26.0 + golang.org/x/net v0.27.0 golang.org/x/sync v0.7.0 - golang.org/x/sys v0.21.0 - golang.org/x/term v0.21.0 + golang.org/x/sys v0.22.0 + golang.org/x/term v0.22.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 diff --git a/go.sum b/go.sum index 32099f2..ddd5402 100644 --- a/go.sum +++ b/go.sum @@ -147,8 +147,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= +golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= @@ -167,8 +167,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= +golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -195,11 +195,11 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= -golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= +golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= +golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= From e264a0ff888c7bf0568579306755a60fc42f6ecc Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 31 Jul 2024 10:18:56 -0500 Subject: [PATCH 16/67] Switch most everything to netip in prep for ipv6 in the overlay (#1173) --- allow_list.go | 93 +++----- allow_list_test.go | 40 ++-- calculated_remote.go | 68 +++--- calculated_remote_test.go | 16 +- cidr/parse.go | 10 - cidr/tree4.go | 203 ----------------- cidr/tree4_test.go | 170 --------------- cidr/tree6.go | 189 ---------------- cidr/tree6_test.go | 98 --------- connection_manager.go | 30 +-- connection_manager_test.go | 34 +-- control.go | 40 ++-- control_test.go | 57 +++-- control_tester.go | 47 ++-- dns_server.go | 18 +- e2e/handshakes_test.go | 338 ++++++++++++++-------------- e2e/helpers.go | 23 +- e2e/helpers_test.go | 52 +++-- e2e/router/hostmap.go | 8 +- e2e/router/router.go | 99 +++------ firewall.go | 100 +++++---- firewall/packet.go | 7 +- firewall_test.go | 147 ++++++------- go.mod | 2 + go.sum | 6 + handshake_ix.go | 58 +++-- handshake_manager.go | 91 ++++---- handshake_manager_test.go | 18 +- hostmap.go | 146 +++++++------ hostmap_test.go | 59 +++-- hostmap_tester.go | 6 +- inside.go | 44 ++-- interface.go | 47 +++- iputil/packet.go | 2 + iputil/util.go | 93 -------- iputil/util_test.go | 17 -- lighthouse.go | 400 ++++++++++++++++++---------------- lighthouse_test.go | 185 +++++++--------- main.go | 30 ++- outside.go | 95 ++++---- outside_test.go | 10 +- overlay/device.go | 8 +- overlay/route.go | 44 ++-- overlay/route_test.go | 43 ++-- overlay/tun.go | 10 +- overlay/tun_android.go | 19 +- overlay/tun_darwin.go | 59 +++-- overlay/tun_disabled.go | 12 +- overlay/tun_freebsd.go | 23 +- overlay/tun_ios.go | 19 +- overlay/tun_linux.go | 97 +++++---- overlay/tun_netbsd.go | 29 ++- overlay/tun_openbsd.go | 29 ++- overlay/tun_tester.go | 19 +- overlay/tun_water_windows.go | 22 +- overlay/tun_windows.go | 6 +- overlay/tun_wintun_windows.go | 50 +---- overlay/user.go | 15 +- pki.go | 2 + relay_manager.go | 83 ++++--- remote_list.go | 166 +++++++------- remote_list_test.go | 187 ++++++++-------- service/service.go | 2 +- service/service_test.go | 16 +- ssh.go | 65 +++--- test/tun.go | 12 +- timeout_test.go | 9 +- udp/conn.go | 14 +- udp/temp.go | 5 +- udp/udp_all.go | 100 --------- udp/udp_android.go | 3 +- udp/udp_bsd.go | 3 +- udp/udp_darwin.go | 3 +- udp/udp_generic.go | 37 ++-- udp/udp_linux.go | 77 ++++--- udp/udp_netbsd.go | 3 +- udp/udp_rio_windows.go | 43 ++-- udp/udp_tester.go | 49 ++--- udp/udp_windows.go | 3 +- 79 files changed, 1900 insertions(+), 2682 deletions(-) delete mode 100644 cidr/parse.go delete mode 100644 cidr/tree4.go delete mode 100644 cidr/tree4_test.go delete mode 100644 cidr/tree6.go delete mode 100644 cidr/tree6_test.go delete mode 100644 iputil/util.go delete mode 100644 iputil/util_test.go delete mode 100644 udp/udp_all.go diff --git a/allow_list.go b/allow_list.go index 9186b2f..90e0de2 100644 --- a/allow_list.go +++ b/allow_list.go @@ -2,17 +2,16 @@ package nebula import ( "fmt" - "net" + "net/netip" "regexp" - "github.com/slackhq/nebula/cidr" + "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) type AllowList struct { // The values of this cidrTree are `bool`, signifying allow/deny - cidrTree *cidr.Tree6[bool] + cidrTree *bart.Table[bool] } type RemoteAllowList struct { @@ -20,7 +19,7 @@ type RemoteAllowList struct { // Inside Range Specific, keys of this tree are inside CIDRs and values // are *AllowList - insideAllowLists *cidr.Tree6[*AllowList] + insideAllowLists *bart.Table[*AllowList] } type LocalAllowList struct { @@ -88,7 +87,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw) } - tree := cidr.NewTree6[bool]() + tree := new(bart.Table[bool]) // Keep track of the rules we have added for both ipv4 and ipv6 type allowListRules struct { @@ -122,18 +121,20 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue) } - _, ipNet, err := net.ParseCIDR(rawCIDR) + ipNet, err := netip.ParsePrefix(rawCIDR) if err != nil { - return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) + return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err) } - // TODO: should we error on duplicate CIDRs in the config? - tree.AddCIDR(ipNet, value) + ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits()) - maskBits, maskSize := ipNet.Mask.Size() + // TODO: should we error on duplicate CIDRs in the config? + tree.Insert(ipNet, value) + + maskBits := ipNet.Bits() var rules *allowListRules - if maskSize == 32 { + if ipNet.Addr().Is4() { rules = &rules4 } else { rules = &rules6 @@ -156,8 +157,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in if !rules4.defaultSet { if rules4.allValuesMatch { - _, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0") - tree.AddCIDR(zeroCIDR, !rules4.allValues) + tree.Insert(netip.PrefixFrom(netip.IPv4Unspecified(), 0), !rules4.allValues) } else { return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k) } @@ -165,8 +165,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in if !rules6.defaultSet { if rules6.allValuesMatch { - _, zeroCIDR, _ := net.ParseCIDR("::/0") - tree.AddCIDR(zeroCIDR, !rules6.allValues) + tree.Insert(netip.PrefixFrom(netip.IPv6Unspecified(), 0), !rules6.allValues) } else { return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k) } @@ -218,13 +217,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error return nameRules, nil } -func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) { +func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error) { value := c.Get(k) if value == nil { return nil, nil } - remoteAllowRanges := cidr.NewTree6[*AllowList]() + remoteAllowRanges := new(bart.Table[*AllowList]) rawMap, ok := value.(map[interface{}]interface{}) if !ok { @@ -241,45 +240,27 @@ func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error return nil, err } - _, ipNet, err := net.ParseCIDR(rawCIDR) + ipNet, err := netip.ParsePrefix(rawCIDR) if err != nil { - return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) + return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err) } - remoteAllowRanges.AddCIDR(ipNet, allowList) + remoteAllowRanges.Insert(netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits()), allowList) } return remoteAllowRanges, nil } -func (al *AllowList) Allow(ip net.IP) bool { +func (al *AllowList) Allow(ip netip.Addr) bool { if al == nil { return true } - _, result := al.cidrTree.MostSpecificContains(ip) + result, _ := al.cidrTree.Lookup(ip) return result } -func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool { - if al == nil { - return true - } - - _, result := al.cidrTree.MostSpecificContainsIpV4(ip) - return result -} - -func (al *AllowList) AllowIpV6(hi, lo uint64) bool { - if al == nil { - return true - } - - _, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo) - return result -} - -func (al *LocalAllowList) Allow(ip net.IP) bool { +func (al *LocalAllowList) Allow(ip netip.Addr) bool { if al == nil { return true } @@ -301,43 +282,23 @@ func (al *LocalAllowList) AllowName(name string) bool { return !al.nameRules[0].Allow } -func (al *RemoteAllowList) AllowUnknownVpnIp(ip net.IP) bool { +func (al *RemoteAllowList) AllowUnknownVpnIp(ip netip.Addr) bool { if al == nil { return true } return al.AllowList.Allow(ip) } -func (al *RemoteAllowList) Allow(vpnIp iputil.VpnIp, ip net.IP) bool { +func (al *RemoteAllowList) Allow(vpnIp netip.Addr, ip netip.Addr) bool { if !al.getInsideAllowList(vpnIp).Allow(ip) { return false } return al.AllowList.Allow(ip) } -func (al *RemoteAllowList) AllowIpV4(vpnIp iputil.VpnIp, ip iputil.VpnIp) bool { - if al == nil { - return true - } - if !al.getInsideAllowList(vpnIp).AllowIpV4(ip) { - return false - } - return al.AllowList.AllowIpV4(ip) -} - -func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool { - if al == nil { - return true - } - if !al.getInsideAllowList(vpnIp).AllowIpV6(hi, lo) { - return false - } - return al.AllowList.AllowIpV6(hi, lo) -} - -func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList { +func (al *RemoteAllowList) getInsideAllowList(vpnIp netip.Addr) *AllowList { if al.insideAllowLists != nil { - ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp) + inside, ok := al.insideAllowLists.Lookup(vpnIp) if ok { return inside } diff --git a/allow_list_test.go b/allow_list_test.go index 334cb60..c8b3d08 100644 --- a/allow_list_test.go +++ b/allow_list_test.go @@ -1,11 +1,11 @@ package nebula import ( - "net" + "net/netip" "regexp" "testing" - "github.com/slackhq/nebula/cidr" + "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" @@ -18,7 +18,7 @@ func TestNewAllowListFromConfig(t *testing.T) { "192.168.0.0": true, } r, err := newAllowListFromConfig(c, "allowlist", nil) - assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0") + assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'") assert.Nil(t, r) c.Settings["allowlist"] = map[interface{}]interface{}{ @@ -98,26 +98,26 @@ func TestNewAllowListFromConfig(t *testing.T) { } func TestAllowList_Allow(t *testing.T) { - assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1"))) + assert.Equal(t, true, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1"))) - tree := cidr.NewTree6[bool]() - tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true) - tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false) - tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true) - tree.AddCIDR(cidr.Parse("10.42.0.0/16"), true) - tree.AddCIDR(cidr.Parse("10.42.42.0/24"), true) - tree.AddCIDR(cidr.Parse("10.42.42.0/24"), false) - tree.AddCIDR(cidr.Parse("::1/128"), true) - tree.AddCIDR(cidr.Parse("::2/128"), false) + tree := new(bart.Table[bool]) + tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true) + tree.Insert(netip.MustParsePrefix("10.0.0.0/8"), false) + tree.Insert(netip.MustParsePrefix("10.42.42.42/32"), true) + tree.Insert(netip.MustParsePrefix("10.42.0.0/16"), true) + tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), true) + tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), false) + tree.Insert(netip.MustParsePrefix("::1/128"), true) + tree.Insert(netip.MustParsePrefix("::2/128"), false) al := &AllowList{cidrTree: tree} - assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1"))) - assert.Equal(t, false, al.Allow(net.ParseIP("10.0.0.4"))) - assert.Equal(t, true, al.Allow(net.ParseIP("10.42.42.42"))) - assert.Equal(t, false, al.Allow(net.ParseIP("10.42.42.41"))) - assert.Equal(t, true, al.Allow(net.ParseIP("10.42.0.1"))) - assert.Equal(t, true, al.Allow(net.ParseIP("::1"))) - assert.Equal(t, false, al.Allow(net.ParseIP("::2"))) + assert.Equal(t, true, al.Allow(netip.MustParseAddr("1.1.1.1"))) + assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.0.0.4"))) + assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.42.42"))) + assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.42.42.41"))) + assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.0.1"))) + assert.Equal(t, true, al.Allow(netip.MustParseAddr("::1"))) + assert.Equal(t, false, al.Allow(netip.MustParseAddr("::2"))) } func TestLocalAllowList_AllowName(t *testing.T) { diff --git a/calculated_remote.go b/calculated_remote.go index 38f5bea..ae2ed50 100644 --- a/calculated_remote.go +++ b/calculated_remote.go @@ -1,41 +1,36 @@ package nebula import ( + "encoding/binary" "fmt" "math" "net" + "net/netip" "strconv" - "github.com/slackhq/nebula/cidr" + "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) // This allows us to "guess" what the remote might be for a host while we wait // for the lighthouse response. See "lighthouse.calculated_remotes" in the // example config file. type calculatedRemote struct { - ipNet net.IPNet - maskIP iputil.VpnIp - mask iputil.VpnIp - port uint32 + ipNet netip.Prefix + mask netip.Prefix + port uint32 } -func newCalculatedRemote(ipNet *net.IPNet, port int) (*calculatedRemote, error) { - // Ensure this is an IPv4 mask that we expect - ones, bits := ipNet.Mask.Size() - if ones == 0 || bits != 32 { - return nil, fmt.Errorf("invalid mask: %v", ipNet) - } +func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) { + masked := maskCidr.Masked() if port < 0 || port > math.MaxUint16 { return nil, fmt.Errorf("invalid port: %d", port) } return &calculatedRemote{ - ipNet: *ipNet, - maskIP: iputil.Ip2VpnIp(ipNet.IP), - mask: iputil.Ip2VpnIp(ipNet.Mask), - port: uint32(port), + ipNet: maskCidr, + mask: masked, + port: uint32(port), }, nil } @@ -43,21 +38,41 @@ func (c *calculatedRemote) String() string { return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port) } -func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort { +func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort { // Combine the masked bytes of the "mask" IP with the unmasked bytes // of the overlay IP - masked := (c.maskIP & c.mask) | (ip & ^c.mask) - - return &Ip4AndPort{Ip: uint32(masked), Port: c.port} + if c.ipNet.Addr().Is4() { + return c.apply4(ip) + } + return c.apply6(ip) } -func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) { +func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort { + //TODO: IPV6-WORK this can be less crappy + maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen()) + mask := binary.BigEndian.Uint32(maskb[:]) + + b := c.mask.Addr().As4() + maskIp := binary.BigEndian.Uint32(b[:]) + + b = ip.As4() + intIp := binary.BigEndian.Uint32(b[:]) + + return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port} +} + +func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort { + //TODO: IPV6-WORK + panic("Can not calculate ipv6 remote addresses") +} + +func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) { value := c.Get(k) if value == nil { return nil, nil } - calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]() + calculatedRemotes := new(bart.Table[[]*calculatedRemote]) rawMap, ok := value.(map[any]any) if !ok { @@ -69,17 +84,18 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calcu return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) } - _, ipNet, err := net.ParseCIDR(rawCIDR) + cidr, err := netip.ParsePrefix(rawCIDR) if err != nil { return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) } + //TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here entry, err := newCalculatedRemotesListFromConfig(rawValue) if err != nil { return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err) } - calculatedRemotes.AddCIDR(ipNet, entry) + calculatedRemotes.Insert(cidr, entry) } return calculatedRemotes, nil @@ -117,7 +133,7 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { if !ok { return nil, fmt.Errorf("invalid mask (type %T): %v", rawValue, rawValue) } - _, ipNet, err := net.ParseCIDR(rawMask) + maskCidr, err := netip.ParsePrefix(rawMask) if err != nil { return nil, fmt.Errorf("invalid mask: %s", rawMask) } @@ -139,5 +155,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue) } - return newCalculatedRemote(ipNet, port) + return newCalculatedRemote(maskCidr, port) } diff --git a/calculated_remote_test.go b/calculated_remote_test.go index 2ddebca..6ff1cb0 100644 --- a/calculated_remote_test.go +++ b/calculated_remote_test.go @@ -1,27 +1,25 @@ package nebula import ( - "net" + "net/netip" "testing" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestCalculatedRemoteApply(t *testing.T) { - _, ipNet, err := net.ParseCIDR("192.168.1.0/24") + ipNet, err := netip.ParsePrefix("192.168.1.0/24") require.NoError(t, err) c, err := newCalculatedRemote(ipNet, 4242) require.NoError(t, err) - input := iputil.Ip2VpnIp([]byte{10, 0, 10, 182}) + input, err := netip.ParseAddr("10.0.10.182") + assert.NoError(t, err) - expected := &Ip4AndPort{ - Ip: uint32(iputil.Ip2VpnIp([]byte{192, 168, 1, 182})), - Port: 4242, - } + expected, err := netip.ParseAddr("192.168.1.182") + assert.NoError(t, err) - assert.Equal(t, expected, c.Apply(input)) + assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input)) } diff --git a/cidr/parse.go b/cidr/parse.go deleted file mode 100644 index 74367f6..0000000 --- a/cidr/parse.go +++ /dev/null @@ -1,10 +0,0 @@ -package cidr - -import "net" - -// Parse is a convenience function that returns only the IPNet -// This function ignores errors since it is primarily a test helper, the result could be nil -func Parse(s string) *net.IPNet { - _, c, _ := net.ParseCIDR(s) - return c -} diff --git a/cidr/tree4.go b/cidr/tree4.go deleted file mode 100644 index c5ebe54..0000000 --- a/cidr/tree4.go +++ /dev/null @@ -1,203 +0,0 @@ -package cidr - -import ( - "net" - - "github.com/slackhq/nebula/iputil" -) - -type Node[T any] struct { - left *Node[T] - right *Node[T] - parent *Node[T] - hasValue bool - value T -} - -type entry[T any] struct { - CIDR *net.IPNet - Value T -} - -type Tree4[T any] struct { - root *Node[T] - list []entry[T] -} - -const ( - startbit = iputil.VpnIp(0x80000000) -) - -func NewTree4[T any]() *Tree4[T] { - tree := new(Tree4[T]) - tree.root = &Node[T]{} - tree.list = []entry[T]{} - return tree -} - -func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) { - bit := startbit - node := tree.root - next := tree.root - - ip := iputil.Ip2VpnIp(cidr.IP) - mask := iputil.Ip2VpnIp(cidr.Mask) - - // Find our last ancestor in the tree - for bit&mask != 0 { - if ip&bit != 0 { - next = node.right - } else { - next = node.left - } - - if next == nil { - break - } - - bit = bit >> 1 - node = next - } - - // We already have this range so update the value - if next != nil { - addCIDR := cidr.String() - for i, v := range tree.list { - if addCIDR == v.CIDR.String() { - tree.list = append(tree.list[:i], tree.list[i+1:]...) - break - } - } - - tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val}) - node.value = val - node.hasValue = true - return - } - - // Build up the rest of the tree we don't already have - for bit&mask != 0 { - next = &Node[T]{} - next.parent = node - - if ip&bit != 0 { - node.right = next - } else { - node.left = next - } - - bit >>= 1 - node = next - } - - // Final node marks our cidr, set the value - node.value = val - node.hasValue = true - tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val}) -} - -// Contains finds the first match, which may be the least specific -func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) { - bit := startbit - node := tree.root - - for node != nil { - if node.hasValue { - return true, node.value - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - - } - - return false, value -} - -// MostSpecificContains finds the most specific match -func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) { - bit := startbit - node := tree.root - - for node != nil { - if node.hasValue { - value = node.value - ok = true - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - return ok, value -} - -type eachFunc[T any] func(T) bool - -// EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete -// The final return value will be true if the provided function returned true -func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool { - bit := startbit - node := tree.root - - for node != nil { - if node.hasValue { - // If the each func returns true then we can exit the loop - if each(node.value) { - return true - } - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - return false -} - -// GetCIDR returns the entry added by the most recent matching AddCIDR call -func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) { - bit := startbit - node := tree.root - - ip := iputil.Ip2VpnIp(cidr.IP) - mask := iputil.Ip2VpnIp(cidr.Mask) - - // Find our last ancestor in the tree - for node != nil && bit&mask != 0 { - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit = bit >> 1 - } - - if bit&mask == 0 && node != nil { - value = node.value - ok = node.hasValue - } - - return ok, value -} - -// List will return all CIDRs and their current values. Do not modify the contents! -func (tree *Tree4[T]) List() []entry[T] { - return tree.list -} diff --git a/cidr/tree4_test.go b/cidr/tree4_test.go deleted file mode 100644 index cd17be4..0000000 --- a/cidr/tree4_test.go +++ /dev/null @@ -1,170 +0,0 @@ -package cidr - -import ( - "net" - "testing" - - "github.com/slackhq/nebula/iputil" - "github.com/stretchr/testify/assert" -) - -func TestCIDRTree_List(t *testing.T) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.0.0.0/16"), "1") - tree.AddCIDR(Parse("1.0.0.0/8"), "2") - tree.AddCIDR(Parse("1.0.0.0/16"), "3") - tree.AddCIDR(Parse("1.0.0.0/16"), "4") - list := tree.List() - assert.Len(t, list, 2) - assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String()) - assert.Equal(t, "2", list[0].Value) - assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String()) - assert.Equal(t, "4", list[1].Value) -} - -func TestCIDRTree_Contains(t *testing.T) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.0/24"), "4a") - tree.AddCIDR(Parse("4.1.1.1/32"), "4b") - tree.AddCIDR(Parse("4.1.2.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "1", "1.0.0.0"}, - {true, "1", "1.255.255.255"}, - {true, "2", "2.1.0.0"}, - {true, "2", "2.1.255.255"}, - {true, "3", "3.1.1.0"}, - {true, "3", "3.1.1.255"}, - {true, "4a", "4.1.1.255"}, - {true, "4a", "4.1.1.1"}, - {true, "5", "240.0.0.0"}, - {true, "5", "255.255.255.255"}, - {false, "", "239.0.0.0"}, - {false, "", "4.1.2.2"}, - } - - for _, tt := range tests { - ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } - - tree = NewTree4[string]() - tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) -} - -func TestCIDRTree_MostSpecificContains(t *testing.T) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.0/24"), "4a") - tree.AddCIDR(Parse("4.1.1.0/30"), "4b") - tree.AddCIDR(Parse("4.1.1.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "1", "1.0.0.0"}, - {true, "1", "1.255.255.255"}, - {true, "2", "2.1.0.0"}, - {true, "2", "2.1.255.255"}, - {true, "3", "3.1.1.0"}, - {true, "3", "3.1.1.255"}, - {true, "4a", "4.1.1.255"}, - {true, "4b", "4.1.1.2"}, - {true, "4c", "4.1.1.1"}, - {true, "5", "240.0.0.0"}, - {true, "5", "255.255.255.255"}, - {false, "", "239.0.0.0"}, - {false, "", "4.1.2.2"}, - } - - for _, tt := range tests { - ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } - - tree = NewTree4[string]() - tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) -} - -func TestTree4_GetCIDR(t *testing.T) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.0/24"), "4a") - tree.AddCIDR(Parse("4.1.1.1/32"), "4b") - tree.AddCIDR(Parse("4.1.2.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - - tests := []struct { - Found bool - Result interface{} - IPNet *net.IPNet - }{ - {true, "1", Parse("1.0.0.0/8")}, - {true, "2", Parse("2.1.0.0/16")}, - {true, "3", Parse("3.1.1.0/24")}, - {true, "4a", Parse("4.1.1.0/24")}, - {true, "4b", Parse("4.1.1.1/32")}, - {true, "4c", Parse("4.1.2.1/32")}, - {true, "5", Parse("254.0.0.0/4")}, - {false, "", Parse("2.0.0.0/8")}, - } - - for _, tt := range tests { - ok, r := tree.GetCIDR(tt.IPNet) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } -} - -func BenchmarkCIDRTree_Contains(b *testing.B) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.1.0.0/16"), "1") - tree.AddCIDR(Parse("1.2.1.1/32"), "1") - tree.AddCIDR(Parse("192.2.1.1/32"), "1") - tree.AddCIDR(Parse("172.2.1.1/32"), "1") - - ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1")) - b.Run("found", func(b *testing.B) { - for i := 0; i < b.N; i++ { - tree.Contains(ip) - } - }) - - ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255")) - b.Run("not found", func(b *testing.B) { - for i := 0; i < b.N; i++ { - tree.Contains(ip) - } - }) -} diff --git a/cidr/tree6.go b/cidr/tree6.go deleted file mode 100644 index 3f2cd2a..0000000 --- a/cidr/tree6.go +++ /dev/null @@ -1,189 +0,0 @@ -package cidr - -import ( - "net" - - "github.com/slackhq/nebula/iputil" -) - -const startbit6 = uint64(1 << 63) - -type Tree6[T any] struct { - root4 *Node[T] - root6 *Node[T] -} - -func NewTree6[T any]() *Tree6[T] { - tree := new(Tree6[T]) - tree.root4 = &Node[T]{} - tree.root6 = &Node[T]{} - return tree -} - -func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) { - var node, next *Node[T] - - cidrIP, ipv4 := isIPV4(cidr.IP) - if ipv4 { - node = tree.root4 - next = tree.root4 - - } else { - node = tree.root6 - next = tree.root6 - } - - for i := 0; i < len(cidrIP); i += 4 { - ip := iputil.Ip2VpnIp(cidrIP[i : i+4]) - mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4]) - bit := startbit - - // Find our last ancestor in the tree - for bit&mask != 0 { - if ip&bit != 0 { - next = node.right - } else { - next = node.left - } - - if next == nil { - break - } - - bit = bit >> 1 - node = next - } - - // Build up the rest of the tree we don't already have - for bit&mask != 0 { - next = &Node[T]{} - next.parent = node - - if ip&bit != 0 { - node.right = next - } else { - node.left = next - } - - bit >>= 1 - node = next - } - } - - // Final node marks our cidr, set the value - node.value = val - node.hasValue = true -} - -// Finds the most specific match -func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) { - var node *Node[T] - - wholeIP, ipv4 := isIPV4(ip) - if ipv4 { - node = tree.root4 - } else { - node = tree.root6 - } - - for i := 0; i < len(wholeIP); i += 4 { - ip := iputil.Ip2VpnIp(wholeIP[i : i+4]) - bit := startbit - - for node != nil { - if node.hasValue { - value = node.value - ok = true - } - - if bit == 0 { - break - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - } - - return ok, value -} - -func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) { - bit := startbit - node := tree.root4 - - for node != nil { - if node.hasValue { - value = node.value - ok = true - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - return ok, value -} - -func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) { - ip := hi - node := tree.root6 - - for i := 0; i < 2; i++ { - bit := startbit6 - - for node != nil { - if node.hasValue { - value = node.value - ok = true - } - - if bit == 0 { - break - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - ip = lo - } - - return ok, value -} - -func isIPV4(ip net.IP) (net.IP, bool) { - if len(ip) == net.IPv4len { - return ip, true - } - - if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff { - return ip[12:16], true - } - - return ip, false -} - -func isZeros(p net.IP) bool { - for i := 0; i < len(p); i++ { - if p[i] != 0 { - return false - } - } - return true -} diff --git a/cidr/tree6_test.go b/cidr/tree6_test.go deleted file mode 100644 index eb159ec..0000000 --- a/cidr/tree6_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package cidr - -import ( - "encoding/binary" - "net" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCIDR6Tree_MostSpecificContains(t *testing.T) { - tree := NewTree6[string]() - tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.1/24"), "4a") - tree.AddCIDR(Parse("4.1.1.1/30"), "4b") - tree.AddCIDR(Parse("4.1.1.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "1", "1.0.0.0"}, - {true, "1", "1.255.255.255"}, - {true, "2", "2.1.0.0"}, - {true, "2", "2.1.255.255"}, - {true, "3", "3.1.1.0"}, - {true, "3", "3.1.1.255"}, - {true, "4a", "4.1.1.255"}, - {true, "4b", "4.1.1.2"}, - {true, "4c", "4.1.1.1"}, - {true, "5", "240.0.0.0"}, - {true, "5", "255.255.255.255"}, - {true, "6a", "1:2:0:4:1:1:1:1"}, - {true, "6b", "1:2:0:4:5:1:1:1"}, - {true, "6c", "1:2:0:4:5:0:0:0"}, - {false, "", "239.0.0.0"}, - {false, "", "4.1.2.2"}, - } - - for _, tt := range tests { - ok, r := tree.MostSpecificContains(net.ParseIP(tt.IP)) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } - - tree = NewTree6[string]() - tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - tree.AddCIDR(Parse("::/0"), "cool6") - ok, r := tree.MostSpecificContains(net.ParseIP("0.0.0.0")) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.MostSpecificContains(net.ParseIP("255.255.255.255")) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.MostSpecificContains(net.ParseIP("::")) - assert.True(t, ok) - assert.Equal(t, "cool6", r) - - ok, r = tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8")) - assert.True(t, ok) - assert.Equal(t, "cool6", r) -} - -func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) { - tree := NewTree6[string]() - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "6a", "1:2:0:4:1:1:1:1"}, - {true, "6b", "1:2:0:4:5:1:1:1"}, - {true, "6c", "1:2:0:4:5:0:0:0"}, - } - - for _, tt := range tests { - ip := net.ParseIP(tt.IP) - hi := binary.BigEndian.Uint64(ip[:8]) - lo := binary.BigEndian.Uint64(ip[8:]) - - ok, r := tree.MostSpecificContainsIpV6(hi, lo) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } -} diff --git a/connection_manager.go b/connection_manager.go index 0b277b5..d2e8616 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -3,6 +3,8 @@ package nebula import ( "bytes" "context" + "encoding/binary" + "net/netip" "sync" "time" @@ -10,8 +12,6 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) type trafficDecision int @@ -224,8 +224,8 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp) var index uint32 - var relayFrom iputil.VpnIp - var relayTo iputil.VpnIp + var relayFrom netip.Addr + var relayTo netip.Addr switch { case ok && existing.State == Established: // This relay already exists in newhostinfo, then do nothing. @@ -235,7 +235,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) index = existing.LocalIndex switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnIp + relayFrom = n.intf.myVpnNet.Addr() relayTo = existing.PeerIp case ForwardingType: relayFrom = existing.PeerIp @@ -260,7 +260,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) } switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnIp + relayFrom = n.intf.myVpnNet.Addr() relayTo = r.PeerIp case ForwardingType: relayFrom = r.PeerIp @@ -270,12 +270,16 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) } } + //TODO: IPV6-WORK + relayFromB := relayFrom.As4() + relayToB := relayTo.As4() + // Send a CreateRelayRequest to the peer. req := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: index, - RelayFromIp: uint32(relayFrom), - RelayToIp: uint32(relayTo), + RelayFromIp: binary.BigEndian.Uint32(relayFromB[:]), + RelayToIp: binary.BigEndian.Uint32(relayToB[:]), } msg, err := req.Marshal() if err != nil { @@ -283,8 +287,8 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) } else { n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) n.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(req.RelayFromIp), - "relayTo": iputil.VpnIp(req.RelayToIp), + "relayFrom": req.RelayFromIp, + "relayTo": req.RelayToIp, "initiatorRelayIndex": req.InitiatorRelayIndex, "responderRelayIndex": req.ResponderRelayIndex, "vpnIp": newhostinfo.vpnIp}). @@ -403,7 +407,7 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // Let's sort this out. - if current.vpnIp < n.intf.myVpnIp { + if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 { // Only one side should flip primary because if both flip then we may never resolve to a single tunnel. // vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping. // The remotes vpn ip is lower than mine. I will not flip. @@ -457,12 +461,12 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { } if n.punchy.GetTargetEverything() { - hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) { + hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { n.metricsTxPunchy.Inc(1) n.intf.outside.WriteTo([]byte{1}, addr) }) - } else if hostinfo.remote != nil { + } else if hostinfo.remote.IsValid() { n.metricsTxPunchy.Inc(1) n.intf.outside.WriteTo([]byte{1}, hostinfo.remote) } diff --git a/connection_manager_test.go b/connection_manager_test.go index f50bcf8..5f97cad 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -5,28 +5,26 @@ import ( "crypto/ed25519" "crypto/rand" "net" + "net/netip" "testing" "time" "github.com/flynn/noise" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" ) -var vpnIp iputil.VpnIp - func newTestLighthouse() *LightHouse { lh := &LightHouse{ l: test.NewLogger(), - addrMap: map[iputil.VpnIp]*RemoteList{}, - queryChan: make(chan iputil.VpnIp, 10), + addrMap: map[netip.Addr]*RemoteList{}, + queryChan: make(chan netip.Addr, 10), } - lighthouses := map[iputil.VpnIp]struct{}{} - staticList := map[iputil.VpnIp]struct{}{} + lighthouses := map[netip.Addr]struct{}{} + staticList := map[netip.Addr]struct{}{} lh.lighthouses.Store(&lighthouses) lh.staticList.Store(&staticList) @@ -37,10 +35,10 @@ func newTestLighthouse() *LightHouse { func Test_NewConnectionManagerTest(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - vpnIp = iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) - preferredRanges := []*net.IPNet{localrange} + vpncidr := netip.MustParsePrefix("172.1.1.1/24") + localrange := netip.MustParsePrefix("10.1.1.1/24") + vpnIp := netip.MustParseAddr("172.1.1.2") + preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects hostMap := newHostMap(l, vpncidr) @@ -120,9 +118,10 @@ func Test_NewConnectionManagerTest(t *testing.T) { func Test_NewConnectionManagerTest2(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - preferredRanges := []*net.IPNet{localrange} + vpncidr := netip.MustParsePrefix("172.1.1.1/24") + localrange := netip.MustParsePrefix("10.1.1.1/24") + vpnIp := netip.MustParseAddr("172.1.1.2") + preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects hostMap := newHostMap(l, vpncidr) @@ -211,9 +210,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { IP: net.IPv4(172, 1, 1, 2), Mask: net.IPMask{255, 255, 255, 0}, } - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - preferredRanges := []*net.IPNet{localrange} + vpncidr := netip.MustParsePrefix("172.1.1.1/24") + localrange := netip.MustParsePrefix("10.1.1.1/24") + vpnIp := netip.MustParseAddr("172.1.1.2") + preferredRanges := []netip.Prefix{localrange} hostMap := newHostMap(l, vpncidr) hostMap.preferredRanges.Store(&preferredRanges) diff --git a/control.go b/control.go index c227b20..7782b23 100644 --- a/control.go +++ b/control.go @@ -2,7 +2,7 @@ package nebula import ( "context" - "net" + "net/netip" "os" "os/signal" "syscall" @@ -10,9 +10,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" - "github.com/slackhq/nebula/udp" ) // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching @@ -21,10 +19,10 @@ import ( type controlEach func(h *HostInfo) type controlHostLister interface { - QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo + QueryVpnIp(vpnIp netip.Addr) *HostInfo ForEachIndex(each controlEach) ForEachVpnIp(each controlEach) - GetPreferredRanges() []*net.IPNet + GetPreferredRanges() []netip.Prefix } type Control struct { @@ -39,15 +37,15 @@ type Control struct { } type ControlHostInfo struct { - VpnIp net.IP `json:"vpnIp"` + VpnIp netip.Addr `json:"vpnIp"` LocalIndex uint32 `json:"localIndex"` RemoteIndex uint32 `json:"remoteIndex"` - RemoteAddrs []*udp.Addr `json:"remoteAddrs"` + RemoteAddrs []netip.AddrPort `json:"remoteAddrs"` Cert *cert.NebulaCertificate `json:"cert"` MessageCounter uint64 `json:"messageCounter"` - CurrentRemote *udp.Addr `json:"currentRemote"` - CurrentRelaysToMe []iputil.VpnIp `json:"currentRelaysToMe"` - CurrentRelaysThroughMe []iputil.VpnIp `json:"currentRelaysThroughMe"` + CurrentRemote netip.AddrPort `json:"currentRemote"` + CurrentRelaysToMe []netip.Addr `json:"currentRelaysToMe"` + CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` } // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() @@ -132,7 +130,8 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { } // GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found -func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo { +// Caller should take care to Unmap() any 4in6 addresses prior to calling. +func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo { var hl controlHostLister if pending { hl = c.f.handshakeManager @@ -150,19 +149,21 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH } // SetRemoteForTunnel forces a tunnel to use a specific remote -func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo { +// Caller should take care to Unmap() any 4in6 addresses prior to calling. +func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo { hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) if hostInfo == nil { return nil } - hostInfo.SetRemote(addr.Copy()) + hostInfo.SetRemote(addr) ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges()) return &ch } // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well. -func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool { +// Caller should take care to Unmap() any 4in6 addresses prior to calling. +func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool { hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) if hostInfo == nil { return false @@ -205,7 +206,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { } // Learn which hosts are being used as relays, so we can shut them down last. - relayingHosts := map[iputil.VpnIp]*HostInfo{} + relayingHosts := map[netip.Addr]*HostInfo{} // Grab the hostMap lock to access the Relays map c.f.hostMap.Lock() for _, relayingHost := range c.f.hostMap.Relays { @@ -236,15 +237,16 @@ func (c *Control) Device() overlay.Device { return c.f.inside } -func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { +func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { chi := ControlHostInfo{ - VpnIp: h.vpnIp.ToIP(), + VpnIp: h.vpnIp, LocalIndex: h.localIndexId, RemoteIndex: h.remoteIndexId, RemoteAddrs: h.remotes.CopyAddrs(preferredRanges), CurrentRelaysToMe: h.relayState.CopyRelayIps(), CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(), + CurrentRemote: h.remote, } if h.ConnectionState != nil { @@ -255,10 +257,6 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { chi.Cert = c.Copy() } - if h.remote != nil { - chi.CurrentRemote = h.remote.Copy() - } - return chi } diff --git a/control_test.go b/control_test.go index c64a3a4..fbf29c0 100644 --- a/control_test.go +++ b/control_test.go @@ -2,15 +2,14 @@ package nebula import ( "net" + "net/netip" "reflect" "testing" "time" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" - "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" ) @@ -18,18 +17,19 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { l := test.NewLogger() // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // To properly ensure we are not exposing core memory to the caller - hm := newHostMap(l, &net.IPNet{}) - hm.preferredRanges.Store(&[]*net.IPNet{}) + hm := newHostMap(l, netip.Prefix{}) + hm.preferredRanges.Store(&[]netip.Prefix{}) + + remote1 := netip.MustParseAddrPort("0.0.0.100:4444") + remote2 := netip.MustParseAddrPort("[1:2:3:4:5:6:7:8]:4444") - remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444) - remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444) ipNet := net.IPNet{ - IP: net.IPv4(1, 2, 3, 4), + IP: remote1.Addr().AsSlice(), Mask: net.IPMask{255, 255, 255, 0}, } ipNet2 := net.IPNet{ - IP: net.ParseIP("1:2:3:4:5:6:7:8"), + IP: remote2.Addr().AsSlice(), Mask: net.IPMask{255, 255, 255, 0}, } @@ -50,8 +50,12 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { } remotes := NewRemoteList(nil) - remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port))) - remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port))) + remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port())) + remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port())) + + vpnIp, ok := netip.AddrFromSlice(ipNet.IP) + assert.True(t, ok) + hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, @@ -60,14 +64,17 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, remoteIndexId: 200, localIndexId: 201, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: vpnIp, relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) + vpnIp2, ok := netip.AddrFromSlice(ipNet2.IP) + assert.True(t, ok) + hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, @@ -76,10 +83,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, remoteIndexId: 200, localIndexId: 201, - vpnIp: iputil.Ip2VpnIp(ipNet2.IP), + vpnIp: vpnIp2, relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) @@ -91,27 +98,29 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { l: logrus.New(), } - thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false) + thi := c.GetHostInfoByVpnIp(vpnIp, false) expectedInfo := ControlHostInfo{ - VpnIp: net.IPv4(1, 2, 3, 4).To4(), + VpnIp: vpnIp, LocalIndex: 201, RemoteIndex: 200, - RemoteAddrs: []*udp.Addr{remote2, remote1}, + RemoteAddrs: []netip.AddrPort{remote2, remote1}, Cert: crt.Copy(), MessageCounter: 0, - CurrentRemote: udp.NewAddr(net.ParseIP("0.0.0.100"), 4444), - CurrentRelaysToMe: []iputil.VpnIp{}, - CurrentRelaysThroughMe: []iputil.VpnIp{}, + CurrentRemote: remote1, + CurrentRelaysToMe: []netip.Addr{}, + CurrentRelaysThroughMe: []netip.Addr{}, } // Make sure we don't have any unexpected fields assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) - test.AssertDeepCopyEqual(t, &expectedInfo, thi) + assert.EqualValues(t, &expectedInfo, thi) + //TODO: netip.Addr reuses global memory for zone identifiers which breaks our "no reused memory check" here + //test.AssertDeepCopyEqual(t, &expectedInfo, thi) // Make sure we don't panic if the host info doesn't have a cert yet assert.NotPanics(t, func() { - thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false) + thi = c.GetHostInfoByVpnIp(vpnIp2, false) }) } diff --git a/control_tester.go b/control_tester.go index b786ba3..d46540f 100644 --- a/control_tester.go +++ b/control_tester.go @@ -4,14 +4,13 @@ package nebula import ( - "net" + "net/netip" "github.com/slackhq/nebula/cert" "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" ) @@ -50,37 +49,30 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, // InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp // This is necessary if you did not configure static hosts or are not running a lighthouse -func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) { +func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp)) + remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() - iVpnIp := iputil.Ip2VpnIp(vpnIp) - if v4 := toAddr.IP.To4(); v4 != nil { - remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port))) + if toAddr.Addr().Is4() { + remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) } else { - remoteList.unlockedPrependV6(iVpnIp, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port))) + remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) } } // InjectRelays will push relayVpnIps into the local lighthouse cache for the vpnIp // This is necessary to inform an initiator of possible relays for communicating with a responder -func (c *Control) InjectRelays(vpnIp net.IP, relayVpnIps []net.IP) { +func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp)) + remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() - iVpnIp := iputil.Ip2VpnIp(vpnIp) - uVpnIp := []uint32{} - for _, rVPnIp := range relayVpnIps { - uVpnIp = append(uVpnIp, uint32(iputil.Ip2VpnIp(rVPnIp))) - } - - remoteList.unlockedSetRelay(iVpnIp, iVpnIp, uVpnIp) + remoteList.unlockedSetRelay(vpnIp, vpnIp, relayVpnIps) } // GetFromTun will pull a packet off the tun side of nebula @@ -107,13 +99,14 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) { } // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol -func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16, data []byte) { +func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) { + //TODO: IPV6-WORK ip := layers.IPv4{ Version: 4, TTL: 64, Protocol: layers.IPProtocolUDP, - SrcIP: c.f.inside.Cidr().IP, - DstIP: toIp, + SrcIP: c.f.inside.Cidr().Addr().Unmap().AsSlice(), + DstIP: toIp.Unmap().AsSlice(), } udp := layers.UDP{ @@ -138,16 +131,16 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16 c.f.inside.(*overlay.TestTun).Send(buffer.Bytes()) } -func (c *Control) GetVpnIp() iputil.VpnIp { - return c.f.myVpnIp +func (c *Control) GetVpnIp() netip.Addr { + return c.f.myVpnNet.Addr() } -func (c *Control) GetUDPAddr() string { - return c.f.outside.(*udp.TesterConn).Addr.String() +func (c *Control) GetUDPAddr() netip.AddrPort { + return c.f.outside.(*udp.TesterConn).Addr } -func (c *Control) KillPendingTunnel(vpnIp net.IP) bool { - hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp)) +func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool { + hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp) if hostinfo == nil { return false } @@ -164,6 +157,6 @@ func (c *Control) GetCert() *cert.NebulaCertificate { return c.f.pki.GetCertState().Certificate } -func (c *Control) ReHandshake(vpnIp iputil.VpnIp) { +func (c *Control) ReHandshake(vpnIp netip.Addr) { c.f.handshakeManager.StartHandshake(vpnIp, nil) } diff --git a/dns_server.go b/dns_server.go index 4e7bb83..5fea65c 100644 --- a/dns_server.go +++ b/dns_server.go @@ -3,6 +3,7 @@ package nebula import ( "fmt" "net" + "net/netip" "strconv" "strings" "sync" @@ -10,7 +11,6 @@ import ( "github.com/miekg/dns" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) // This whole thing should be rewritten to use context @@ -42,19 +42,21 @@ func (d *dnsRecords) Query(data string) string { } func (d *dnsRecords) QueryCert(data string) string { - ip := net.ParseIP(data[:len(data)-1]) - if ip == nil { + ip, err := netip.ParseAddr(data[:len(data)-1]) + if err != nil { return "" } - iip := iputil.Ip2VpnIp(ip) - hostinfo := d.hostMap.QueryVpnIp(iip) + + hostinfo := d.hostMap.QueryVpnIp(ip) if hostinfo == nil { return "" } + q := hostinfo.GetCert() if q == nil { return "" } + cert := q.Details c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer) return c @@ -80,7 +82,11 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { } case dns.TypeTXT: a, _, _ := net.SplitHostPort(w.RemoteAddr().String()) - b := net.ParseIP(a) + b, err := netip.ParseAddr(a) + if err != nil { + return + } + // We don't answer these queries from non nebula nodes or localhost //l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR) if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" { diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 59f1d0e..3d42a56 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -5,7 +5,7 @@ package e2e import ( "fmt" - "net" + "net/netip" "testing" "time" @@ -13,19 +13,18 @@ import ( "github.com/slackhq/nebula" "github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v2" ) func BenchmarkHotPath(b *testing.B) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, _, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Start the servers myControl.Start() @@ -35,7 +34,7 @@ func BenchmarkHotPath(b *testing.B) { r.CancelFlowLogs() for n := 0; n < b.N; n++ { - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) _ = r.RouteForAllUntilTxTun(theirControl) } @@ -44,19 +43,19 @@ func BenchmarkHotPath(b *testing.B) { } func TestGoodHandshake(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Start the servers myControl.Start() theirControl.Start() t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) t.Log("Have them consume my stage 0 packet. They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) @@ -77,16 +76,16 @@ func TestGoodHandshake(t *testing.T) { myControl.WaitForType(1, 0, theirControl) t.Log("Make sure our host infos are correct") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) t.Log("Get that cached packet and make sure it looks right") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) t.Log("Do a bidirectional tunnel test") r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() @@ -95,20 +94,20 @@ func TestGoodHandshake(t *testing.T) { } func TestWrongResponderHandshake(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) // The IPs here are chosen on purpose: // The current remote handling will sort by preference, public, and then lexically. // So we need them to have a higher address than evil (we could apply a preference though) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil) - evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil) + evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil) // Add their real udp addr, which should be tried after evil. - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, evilUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl, evilControl) @@ -120,7 +119,7 @@ func TestWrongResponderHandshake(t *testing.T) { evilControl.Start() t.Log("Start the handshake process, we will route until we see our cached packet get sent to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { h := &header.H{} err := h.Parse(p.Data) @@ -128,7 +127,7 @@ func TestWrongResponderHandshake(t *testing.T) { panic(err) } - if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 { + if p.To == theirUdpAddr && h.Type == 1 { return router.RouteAndExit } @@ -139,18 +138,18 @@ func TestWrongResponderHandshake(t *testing.T) { t.Log("My cached packet should be received by them") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) t.Log("Test the tunnel with them") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl) - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) t.Log("Flush all packets from all controllers") r.FlushAll() t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), true), "My pending hostmap should not contain evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), false), "My main hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil") //NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete //TODO: assert hostmaps for everyone @@ -164,13 +163,13 @@ func TestStage1Race(t *testing.T) { // This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow // But will eventually collapse down to a single tunnel - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -181,8 +180,8 @@ func TestStage1Race(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake to start on both me and them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) t.Log("Get both stage 1 handshake packets") myHsForThem := myControl.GetFromUDP(true) @@ -194,14 +193,14 @@ func TestStage1Race(t *testing.T) { r.Log("Route until they receive a message packet") myCachedPacket := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.Log("Their cached packet should be received by me") theirCachedPacket := r.RouteForAllUntilTxTun(myControl) - assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80) r.Log("Do a bidirectional tunnel test") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) myHostmapHosts := myControl.ListHostmapHosts(false) myHostmapIndexes := myControl.ListHostmapIndexes(false) @@ -219,7 +218,7 @@ func TestStage1Race(t *testing.T) { r.Log("Spin until connection manager tears down a tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } @@ -241,13 +240,13 @@ func TestStage1Race(t *testing.T) { } func TestUncleanShutdownRaceLoser(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -258,28 +257,28 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.Log("Nuke my hostmap") myHostmap := myControl.GetHostmap() - myHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{} + myHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{} myHostmap.Indexes = map[uint32]*nebula.HostInfo{} myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me again")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me again")) p = r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.Log("Assert the tunnel works") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(theirControl.GetHostmap().Indexes) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) if len(theirControl.GetHostmap().Indexes) < start { break } @@ -290,13 +289,13 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { } func TestUncleanShutdownRaceWinner(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -307,30 +306,30 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, theirControl) r.Log("Nuke my hostmap") theirHostmap := theirControl.GetHostmap() - theirHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{} + theirHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{} theirHostmap.Indexes = map[uint32]*nebula.HostInfo{} theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them again")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them again")) p = r.RouteForAllUntilTxTun(myControl) - assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("Derp hostmaps", myControl, theirControl) r.Log("Assert the tunnel works") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(myControl.GetHostmap().Indexes) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) if len(myControl.GetHostmap().Indexes) < start { break } @@ -341,15 +340,15 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { } func TestRelays(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -361,31 +360,31 @@ func TestRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) //TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it } func TestStage1RaceRelays(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -397,14 +396,14 @@ func TestStage1RaceRelays(t *testing.T) { theirControl.Start() r.Log("Get a tunnel between me and relay") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) + assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) r.Log("Wait for a packet from them to me") p := r.RouteForAllUntilTxTun(myControl) @@ -421,21 +420,21 @@ func TestStage1RaceRelays(t *testing.T) { func TestStage1RaceRelays2(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) l := NewTestLogger() // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -448,16 +447,16 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Get a tunnel between me and relay") l.Info("Get a tunnel between me and relay") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) + assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") l.Info("Get a tunnel between them and relay") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") l.Info("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) @@ -470,7 +469,7 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Assert the tunnel works") l.Info("Assert the tunnel works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) t.Log("Wait until we remove extra tunnels") l.Info("Wait until we remove extra tunnels") @@ -490,7 +489,7 @@ func TestStage1RaceRelays2(t *testing.T) { "theirControl": len(theirControl.GetHostmap().Indexes), "relayControl": len(relayControl.GetHostmap().Indexes), }).Info("Waiting for hostinfos to be removed...") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) retries-- @@ -498,7 +497,7 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Assert the tunnel works") l.Info("Assert the tunnel works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) myControl.Stop() theirControl.Stop() @@ -507,16 +506,17 @@ func TestStage1RaceRelays2(t *testing.T) { // ////TODO: assert hostmaps } + func TestRehandshakingRelays(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -528,11 +528,11 @@ func TestRehandshakingRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, @@ -556,8 +556,8 @@ func TestRehandshakingRelays(t *testing.T) { for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) - c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) + c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) if len(c.Cert.Details.Groups) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") @@ -569,8 +569,8 @@ func TestRehandshakingRelays(t *testing.T) { for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) if len(c.Cert.Details.Groups) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") @@ -581,13 +581,13 @@ func TestRehandshakingRelays(t *testing.T) { } r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -595,7 +595,7 @@ func TestRehandshakingRelays(t *testing.T) { for len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -603,7 +603,7 @@ func TestRehandshakingRelays(t *testing.T) { for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -612,15 +612,15 @@ func TestRehandshakingRelays(t *testing.T) { func TestRehandshakingRelaysPrimary(t *testing.T) { // This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -632,11 +632,11 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, @@ -660,8 +660,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) - c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) + c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) if len(c.Cert.Details.Groups) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") @@ -673,8 +673,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) if len(c.Cert.Details.Groups) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") @@ -685,13 +685,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -699,7 +699,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -707,7 +707,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -715,13 +715,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } func TestRehandshaking(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -732,7 +732,7 @@ func TestRehandshaking(t *testing.T) { theirControl.Start() t.Log("Stand up a tunnel between me and them") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) @@ -754,8 +754,8 @@ func TestRehandshaking(t *testing.T) { myConfig.ReloadConfigString(string(rc)) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) if len(c.Cert.Details.Groups) != 0 { // We have a new certificate now break @@ -781,19 +781,19 @@ func TestRehandshaking(t *testing.T) { r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) + c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) assert.Contains(t, c.Cert.Details.Groups, "new group") // We should only have a single tunnel now on both sides @@ -811,13 +811,13 @@ func TestRehandshaking(t *testing.T) { func TestRehandshakingLoser(t *testing.T) { // The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel // Should be the one with the new certificate - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -828,10 +828,10 @@ func TestRehandshakingLoser(t *testing.T) { theirControl.Start() t.Log("Stand up a tunnel between me and them") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) - tt1 := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) - tt2 := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) + tt1 := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) + tt2 := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) fmt.Println(tt1.LocalIndex, tt2.LocalIndex) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) @@ -854,8 +854,8 @@ func TestRehandshakingLoser(t *testing.T) { theirConfig.ReloadConfigString(string(rc)) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) - theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) _, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"] if theirNewGroup { @@ -882,19 +882,19 @@ func TestRehandshakingLoser(t *testing.T) { r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won - theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) + theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group") // We should only have a single tunnel now on both sides @@ -912,13 +912,13 @@ func TestRaceRegression(t *testing.T) { // This test forces stage 1, stage 2, stage 1 to be received by me from them // We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which // caused a cross-linked hostinfo - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Start the servers myControl.Start() @@ -932,8 +932,8 @@ func TestRaceRegression(t *testing.T) { //them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089 t.Log("Start both handshakes") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) t.Log("Get both stage 1") myStage1ForThem := myControl.GetFromUDP(true) @@ -963,7 +963,7 @@ func TestRaceRegression(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) t.Log("Make sure the tunnel still works") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) myControl.Stop() theirControl.Stop() diff --git a/e2e/helpers.go b/e2e/helpers.go index 13146ab..71df805 100644 --- a/e2e/helpers.go +++ b/e2e/helpers.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "io" "net" + "net/netip" "time" "github.com/slackhq/nebula/cert" @@ -12,7 +13,7 @@ import ( ) // NewTestCaCert will generate a CA cert -func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { +func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { pub, priv, err := ed25519.GenerateKey(rand.Reader) if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) @@ -33,11 +34,17 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups [] } if len(ips) > 0 { - nc.Details.Ips = ips + nc.Details.Ips = make([]*net.IPNet, len(ips)) + for i, ip := range ips { + nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())} + } } if len(subnets) > 0 { - nc.Details.Subnets = subnets + nc.Details.Subnets = make([]*net.IPNet, len(subnets)) + for i, ip := range subnets { + nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())} + } } if len(groups) > 0 { @@ -59,7 +66,7 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups [] // NewTestCert will generate a signed certificate with the provided details. // Expiry times are defaulted if you do not pass them in -func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { +func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip netip.Prefix, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { issuer, err := ca.Sha256Sum() if err != nil { panic(err) @@ -74,12 +81,12 @@ func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, af } pub, rawPriv := x25519Keypair() - + ipb := ip.Addr().AsSlice() nc := &cert.NebulaCertificate{ Details: cert.NebulaCertificateDetails{ - Name: name, - Ips: []*net.IPNet{ip}, - Subnets: subnets, + Name: name, + Ips: []*net.IPNet{{IP: ipb[:], Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}}, + //Subnets: subnets, Groups: groups, NotBefore: time.Unix(before.Unix(), 0), NotAfter: time.Unix(after.Unix(), 0), diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index b05c84a..527f55b 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -6,7 +6,7 @@ package e2e import ( "fmt" "io" - "net" + "net/netip" "os" "testing" "time" @@ -19,7 +19,6 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/e2e/router" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v2" ) @@ -27,15 +26,23 @@ import ( type m map[string]interface{} // newSimpleServer creates a nebula instance with many assumptions -func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr, *config.C) { +func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) { l := NewTestLogger() - vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} - copy(vpnIpNet.IP, udpIp) - vpnIpNet.IP[1] += 128 - udpAddr := net.UDPAddr{ - IP: udpIp, - Port: 4242, + vpnIpNet, err := netip.ParsePrefix(sVpnIpNet) + if err != nil { + panic(err) + } + + var udpAddr netip.AddrPort + if vpnIpNet.Addr().Is4() { + budpIp := vpnIpNet.Addr().As4() + budpIp[1] -= 128 + udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242) + } else { + budpIp := vpnIpNet.Addr().As16() + budpIp[13] -= 128 + udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) } _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) @@ -67,8 +74,8 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u // "try_interval": "1s", //}, "listen": m{ - "host": udpAddr.IP.String(), - "port": udpAddr.Port, + "host": udpAddr.Addr().String(), + "port": udpAddr.Port(), }, "logging": m{ "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name), @@ -102,7 +109,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u panic(err) } - return control, vpnIpNet, &udpAddr, c + return control, vpnIpNet, udpAddr, c } type doneCb func() @@ -123,7 +130,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb { } } -func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) { +func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { // Send a packet from them to me controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B")) bPacket := r.RouteForAllUntilTxTun(controlA) @@ -135,23 +142,20 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebul assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) } -func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) { +func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) { // Get both host infos - hBinA := controlA.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpB), false) + hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false) assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA") - hAinB := controlB.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpA), false) + hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false) assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB") // Check that both vpn and real addr are correct assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A") assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B") - assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A") - assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B") - - assert.Equal(t, addrB.Port, int(hBinA.CurrentRemote.Port), "Host B remote port is wrong in control A") - assert.Equal(t, addrA.Port, int(hAinB.CurrentRemote.Port), "Host A remote port is wrong in control B") + assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A") + assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B") // Check that our indexes match assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index") @@ -174,13 +178,13 @@ func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB //checkIndexes("hmB", hmB, hAinB) } -func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp net.IP, fromPort, toPort uint16) { +func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy) v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) assert.NotNil(t, v4, "No ipv4 data found") - assert.Equal(t, fromIp, v4.SrcIP, "Source ip was incorrect") - assert.Equal(t, toIp, v4.DstIP, "Dest ip was incorrect") + assert.Equal(t, fromIp.AsSlice(), []byte(v4.SrcIP), "Source ip was incorrect") + assert.Equal(t, toIp.AsSlice(), []byte(v4.DstIP), "Dest ip was incorrect") udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) assert.NotNil(t, udp, "No udp data found") diff --git a/e2e/router/hostmap.go b/e2e/router/hostmap.go index 120be69..c14ab2e 100644 --- a/e2e/router/hostmap.go +++ b/e2e/router/hostmap.go @@ -5,11 +5,11 @@ package router import ( "fmt" + "net/netip" "sort" "strings" "github.com/slackhq/nebula" - "github.com/slackhq/nebula/iputil" ) type edge struct { @@ -118,14 +118,14 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { return r, globalLines } -func sortedHosts(hosts map[iputil.VpnIp]*nebula.HostInfo) []iputil.VpnIp { - keys := make([]iputil.VpnIp, 0, len(hosts)) +func sortedHosts(hosts map[netip.Addr]*nebula.HostInfo) []netip.Addr { + keys := make([]netip.Addr, 0, len(hosts)) for key := range hosts { keys = append(keys, key) } sort.SliceStable(keys, func(i, j int) bool { - return keys[i] > keys[j] + return keys[i].Compare(keys[j]) > 0 }) return keys diff --git a/e2e/router/router.go b/e2e/router/router.go index 730853a..0890570 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -6,12 +6,11 @@ package router import ( "context" "fmt" - "net" + "net/netip" "os" "path/filepath" "reflect" "sort" - "strconv" "strings" "sync" "testing" @@ -21,7 +20,6 @@ import ( "github.com/google/gopacket/layers" "github.com/slackhq/nebula" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "golang.org/x/exp/maps" ) @@ -29,18 +27,18 @@ import ( type R struct { // Simple map of the ip:port registered on a control to the control // Basically a router, right? - controls map[string]*nebula.Control + controls map[netip.AddrPort]*nebula.Control // A map for inbound packets for a control that doesn't know about this address - inNat map[string]*nebula.Control + inNat map[netip.AddrPort]*nebula.Control // A last used map, if an inbound packet hit the inNat map then // all return packets should use the same last used inbound address for the outbound sender // map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver - outNat map[string]net.UDPAddr + outNat map[string]netip.AddrPort // A map of vpn ip to the nebula control it belongs to - vpnControls map[iputil.VpnIp]*nebula.Control + vpnControls map[netip.Addr]*nebula.Control ignoreFlows []ignoreFlow flow []flowEntry @@ -118,10 +116,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { } r := &R{ - controls: make(map[string]*nebula.Control), - vpnControls: make(map[iputil.VpnIp]*nebula.Control), - inNat: make(map[string]*nebula.Control), - outNat: make(map[string]net.UDPAddr), + controls: make(map[netip.AddrPort]*nebula.Control), + vpnControls: make(map[netip.Addr]*nebula.Control), + inNat: make(map[netip.AddrPort]*nebula.Control), + outNat: make(map[string]netip.AddrPort), flow: []flowEntry{}, ignoreFlows: []ignoreFlow{}, fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())), @@ -135,7 +133,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { for _, c := range controls { addr := c.GetUDPAddr() if _, ok := r.controls[addr]; ok { - panic("Duplicate listen address: " + addr) + panic("Duplicate listen address: " + addr.String()) } r.vpnControls[c.GetVpnIp()] = c @@ -165,13 +163,13 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { // It does not look at the addr attached to the instance. // If a route is used, this will behave like a NAT for the return path. // Rewriting the source ip:port to what was last sent to from the origin -func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) { +func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) { r.Lock() defer r.Unlock() - inAddr := net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)) + inAddr := netip.AddrPortFrom(ip, port) if _, ok := r.inNat[inAddr]; ok { - panic("Duplicate listen address inNat: " + inAddr) + panic("Duplicate listen address inNat: " + inAddr.String()) } r.inNat[inAddr] = c } @@ -198,7 +196,7 @@ func (r *R) renderFlow() { panic(err) } - var participants = map[string]struct{}{} + var participants = map[netip.AddrPort]struct{}{} var participantsVals []string fmt.Fprintln(f, "```mermaid") @@ -215,7 +213,7 @@ func (r *R) renderFlow() { continue } participants[addr] = struct{}{} - sanAddr := strings.Replace(addr, ":", "-", 1) + sanAddr := strings.Replace(addr.String(), ":", "-", 1) participantsVals = append(participantsVals, sanAddr) fmt.Fprintf( f, " participant %s as Nebula: %s
UDP: %s\n", @@ -252,9 +250,9 @@ func (r *R) renderFlow() { fmt.Fprintf(f, " %s%s%s: %s(%s), index %v, counter: %v\n", - strings.Replace(p.from.GetUDPAddr(), ":", "-", 1), + strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1), line, - strings.Replace(p.to.GetUDPAddr(), ":", "-", 1), + strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter, ) } @@ -305,7 +303,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) { func (r *R) renderHostmaps(title string) { c := maps.Values(r.controls) sort.SliceStable(c, func(i, j int) bool { - return c[i].GetVpnIp() > c[j].GetVpnIp() + return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0 }) s := renderHostmaps(c...) @@ -420,10 +418,8 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) [] // Nope, lets push the sender along case p := <-udpTx: - outAddr := sender.GetUDPAddr() r.Lock() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - c := r.getControl(outAddr, inAddr, p) + c := r.getControl(sender.GetUDPAddr(), p.To, p) if c == nil { r.Unlock() panic("No control for udp tx") @@ -479,10 +475,7 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte { } else { // we are a udp tx, route and continue p := rx.Interface().(*udp.Packet) - outAddr := cm[x].GetUDPAddr() - - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - c := r.getControl(outAddr, inAddr, p) + c := r.getControl(cm[x].GetUDPAddr(), p.To, p) if c == nil { r.Unlock() panic("No control for udp tx") @@ -509,12 +502,10 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { panic(err) } - outAddr := sender.GetUDPAddr() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - receiver := r.getControl(outAddr, inAddr, p) + receiver := r.getControl(sender.GetUDPAddr(), p.To, p) if receiver == nil { r.Unlock() - panic("Can't route for host: " + inAddr) + panic("Can't RouteExitFunc for host: " + p.To.String()) } e := whatDo(p, receiver) @@ -590,13 +581,13 @@ func (r *R) InjectUDPPacket(sender, receiver *nebula.Control, packet *udp.Packet // RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr // finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit` // If the router doesn't have the nebula controller for that address, we panic -func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) { +func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr netip.AddrPort, finish ExitType) { if finish == KeepRouting { finish = RouteAndExit } r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType { - if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) { + if p.To == toAddr { return finish } @@ -630,13 +621,10 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) { r.Lock() p := rx.Interface().(*udp.Packet) - - outAddr := cm[x].GetUDPAddr() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - receiver := r.getControl(outAddr, inAddr, p) + receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p) if receiver == nil { r.Unlock() - panic("Can't route for host: " + inAddr) + panic("Can't RouteForAllExitFunc for host: " + p.To.String()) } e := whatDo(p, receiver) @@ -697,12 +685,10 @@ func (r *R) FlushAll() { p := rx.Interface().(*udp.Packet) - outAddr := cm[x].GetUDPAddr() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - receiver := r.getControl(outAddr, inAddr, p) + receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p) if receiver == nil { r.Unlock() - panic("Can't route for host: " + inAddr) + panic("Can't FlushAll for host: " + p.To.String()) } r.Unlock() } @@ -710,28 +696,14 @@ func (r *R) FlushAll() { // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change // This is an internal router function, the caller must hold the lock -func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control { - if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok { - p.FromIp = newAddr.IP - p.FromPort = uint16(newAddr.Port) +func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control { + if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok { + p.From = newAddr } c, ok := r.inNat[toAddr] if ok { - sHost, sPort, err := net.SplitHostPort(toAddr) - if err != nil { - panic(err) - } - - port, err := strconv.Atoi(sPort) - if err != nil { - panic(err) - } - - r.outNat[c.GetUDPAddr()+":"+fromAddr] = net.UDPAddr{ - IP: net.ParseIP(sHost), - Port: port, - } + r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr return c } @@ -746,8 +718,9 @@ func (r *R) formatUdpPacket(p *packet) string { } from := "unknown" - if c, ok := r.vpnControls[iputil.Ip2VpnIp(v4.SrcIP)]; ok { - from = c.GetUDPAddr() + srcAddr, _ := netip.AddrFromSlice(v4.SrcIP) + if c, ok := r.vpnControls[srcAddr]; ok { + from = c.GetUDPAddr().String() } udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) @@ -759,7 +732,7 @@ func (r *R) formatUdpPacket(p *packet) string { return fmt.Sprintf( " %s-->>%s: src port: %v
dest port: %v
data: \"%v\"\n", strings.Replace(from, ":", "-", 1), - strings.Replace(p.to.GetUDPAddr(), ":", "-", 1), + strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), udp.SrcPort, udp.DstPort, string(data.Payload()), diff --git a/firewall.go b/firewall.go index 3e760fe..8a409d2 100644 --- a/firewall.go +++ b/firewall.go @@ -6,23 +6,23 @@ import ( "errors" "fmt" "hash/fnv" - "net" + "net/netip" "reflect" "strconv" "strings" "sync" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" ) type FirewallInterface interface { - AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error + AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error } type conn struct { @@ -52,8 +52,8 @@ type Firewall struct { DefaultTimeout time.Duration //linux: 600s // Used to ensure we don't emit local packets for ips we don't own - localIps *cidr.Tree4[struct{}] - assignedCIDR *net.IPNet + localIps *bart.Table[struct{}] + assignedCIDR netip.Prefix hasSubnets bool rules string @@ -108,7 +108,7 @@ type FirewallRule struct { Any *firewallLocalCIDR Hosts map[string]*firewallLocalCIDR Groups []*firewallGroups - CIDR *cidr.Tree4[*firewallLocalCIDR] + CIDR *bart.Table[*firewallLocalCIDR] } type firewallGroups struct { @@ -122,7 +122,7 @@ type firewallPort map[int32]*FirewallCA type firewallLocalCIDR struct { Any bool - LocalCIDR *cidr.Tree4[struct{}] + LocalCIDR *bart.Table[struct{}] } // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. @@ -144,20 +144,28 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D max = defaultTimeout } - localIps := cidr.NewTree4[struct{}]() - var assignedCIDR *net.IPNet + localIps := new(bart.Table[struct{}]) + var assignedCIDR netip.Prefix + var assignedSet bool for _, ip := range c.Details.Ips { - ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}} - localIps.AddCIDR(ipNet, struct{}{}) + //TODO: IPV6-WORK the unmap is a bit unfortunate + nip, _ := netip.AddrFromSlice(ip.IP) + nip = nip.Unmap() + nprefix := netip.PrefixFrom(nip, nip.BitLen()) + localIps.Insert(nprefix, struct{}{}) - if assignedCIDR == nil { + if !assignedSet { // Only grabbing the first one in the cert since any more than that currently has undefined behavior - assignedCIDR = ipNet + assignedCIDR = nprefix + assignedSet = true } } for _, n := range c.Details.Subnets { - localIps.AddCIDR(n, struct{}{}) + nip, _ := netip.AddrFromSlice(n.IP) + ones, _ := n.Mask.Size() + nip = nip.Unmap() + localIps.Insert(netip.PrefixFrom(nip, ones), struct{}{}) } return &Firewall{ @@ -237,15 +245,15 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf } // AddRule properly creates the in memory rule structure for a firewall table. -func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { +func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error { // Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS // https://github.com/golang/go/issues/14131 sIp := "" - if ip != nil { + if ip.IsValid() { sIp = ip.String() } lIp := "" - if localIp != nil { + if localIp.IsValid() { lIp = localIp.String() } @@ -382,17 +390,17 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) } - var cidr *net.IPNet + var cidr netip.Prefix if r.Cidr != "" { - _, cidr, err = net.ParseCIDR(r.Cidr) + cidr, err = netip.ParsePrefix(r.Cidr) if err != nil { return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err) } } - var localCidr *net.IPNet + var localCidr netip.Prefix if r.LocalCidr != "" { - _, localCidr, err = net.ParseCIDR(r.LocalCidr) + localCidr, err = netip.ParsePrefix(r.LocalCidr) if err != nil { return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err) } @@ -421,7 +429,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * // Make sure remote address matches nebula certificate if remoteCidr := h.remoteCidr; remoteCidr != nil { - ok, _ := remoteCidr.Contains(fp.RemoteIP) + //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different + _, ok := remoteCidr.Lookup(fp.RemoteIP) if !ok { f.metrics(incoming).droppedRemoteIP.Inc(1) return ErrInvalidRemoteIP @@ -435,7 +444,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // Make sure we are supposed to be handling this local ip address - ok, _ := f.localIps.Contains(fp.LocalIP) + //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different + _, ok := f.localIps.Lookup(fp.LocalIP) if !ok { f.metrics(incoming).droppedLocalIP.Inc(1) return ErrInvalidLocalIP @@ -589,7 +599,6 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel // Caller must own the connMutex lock! func (f *Firewall) evict(p firewall.Packet) { - //TODO: report a stat if the tcp rtt tracking was never resolved? // Are we still tracking this conn? conntrack := f.Conntrack t, ok := conntrack.Conns[p] @@ -633,7 +642,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC return false } -func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { +func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error { if startPort > endPort { return fmt.Errorf("start port was lower than end port") } @@ -677,12 +686,12 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer return fp[firewall.PortAny].match(p, c, caPool) } -func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error { +func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error { fr := func() *FirewallRule { return &FirewallRule{ Hosts: make(map[string]*firewallLocalCIDR), Groups: make([]*firewallGroups, 0), - CIDR: cidr.NewTree4[*firewallLocalCIDR](), + CIDR: new(bart.Table[*firewallLocalCIDR]), } } @@ -740,10 +749,10 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool return fc.CANames[s.Details.Name].match(p, c) } -func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error { +func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error { flc := func() *firewallLocalCIDR { return &firewallLocalCIDR{ - LocalCIDR: cidr.NewTree4[struct{}](), + LocalCIDR: new(bart.Table[struct{}]), } } @@ -780,8 +789,8 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n fr.Hosts[host] = nlc } - if ip != nil { - _, nlc := fr.CIDR.GetCIDR(ip) + if ip.IsValid() { + nlc, _ := fr.CIDR.Get(ip) if nlc == nil { nlc = flc() } @@ -789,14 +798,14 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n if err != nil { return err } - fr.CIDR.AddCIDR(ip, nlc) + fr.CIDR.Insert(ip, nlc) } return nil } -func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool { - if len(groups) == 0 && host == "" && ip == nil { +func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool { + if len(groups) == 0 && host == "" && !ip.IsValid() { return true } @@ -810,7 +819,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool return true } - if ip != nil && ip.Contains(net.IPv4(0, 0, 0, 0)) { + if ip.IsValid() && ip.Bits() == 0 { return true } @@ -853,24 +862,31 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool } } - return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool { - return flc.match(p, c) + matched := false + prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen()) + fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool { + if prefix.Contains(p.RemoteIP) && val.match(p, c) { + matched = true + return false + } + return true }) + return matched } -func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error { - if localIp == nil { +func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { + if !localIp.IsValid() { if !f.hasSubnets || f.defaultLocalCIDRAny { flc.Any = true return nil } localIp = f.assignedCIDR - } else if localIp.Contains(net.IPv4(0, 0, 0, 0)) { + } else if localIp.Bits() == 0 { flc.Any = true } - flc.LocalCIDR.AddCIDR(localIp, struct{}{}) + flc.LocalCIDR.Insert(localIp, struct{}{}) return nil } @@ -883,7 +899,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate return true } - ok, _ := flc.LocalCIDR.Contains(p.LocalIP) + _, ok := flc.LocalCIDR.Lookup(p.LocalIP) return ok } diff --git a/firewall/packet.go b/firewall/packet.go index 1c4affd..8954f4c 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -3,8 +3,7 @@ package firewall import ( "encoding/json" "fmt" - - "github.com/slackhq/nebula/iputil" + "net/netip" ) type m map[string]interface{} @@ -20,8 +19,8 @@ const ( ) type Packet struct { - LocalIP iputil.VpnIp - RemoteIP iputil.VpnIp + LocalIP netip.Addr + RemoteIP netip.Addr LocalPort uint16 RemotePort uint16 Protocol uint8 diff --git a/firewall_test.go b/firewall_test.go index b5beff6..4d47e78 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -5,13 +5,13 @@ import ( "errors" "math" "net" + "net/netip" "testing" "time" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) @@ -65,59 +65,62 @@ func TestFirewall_AddRule(t *testing.T) { assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.OutRules) - _, ti, _ := net.ParseCIDR("1.2.3.4/32") + ti, err := netip.ParsePrefix("1.2.3.4/32") + assert.NoError(t, err) - assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) // An empty rule is any assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) - ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti) + _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) - ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti) + _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - _, anyIp, _ := net.ParseCIDR("0.0.0.0/0") - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", "")) + anyIp, err := netip.ParsePrefix("0.0.0.0/0") + assert.NoError(t, err) + + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) // Test error conditions fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", "")) - assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", "")) + assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) } func TestFirewall_Drop(t *testing.T) { @@ -126,8 +129,8 @@ func TestFirewall_Drop(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalIP: netip.MustParseAddr("1.2.3.4"), + RemoteIP: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -152,16 +155,16 @@ func TestFirewall_Drop(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr("1.2.3.4"), } h.CreateRemoteCIDR(&c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) @@ -170,34 +173,34 @@ func TestFirewall_Drop(t *testing.T) { // test remote mismatch oldRemote := p.RemoteIP - p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10)) + p.RemoteIP = netip.MustParseAddr("1.2.3.10") assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteIP = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) } @@ -207,10 +210,9 @@ func BenchmarkFirewallTable_match(b *testing.B) { TCP: firewallPort{}, } - _, n, _ := net.ParseCIDR("172.1.1.1/32") - goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP) - _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "") - _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "") + pfix := netip.MustParsePrefix("172.1.1.1/32") + _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "") + _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "") cp := cert.NewCAPool() b.Run("fail on proto", func(b *testing.B) { @@ -231,10 +233,9 @@ func BenchmarkFirewallTable_match(b *testing.B) { b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) { c := &cert.NebulaCertificate{} - ip, _, _ := net.ParseCIDR("9.254.254.254/32") - lip := iputil.Ip2VpnIp(ip) + ip := netip.MustParsePrefix("9.254.254.254/32") for n := 0; n < b.N; n++ { - assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp)) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp)) } }) @@ -262,7 +263,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { }, } for n := 0; n < b.N; n++ { - assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp)) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) } }) @@ -286,7 +287,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { }, } for n := 0; n < b.N; n++ { - assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp)) + assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) } }) @@ -363,8 +364,8 @@ func TestFirewall_Drop2(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalIP: netip.MustParseAddr("1.2.3.4"), + RemoteIP: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -387,7 +388,7 @@ func TestFirewall_Drop2(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h.CreateRemoteCIDR(&c) @@ -406,7 +407,7 @@ func TestFirewall_Drop2(t *testing.T) { h1.CreateRemoteCIDR(&c1) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // h1/c1 lacks the proper groups @@ -422,8 +423,8 @@ func TestFirewall_Drop3(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalIP: netip.MustParseAddr("1.2.3.4"), + RemoteIP: netip.MustParseAddr("1.2.3.4"), LocalPort: 1, RemotePort: 1, Protocol: firewall.ProtoUDP, @@ -453,7 +454,7 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c1, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h1.CreateRemoteCIDR(&c1) @@ -468,7 +469,7 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c2, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h2.CreateRemoteCIDR(&c2) @@ -483,13 +484,13 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c3, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h3.CreateRemoteCIDR(&c3) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) cp := cert.NewCAPool() // c1 should pass because host match @@ -508,8 +509,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalIP: netip.MustParseAddr("1.2.3.4"), + RemoteIP: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -534,12 +535,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h.CreateRemoteCIDR(&c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound @@ -552,7 +553,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { oldFw := fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -561,7 +562,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { oldFw = fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -725,13 +726,13 @@ func TestNewFirewallFromConfig(t *testing.T) { conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, c, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh") + assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, c, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh") + assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups conf = config.NewC(l) @@ -747,78 +748,78 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { mf := &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding udp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding icmp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding any rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with cidr - cidr := &net.IPNet{IP: net.ParseIP("10.0.0.0").To4(), Mask: net.IPv4Mask(255, 0, 0, 0)} + cidr := netip.MustParsePrefix("10.0.0.0/8") conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with local_cidr conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall) // Test adding rule with ca_sha conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall) // Test single group conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test single groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test multiple AND groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test Add error conf = config.NewC(l) @@ -871,8 +872,8 @@ type addRuleCall struct { endPort int32 groups []string host string - ip *net.IPNet - localIp *net.IPNet + ip netip.Prefix + localIp netip.Prefix caName string caSha string } @@ -882,7 +883,7 @@ type mockFirewall struct { nextCallReturn error } -func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { +func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error { mf.lastCall = addRuleCall{ incoming: incoming, proto: proto, diff --git a/go.mod b/go.mod index bec08c4..7680d09 100644 --- a/go.mod +++ b/go.mod @@ -38,8 +38,10 @@ require ( require ( github.com/beorn7/perks v1.0.1 // indirect + github.com/bits-and-blooms/bitset v1.13.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gaissmai/bart v0.11.1 // indirect github.com/google/btree v1.1.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.5.0 // indirect diff --git a/go.sum b/go.sum index ddd5402..7ce7e0e 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE= +github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -24,6 +26,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= +github.com/gaissmai/bart v0.10.0 h1:yCZCYF8xzcRnqDe4jMk14NlJjL1WmMsE7ilBzvuHtiI= +github.com/gaissmai/bart v0.10.0/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= +github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc= +github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= diff --git a/handshake_ix.go b/handshake_ix.go index d0bee86..8cf5341 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -1,13 +1,12 @@ package nebula import ( + "net/netip" "time" "github.com/flynn/noise" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) // NOISE IX Handshakes @@ -63,7 +62,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { return true } -func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { +func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { certState := f.pki.GetCertState() ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0) // Mark packet 1 as seen so it doesn't show up as missed @@ -99,12 +98,26 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by e.Info("Invalid certificate from host") return } - vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP) + + vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP) + if !ok { + e := f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) + + if f.l.Level > logrus.DebugLevel { + e = e.WithField("cert", remoteCert) + } + + e.Info("Invalid vpn ip from host") + return + } + + vpnIp = vpnIp.Unmap() certName := remoteCert.Details.Name fingerprint, _ := remoteCert.Sha256Sum() issuer := remoteCert.Details.Issuer - if vpnIp == f.myVpnIp { + if vpnIp == f.myVpnNet.Addr() { f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). @@ -113,8 +126,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by return } - if addr != nil { - if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.IP) { + if addr.IsValid() { + if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) { f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } @@ -138,8 +151,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by HandshakePacket: make(map[uint8][]byte, 0), lastHandshakeTime: hs.Details.Time, relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, } @@ -218,7 +231,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by msg = existing.HandshakePacket[2] f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if addr != nil { + if addr.IsValid() { err := f.outside.WriteTo(msg, addr) if err != nil { f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). @@ -284,7 +297,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by // Do the send f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if addr != nil { + if addr.IsValid() { err = f.outside.WriteTo(msg, addr) if err != nil { f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). @@ -326,7 +339,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by return } -func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { +func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { if hh == nil { // Nothing here to tear down, got a bogus stage 2 packet return true @@ -336,8 +349,8 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha defer hh.Unlock() hostinfo := hh.hostinfo - if addr != nil { - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) { + if addr.IsValid() { + if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) { f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return false } @@ -389,7 +402,20 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha return true } - vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP) + vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP) + if !ok { + e := f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) + + if f.l.Level > logrus.DebugLevel { + e = e.WithField("cert", remoteCert) + } + + e.Info("Invalid vpn ip from host") + return true + } + + vpnIp = vpnIp.Unmap() certName := remoteCert.Details.Name fingerprint, _ := remoteCert.Sha256Sum() issuer := remoteCert.Details.Issuer @@ -453,7 +479,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha ci.eKey = NewNebulaCipherState(eKey) // Make sure the current udpAddr being used is set for responding - if addr != nil { + if addr.IsValid() { hostinfo.SetRemote(addr) } else { hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) diff --git a/handshake_manager.go b/handshake_manager.go index 2372ced..7960435 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -6,15 +6,15 @@ import ( "crypto/rand" "encoding/binary" "errors" - "net" + "net/netip" "sync" "time" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" + "golang.org/x/exp/slices" ) const ( @@ -46,14 +46,14 @@ type HandshakeManager struct { // Mutex for interacting with the vpnIps and indexes maps sync.RWMutex - vpnIps map[iputil.VpnIp]*HandshakeHostInfo + vpnIps map[netip.Addr]*HandshakeHostInfo indexes map[uint32]*HandshakeHostInfo mainHostMap *HostMap lightHouse *LightHouse outside udp.Conn config HandshakeConfig - OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp] + OutboundHandshakeTimer *LockingTimerWheel[netip.Addr] messageMetrics *MessageMetrics metricInitiated metrics.Counter metricTimedOut metrics.Counter @@ -61,17 +61,17 @@ type HandshakeManager struct { l *logrus.Logger // can be used to trigger outbound handshake for the given vpnIp - trigger chan iputil.VpnIp + trigger chan netip.Addr } type HandshakeHostInfo struct { sync.Mutex - startTime time.Time // Time that we first started trying with this handshake - ready bool // Is the handshake ready - counter int // How many attempts have we made so far - lastRemotes []*udp.Addr // Remotes that we sent to during the previous attempt - packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes + startTime time.Time // Time that we first started trying with this handshake + ready bool // Is the handshake ready + counter int // How many attempts have we made so far + lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt + packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes hostinfo *HostInfo } @@ -103,14 +103,14 @@ func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { return &HandshakeManager{ - vpnIps: map[iputil.VpnIp]*HandshakeHostInfo{}, + vpnIps: map[netip.Addr]*HandshakeHostInfo{}, indexes: map[uint32]*HandshakeHostInfo{}, mainHostMap: mainHostMap, lightHouse: lightHouse, outside: outside, config: config, - trigger: make(chan iputil.VpnIp, config.triggerBuffer), - OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)), + trigger: make(chan netip.Addr, config.triggerBuffer), + OutboundHandshakeTimer: NewLockingTimerWheel[netip.Addr](config.tryInterval, hsTimeout(config.retries, config.tryInterval)), messageMetrics: config.messageMetrics, metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil), metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil), @@ -134,10 +134,10 @@ func (c *HandshakeManager) Run(ctx context.Context) { } } -func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { +func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { // First remote allow list check before we know the vpnIp - if addr != nil { - if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) { + if addr.IsValid() { + if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) { hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } @@ -170,7 +170,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { } } -func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) { +func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered bool) { hh := hm.queryVpnIp(vpnIp) if hh == nil { return @@ -212,7 +212,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger } remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()) - remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes) + remotesHaveChanged := !slices.Equal(remotes, hh.lastRemotes) // We only care about a lighthouse trigger if we have new remotes to send to. // This is a very specific optimization for a fast lighthouse reply. @@ -234,8 +234,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger } // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply - var sentTo []*udp.Addr - hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr *udp.Addr, _ bool) { + var sentTo []netip.AddrPort + hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) { hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) if err != nil { @@ -268,13 +268,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger // Send a RelayRequest to all known Relay IP's for _, relay := range hostinfo.remotes.relays { // Don't relay to myself, and don't relay through the host I'm trying to connect to - if *relay == vpnIp || *relay == hm.lightHouse.myVpnIp { + if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() { continue } - relayHostInfo := hm.mainHostMap.QueryVpnIp(*relay) - if relayHostInfo == nil || relayHostInfo.remote == nil { + relayHostInfo := hm.mainHostMap.QueryVpnIp(relay) + if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") - hm.f.Handshake(*relay) + hm.f.Handshake(relay) continue } // Check the relay HostInfo to see if we already established a relay through it @@ -285,12 +285,17 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) case Requested: hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") + + //TODO: IPV6-WORK + myVpnIpB := hm.f.myVpnNet.Addr().As4() + theirVpnIpB := vpnIp.As4() + // Re-send the CreateRelay request, in case the previous one was lost. m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: existingRelay.LocalIndex, - RelayFromIp: uint32(hm.lightHouse.myVpnIp), - RelayToIp: uint32(vpnIp), + RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]), + RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]), } msg, err := m.Marshal() if err != nil { @@ -301,10 +306,10 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger // This must send over the hostinfo, not over hm.Hosts[ip] hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.lightHouse.myVpnIp, + "relayFrom": hm.f.myVpnNet.Addr(), "relayTo": vpnIp, "initiatorRelayIndex": existingRelay.LocalIndex, - "relay": *relay}). + "relay": relay}). Info("send CreateRelayRequest") } default: @@ -316,17 +321,21 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger } } else { // No relays exist or requested yet. - if relayHostInfo.remote != nil { + if relayHostInfo.remote.IsValid() { idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) if err != nil { hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") } + //TODO: IPV6-WORK + myVpnIpB := hm.f.myVpnNet.Addr().As4() + theirVpnIpB := vpnIp.As4() + m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: idx, - RelayFromIp: uint32(hm.lightHouse.myVpnIp), - RelayToIp: uint32(vpnIp), + RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]), + RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]), } msg, err := m.Marshal() if err != nil { @@ -336,10 +345,10 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger } else { hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.lightHouse.myVpnIp, + "relayFrom": hm.f.myVpnNet.Addr(), "relayTo": vpnIp, "initiatorRelayIndex": idx, - "relay": *relay}). + "relay": relay}). Info("send CreateRelayRequest") } } @@ -355,7 +364,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger // GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present // The 2nd argument will be true if the hostinfo is ready to transmit traffic -func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) { +func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) { hm.mainHostMap.RLock() h, ok := hm.mainHostMap.Hosts[vpnIp] hm.mainHostMap.RUnlock() @@ -372,7 +381,7 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han } // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip -func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo { +func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo { hm.Lock() if hh, ok := hm.vpnIps[vpnIp]; ok { @@ -388,8 +397,8 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han vpnIp: vpnIp, HandshakePacket: make(map[uint8][]byte, 0), relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, } @@ -555,7 +564,7 @@ func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { delete(c.vpnIps, hostinfo.vpnIp) if len(c.vpnIps) == 0 { - c.vpnIps = map[iputil.VpnIp]*HandshakeHostInfo{} + c.vpnIps = map[netip.Addr]*HandshakeHostInfo{} } delete(c.indexes, hostinfo.localIndexId) @@ -570,7 +579,7 @@ func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { } } -func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { +func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo { hh := hm.queryVpnIp(vpnIp) if hh != nil { return hh.hostinfo @@ -579,7 +588,7 @@ func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { } -func (hm *HandshakeManager) queryVpnIp(vpnIp iputil.VpnIp) *HandshakeHostInfo { +func (hm *HandshakeManager) queryVpnIp(vpnIp netip.Addr) *HandshakeHostInfo { hm.RLock() defer hm.RUnlock() return hm.vpnIps[vpnIp] @@ -599,7 +608,7 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo { return hm.indexes[index] } -func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet { +func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix { return c.mainHostMap.GetPreferredRanges() } diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 9a63357..a78b45f 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -1,13 +1,12 @@ package nebula import ( - "net" + "net/netip" "testing" "time" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" @@ -15,10 +14,11 @@ import ( func Test_NewHandshakeManagerVpnIp(t *testing.T) { l := test.NewLogger() - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) - preferredRanges := []*net.IPNet{localrange} + vpncidr := netip.MustParsePrefix("172.1.1.1/24") + localrange := netip.MustParsePrefix("10.1.1.1/24") + ip := netip.MustParseAddr("172.1.1.2") + + preferredRanges := []netip.Prefix{localrange} mainHM := newHostMap(l, vpncidr) mainHM.preferredRanges.Store(&preferredRanges) @@ -66,7 +66,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { assert.NotContains(t, blah.vpnIps, ip) } -func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) { +func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) { for _, i := range tw.t.wheel { n := i.Head for n != nil { @@ -80,7 +80,7 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) { type mockEncWriter struct { } -func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) { +func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) { return } @@ -92,4 +92,4 @@ func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M return } -func (mw *mockEncWriter) Handshake(vpnIP iputil.VpnIp) {} +func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {} diff --git a/hostmap.go b/hostmap.go index 589a124..fb97b76 100644 --- a/hostmap.go +++ b/hostmap.go @@ -3,18 +3,17 @@ package nebula import ( "errors" "net" + "net/netip" "sync" "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) // const ProbeLen = 100 @@ -49,7 +48,7 @@ type Relay struct { State int LocalIndex uint32 RemoteIndex uint32 - PeerIp iputil.VpnIp + PeerIp netip.Addr } type HostMap struct { @@ -57,9 +56,9 @@ type HostMap struct { Indexes map[uint32]*HostInfo Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object RemoteIndexes map[uint32]*HostInfo - Hosts map[iputil.VpnIp]*HostInfo - preferredRanges atomic.Pointer[[]*net.IPNet] - vpnCIDR *net.IPNet + Hosts map[netip.Addr]*HostInfo + preferredRanges atomic.Pointer[[]netip.Prefix] + vpnCIDR netip.Prefix l *logrus.Logger } @@ -69,12 +68,12 @@ type HostMap struct { type RelayState struct { sync.RWMutex - relays map[iputil.VpnIp]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer - relayForByIp map[iputil.VpnIp]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info - relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info + relays map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer + relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info + relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info } -func (rs *RelayState) DeleteRelay(ip iputil.VpnIp) { +func (rs *RelayState) DeleteRelay(ip netip.Addr) { rs.Lock() defer rs.Unlock() delete(rs.relays, ip) @@ -90,33 +89,33 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay { return ret } -func (rs *RelayState) GetRelayForByIp(ip iputil.VpnIp) (*Relay, bool) { +func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() r, ok := rs.relayForByIp[ip] return r, ok } -func (rs *RelayState) InsertRelayTo(ip iputil.VpnIp) { +func (rs *RelayState) InsertRelayTo(ip netip.Addr) { rs.Lock() defer rs.Unlock() rs.relays[ip] = struct{}{} } -func (rs *RelayState) CopyRelayIps() []iputil.VpnIp { +func (rs *RelayState) CopyRelayIps() []netip.Addr { rs.RLock() defer rs.RUnlock() - ret := make([]iputil.VpnIp, 0, len(rs.relays)) + ret := make([]netip.Addr, 0, len(rs.relays)) for ip := range rs.relays { ret = append(ret, ip) } return ret } -func (rs *RelayState) CopyRelayForIps() []iputil.VpnIp { +func (rs *RelayState) CopyRelayForIps() []netip.Addr { rs.RLock() defer rs.RUnlock() - currentRelays := make([]iputil.VpnIp, 0, len(rs.relayForByIp)) + currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp)) for relayIp := range rs.relayForByIp { currentRelays = append(currentRelays, relayIp) } @@ -133,19 +132,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 { return ret } -func (rs *RelayState) RemoveRelay(localIdx uint32) (iputil.VpnIp, bool) { - rs.Lock() - defer rs.Unlock() - r, ok := rs.relayForByIdx[localIdx] - if !ok { - return iputil.VpnIp(0), false - } - delete(rs.relayForByIdx, localIdx) - delete(rs.relayForByIp, r.PeerIp) - return r.PeerIp, true -} - -func (rs *RelayState) CompleteRelayByIP(vpnIp iputil.VpnIp, remoteIdx uint32) bool { +func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool { rs.Lock() defer rs.Unlock() r, ok := rs.relayForByIp[vpnIp] @@ -175,7 +162,7 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re return &newRelay, true } -func (rs *RelayState) QueryRelayForByIp(vpnIp iputil.VpnIp) (*Relay, bool) { +func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() r, ok := rs.relayForByIp[vpnIp] @@ -189,7 +176,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) { return r, ok } -func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) { +func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) { rs.Lock() defer rs.Unlock() rs.relayForByIp[ip] = r @@ -197,15 +184,15 @@ func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) { } type HostInfo struct { - remote *udp.Addr + remote netip.AddrPort remotes *RemoteList promoteCounter atomic.Uint32 ConnectionState *ConnectionState remoteIndexId uint32 localIndexId uint32 - vpnIp iputil.VpnIp + vpnIp netip.Addr recvError atomic.Uint32 - remoteCidr *cidr.Tree4[struct{}] + remoteCidr *bart.Table[struct{}] relayState RelayState // HandshakePacket records the packets used to create this hostinfo @@ -227,7 +214,7 @@ type HostInfo struct { lastHandshakeTime uint64 lastRoam time.Time - lastRoamRemote *udp.Addr + lastRoamRemote netip.AddrPort // Used to track other hostinfos for this vpn ip since only 1 can be primary // Synchronised via hostmap lock and not the hostinfo lock. @@ -254,7 +241,7 @@ type cachedPacketMetrics struct { dropped metrics.Counter } -func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *HostMap { +func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap { hm := newHostMap(l, vpnCIDR) hm.reload(c, true) @@ -269,12 +256,12 @@ func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *Ho return hm } -func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap { +func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap { return &HostMap{ Indexes: map[uint32]*HostInfo{}, Relays: map[uint32]*HostInfo{}, RemoteIndexes: map[uint32]*HostInfo{}, - Hosts: map[iputil.VpnIp]*HostInfo{}, + Hosts: map[netip.Addr]*HostInfo{}, vpnCIDR: vpnCIDR, l: l, } @@ -282,11 +269,11 @@ func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap { func (hm *HostMap) reload(c *config.C, initial bool) { if initial || c.HasChanged("preferred_ranges") { - var preferredRanges []*net.IPNet + var preferredRanges []netip.Prefix rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{}) for _, rawPreferredRange := range rawPreferredRanges { - _, preferredRange, err := net.ParseCIDR(rawPreferredRange) + preferredRange, err := netip.ParsePrefix(rawPreferredRange) if err != nil { hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring") @@ -378,7 +365,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { // The vpnIp pointer points to the same hostinfo as the local index id, we can remove it delete(hm.Hosts, hostinfo.vpnIp) if len(hm.Hosts) == 0 { - hm.Hosts = map[iputil.VpnIp]*HostInfo{} + hm.Hosts = map[netip.Addr]*HostInfo{} } if hostinfo.next != nil { @@ -461,11 +448,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo { } } -func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { +func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo { return hm.queryVpnIp(vpnIp, nil) } -func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*HostInfo, *Relay, error) { +func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) { hm.RLock() defer hm.RUnlock() @@ -483,7 +470,7 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host return nil, nil, errors.New("unable to find host with relay") } -func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo { +func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo { hm.RLock() if h, ok := hm.Hosts[vpnIp]; ok { hm.RUnlock() @@ -535,7 +522,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { } } -func (hm *HostMap) GetPreferredRanges() []*net.IPNet { +func (hm *HostMap) GetPreferredRanges() []netip.Prefix { //NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer return *hm.preferredRanges.Load() } @@ -560,14 +547,14 @@ func (hm *HostMap) ForEachIndex(f controlEach) { // TryPromoteBest handles re-querying lighthouses and probing for better paths // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients! -func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) { +func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interface) { c := i.promoteCounter.Add(1) if c%ifce.tryPromoteEvery.Load() == 0 { remote := i.remote // return early if we are already on a preferred remote - if remote != nil { - rIP := remote.IP + if remote.IsValid() { + rIP := remote.Addr() for _, l := range preferredRanges { if l.Contains(rIP) { return @@ -575,8 +562,8 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) } } - i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) { - if remote != nil && (addr == nil || !preferred) { + i.remotes.ForEach(preferredRanges, func(addr netip.AddrPort, preferred bool) { + if remote.IsValid() && (!addr.IsValid() || !preferred) { return } @@ -605,23 +592,23 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate { return nil } -func (i *HostInfo) SetRemote(remote *udp.Addr) { +func (i *HostInfo) SetRemote(remote netip.AddrPort) { // We copy here because we likely got this remote from a source that reuses the object - if !i.remote.Equals(remote) { - i.remote = remote.Copy() - i.remotes.LearnRemote(i.vpnIp, remote.Copy()) + if i.remote != remote { + i.remote = remote + i.remotes.LearnRemote(i.vpnIp, remote) } } // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam // time on the HostInfo will also be updated. -func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { - if newRemote == nil { +func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool { + if !newRemote.IsValid() { // relays have nil udp Addrs return false } currentRemote := i.remote - if currentRemote == nil { + if !currentRemote.IsValid() { i.SetRemote(newRemote) return true } @@ -631,11 +618,11 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { newIsPreferred := false for _, l := range hm.GetPreferredRanges() { // return early if we are already on a preferred remote - if l.Contains(currentRemote.IP) { + if l.Contains(currentRemote.Addr()) { return false } - if l.Contains(newRemote.IP) { + if l.Contains(newRemote.Addr()) { newIsPreferred = true } } @@ -643,7 +630,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { if newIsPreferred { // Consider this a roaming event i.lastRoam = time.Now() - i.lastRoamRemote = currentRemote.Copy() + i.lastRoamRemote = currentRemote i.SetRemote(newRemote) @@ -666,13 +653,21 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) { return } - remoteCidr := cidr.NewTree4[struct{}]() + remoteCidr := new(bart.Table[struct{}]) for _, ip := range c.Details.Ips { - remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) + //TODO: IPV6-WORK what to do when ip is invalid? + nip, _ := netip.AddrFromSlice(ip.IP) + nip = nip.Unmap() + bits, _ := ip.Mask.Size() + remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{}) } for _, n := range c.Details.Subnets { - remoteCidr.AddCIDR(n, struct{}{}) + //TODO: IPV6-WORK what to do when ip is invalid? + nip, _ := netip.AddrFromSlice(n.IP) + nip = nip.Unmap() + bits, _ := n.Mask.Size() + remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{}) } i.remoteCidr = remoteCidr } @@ -697,9 +692,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { // Utility functions -func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP { +func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { //FIXME: This function is pretty garbage - var ips []net.IP + var ips []netip.Addr ifaces, _ := net.Interfaces() for _, i := range ifaces { allow := allowList.AllowName(i.Name) @@ -721,20 +716,29 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP { ip = v.IP } + nip, ok := netip.AddrFromSlice(ip) + if !ok { + if l.Level >= logrus.DebugLevel { + l.WithField("localIp", ip).Debug("ip was invalid for netip") + } + continue + } + nip = nip.Unmap() + //TODO: Filtering out link local for now, this is probably the most correct thing //TODO: Would be nice to filter out SLAAC MAC based ips as well - if ip.IsLoopback() == false && !ip.IsLinkLocalUnicast() { - allow := allowList.Allow(ip) + if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false { + allow := allowList.Allow(nip) if l.Level >= logrus.TraceLevel { - l.WithField("localIp", ip).WithField("allow", allow).Trace("localAllowList.Allow") + l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow") } if !allow { continue } - ips = append(ips, ip) + ips = append(ips, nip) } } } - return &ips + return ips } diff --git a/hostmap_test.go b/hostmap_test.go index 8311cef..7e2feb8 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -1,7 +1,7 @@ package nebula import ( - "net" + "net/netip" "testing" "github.com/slackhq/nebula/config" @@ -13,18 +13,15 @@ func TestHostMap_MakePrimary(t *testing.T) { l := test.NewLogger() hm := newHostMap( l, - &net.IPNet{ - IP: net.IP{10, 0, 0, 1}, - Mask: net.IPMask{255, 255, 255, 0}, - }, + netip.MustParsePrefix("10.0.0.1/24"), ) f := &Interface{} - h1 := &HostInfo{vpnIp: 1, localIndexId: 1} - h2 := &HostInfo{vpnIp: 1, localIndexId: 2} - h3 := &HostInfo{vpnIp: 1, localIndexId: 3} - h4 := &HostInfo{vpnIp: 1, localIndexId: 4} + h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} + h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} + h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} + h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} hm.unlockedAddHostInfo(h4, f) hm.unlockedAddHostInfo(h3, f) @@ -32,7 +29,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.unlockedAddHostInfo(h1, f) // Make sure we go h1 -> h2 -> h3 -> h4 - prim := hm.QueryVpnIp(1) + prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -47,7 +44,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h3) // Make sure we go h3 -> h1 -> h2 -> h4 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h3.localIndexId, prim.localIndexId) assert.Equal(t, h1.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -62,7 +59,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -77,7 +74,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -93,20 +90,17 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { l := test.NewLogger() hm := newHostMap( l, - &net.IPNet{ - IP: net.IP{10, 0, 0, 1}, - Mask: net.IPMask{255, 255, 255, 0}, - }, + netip.MustParsePrefix("10.0.0.1/24"), ) f := &Interface{} - h1 := &HostInfo{vpnIp: 1, localIndexId: 1} - h2 := &HostInfo{vpnIp: 1, localIndexId: 2} - h3 := &HostInfo{vpnIp: 1, localIndexId: 3} - h4 := &HostInfo{vpnIp: 1, localIndexId: 4} - h5 := &HostInfo{vpnIp: 1, localIndexId: 5} - h6 := &HostInfo{vpnIp: 1, localIndexId: 6} + h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} + h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} + h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} + h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} + h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5} + h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6} hm.unlockedAddHostInfo(h6, f) hm.unlockedAddHostInfo(h5, f) @@ -122,7 +116,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h) // Make sure we go h1 -> h2 -> h3 -> h4 -> h5 - prim := hm.QueryVpnIp(1) + prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -141,7 +135,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h1.next) // Make sure we go h2 -> h3 -> h4 -> h5 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -159,7 +153,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h3.next) // Make sure we go h2 -> h4 -> h5 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -175,7 +169,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h5.next) // Make sure we go h2 -> h4 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -189,7 +183,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h2.next) // Make sure we only have h4 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Nil(t, prim.prev) assert.Nil(t, prim.next) @@ -201,7 +195,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h4.next) // Make sure we have nil - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Nil(t, prim) } @@ -211,14 +205,11 @@ func TestHostMap_reload(t *testing.T) { hm := NewHostMapFromConfig( l, - &net.IPNet{ - IP: net.IP{10, 0, 0, 1}, - Mask: net.IPMask{255, 255, 255, 0}, - }, + netip.MustParsePrefix("10.0.0.1/24"), c, ) - toS := func(ipn []*net.IPNet) []string { + toS := func(ipn []netip.Prefix) []string { var s []string for _, n := range ipn { s = append(s, n.String()) diff --git a/hostmap_tester.go b/hostmap_tester.go index 0d5d41b..b2d1d1b 100644 --- a/hostmap_tester.go +++ b/hostmap_tester.go @@ -5,9 +5,11 @@ package nebula // This file contains functions used to export information to the e2e testing framework -import "github.com/slackhq/nebula/iputil" +import ( + "net/netip" +) -func (i *HostInfo) GetVpnIp() iputil.VpnIp { +func (i *HostInfo) GetVpnIp() netip.Addr { return i.vpnIp } diff --git a/inside.go b/inside.go index 079e4dd..0ccd179 100644 --- a/inside.go +++ b/inside.go @@ -1,12 +1,13 @@ package nebula import ( + "net/netip" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/noiseutil" - "github.com/slackhq/nebula/udp" ) func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { @@ -19,11 +20,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } // Ignore local broadcast packets - if f.dropLocalBroadcast && fwPacket.RemoteIP == f.localBroadcast { + if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr { return } - if fwPacket.RemoteIP == f.myVpnIp { + if fwPacket.RemoteIP == f.myVpnNet.Addr() { // Immediately forward packets from self to self. // This should only happen on Darwin-based and FreeBSD hosts, which // routes packets from the Nebula IP to the Nebula IP through the Nebula @@ -39,8 +40,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - // Ignore broadcast packets - if f.dropMulticast && isMulticast(fwPacket.RemoteIP) { + // Ignore multicast packets + if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() { return } @@ -64,7 +65,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { - f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, nil, packet, nb, out, q) + f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) } else { f.rejectInside(packet, out, q) @@ -113,19 +114,19 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * return } - f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, out, nb, packet, q) + f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) } -func (f *Interface) Handshake(vpnIp iputil.VpnIp) { +func (f *Interface) Handshake(vpnIp netip.Addr) { f.getOrHandshake(vpnIp, nil) } // getOrHandshake returns nil if the vpnIp is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel -func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { - if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) { +func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { + if !f.myVpnNet.Contains(vpnIp) { vpnIp = f.inside.RouteFor(vpnIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return nil, false } } @@ -152,11 +153,11 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp return } - f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, nil, p, nb, out, 0) + f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) } // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp -func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) { +func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) { hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) }) @@ -182,10 +183,10 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) { f.messageMetrics.Tx(t, st, 1) - f.sendNoMetrics(t, st, ci, hostinfo, nil, p, nb, out, 0) + f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0) } -func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte) { +func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) { f.messageMetrics.Tx(t, st, 1) f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) } @@ -255,12 +256,12 @@ func (f *Interface) SendVia(via *HostInfo, f.connectionManager.RelayUsed(relay.LocalIndex) } -func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) { +func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) { if ci.eKey == nil { //TODO: log warning return } - useRelay := remote == nil && hostinfo.remote == nil + useRelay := !remote.IsValid() && !hostinfo.remote.IsValid() fullOut := out if useRelay { @@ -308,13 +309,13 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType return } - if remote != nil { + if remote.IsValid() { err = f.writers[q].WriteTo(out, remote) if err != nil { hostinfo.logger(f.l).WithError(err). WithField("udpAddr", remote).Error("Failed to write outgoing packet") } - } else if hostinfo.remote != nil { + } else if hostinfo.remote.IsValid() { err = f.writers[q].WriteTo(out, hostinfo.remote) if err != nil { hostinfo.logger(f.l).WithError(err). @@ -334,8 +335,3 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } } } - -func isMulticast(ip iputil.VpnIp) bool { - // Class D multicast - return (((ip >> 24) & 0xff) & 0xf0) == 0xe0 -} diff --git a/interface.go b/interface.go index d16348a..f251907 100644 --- a/interface.go +++ b/interface.go @@ -2,10 +2,11 @@ package nebula import ( "context" + "encoding/binary" "errors" "fmt" "io" - "net" + "net/netip" "os" "runtime" "sync/atomic" @@ -16,7 +17,6 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" ) @@ -63,8 +63,8 @@ type Interface struct { serveDns bool createTime time.Time lightHouse *LightHouse - localBroadcast iputil.VpnIp - myVpnIp iputil.VpnIp + myBroadcastAddr netip.Addr + myVpnNet netip.Prefix dropLocalBroadcast bool dropMulticast bool routines int @@ -102,9 +102,9 @@ type EncWriter interface { out []byte, nocopy bool, ) - SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) + SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) - Handshake(vpnIp iputil.VpnIp) + Handshake(vpnIp netip.Addr) } type sendRecvErrorConfig uint8 @@ -115,10 +115,10 @@ const ( sendRecvErrorPrivate ) -func (s sendRecvErrorConfig) ShouldSendRecvError(ip net.IP) bool { +func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool { switch s { case sendRecvErrorPrivate: - return ip.IsPrivate() + return ip.Addr().IsPrivate() case sendRecvErrorAlways: return true case sendRecvErrorNever: @@ -156,7 +156,27 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { } certificate := c.pki.GetCertState().Certificate - myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP) + + myVpnAddr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP) + if !ok { + return nil, fmt.Errorf("invalid ip address in certificate: %s", certificate.Details.Ips[0].IP) + } + + myVpnMask, ok := netip.AddrFromSlice(certificate.Details.Ips[0].Mask) + if !ok { + return nil, fmt.Errorf("invalid ip mask in certificate: %s", certificate.Details.Ips[0].Mask) + } + + myVpnAddr = myVpnAddr.Unmap() + myVpnMask = myVpnMask.Unmap() + + if myVpnAddr.BitLen() != myVpnMask.BitLen() { + return nil, fmt.Errorf("ip address and mask are different lengths in certificate") + } + + ones, _ := certificate.Details.Ips[0].Mask.Size() + myVpnNet := netip.PrefixFrom(myVpnAddr, ones) + ifce := &Interface{ pki: c.pki, hostMap: c.HostMap, @@ -168,14 +188,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { handshakeManager: c.HandshakeManager, createTime: time.Now(), lightHouse: c.lightHouse, - localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask), dropLocalBroadcast: c.DropLocalBroadcast, dropMulticast: c.DropMulticast, routines: c.routines, version: c.version, writers: make([]udp.Conn, c.routines), readers: make([]io.ReadWriteCloser, c.routines), - myVpnIp: myVpnIp, + myVpnNet: myVpnNet, relayManager: c.relayManager, conntrackCacheTimeout: c.ConntrackCacheTimeout, @@ -190,6 +209,12 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { l: c.l, } + if myVpnAddr.Is4() { + addr := myVpnNet.Masked().Addr().As4() + binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask)) + ifce.myBroadcastAddr = netip.AddrFrom4(addr) + } + ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryWait.Store(int64(c.reQueryWait)) diff --git a/iputil/packet.go b/iputil/packet.go index b18e524..719e034 100644 --- a/iputil/packet.go +++ b/iputil/packet.go @@ -6,6 +6,8 @@ import ( "golang.org/x/net/ipv4" ) +//TODO: IPV6-WORK can probably delete this + const ( // Need 96 bytes for the largest reject packet: // - 20 byte ipv4 header diff --git a/iputil/util.go b/iputil/util.go deleted file mode 100644 index 65f7677..0000000 --- a/iputil/util.go +++ /dev/null @@ -1,93 +0,0 @@ -package iputil - -import ( - "encoding/binary" - "fmt" - "net" - "net/netip" -) - -type VpnIp uint32 - -const maxIPv4StringLen = len("255.255.255.255") - -func (ip VpnIp) String() string { - b := make([]byte, maxIPv4StringLen) - - n := ubtoa(b, 0, byte(ip>>24)) - b[n] = '.' - n++ - - n += ubtoa(b, n, byte(ip>>16&255)) - b[n] = '.' - n++ - - n += ubtoa(b, n, byte(ip>>8&255)) - b[n] = '.' - n++ - - n += ubtoa(b, n, byte(ip&255)) - return string(b[:n]) -} - -func (ip VpnIp) MarshalJSON() ([]byte, error) { - return []byte(fmt.Sprintf("\"%s\"", ip.String())), nil -} - -func (ip VpnIp) ToIP() net.IP { - nip := make(net.IP, 4) - binary.BigEndian.PutUint32(nip, uint32(ip)) - return nip -} - -func (ip VpnIp) ToNetIpAddr() netip.Addr { - var nip [4]byte - binary.BigEndian.PutUint32(nip[:], uint32(ip)) - return netip.AddrFrom4(nip) -} - -func Ip2VpnIp(ip []byte) VpnIp { - if len(ip) == 16 { - return VpnIp(binary.BigEndian.Uint32(ip[12:16])) - } - return VpnIp(binary.BigEndian.Uint32(ip)) -} - -func ToNetIpAddr(ip net.IP) (netip.Addr, error) { - addr, ok := netip.AddrFromSlice(ip) - if !ok { - return netip.Addr{}, fmt.Errorf("invalid net.IP: %v", ip) - } - return addr, nil -} - -func ToNetIpPrefix(ipNet net.IPNet) (netip.Prefix, error) { - addr, err := ToNetIpAddr(ipNet.IP) - if err != nil { - return netip.Prefix{}, err - } - ones, bits := ipNet.Mask.Size() - if ones == 0 && bits == 0 { - return netip.Prefix{}, fmt.Errorf("invalid net.IP: %v", ipNet) - } - return netip.PrefixFrom(addr, ones), nil -} - -// ubtoa encodes the string form of the integer v to dst[start:] and -// returns the number of bytes written to dst. The caller must ensure -// that dst has sufficient length. -func ubtoa(dst []byte, start int, v byte) int { - if v < 10 { - dst[start] = v + '0' - return 1 - } else if v < 100 { - dst[start+1] = v%10 + '0' - dst[start] = v/10 + '0' - return 2 - } - - dst[start+2] = v%10 + '0' - dst[start+1] = (v/10)%10 + '0' - dst[start] = v/100 + '0' - return 3 -} diff --git a/iputil/util_test.go b/iputil/util_test.go deleted file mode 100644 index 712d426..0000000 --- a/iputil/util_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package iputil - -import ( - "net" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestVpnIp_String(t *testing.T) { - assert.Equal(t, "255.255.255.255", Ip2VpnIp(net.ParseIP("255.255.255.255")).String()) - assert.Equal(t, "1.255.255.255", Ip2VpnIp(net.ParseIP("1.255.255.255")).String()) - assert.Equal(t, "1.1.255.255", Ip2VpnIp(net.ParseIP("1.1.255.255")).String()) - assert.Equal(t, "1.1.1.255", Ip2VpnIp(net.ParseIP("1.1.1.255")).String()) - assert.Equal(t, "1.1.1.1", Ip2VpnIp(net.ParseIP("1.1.1.1")).String()) - assert.Equal(t, "0.0.0.0", Ip2VpnIp(net.ParseIP("0.0.0.0")).String()) -} diff --git a/lighthouse.go b/lighthouse.go index df68e1e..62f4065 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -7,16 +7,16 @@ import ( "fmt" "net" "net/netip" + "strconv" "sync" "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" ) @@ -26,25 +26,18 @@ import ( var ErrHostNotKnown = errors.New("host not known") -type netIpAndPort struct { - ip net.IP - port uint16 -} - type LightHouse struct { //TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time sync.RWMutex //Because we concurrently read and write to our maps ctx context.Context amLighthouse bool - myVpnIp iputil.VpnIp - myVpnZeros iputil.VpnIp - myVpnNet *net.IPNet + myVpnNet netip.Prefix punchConn udp.Conn punchy *Punchy // Local cache of answers from light houses // map of vpn Ip to answers - addrMap map[iputil.VpnIp]*RemoteList + addrMap map[netip.Addr]*RemoteList // filters remote addresses allowed for each host // - When we are a lighthouse, this filters what addresses we store and @@ -57,26 +50,26 @@ type LightHouse struct { localAllowList atomic.Pointer[LocalAllowList] // used to trigger the HandshakeManager when we receive HostQueryReply - handshakeTrigger chan<- iputil.VpnIp + handshakeTrigger chan<- netip.Addr // staticList exists to avoid having a bool in each addrMap entry // since static should be rare - staticList atomic.Pointer[map[iputil.VpnIp]struct{}] - lighthouses atomic.Pointer[map[iputil.VpnIp]struct{}] + staticList atomic.Pointer[map[netip.Addr]struct{}] + lighthouses atomic.Pointer[map[netip.Addr]struct{}] interval atomic.Int64 updateCancel context.CancelFunc ifce EncWriter nebulaPort uint32 // 32 bits because protobuf does not have a uint16 - advertiseAddrs atomic.Pointer[[]netIpAndPort] + advertiseAddrs atomic.Pointer[[]netip.AddrPort] // IP's of relays that can be used by peers to access me - relaysForMe atomic.Pointer[[]iputil.VpnIp] + relaysForMe atomic.Pointer[[]netip.Addr] - queryChan chan iputil.VpnIp + queryChan chan netip.Addr - calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote + calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote metrics *MessageMetrics metricHolepunchTx metrics.Counter @@ -85,7 +78,7 @@ type LightHouse struct { // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // addrMap should be nil unless this is during a config reload -func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc udp.Conn, p *Punchy) (*LightHouse, error) { +func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet netip.Prefix, pc udp.Conn, p *Punchy) (*LightHouse, error) { amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) nebulaPort := uint32(c.GetInt("listen.port", 0)) if amLighthouse && nebulaPort == 0 { @@ -98,26 +91,23 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, if err != nil { return nil, util.NewContextualError("Failed to get listening port", nil, err) } - nebulaPort = uint32(uPort.Port) + nebulaPort = uint32(uPort.Port()) } - ones, _ := myVpnNet.Mask.Size() h := LightHouse{ ctx: ctx, amLighthouse: amLighthouse, - myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP), - myVpnZeros: iputil.VpnIp(32 - ones), myVpnNet: myVpnNet, - addrMap: make(map[iputil.VpnIp]*RemoteList), + addrMap: make(map[netip.Addr]*RemoteList), nebulaPort: nebulaPort, punchConn: pc, punchy: p, - queryChan: make(chan iputil.VpnIp, c.GetUint32("handshakes.query_buffer", 64)), + queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), l: l, } - lighthouses := make(map[iputil.VpnIp]struct{}) + lighthouses := make(map[netip.Addr]struct{}) h.lighthouses.Store(&lighthouses) - staticList := make(map[iputil.VpnIp]struct{}) + staticList := make(map[netip.Addr]struct{}) h.staticList.Store(&staticList) if c.GetBool("stats.lighthouse_metrics", false) { @@ -147,11 +137,11 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, return &h, nil } -func (lh *LightHouse) GetStaticHostList() map[iputil.VpnIp]struct{} { +func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} { return *lh.staticList.Load() } -func (lh *LightHouse) GetLighthouses() map[iputil.VpnIp]struct{} { +func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} { return *lh.lighthouses.Load() } @@ -163,15 +153,15 @@ func (lh *LightHouse) GetLocalAllowList() *LocalAllowList { return lh.localAllowList.Load() } -func (lh *LightHouse) GetAdvertiseAddrs() []netIpAndPort { +func (lh *LightHouse) GetAdvertiseAddrs() []netip.AddrPort { return *lh.advertiseAddrs.Load() } -func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp { +func (lh *LightHouse) GetRelaysForMe() []netip.Addr { return *lh.relaysForMe.Load() } -func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4[[]*calculatedRemote] { +func (lh *LightHouse) getCalculatedRemotes() *bart.Table[[]*calculatedRemote] { return lh.calculatedRemotes.Load() } @@ -182,25 +172,40 @@ func (lh *LightHouse) GetUpdateInterval() int64 { func (lh *LightHouse) reload(c *config.C, initial bool) error { if initial || c.HasChanged("lighthouse.advertise_addrs") { rawAdvAddrs := c.GetStringSlice("lighthouse.advertise_addrs", []string{}) - advAddrs := make([]netIpAndPort, 0) + advAddrs := make([]netip.AddrPort, 0) for i, rawAddr := range rawAdvAddrs { - fIp, fPort, err := udp.ParseIPAndPort(rawAddr) + host, sport, err := net.SplitHostPort(rawAddr) if err != nil { return util.NewContextualError("Unable to parse lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) } - if fPort == 0 { - fPort = uint16(lh.nebulaPort) + ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", host) + if err != nil { + return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) + } + if len(ips) == 0 { + return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, nil) } - if ip4 := fIp.To4(); ip4 != nil && lh.myVpnNet.Contains(fIp) { + port, err := strconv.Atoi(sport) + if err != nil { + return util.NewContextualError("Unable to parse port in lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) + } + + if port == 0 { + port = int(lh.nebulaPort) + } + + //TODO: we could technically insert all returned ips instead of just the first one if a dns lookup was used + ip := ips[0].Unmap() + if lh.myVpnNet.Contains(ip) { lh.l.WithField("addr", rawAddr).WithField("entry", i+1). Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range") continue } - advAddrs = append(advAddrs, netIpAndPort{ip: fIp, port: fPort}) + advAddrs = append(advAddrs, netip.AddrPortFrom(ip, uint16(port))) } lh.advertiseAddrs.Store(&advAddrs) @@ -278,8 +283,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.RUnlock() } // Build a new list based on current config. - staticList := make(map[iputil.VpnIp]struct{}) - err := lh.loadStaticMap(c, lh.myVpnNet, staticList) + staticList := make(map[netip.Addr]struct{}) + err := lh.loadStaticMap(c, staticList) if err != nil { return err } @@ -303,8 +308,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { } if initial || c.HasChanged("lighthouse.hosts") { - lhMap := make(map[iputil.VpnIp]struct{}) - err := lh.parseLighthouses(c, lh.myVpnNet, lhMap) + lhMap := make(map[netip.Addr]struct{}) + err := lh.parseLighthouses(c, lhMap) if err != nil { return err } @@ -323,16 +328,17 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { if len(c.GetStringSlice("relay.relays", nil)) > 0 { lh.l.Info("Ignoring relays from config because am_relay is true") } - relaysForMe := []iputil.VpnIp{} + relaysForMe := []netip.Addr{} lh.relaysForMe.Store(&relaysForMe) case false: - relaysForMe := []iputil.VpnIp{} + relaysForMe := []netip.Addr{} for _, v := range c.GetStringSlice("relay.relays", nil) { lh.l.WithField("relay", v).Info("Read relay from config") - configRIP := net.ParseIP(v) - if configRIP != nil { - relaysForMe = append(relaysForMe, iputil.Ip2VpnIp(configRIP)) + configRIP, err := netip.ParseAddr(v) + //TODO: We could print the error here + if err == nil { + relaysForMe = append(relaysForMe, configRIP) } } lh.relaysForMe.Store(&relaysForMe) @@ -342,21 +348,21 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { return nil } -func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap map[iputil.VpnIp]struct{}) error { +func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error { lhs := c.GetStringSlice("lighthouse.hosts", []string{}) if lh.amLighthouse && len(lhs) != 0 { lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config") } for i, host := range lhs { - ip := net.ParseIP(host) - if ip == nil { - return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil) + ip, err := netip.ParseAddr(host) + if err != nil { + return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err) } - if !tunCidr.Contains(ip) { - return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) + if !lh.myVpnNet.Contains(ip) { + return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": lh.myVpnNet}, nil) } - lhMap[iputil.Ip2VpnIp(ip)] = struct{}{} + lhMap[ip] = struct{}{} } if !lh.amLighthouse && len(lhMap) == 0 { @@ -399,7 +405,7 @@ func getStaticMapNetwork(c *config.C) (string, error) { return network, nil } -func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error { +func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struct{}) error { d, err := getStaticMapCadence(c) if err != nil { return err @@ -410,7 +416,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList return err } - lookup_timeout, err := getStaticMapLookupTimeout(c) + lookupTimeout, err := getStaticMapLookupTimeout(c) if err != nil { return err } @@ -419,16 +425,15 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList i := 0 for k, v := range shm { - rip := net.ParseIP(fmt.Sprintf("%v", k)) - if rip == nil { - return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, nil) + vpnIp, err := netip.ParseAddr(fmt.Sprintf("%v", k)) + if err != nil { + return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err) } - if !tunCidr.Contains(rip) { - return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": rip, "network": tunCidr.String(), "entry": i + 1}, nil) + if !lh.myVpnNet.Contains(vpnIp) { + return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": lh.myVpnNet, "entry": i + 1}, nil) } - vpnIp := iputil.Ip2VpnIp(rip) vals, ok := v.([]interface{}) if !ok { vals = []interface{}{v} @@ -438,7 +443,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v)) } - err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList) + err = lh.addStaticRemotes(i, d, network, lookupTimeout, vpnIp, remoteAddrs, staticList) if err != nil { return err } @@ -448,7 +453,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList return nil } -func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList { +func (lh *LightHouse) Query(ip netip.Addr) *RemoteList { if !lh.IsLighthouseIP(ip) { lh.QueryServer(ip) } @@ -462,7 +467,7 @@ func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList { } // QueryServer is asynchronous so no reply should be expected -func (lh *LightHouse) QueryServer(ip iputil.VpnIp) { +func (lh *LightHouse) QueryServer(ip netip.Addr) { // Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses if lh.amLighthouse || lh.IsLighthouseIP(ip) { return @@ -471,7 +476,7 @@ func (lh *LightHouse) QueryServer(ip iputil.VpnIp) { lh.queryChan <- ip } -func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList { +func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList { lh.RLock() if v, ok := lh.addrMap[ip]; ok { lh.RUnlock() @@ -488,7 +493,7 @@ func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList { // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing // details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp // If one is found then f() is called with proper locking, f() must return result of n.MarshalTo() -func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (int, error)) (bool, int, error) { +func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, error)) (bool, int, error) { lh.RLock() // Do we have an entry in the main cache? if v, ok := lh.addrMap[vpnIp]; ok { @@ -511,7 +516,7 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (in return false, 0, nil } -func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) { +func (lh *LightHouse) DeleteVpnIp(vpnIp netip.Addr) { // First we check the static mapping // and do nothing if it is there if _, ok := lh.GetStaticHostList()[vpnIp]; ok { @@ -532,7 +537,7 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) { // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it -func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error { +func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error { lh.Lock() am := lh.unlockedGetRemoteList(vpnIp) am.Lock() @@ -553,20 +558,14 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t am.unlockedSetHostnamesResults(hr) for _, addrPort := range hr.GetIPs() { - + if !lh.shouldAdd(vpnIp, addrPort.Addr()) { + continue + } switch { case addrPort.Addr().Is4(): - to := NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port()) - if !lh.unlockedShouldAddV4(vpnIp, to) { - continue - } - am.unlockedPrependV4(lh.myVpnIp, to) + am.unlockedPrependV4(lh.myVpnNet.Addr(), NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port())) case addrPort.Addr().Is6(): - to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port()) - if !lh.unlockedShouldAddV6(vpnIp, to) { - continue - } - am.unlockedPrependV6(lh.myVpnIp, to) + am.unlockedPrependV6(lh.myVpnNet.Addr(), NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port())) } } @@ -578,12 +577,12 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t // addCalculatedRemotes adds any calculated remotes based on the // lighthouse.calculated_remotes configuration. It returns true if any // calculated remotes were added -func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool { +func (lh *LightHouse) addCalculatedRemotes(vpnIp netip.Addr) bool { tree := lh.getCalculatedRemotes() if tree == nil { return false } - ok, calculatedRemotes := tree.MostSpecificContains(vpnIp) + calculatedRemotes, ok := tree.Lookup(vpnIp) if !ok { return false } @@ -602,13 +601,13 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool { defer am.Unlock() lh.Unlock() - am.unlockedSetV4(lh.myVpnIp, vpnIp, calculated, lh.unlockedShouldAddV4) + am.unlockedSetV4(lh.myVpnNet.Addr(), vpnIp, calculated, lh.unlockedShouldAddV4) return len(calculated) > 0 } // unlockedGetRemoteList assumes you have the lh lock -func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList { +func (lh *LightHouse) unlockedGetRemoteList(vpnIp netip.Addr) *RemoteList { am, ok := lh.addrMap[vpnIp] if !ok { am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) }) @@ -617,44 +616,27 @@ func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList { return am } -func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool { - switch { - case to.Is4(): - ipBytes := to.As4() - ip := iputil.Ip2VpnIp(ipBytes[:]) - allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") - } - if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) { - return false - } - case to.Is6(): - ipBytes := to.As16() - - hi := binary.BigEndian.Uint64(ipBytes[:8]) - lo := binary.BigEndian.Uint64(ipBytes[8:]) - allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow") - } - - // We don't check our vpn network here because nebula does not support ipv6 on the inside - if !allow { - return false - } +func (lh *LightHouse) shouldAdd(vpnIp netip.Addr, to netip.Addr) bool { + allow := lh.GetRemoteAllowList().Allow(vpnIp, to) + if lh.l.Level >= logrus.TraceLevel { + lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") } + if !allow || lh.myVpnNet.Contains(to) { + return false + } + return true } // unlockedShouldAddV4 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool { - allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip)) +func (lh *LightHouse) unlockedShouldAddV4(vpnIp netip.Addr, to *Ip4AndPort) bool { + ip := AddrPortFromIp4AndPort(to) + allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr()) if lh.l.Level >= logrus.TraceLevel { lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") } - if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.VpnIp(to.Ip)) { + if !allow || lh.myVpnNet.Contains(ip.Addr()) { return false } @@ -662,14 +644,14 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bo } // unlockedShouldAddV6 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV6(vpnIp iputil.VpnIp, to *Ip6AndPort) bool { - allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, to.Hi, to.Lo) +func (lh *LightHouse) unlockedShouldAddV6(vpnIp netip.Addr, to *Ip6AndPort) bool { + ip := AddrPortFromIp6AndPort(to) + allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr()) if lh.l.Level >= logrus.TraceLevel { lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow") } - // We don't check our vpn network here because nebula does not support ipv6 on the inside - if !allow { + if !allow || lh.myVpnNet.Contains(ip.Addr()) { return false } @@ -683,26 +665,39 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP { return ip } -func (lh *LightHouse) IsLighthouseIP(vpnIp iputil.VpnIp) bool { +func (lh *LightHouse) IsLighthouseIP(vpnIp netip.Addr) bool { if _, ok := lh.GetLighthouses()[vpnIp]; ok { return true } return false } -func NewLhQueryByInt(VpnIp iputil.VpnIp) *NebulaMeta { +func NewLhQueryByInt(vpnIp netip.Addr) *NebulaMeta { + if vpnIp.Is6() { + //TODO: need to support ipv6 + panic("ipv6 is not yet supported") + } + + b := vpnIp.As4() return &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: uint32(VpnIp), + VpnIp: binary.BigEndian.Uint32(b[:]), }, } } -func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort { - ipp := Ip4AndPort{Port: port} - ipp.Ip = uint32(iputil.Ip2VpnIp(ip)) - return &ipp +func AddrPortFromIp4AndPort(ip *Ip4AndPort) netip.AddrPort { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], ip.Ip) + return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ip.Port)) +} + +func AddrPortFromIp6AndPort(ip *Ip6AndPort) netip.AddrPort { + b := [16]byte{} + binary.BigEndian.PutUint64(b[:8], ip.Hi) + binary.BigEndian.PutUint64(b[8:], ip.Lo) + return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ip.Port)) } func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort { @@ -713,14 +708,7 @@ func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort { } } -func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort { - return &Ip6AndPort{ - Hi: binary.BigEndian.Uint64(ip[:8]), - Lo: binary.BigEndian.Uint64(ip[8:]), - Port: port, - } -} - +// TODO: IPV6-WORK we can delete some more of these func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort { ip6Addr := ip.As16() return &Ip6AndPort{ @@ -729,17 +717,6 @@ func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort { Port: uint32(port), } } -func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr { - ip := ipp.Ip - return udp.NewAddr( - net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)), - uint16(ipp.Port), - ) -} - -func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr { - return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port)) -} func (lh *LightHouse) startQueryWorker() { if lh.amLighthouse { @@ -761,7 +738,7 @@ func (lh *LightHouse) startQueryWorker() { }() } -func (lh *LightHouse) innerQueryServer(ip iputil.VpnIp, nb, out []byte) { +func (lh *LightHouse) innerQueryServer(ip netip.Addr, nb, out []byte) { if lh.IsLighthouseIP(ip) { return } @@ -812,36 +789,41 @@ func (lh *LightHouse) SendUpdate() { var v6 []*Ip6AndPort for _, e := range lh.GetAdvertiseAddrs() { - if ip := e.ip.To4(); ip != nil { - v4 = append(v4, NewIp4AndPort(e.ip, uint32(e.port))) + if e.Addr().Is4() { + v4 = append(v4, NewIp4AndPortFromNetIP(e.Addr(), e.Port())) } else { - v6 = append(v6, NewIp6AndPort(e.ip, uint32(e.port))) + v6 = append(v6, NewIp6AndPortFromNetIP(e.Addr(), e.Port())) } } lal := lh.GetLocalAllowList() - for _, e := range *localIps(lh.l, lal) { - if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.Ip2VpnIp(ip4)) { + for _, e := range localIps(lh.l, lal) { + if lh.myVpnNet.Contains(e) { continue } // Only add IPs that aren't my VPN/tun IP - if ip := e.To4(); ip != nil { - v4 = append(v4, NewIp4AndPort(e, lh.nebulaPort)) + if e.Is4() { + v4 = append(v4, NewIp4AndPortFromNetIP(e, uint16(lh.nebulaPort))) } else { - v6 = append(v6, NewIp6AndPort(e, lh.nebulaPort)) + v6 = append(v6, NewIp6AndPortFromNetIP(e, uint16(lh.nebulaPort))) } } var relays []uint32 for _, r := range lh.GetRelaysForMe() { - relays = append(relays, (uint32)(r)) + //TODO: IPV6-WORK both relays and vpnip need ipv6 support + b := r.As4() + relays = append(relays, binary.BigEndian.Uint32(b[:])) } + //TODO: IPV6-WORK both relays and vpnip need ipv6 support + b := lh.myVpnNet.Addr().As4() + m := &NebulaMeta{ Type: NebulaMeta_HostUpdateNotification, Details: &NebulaMetaDetails{ - VpnIp: uint32(lh.myVpnIp), + VpnIp: binary.BigEndian.Uint32(b[:]), Ip4AndPorts: v4, Ip6AndPorts: v6, RelayVpnIp: relays, @@ -913,12 +895,12 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { } func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc { - return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) { + return func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) { lhh.HandleRequest(rAddr, vpnIp, p, f) } } -func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) { +func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte, w EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { @@ -956,7 +938,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, } } -func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) { +func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, addr netip.AddrPort, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { @@ -967,8 +949,14 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, //TODO: we can DRY this further reqVpnIp := n.Details.VpnIp + + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + queryVpnIp := netip.AddrFrom4(b) + //TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data - found, ln, err := lhh.lh.queryAndPrepMessage(iputil.VpnIp(n.Details.VpnIp), func(c *cache) (int, error) { + found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnIp, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostQueryReply n.Details.VpnIp = reqVpnIp @@ -994,8 +982,9 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostPunchNotification - n.Details.VpnIp = uint32(vpnIp) - + //TODO: IPV6-WORK + b = vpnIp.As4() + n.Details.VpnIp = binary.BigEndian.Uint32(b[:]) lhh.coalesceAnswers(c, n) return n.MarshalTo(lhh.pb) @@ -1011,7 +1000,11 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, } lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, iputil.VpnIp(reqVpnIp), lhh.pb[:ln], lhh.nb, lhh.out[:0]) + + //TODO: IPV6-WORK + binary.BigEndian.PutUint32(b[:], reqVpnIp) + sendTo := netip.AddrFrom4(b) + w.SendMessageToVpnIp(header.LightHouse, 0, sendTo, lhh.pb[:ln], lhh.nb, lhh.out[:0]) } func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) { @@ -1034,34 +1027,52 @@ func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) { } if c.relay != nil { - n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, c.relay.relay...) + //TODO: IPV6-WORK + relays := make([]uint32, len(c.relay.relay)) + b := [4]byte{} + for i, _ := range relays { + b = c.relay.relay[i].As4() + relays[i] = binary.BigEndian.Uint32(b[:]) + } + n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, relays...) } } -func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp iputil.VpnIp) { +func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp netip.Addr) { if !lhh.lh.IsLighthouseIP(vpnIp) { return } lhh.lh.Lock() - am := lhh.lh.unlockedGetRemoteList(iputil.VpnIp(n.Details.VpnIp)) + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + certVpnIp := netip.AddrFrom4(b) + am := lhh.lh.unlockedGetRemoteList(certVpnIp) am.Lock() lhh.lh.Unlock() - certVpnIp := iputil.VpnIp(n.Details.VpnIp) + //TODO: IPV6-WORK am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) - am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp) + + //TODO: IPV6-WORK + relays := make([]netip.Addr, len(n.Details.RelayVpnIp)) + for i, _ := range n.Details.RelayVpnIp { + binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i]) + relays[i] = netip.AddrFrom4(b) + } + am.unlockedSetRelay(vpnIp, certVpnIp, relays) am.Unlock() // Non-blocking attempt to trigger, skip if it would block select { - case lhh.lh.handshakeTrigger <- iputil.VpnIp(n.Details.VpnIp): + case lhh.lh.handshakeTrigger <- certVpnIp: default: } } -func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) { +func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) { if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp) @@ -1070,9 +1081,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp } //Simple check that the host sent this not someone else - if n.Details.VpnIp != uint32(vpnIp) { + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + detailsVpnIp := netip.AddrFrom4(b) + if detailsVpnIp != vpnIp { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnIp", vpnIp).WithField("answer", iputil.VpnIp(n.Details.VpnIp)).Debugln("Host sent invalid update") + lhh.l.WithField("vpnIp", vpnIp).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update") } return } @@ -1082,15 +1097,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp am.Lock() lhh.lh.Unlock() - certVpnIp := iputil.VpnIp(n.Details.VpnIp) - am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) - am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp) + am.unlockedSetV4(vpnIp, detailsVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(vpnIp, detailsVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) + + //TODO: IPV6-WORK + relays := make([]netip.Addr, len(n.Details.RelayVpnIp)) + for i, _ := range n.Details.RelayVpnIp { + binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i]) + relays[i] = netip.AddrFrom4(b) + } + am.unlockedSetRelay(vpnIp, detailsVpnIp, relays) am.Unlock() n = lhh.resetMeta() n.Type = NebulaMeta_HostUpdateNotificationAck - n.Details.VpnIp = uint32(vpnIp) + + //TODO: IPV6-WORK + vpnIpB := vpnIp.As4() + n.Details.VpnIp = binary.BigEndian.Uint32(vpnIpB[:]) ln, err := n.MarshalTo(lhh.pb) if err != nil { @@ -1102,14 +1126,14 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) } -func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) { +func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) { if !lhh.lh.IsLighthouseIP(vpnIp) { return } empty := []byte{0} - punch := func(vpnPeer *udp.Addr) { - if vpnPeer == nil { + punch := func(vpnPeer netip.AddrPort) { + if !vpnPeer.IsValid() { return } @@ -1121,23 +1145,29 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i if lhh.l.Level >= logrus.DebugLevel { //TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp)) - lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, iputil.VpnIp(n.Details.VpnIp)) + //TODO: IPV6-WORK, make this debug line not suck + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + lhh.l.Debugf("Punching on %d for %v", vpnPeer.Port(), netip.AddrFrom4(b)) } } for _, a := range n.Details.Ip4AndPorts { - punch(NewUDPAddrFromLH4(a)) + punch(AddrPortFromIp4AndPort(a)) } for _, a := range n.Details.Ip6AndPorts { - punch(NewUDPAddrFromLH6(a)) + punch(AddrPortFromIp6AndPort(a)) } // This sends a nebula test packet to the host trying to contact us. In the case // of a double nat or other difficult scenario, this may help establish // a tunnel. if lhh.lh.punchy.GetRespond() { - queryVpnIp := iputil.VpnIp(n.Details.VpnIp) + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + queryVpnIp := netip.AddrFrom4(b) go func() { time.Sleep(lhh.lh.punchy.GetRespondDelay()) if lhh.l.Level >= logrus.DebugLevel { @@ -1150,9 +1180,3 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i }() } } - -// ipMaskContains checks if testIp is contained by ip after applying a cidr. -// zeros is 32 - bits from net.IPMask.Size() -func ipMaskContains(ip iputil.VpnIp, zeros iputil.VpnIp, testIp iputil.VpnIp) bool { - return (testIp^ip)>>zeros == 0 -} diff --git a/lighthouse_test.go b/lighthouse_test.go index 66427e3..2599f5f 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -2,15 +2,14 @@ package nebula import ( "context" + "encoding/binary" "fmt" - "net" + "net/netip" "testing" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" - "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v2" ) @@ -23,15 +22,17 @@ func TestOldIPv4Only(t *testing.T) { var m Ip4AndPort err := m.Unmarshal(b) assert.NoError(t, err) - assert.Equal(t, "10.1.1.1", iputil.VpnIp(m.GetIp()).String()) + ip := netip.MustParseAddr("10.1.1.1") + bp := ip.As4() + assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp()) } func TestNewLhQuery(t *testing.T) { - myIp := net.ParseIP("192.1.1.1") - myIpint := iputil.Ip2VpnIp(myIp) + myIp, err := netip.ParseAddr("192.1.1.1") + assert.NoError(t, err) // Generating a new lh query should work - a := NewLhQueryByInt(myIpint) + a := NewLhQueryByInt(myIp) // The result should be a nebulameta protobuf assert.IsType(t, &NebulaMeta{}, a) @@ -49,7 +50,7 @@ func TestNewLhQuery(t *testing.T) { func Test_lhStaticMapping(t *testing.T) { l := test.NewLogger() - _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16") + myVpnNet := netip.MustParsePrefix("10.128.0.1/16") lh1 := "10.128.0.2" c := config.NewC(l) @@ -68,7 +69,7 @@ func Test_lhStaticMapping(t *testing.T) { func TestReloadLighthouseInterval(t *testing.T) { l := test.NewLogger() - _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16") + myVpnNet := netip.MustParsePrefix("10.128.0.1/16") lh1 := "10.128.0.2" c := config.NewC(l) @@ -83,21 +84,21 @@ func TestReloadLighthouseInterval(t *testing.T) { lh.ifce = &mockEncWriter{} // The first one routine is kicked off by main.go currently, lets make sure that one dies - c.ReloadConfigString("lighthouse:\n interval: 5") + assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5")) assert.Equal(t, int64(5), lh.interval.Load()) // Subsequent calls are killed off by the LightHouse.Reload function - c.ReloadConfigString("lighthouse:\n interval: 10") + assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10")) assert.Equal(t, int64(10), lh.interval.Load()) // If this completes then nothing is stealing our reload routine - c.ReloadConfigString("lighthouse:\n interval: 11") + assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11")) assert.Equal(t, int64(11), lh.interval.Load()) } func BenchmarkLighthouseHandleRequest(b *testing.B) { l := test.NewLogger() - _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0") + myVpnNet := netip.MustParsePrefix("10.128.0.1/0") c := config.NewC(l) lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) @@ -105,30 +106,33 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { b.Fatal() } - hAddr := udp.NewAddrFromString("4.5.6.7:12345") - hAddr2 := udp.NewAddrFromString("4.5.6.7:12346") - lh.addrMap[3] = NewRemoteList(nil) - lh.addrMap[3].unlockedSetV4( - 3, - 3, + hAddr := netip.MustParseAddrPort("4.5.6.7:12345") + hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346") + + vpnIp3 := netip.MustParseAddr("0.0.0.3") + lh.addrMap[vpnIp3] = NewRemoteList(nil) + lh.addrMap[vpnIp3].unlockedSetV4( + vpnIp3, + vpnIp3, []*Ip4AndPort{ - NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)), - NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)), + NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()), + NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()), }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) - rAddr := udp.NewAddrFromString("1.2.2.3:12345") - rAddr2 := udp.NewAddrFromString("1.2.2.3:12346") - lh.addrMap[2] = NewRemoteList(nil) - lh.addrMap[2].unlockedSetV4( - 3, - 3, + rAddr := netip.MustParseAddrPort("1.2.2.3:12345") + rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346") + vpnIp2 := netip.MustParseAddr("0.0.0.3") + lh.addrMap[vpnIp2] = NewRemoteList(nil) + lh.addrMap[vpnIp2].unlockedSetV4( + vpnIp3, + vpnIp3, []*Ip4AndPort{ - NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)), - NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)), + NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()), + NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()), }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) mw := &mockEncWriter{} @@ -145,7 +149,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { p, err := req.Marshal() assert.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, 2, p, mw) + lhh.HandleRequest(rAddr, vpnIp2, p, mw) } }) b.Run("found", func(b *testing.B) { @@ -161,7 +165,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { assert.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, 2, p, mw) + lhh.HandleRequest(rAddr, vpnIp2, p, mw) } }) } @@ -169,51 +173,51 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { func TestLighthouse_Memory(t *testing.T) { l := test.NewLogger() - myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242} - myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242} - myUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.2"), Port: 4242} - myUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.2"), Port: 4242} - myUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.2"), Port: 4242} - myUdpAddr5 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4243} - myUdpAddr6 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4244} - myUdpAddr7 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4245} - myUdpAddr8 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4246} - myUdpAddr9 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4247} - myUdpAddr10 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4248} - myUdpAddr11 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4249} - myVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.2")) + myUdpAddr0 := netip.MustParseAddrPort("10.0.0.2:4242") + myUdpAddr1 := netip.MustParseAddrPort("192.168.0.2:4242") + myUdpAddr2 := netip.MustParseAddrPort("172.16.0.2:4242") + myUdpAddr3 := netip.MustParseAddrPort("100.152.0.2:4242") + myUdpAddr4 := netip.MustParseAddrPort("24.15.0.2:4242") + myUdpAddr5 := netip.MustParseAddrPort("192.168.0.2:4243") + myUdpAddr6 := netip.MustParseAddrPort("192.168.0.2:4244") + myUdpAddr7 := netip.MustParseAddrPort("192.168.0.2:4245") + myUdpAddr8 := netip.MustParseAddrPort("192.168.0.2:4246") + myUdpAddr9 := netip.MustParseAddrPort("192.168.0.2:4247") + myUdpAddr10 := netip.MustParseAddrPort("192.168.0.2:4248") + myUdpAddr11 := netip.MustParseAddrPort("192.168.0.2:4249") + myVpnIp := netip.MustParseAddr("10.128.0.2") - theirUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.3"), Port: 4242} - theirUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.3"), Port: 4242} - theirUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.3"), Port: 4242} - theirUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.3"), Port: 4242} - theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242} - theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3")) + theirUdpAddr0 := netip.MustParseAddrPort("10.0.0.3:4242") + theirUdpAddr1 := netip.MustParseAddrPort("192.168.0.3:4242") + theirUdpAddr2 := netip.MustParseAddrPort("172.16.0.3:4242") + theirUdpAddr3 := netip.MustParseAddrPort("100.152.0.3:4242") + theirUdpAddr4 := netip.MustParseAddrPort("24.15.0.3:4242") + theirVpnIp := netip.MustParseAddr("10.128.0.3") c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil) assert.NoError(t, err) lhh := lh.NewRequestHandler() // Test that my first update responds with just that - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr2}, lhh) + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh) r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2) // Ensure we don't accumulate addresses - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr3}, lhh) + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3) // Grow it back to 2 - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr4}, lhh) + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) // Update a different host and ask about it - newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) + newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) @@ -233,7 +237,7 @@ func TestLighthouse_Memory(t *testing.T) { newLHHostUpdate( myUdpAddr0, myVpnIp, - []*udp.Addr{ + []netip.AddrPort{ myUdpAddr1, myUdpAddr2, myUdpAddr3, @@ -256,10 +260,10 @@ func TestLighthouse_Memory(t *testing.T) { ) // Make sure we won't add ips in our vpn network - bad1 := &udp.Addr{IP: net.ParseIP("10.128.0.99"), Port: 4242} - bad2 := &udp.Addr{IP: net.ParseIP("10.128.0.100"), Port: 4242} - good := &udp.Addr{IP: net.ParseIP("1.128.0.99"), Port: 4242} - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{bad1, bad2, good}, lhh) + bad1 := netip.MustParseAddrPort("10.128.0.99:4242") + bad2 := netip.MustParseAddrPort("10.128.0.100:4242") + good := netip.MustParseAddrPort("1.128.0.99:4242") + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good) } @@ -269,7 +273,7 @@ func TestLighthouse_reload(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil) assert.NoError(t, err) nc := map[interface{}]interface{}{ @@ -285,11 +289,13 @@ func TestLighthouse_reload(t *testing.T) { assert.NoError(t, err) } -func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply { +func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply { + //TODO: IPV6-WORK + bip := queryVpnIp.As4() req := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: uint32(queryVpnIp), + VpnIp: binary.BigEndian.Uint32(bip[:]), }, } @@ -306,17 +312,19 @@ func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh return w.lastReply } -func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, lhh *LightHouseHandler) { +func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) { + //TODO: IPV6-WORK + bip := vpnIp.As4() req := &NebulaMeta{ Type: NebulaMeta_HostUpdateNotification, Details: &NebulaMetaDetails{ - VpnIp: uint32(vpnIp), + VpnIp: binary.BigEndian.Uint32(bip[:]), Ip4AndPorts: make([]*Ip4AndPort, len(addrs)), }, } for k, v := range addrs { - req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: uint32(iputil.Ip2VpnIp(v.IP)), Port: uint32(v.Port)} + req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port()) } b, err := req.Marshal() @@ -394,16 +402,10 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, // ) //} -func Test_ipMaskContains(t *testing.T) { - assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.0.255")))) - assert.False(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1")))) - assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1")))) -} - type testLhReply struct { nebType header.MessageType nebSubType header.MessageSubType - vpnIp iputil.VpnIp + vpnIp netip.Addr msg *NebulaMeta } @@ -414,7 +416,7 @@ type testEncWriter struct { func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { } -func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) { +func (tw *testEncWriter) Handshake(vpnIp netip.Addr) { } func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) { @@ -434,7 +436,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M } } -func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) { +func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) { msg := &NebulaMeta{} err := msg.Unmarshal(p) if tw.metaFilter == nil || msg.Type == *tw.metaFilter { @@ -452,35 +454,16 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess } // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match -func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) { +func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) { if !assert.Len(t, have, len(want)) { return } for k, w := range want { - if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == uint32(w.Port)) { - assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have))) + //TODO: IPV6-WORK + h := AddrPortFromIp4AndPort(have[k]) + if !(h == w) { + assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h)) } } } - -// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match -func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) { - if !assert.Len(t, have, len(want)) { - return - } - - for k, w := range want { - if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) { - assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have)) - } - } -} - -func translateV4toUdpAddr(ips []*Ip4AndPort) []*udp.Addr { - addrs := make([]*udp.Addr, len(ips)) - for k, v := range ips { - addrs[k] = NewUDPAddrFromLH4(v) - } - return addrs -} diff --git a/main.go b/main.go index 7a0a0cf..248f329 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "fmt" "net" + "net/netip" "time" "github.com/sirupsen/logrus" @@ -67,8 +68,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") - // TODO: make sure mask is 4 bytes - tunCidr := certificate.Details.Ips[0] + ones, _ := certificate.Details.Ips[0].Mask.Size() + addr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP) + if !ok { + err = util.NewContextualError( + "Invalid ip address in certificate", + m{"vpnIp": certificate.Details.Ips[0].IP}, + nil, + ) + return nil, err + } + tunCidr := netip.PrefixFrom(addr, ones) ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) if err != nil { @@ -150,21 +160,25 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if !configTest { rawListenHost := c.GetString("listen.host", "0.0.0.0") - var listenHost *net.IPAddr + var listenHost netip.Addr if rawListenHost == "[::]" { // Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve. - listenHost = &net.IPAddr{IP: net.IPv6zero} + listenHost = netip.IPv6Unspecified() } else { - listenHost, err = net.ResolveIPAddr("ip", rawListenHost) + ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", rawListenHost) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err) } + if len(ips) == 0 { + return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err) + } + listenHost = ips[0].Unmap() } for i := 0; i < routines; i++ { - l.Infof("listening %q %d", listenHost.IP, port) - udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64)) + l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) + udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64)) if err != nil { return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) } @@ -178,7 +192,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if err != nil { return nil, util.NewContextualError("Failed to get listening port", nil, err) } - port = int(uPort.Port) + port = int(uPort.Port()) } } } diff --git a/outside.go b/outside.go index 818e2ae..be60294 100644 --- a/outside.go +++ b/outside.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "errors" "fmt" + "net/netip" "time" "github.com/flynn/noise" @@ -11,7 +12,6 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "golang.org/x/net/ipv4" "google.golang.org/protobuf/proto" @@ -21,9 +21,10 @@ const ( minFwPacketLen = 4 ) +// TODO: IPV6-WORK this can likely be removed now func readOutsidePackets(f *Interface) udp.EncReader { return func( - addr *udp.Addr, + addr netip.AddrPort, out []byte, packet []byte, header *header.H, @@ -37,27 +38,25 @@ func readOutsidePackets(f *Interface) udp.EncReader { } } -func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { // TODO: best if we return this and let caller log // TODO: Might be better to send the literal []byte("holepunch") packet and ignore that? // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors if len(packet) > 1 { - f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err) + f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err) } return } //l.Error("in packet ", header, packet[HeaderLen:]) - if addr != nil { - if ip4 := addr.IP.To4(); ip4 != nil { - if ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, iputil.VpnIp(binary.BigEndian.Uint32(ip4))) { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet") - } - return + if ip.IsValid() { + if f.myVpnNet.Contains(ip.Addr()) { + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") } + return } } @@ -77,7 +76,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt switch h.Type { case header.Message: // TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case. - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } @@ -101,7 +100,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt // Successfully validated the thing. Get rid of the Relay header. signedPayload = signedPayload[header.Len:] // Pull the Roaming parts up here, and return in all call paths. - f.handleHostRoaming(hostinfo, addr) + f.handleHostRoaming(hostinfo, ip) // Track usage of both the HostInfo and the Relay for the received & authenticated packet f.connectionManager.In(hostinfo.localIndexId) f.connectionManager.RelayUsed(h.RemoteIndex) @@ -118,7 +117,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt case TerminalType: // If I am the target of this relay, process the unwrapped packet // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. - f.readOutsidePackets(nil, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) return case ForwardingType: // Find the target HostInfo relay object @@ -148,13 +147,13 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt case header.LightHouse: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). WithField("packet", packet). Error("Failed to decrypt lighthouse packet") @@ -163,19 +162,19 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt return } - lhf(addr, hostinfo.vpnIp, d) + lhf(ip, hostinfo.vpnIp, d) // Fallthrough to the bottom to record incoming traffic case header.Test: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). WithField("packet", packet). Error("Failed to decrypt test packet") @@ -187,7 +186,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt if h.Subtype == header.TestRequest { // This testRequest might be from TryPromoteBest, so we should roam // to the new IP address before responding - f.handleHostRoaming(hostinfo, addr) + f.handleHostRoaming(hostinfo, ip) f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) } @@ -198,34 +197,34 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt case header.Handshake: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handshakeManager.HandleIncoming(addr, via, packet, h) + f.handshakeManager.HandleIncoming(ip, via, packet, h) return case header.RecvError: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handleRecvError(addr, h) + f.handleRecvError(ip, h) return case header.CloseTunnel: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } - hostinfo.logger(f.l).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithField("udpAddr", ip). Info("Close tunnel received, tearing down.") f.closeTunnel(hostinfo) return case header.Control: - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). WithField("packet", packet). Error("Failed to decrypt Control packet") return @@ -241,11 +240,11 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt default: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr) + hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip) return } - f.handleHostRoaming(hostinfo, addr) + f.handleHostRoaming(hostinfo, ip) f.connectionManager.In(hostinfo.localIndexId) } @@ -264,34 +263,34 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) } -func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) { - if addr != nil && !hostinfo.remote.Equals(addr) { - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) { - hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming") +func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) { + if ip.IsValid() && hostinfo.remote != ip { + if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) { + hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming") return } - if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { + if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) } return } - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). Info("Host roamed to new udp ip/port.") hostinfo.lastRoam = time.Now() hostinfo.lastRoamRemote = hostinfo.remote - hostinfo.SetRemote(addr) + hostinfo.SetRemote(ip) } } -func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udp.Addr, h *header.H) bool { +func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool { // If connectionstate exists and the replay protector allows, process packet // Else, send recv errors for 300 seconds after a restart to allow fast reconnection. if ci == nil || !ci.window.Check(f.l, h.MessageCounter) { - if addr != nil { + if addr.IsValid() { f.maybeSendRecvError(addr, h.RemoteIndex) return false } else { @@ -340,8 +339,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { // Firewall packets are locally oriented if incoming { - fp.RemoteIP = iputil.Ip2VpnIp(data[12:16]) - fp.LocalIP = iputil.Ip2VpnIp(data[16:20]) + //TODO: IPV6-WORK + fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16]) + fp.LocalIP, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 @@ -350,8 +350,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) } } else { - fp.LocalIP = iputil.Ip2VpnIp(data[12:16]) - fp.RemoteIP = iputil.Ip2VpnIp(data[16:20]) + //TODO: IPV6-WORK + fp.LocalIP, _ = netip.AddrFromSlice(data[12:16]) + fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 @@ -425,13 +426,13 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return true } -func (f *Interface) maybeSendRecvError(endpoint *udp.Addr, index uint32) { - if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint.IP) { +func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) { + if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint) { f.sendRecvError(endpoint, index) } } -func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) { +func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) { f.messageMetrics.Tx(header.RecvError, 0, 1) //TODO: this should be a signed message so we can trust that we should drop the index @@ -444,7 +445,7 @@ func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) { } } -func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { +func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) { if f.l.Level >= logrus.DebugLevel { f.l.WithField("index", h.RemoteIndex). WithField("udpAddr", addr). @@ -461,7 +462,7 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { return } - if hostinfo.remote != nil && !hostinfo.remote.Equals(addr) { + if hostinfo.remote.IsValid() && hostinfo.remote != addr { f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) return } diff --git a/outside_test.go b/outside_test.go index 682107b..f9d4bfa 100644 --- a/outside_test.go +++ b/outside_test.go @@ -2,10 +2,10 @@ package nebula import ( "net" + "net/netip" "testing" "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" "golang.org/x/net/ipv4" ) @@ -55,8 +55,8 @@ func Test_newPacket(t *testing.T) { assert.Nil(t, err) assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP)) - assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2))) - assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1))) + assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2")) + assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1")) assert.Equal(t, p.RemotePort, uint16(3)) assert.Equal(t, p.LocalPort, uint16(4)) @@ -76,8 +76,8 @@ func Test_newPacket(t *testing.T) { assert.Nil(t, err) assert.Equal(t, p.Protocol, uint8(2)) - assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1))) - assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2))) + assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1")) + assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2")) assert.Equal(t, p.RemotePort, uint16(6)) assert.Equal(t, p.LocalPort, uint16(5)) } diff --git a/overlay/device.go b/overlay/device.go index 3f3f2eb..50ad6ad 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -2,16 +2,14 @@ package overlay import ( "io" - "net" - - "github.com/slackhq/nebula/iputil" + "net/netip" ) type Device interface { io.ReadWriteCloser Activate() error - Cidr() *net.IPNet + Cidr() netip.Prefix Name() string - RouteFor(iputil.VpnIp) iputil.VpnIp + RouteFor(netip.Addr) netip.Addr NewMultiQueueReader() (io.ReadWriteCloser, error) } diff --git a/overlay/route.go b/overlay/route.go index 64c624c..8ccc994 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -1,34 +1,30 @@ package overlay import ( - "bytes" "fmt" "math" "net" + "net/netip" "runtime" "strconv" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) type Route struct { MTU int Metric int - Cidr *net.IPNet - Via *iputil.VpnIp + Cidr netip.Prefix + Via netip.Addr Install bool } // Equal determines if a route that could be installed in the system route table is equal to another // Via is ignored since that is only consumed within nebula itself func (r Route) Equal(t Route) bool { - if !r.Cidr.IP.Equal(t.Cidr.IP) { - return false - } - if !bytes.Equal(r.Cidr.Mask, t.Cidr.Mask) { + if r.Cidr != t.Cidr { return false } if r.Metric != t.Metric { @@ -51,21 +47,21 @@ func (r Route) String() string { return s } -func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) { - routeTree := cidr.NewTree4[iputil.VpnIp]() +func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) { + routeTree := new(bart.Table[netip.Addr]) for _, r := range routes { if !allowMTU && r.MTU > 0 { l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS) } - if r.Via != nil { - routeTree.AddCIDR(r.Cidr, *r.Via) + if r.Via.IsValid() { + routeTree.Insert(r.Cidr, r.Via) } } return routeTree, nil } -func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { +func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.routes") @@ -116,12 +112,12 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { MTU: mtu, } - _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) + r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute)) if err != nil { return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err) } - if !ipWithin(network, r.Cidr) { + if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() { return nil, fmt.Errorf( "entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v", i+1, @@ -136,7 +132,7 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { return routes, nil } -func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { +func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.unsafe_routes") @@ -202,9 +198,9 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia) } - nVia := net.ParseIP(via) - if nVia == nil { - return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, via) + viaVpnIp, err := netip.ParseAddr(via) + if err != nil { + return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err) } rRoute, ok := m["route"] @@ -212,8 +208,6 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1) } - viaVpnIp := iputil.Ip2VpnIp(nVia) - install := true rInstall, ok := m["install"] if ok { @@ -224,18 +218,18 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { } r := Route{ - Via: &viaVpnIp, + Via: viaVpnIp, MTU: mtu, Metric: metric, Install: install, } - _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) + r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute)) if err != nil { return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err) } - if ipWithin(network, r.Cidr) { + if network.Contains(r.Cidr.Addr()) { return nil, fmt.Errorf( "entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v", i+1, diff --git a/overlay/route_test.go b/overlay/route_test.go index 46fb87c..d791389 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -2,11 +2,10 @@ package overlay import ( "fmt" - "net" + "net/netip" "testing" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) @@ -14,7 +13,8 @@ import ( func Test_parseRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - _, n, _ := net.ParseCIDR("10.0.0.0/24") + n, err := netip.ParsePrefix("10.0.0.0/24") + assert.NoError(t, err) // test no routes config routes, err := parseRoutes(c, n) @@ -67,7 +67,7 @@ func Test_parseRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} routes, err = parseRoutes(c, n) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: invalid CIDR address: nope") + assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // below network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} @@ -112,7 +112,8 @@ func Test_parseRoutes(t *testing.T) { func Test_parseUnsafeRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - _, n, _ := net.ParseCIDR("10.0.0.0/24") + n, err := netip.ParsePrefix("10.0.0.0/24") + assert.NoError(t, err) // test no routes config routes, err := parseUnsafeRoutes(c, n) @@ -157,7 +158,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: nope") + assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") // missing route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} @@ -169,7 +170,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: invalid CIDR address: nope") + assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // within network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} @@ -252,7 +253,8 @@ func Test_parseUnsafeRoutes(t *testing.T) { func Test_makeRouteTree(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - _, n, _ := net.ParseCIDR("10.0.0.0/24") + n, err := netip.ParsePrefix("10.0.0.0/24") + assert.NoError(t, err) c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, @@ -264,17 +266,26 @@ func Test_makeRouteTree(t *testing.T) { routeTree, err := makeRouteTree(l, routes, true) assert.NoError(t, err) - ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2")) - ok, r := routeTree.MostSpecificContains(ip) + ip, err := netip.ParseAddr("1.0.0.2") + assert.NoError(t, err) + r, ok := routeTree.Lookup(ip) assert.True(t, ok) - assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r) - ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1")) - ok, r = routeTree.MostSpecificContains(ip) + nip, err := netip.ParseAddr("192.168.0.1") + assert.NoError(t, err) + assert.Equal(t, nip, r) + + ip, err = netip.ParseAddr("1.0.0.1") + assert.NoError(t, err) + r, ok = routeTree.Lookup(ip) assert.True(t, ok) - assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r) - ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1")) - ok, r = routeTree.MostSpecificContains(ip) + nip, err = netip.ParseAddr("192.168.0.2") + assert.NoError(t, err) + assert.Equal(t, nip, r) + + ip, err = netip.ParseAddr("1.1.0.1") + assert.NoError(t, err) + r, ok = routeTree.Lookup(ip) assert.False(t, ok) } diff --git a/overlay/tun.go b/overlay/tun.go index cedd7fe..12460da 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -1,7 +1,7 @@ package overlay import ( - "net" + "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -11,9 +11,9 @@ import ( const DefaultMTU = 1300 // TODO: We may be able to remove routines -type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) +type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { +func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { switch { case c.GetBool("tun.disabled", false): tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) @@ -25,12 +25,12 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, rout } func NewFdDeviceFromConfig(fd *int) DeviceFactory { - return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { + return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { return newTunFromFd(c, l, *fd, tunCidr) } } -func getAllRoutesFromConfig(c *config.C, cidr *net.IPNet, initial bool) (bool, []Route, error) { +func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) { if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") { return false, nil, nil } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index c15827f..98ad9b4 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -6,27 +6,26 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" "sync/atomic" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" ) type tun struct { io.ReadWriteCloser fd int - cidr *net.IPNet + cidr netip.Prefix Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") @@ -53,12 +52,12 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) return t, nil } -func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -87,7 +86,7 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 1c63828..0b573e6 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -8,15 +8,15 @@ import ( "fmt" "io" "net" + "net/netip" "os" "sync/atomic" "syscall" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" "golang.org/x/sys/unix" @@ -25,10 +25,10 @@ import ( type tun struct { io.ReadWriteCloser Device string - cidr *net.IPNet + cidr netip.Prefix DefaultMTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] linkAddr *netroute.LinkAddr l *logrus.Logger @@ -73,7 +73,7 @@ type ifreqMTU struct { pad [8]byte } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { name := c.GetString("tun.dev", "") ifIndex := -1 if name != "" && name != "utun" { @@ -172,7 +172,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } @@ -188,8 +188,13 @@ func (t *tun) Activate() error { var addr, mask [4]byte - copy(addr[:], t.cidr.IP.To4()) - copy(mask[:], t.cidr.Mask) + if !t.cidr.Addr().Is4() { + //TODO: IPV6-WORK + panic("need ipv6") + } + + addr = t.cidr.Addr().As4() + copy(mask[:], prefixToMask(t.cidr)) s, err := unix.Socket( unix.AF_INET, @@ -329,13 +334,12 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - ok, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, ok := t.routeTree.Load().Lookup(ip) if ok { return r } - - return 0 + return netip.Addr{} } // Get the LinkAddr for the interface of the given name @@ -384,13 +388,19 @@ func (t *tun) addRoutes(logErrors bool) error { maskAddr := &netroute.Inet4Addr{} routes := *t.Routes.Load() for _, r := range routes { - if r.Via == nil || !r.Install { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - copy(routeAddr.IP[:], r.Cidr.IP.To4()) - copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4()) + if !r.Cidr.Addr().Is4() { + //TODO: implement ipv6 + panic("Cant handle ipv6 routes yet") + } + + routeAddr.IP = r.Cidr.Addr().As4() + //TODO: we could avoid the copy + copy(maskAddr.IP[:], prefixToMask(r.Cidr)) err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr) if err != nil { @@ -435,8 +445,13 @@ func (t *tun) removeRoutes(routes []Route) error { continue } - copy(routeAddr.IP[:], r.Cidr.IP.To4()) - copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4()) + if r.Cidr.Addr().Is6() { + //TODO: implement ipv6 + panic("Cant handle ipv6 routes yet") + } + + routeAddr.IP = r.Cidr.Addr().As4() + copy(maskAddr.IP[:], prefixToMask(r.Cidr)) err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr) if err != nil { @@ -536,7 +551,7 @@ func (t *tun) Write(from []byte) (int, error) { return n - 4, err } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } @@ -547,3 +562,11 @@ func (t *tun) Name() string { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") } + +func prefixToMask(prefix netip.Prefix) []byte { + pLen := 128 + if prefix.Addr().Is4() { + pLen = 32 + } + return net.CIDRMask(prefix.Bits(), pLen) +} diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index e1e4ede..130f8f9 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -3,7 +3,7 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "strings" "github.com/rcrowley/go-metrics" @@ -13,7 +13,7 @@ import ( type disabledTun struct { read chan []byte - cidr *net.IPNet + cidr netip.Prefix // Track these metrics since we don't have the tun device to do it for us tx metrics.Counter @@ -21,7 +21,7 @@ type disabledTun struct { l *logrus.Logger } -func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { +func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { tun := &disabledTun{ cidr: cidr, read: make(chan []byte, queueLen), @@ -43,11 +43,11 @@ func (*disabledTun) Activate() error { return nil } -func (*disabledTun) RouteFor(iputil.VpnIp) iputil.VpnIp { - return 0 +func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr { + return netip.Addr{} } -func (t *disabledTun) Cidr() *net.IPNet { +func (t *disabledTun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 3b1b80f..bdfeb58 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -9,7 +9,7 @@ import ( "fmt" "io" "io/fs" - "net" + "net/netip" "os" "os/exec" "strconv" @@ -17,10 +17,9 @@ import ( "syscall" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" ) @@ -48,10 +47,10 @@ type ifreqDestroy struct { type tun struct { Device string - cidr *net.IPNet + cidr netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger io.ReadWriteCloser @@ -79,11 +78,11 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device var file *os.File var err error @@ -174,7 +173,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error func (t *tun) Activate() error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) @@ -233,12 +232,12 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } @@ -253,7 +252,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if r.Via == nil || !r.Install { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index ba15d66..20981f0 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -7,32 +7,31 @@ import ( "errors" "fmt" "io" - "net" + "net/netip" "os" "sync" "sync/atomic" "syscall" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" ) type tun struct { io.ReadWriteCloser - cidr *net.IPNet + cidr netip.Prefix Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger } -func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ cidr: cidr, @@ -80,8 +79,8 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -143,7 +142,7 @@ func (tr *tunReadCloser) Close() error { return tr.f.Close() } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 2f06951..0e7e20d 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -4,19 +4,18 @@ package overlay import ( - "bytes" "fmt" "io" "net" + "net/netip" "os" "strings" "sync/atomic" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" @@ -26,7 +25,7 @@ type tun struct { io.ReadWriteCloser fd int Device string - cidr *net.IPNet + cidr netip.Prefix MaxMTU int DefaultMTU int TXQueueLen int @@ -34,7 +33,7 @@ type tun struct { ioctlFd uintptr Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] routeChan chan struct{} useSystemRoutes bool @@ -65,7 +64,7 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") t, err := newTunGeneric(c, l, file, cidr) @@ -78,7 +77,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) return t, nil } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { // If /dev/net/tun doesn't exist, try to create it (will happen in docker) @@ -123,7 +122,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*t return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr *net.IPNet) (*tun, error) { +func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) { t := &tun{ ReadWriteCloser: file, fd: int(file.Fd()), @@ -231,8 +230,8 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return file, nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -275,8 +274,10 @@ func (t *tun) Activate() error { var addr, mask [4]byte - copy(addr[:], t.cidr.IP.To4()) - copy(mask[:], t.cidr.Mask) + //TODO: IPV6-WORK + addr = t.cidr.Addr().As4() + tmask := net.CIDRMask(t.cidr.Bits(), 32) + copy(mask[:], tmask) s, err := unix.Socket( unix.AF_INET, @@ -364,14 +365,19 @@ func (t *tun) setMTU() { func (t *tun) setDefaultRoute() error { // Default route - dr := &net.IPNet{IP: t.cidr.IP.Mask(t.cidr.Mask), Mask: t.cidr.Mask} + + dr := &net.IPNet{ + IP: t.cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()), + } + nr := netlink.Route{ LinkIndex: t.deviceIndex, Dst: dr, MTU: t.DefaultMTU, AdvMSS: t.advMSS(Route{}), Scope: unix.RT_SCOPE_LINK, - Src: t.cidr.IP, + Src: net.IP(t.cidr.Addr().AsSlice()), Protocol: unix.RTPROT_KERNEL, Table: unix.RT_TABLE_MAIN, Type: unix.RTN_UNICAST, @@ -392,9 +398,14 @@ func (t *tun) addRoutes(logErrors bool) error { continue } + dr := &net.IPNet{ + IP: r.Cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()), + } + nr := netlink.Route{ LinkIndex: t.deviceIndex, - Dst: r.Cidr, + Dst: dr, MTU: r.MTU, AdvMSS: t.advMSS(r), Scope: unix.RT_SCOPE_LINK, @@ -426,9 +437,14 @@ func (t *tun) removeRoutes(routes []Route) { continue } + dr := &net.IPNet{ + IP: r.Cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()), + } + nr := netlink.Route{ LinkIndex: t.deviceIndex, - Dst: r.Cidr, + Dst: dr, MTU: r.MTU, AdvMSS: t.advMSS(r), Scope: unix.RT_SCOPE_LINK, @@ -447,7 +463,7 @@ func (t *tun) removeRoutes(routes []Route) { } } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } @@ -499,7 +515,15 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { return } - if !t.cidr.Contains(r.Gw) { + //TODO: IPV6-WORK what if not ok? + gwAddr, ok := netip.AddrFromSlice(r.Gw) + if !ok { + t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address") + return + } + + gwAddr = gwAddr.Unmap() + if !t.cidr.Contains(gwAddr) { // Gateway isn't in our overlay network, ignore t.l.WithField("route", r).Debug("Ignoring route update, not in our network") return @@ -511,28 +535,25 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { return } - newTree := cidr.NewTree4[iputil.VpnIp]() - if r.Type == unix.RTM_NEWROUTE { - for _, oldR := range t.routeTree.Load().List() { - newTree.AddCIDR(oldR.CIDR, oldR.Value) - } - - t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route") - newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw)) - - } else { - gw := iputil.Ip2VpnIp(r.Gw) - for _, oldR := range t.routeTree.Load().List() { - if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw { - // This is the record to delete - t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route") - continue - } - - newTree.AddCIDR(oldR.CIDR, oldR.Value) - } + dstAddr, ok := netip.AddrFromSlice(r.Dst.IP) + if !ok { + t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address") + return } + ones, _ := r.Dst.Mask.Size() + dst := netip.PrefixFrom(dstAddr, ones) + + newTree := t.routeTree.Load().Clone() + + if r.Type == unix.RTM_NEWROUTE { + t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route") + newTree.Insert(dst, gwAddr) + + } else { + newTree.Delete(dst) + t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route") + } t.routeTree.Store(newTree) } diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index cc0216f..24ab24f 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -6,7 +6,7 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" "os/exec" "regexp" @@ -15,10 +15,9 @@ import ( "syscall" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" ) @@ -29,10 +28,10 @@ type ifreqDestroy struct { type tun struct { Device string - cidr *net.IPNet + cidr netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger io.ReadWriteCloser @@ -59,13 +58,13 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var file *os.File var err error @@ -109,13 +108,13 @@ func (t *tun) Activate() error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.IP.String()) + cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -168,12 +167,12 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } @@ -188,12 +187,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if r.Via == nil || !r.Install { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) @@ -214,7 +213,7 @@ func (t *tun) removeRoutes(routes []Route) error { continue } - cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 53f57b1..6463ccb 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -6,7 +6,7 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" "os/exec" "regexp" @@ -14,19 +14,18 @@ import ( "sync/atomic" "syscall" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" ) type tun struct { Device string - cidr *net.IPNet + cidr netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger io.ReadWriteCloser @@ -43,13 +42,13 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { deviceName := c.GetString("tun.dev", "") if deviceName == "" { return nil, fmt.Errorf("a device name in the format of tunN must be specified") @@ -127,7 +126,7 @@ func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) Activate() error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) @@ -139,7 +138,7 @@ func (t *tun) Activate() error { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.IP.String()) + cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -149,20 +148,20 @@ func (t *tun) Activate() error { return t.addRoutes(false) } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if r.Via == nil || !r.Install { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) @@ -183,7 +182,7 @@ func (t *tun) removeRoutes(routes []Route) error { continue } - cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") @@ -194,7 +193,7 @@ func (t *tun) removeRoutes(routes []Route) error { return nil } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 3833983..ba15723 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -6,21 +6,20 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" "sync/atomic" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) type TestTun struct { Device string - cidr *net.IPNet + cidr netip.Prefix Routes []Route - routeTree *cidr.Tree4[iputil.VpnIp] + routeTree *bart.Table[netip.Addr] l *logrus.Logger closed atomic.Bool @@ -28,7 +27,7 @@ type TestTun struct { TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) { _, routes, err := getAllRoutesFromConfig(c, cidr, true) if err != nil { return nil, err @@ -49,7 +48,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, e }, nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*TestTun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } @@ -87,8 +86,8 @@ func (t *TestTun) Get(block bool) []byte { // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// -func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.MostSpecificContains(ip) +func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Lookup(ip) return r } @@ -96,7 +95,7 @@ func (t *TestTun) Activate() error { return nil } -func (t *TestTun) Cidr() *net.IPNet { +func (t *TestTun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/tun_water_windows.go b/overlay/tun_water_windows.go index a1acd2b..d78f564 100644 --- a/overlay/tun_water_windows.go +++ b/overlay/tun_water_windows.go @@ -4,30 +4,30 @@ import ( "fmt" "io" "net" + "net/netip" "os/exec" "strconv" "sync/atomic" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" "github.com/songgao/water" ) type waterTun struct { Device string - cidr *net.IPNet + cidr netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger f *net.Interface *water.Interface } -func newWaterTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*waterTun, error) { +func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) { // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() t := &waterTun{ cidr: cidr, @@ -70,8 +70,8 @@ func (t *waterTun) Activate() error { `C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address", fmt.Sprintf("name=%s", t.Device), "source=static", - fmt.Sprintf("addr=%s", t.cidr.IP), - fmt.Sprintf("mask=%s", net.IP(t.cidr.Mask)), + fmt.Sprintf("addr=%s", t.cidr.Addr()), + fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())), "gateway=none", ).Run() if err != nil { @@ -141,7 +141,7 @@ func (t *waterTun) addRoutes(logErrors bool) error { // Path routes routes := *t.Routes.Load() for _, r := range routes { - if r.Via == nil || !r.Install { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } @@ -182,12 +182,12 @@ func (t *waterTun) removeRoutes(routes []Route) { } } -func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } -func (t *waterTun) Cidr() *net.IPNet { +func (t *waterTun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index f85ee9c..3d88309 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -5,7 +5,7 @@ package overlay import ( "fmt" - "net" + "net/netip" "os" "path/filepath" "runtime" @@ -15,11 +15,11 @@ import ( "github.com/slackhq/nebula/config" ) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (Device, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (Device, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) { useWintun := true if err := checkWinTunExists(); err != nil { l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver") diff --git a/overlay/tun_wintun_windows.go b/overlay/tun_wintun_windows.go index 197e3a7..d010387 100644 --- a/overlay/tun_wintun_windows.go +++ b/overlay/tun_wintun_windows.go @@ -4,15 +4,13 @@ import ( "crypto" "fmt" "io" - "net" "net/netip" "sync/atomic" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" "github.com/slackhq/nebula/wintun" "golang.org/x/sys/windows" @@ -23,11 +21,10 @@ const tunGUIDLabel = "Fixed Nebula Windows GUID v1" type winTun struct { Device string - cidr *net.IPNet - prefix netip.Prefix + cidr netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger tun *wintun.NativeTun @@ -52,22 +49,16 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) { return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil } -func newWinTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*winTun, error) { +func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) { deviceName := c.GetString("tun.dev", "") guid, err := generateGUIDByDeviceName(deviceName) if err != nil { return nil, fmt.Errorf("generate GUID failed: %w", err) } - prefix, err := iputil.ToNetIpPrefix(*cidr) - if err != nil { - return nil, err - } - t := &winTun{ Device: deviceName, cidr: cidr, - prefix: prefix, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -140,7 +131,7 @@ func (t *winTun) reload(c *config.C, initial bool) error { func (t *winTun) Activate() error { luid := winipcfg.LUID(t.tun.LUID()) - err := luid.SetIPAddresses([]netip.Prefix{t.prefix}) + err := luid.SetIPAddresses([]netip.Prefix{t.cidr}) if err != nil { return fmt.Errorf("failed to set address: %w", err) } @@ -159,24 +150,13 @@ func (t *winTun) addRoutes(logErrors bool) error { foundDefault4 := false for _, r := range routes { - if r.Via == nil || !r.Install { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - prefix, err := iputil.ToNetIpPrefix(*r.Cidr) - if err != nil { - retErr := util.NewContextualError("Failed to parse cidr to netip prefix, ignoring route", map[string]interface{}{"route": r}, err) - if logErrors { - retErr.Log(t.l) - continue - } else { - return retErr - } - } - // Add our unsafe route - err = luid.AddRoute(prefix, r.Via.ToNetIpAddr(), uint32(r.Metric)) + err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric)) if err != nil { retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) if logErrors { @@ -190,7 +170,7 @@ func (t *winTun) addRoutes(logErrors bool) error { } if !foundDefault4 { - if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 { + if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 { foundDefault4 = true } } @@ -221,13 +201,7 @@ func (t *winTun) removeRoutes(routes []Route) error { continue } - prefix, err := iputil.ToNetIpPrefix(*r.Cidr) - if err != nil { - t.l.WithError(err).WithField("route", r).Info("Failed to convert cidr to netip prefix") - continue - } - - err = luid.DeleteRoute(prefix, r.Via.ToNetIpAddr()) + err := luid.DeleteRoute(r.Cidr, r.Via) if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else { @@ -237,12 +211,12 @@ func (t *winTun) removeRoutes(routes []Route) error { return nil } -func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *winTun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } -func (t *winTun) Cidr() *net.IPNet { +func (t *winTun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/user.go b/overlay/user.go index 9d819ae..1bb4ef5 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -2,18 +2,17 @@ package overlay import ( "io" - "net" + "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) -func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { +func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { return NewUserDevice(tunCidr) } -func NewUserDevice(tunCidr *net.IPNet) (Device, error) { +func NewUserDevice(tunCidr netip.Prefix) (Device, error) { // these pipes guarantee each write/read will match 1:1 or, ow := io.Pipe() ir, iw := io.Pipe() @@ -27,7 +26,7 @@ func NewUserDevice(tunCidr *net.IPNet) (Device, error) { } type UserDevice struct { - tunCidr *net.IPNet + tunCidr netip.Prefix outboundReader *io.PipeReader outboundWriter *io.PipeWriter @@ -39,9 +38,9 @@ type UserDevice struct { func (d *UserDevice) Activate() error { return nil } -func (d *UserDevice) Cidr() *net.IPNet { return d.tunCidr } -func (d *UserDevice) Name() string { return "faketun0" } -func (d *UserDevice) RouteFor(ip iputil.VpnIp) iputil.VpnIp { return ip } +func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr } +func (d *UserDevice) Name() string { return "faketun0" } +func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip } func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { return d, nil } diff --git a/pki.go b/pki.go index 91478ce..ab95a04 100644 --- a/pki.go +++ b/pki.go @@ -80,6 +80,8 @@ func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError { } if !initial { + //TODO: include check for mask equality as well + // did IP in cert change? if so, don't set currentCert := p.cs.Load().Certificate oldIPs := currentCert.Details.Ips diff --git a/relay_manager.go b/relay_manager.go index 7aa06cc..1a3a4d4 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -2,14 +2,15 @@ package nebula import ( "context" + "encoding/binary" "errors" "fmt" + "net/netip" "sync/atomic" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" ) type relayManager struct { @@ -50,7 +51,7 @@ func (rm *relayManager) setAmRelay(v bool) { // AddRelay finds an available relay index on the hostmap, and associates the relay info with it. // relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp. -func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp iputil.VpnIp, remoteIdx *uint32, relayType int, state int) (uint32, error) { +func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) { hm.Lock() defer hm.Unlock() for i := 0; i < 32; i++ { @@ -113,13 +114,17 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Inter func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) { rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(m.RelayFromIp), - "relayTo": iputil.VpnIp(m.RelayToIp), + "relayFrom": m.RelayFromIp, + "relayTo": m.RelayToIp, "initiatorRelayIndex": m.InitiatorRelayIndex, "responderRelayIndex": m.ResponderRelayIndex, "vpnIp": h.vpnIp}). Info("handleCreateRelayResponse") - target := iputil.VpnIp(m.RelayToIp) + target := m.RelayToIp + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], m.RelayToIp) + targetAddr := netip.AddrFrom4(b) relay, err := rm.EstablishRelay(h, m) if err != nil { @@ -136,18 +141,20 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer") return } - peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(target) + peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr) if !ok { rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo") return } if peerRelay.State == PeerRequested { + //TODO: IPV6-WORK + b = peerHostInfo.vpnIp.As4() peerRelay.State = Established resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: peerRelay.LocalIndex, InitiatorRelayIndex: peerRelay.RemoteIndex, - RelayFromIp: uint32(peerHostInfo.vpnIp), + RelayFromIp: binary.BigEndian.Uint32(b[:]), RelayToIp: uint32(target), } msg, err := resp.Marshal() @@ -157,8 +164,8 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * } else { f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(resp.RelayFromIp), - "relayTo": iputil.VpnIp(resp.RelayToIp), + "relayFrom": resp.RelayFromIp, + "relayTo": resp.RelayToIp, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, "vpnIp": peerHostInfo.vpnIp}). @@ -168,9 +175,13 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * } func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) { + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], m.RelayFromIp) + from := netip.AddrFrom4(b) - from := iputil.VpnIp(m.RelayFromIp) - target := iputil.VpnIp(m.RelayToIp) + binary.BigEndian.PutUint32(b[:], m.RelayToIp) + target := netip.AddrFrom4(b) logMsg := rm.l.WithFields(logrus.Fields{ "relayFrom": from, @@ -181,12 +192,12 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N logMsg.Info("handleCreateRelayRequest") // Is the source of the relay me? This should never happen, but did happen due to // an issue migrating relays over to newly re-handshaked host info objects. - if from == f.myVpnIp { - logMsg.WithField("myIP", f.myVpnIp).Error("Discarding relay request from myself") + if from == f.myVpnNet.Addr() { + logMsg.WithField("myIP", from).Error("Discarding relay request from myself") return } // Is the target of the relay me? - if target == f.myVpnIp { + if target == f.myVpnNet.Addr() { existingRelay, ok := h.relayState.QueryRelayForByIp(from) if ok { switch existingRelay.State { @@ -219,12 +230,16 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N return } + //TODO: IPV6-WORK + fromB := from.As4() + targetB := target.As4() + resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: relay.LocalIndex, InitiatorRelayIndex: relay.RemoteIndex, - RelayFromIp: uint32(from), - RelayToIp: uint32(target), + RelayFromIp: binary.BigEndian.Uint32(fromB[:]), + RelayToIp: binary.BigEndian.Uint32(targetB[:]), } msg, err := resp.Marshal() if err != nil { @@ -233,8 +248,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } else { f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(resp.RelayFromIp), - "relayTo": iputil.VpnIp(resp.RelayToIp), + //TODO: IPV6-WORK, this used to use the resp object but I am getting lazy now + "relayFrom": from, + "relayTo": target, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, "vpnIp": h.vpnIp}). @@ -253,7 +269,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N f.Handshake(target) return } - if peer.remote == nil { + if !peer.remote.IsValid() { // Only create relays to peers for whom I have a direct connection return } @@ -275,12 +291,16 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N sendCreateRequest = true } if sendCreateRequest { + //TODO: IPV6-WORK + fromB := h.vpnIp.As4() + targetB := target.As4() + // Send a CreateRelayRequest to the peer. req := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: index, - RelayFromIp: uint32(h.vpnIp), - RelayToIp: uint32(target), + RelayFromIp: binary.BigEndian.Uint32(fromB[:]), + RelayToIp: binary.BigEndian.Uint32(targetB[:]), } msg, err := req.Marshal() if err != nil { @@ -289,8 +309,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } else { f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(req.RelayFromIp), - "relayTo": iputil.VpnIp(req.RelayToIp), + //TODO: IPV6-WORK another lazy used to use the req object + "relayFrom": h.vpnIp, + "relayTo": target, "initiatorRelayIndex": req.InitiatorRelayIndex, "responderRelayIndex": req.ResponderRelayIndex, "vpnIp": target}). @@ -321,12 +342,15 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N "existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") return } + //TODO: IPV6-WORK + fromB := h.vpnIp.As4() + targetB := target.As4() resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: relay.LocalIndex, InitiatorRelayIndex: relay.RemoteIndex, - RelayFromIp: uint32(h.vpnIp), - RelayToIp: uint32(target), + RelayFromIp: binary.BigEndian.Uint32(fromB[:]), + RelayToIp: binary.BigEndian.Uint32(targetB[:]), } msg, err := resp.Marshal() if err != nil { @@ -335,8 +359,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } else { f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(resp.RelayFromIp), - "relayTo": iputil.VpnIp(resp.RelayToIp), + //TODO: IPV6-WORK more lazy, used to use resp object + "relayFrom": h.vpnIp, + "relayTo": target, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, "vpnIp": h.vpnIp}). @@ -349,7 +374,3 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } } } - -func (rm *relayManager) RemoveRelay(localIdx uint32) { - rm.hostmap.RemoveRelay(localIdx) -} diff --git a/remote_list.go b/remote_list.go index 60a1afd..fa14f42 100644 --- a/remote_list.go +++ b/remote_list.go @@ -1,7 +1,6 @@ package nebula import ( - "bytes" "context" "net" "net/netip" @@ -12,16 +11,14 @@ import ( "time" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) // forEachFunc is used to benefit folks that want to do work inside the lock -type forEachFunc func(addr *udp.Addr, preferred bool) +type forEachFunc func(addr netip.AddrPort, preferred bool) // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate) -type checkFuncV4 func(vpnIp iputil.VpnIp, to *Ip4AndPort) bool -type checkFuncV6 func(vpnIp iputil.VpnIp, to *Ip6AndPort) bool +type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool +type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool // CacheMap is a struct that better represents the lighthouse cache for humans // The string key is the owners vpnIp @@ -30,9 +27,9 @@ type CacheMap map[string]*Cache // Cache is the other part of CacheMap to better represent the lighthouse cache for humans // We don't reason about ipv4 vs ipv6 here type Cache struct { - Learned []*udp.Addr `json:"learned,omitempty"` - Reported []*udp.Addr `json:"reported,omitempty"` - Relay []*net.IP `json:"relay"` + Learned []netip.AddrPort `json:"learned,omitempty"` + Reported []netip.AddrPort `json:"reported,omitempty"` + Relay []netip.Addr `json:"relay"` } //TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion @@ -46,7 +43,7 @@ type cache struct { } type cacheRelay struct { - relay []uint32 + relay []netip.Addr } // cacheV4 stores learned and reported ipv4 records under cache @@ -130,7 +127,7 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, continue } for _, a := range addrs { - netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{} + netipAddrs[netip.AddrPortFrom(a.Unmap(), hostPort.port)] = struct{}{} } } origSet := r.ips.Load() @@ -193,22 +190,22 @@ type RemoteList struct { sync.RWMutex // A deduplicated set of addresses. Any accessor should lock beforehand. - addrs []*udp.Addr + addrs []netip.AddrPort // A set of relay addresses. VpnIp addresses that the remote identified as relays. - relays []*iputil.VpnIp + relays []netip.Addr // These are maps to store v4 and v6 addresses per lighthouse // Map key is the vpnIp of the person that told us about this the cached entries underneath. // For learned addresses, this is the vpnIp that sent the packet - cache map[iputil.VpnIp]*cache + cache map[netip.Addr]*cache hr *hostnamesResults shouldAdd func(netip.Addr) bool // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. // They should not be tried again during a handshake - badRemotes []*udp.Addr + badRemotes []netip.AddrPort // A flag that the cache may have changed and addrs needs to be rebuilt shouldRebuild bool @@ -217,9 +214,9 @@ type RemoteList struct { // NewRemoteList creates a new empty RemoteList func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList { return &RemoteList{ - addrs: make([]*udp.Addr, 0), - relays: make([]*iputil.VpnIp, 0), - cache: make(map[iputil.VpnIp]*cache), + addrs: make([]netip.AddrPort, 0), + relays: make([]netip.Addr, 0), + cache: make(map[netip.Addr]*cache), shouldAdd: shouldAdd, } } @@ -232,7 +229,7 @@ func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) { // Len locks and reports the size of the deduplicated address list // The deduplication work may need to occur here, so you must pass preferredRanges -func (r *RemoteList) Len(preferredRanges []*net.IPNet) int { +func (r *RemoteList) Len(preferredRanges []netip.Prefix) int { r.Rebuild(preferredRanges) r.RLock() defer r.RUnlock() @@ -241,18 +238,18 @@ func (r *RemoteList) Len(preferredRanges []*net.IPNet) int { // ForEach locks and will call the forEachFunc for every deduplicated address in the list // The deduplication work may need to occur here, so you must pass preferredRanges -func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc) { +func (r *RemoteList) ForEach(preferredRanges []netip.Prefix, forEach forEachFunc) { r.Rebuild(preferredRanges) r.RLock() for _, v := range r.addrs { - forEach(v, isPreferred(v.IP, preferredRanges)) + forEach(v, isPreferred(v.Addr(), preferredRanges)) } r.RUnlock() } // CopyAddrs locks and makes a deep copy of the deduplicated address list // The deduplication work may need to occur here, so you must pass preferredRanges -func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr { +func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort { if r == nil { return nil } @@ -261,9 +258,9 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr { r.RLock() defer r.RUnlock() - c := make([]*udp.Addr, len(r.addrs)) + c := make([]netip.AddrPort, len(r.addrs)) for i, v := range r.addrs { - c[i] = v.Copy() + c[i] = v } return c } @@ -272,13 +269,13 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr { // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming. // It will mark the deduplicated address list as dirty, so do not call it unless new information is available // TODO: this needs to support the allow list list -func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) { +func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) { r.Lock() defer r.Unlock() - if v4 := addr.IP.To4(); v4 != nil { - r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPort(v4, uint32(addr.Port))) + if remote.Addr().Is4() { + r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port())) } else { - r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPort(addr.IP, uint32(addr.Port))) + r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port())) } } @@ -293,9 +290,9 @@ func (r *RemoteList) CopyCache() *CacheMap { c := cm[vpnIp] if c == nil { c = &Cache{ - Learned: make([]*udp.Addr, 0), - Reported: make([]*udp.Addr, 0), - Relay: make([]*net.IP, 0), + Learned: make([]netip.AddrPort, 0), + Reported: make([]netip.AddrPort, 0), + Relay: make([]netip.Addr, 0), } cm[vpnIp] = c } @@ -307,28 +304,27 @@ func (r *RemoteList) CopyCache() *CacheMap { if mc.v4 != nil { if mc.v4.learned != nil { - c.Learned = append(c.Learned, NewUDPAddrFromLH4(mc.v4.learned)) + c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned)) } for _, a := range mc.v4.reported { - c.Reported = append(c.Reported, NewUDPAddrFromLH4(a)) + c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a)) } } if mc.v6 != nil { if mc.v6.learned != nil { - c.Learned = append(c.Learned, NewUDPAddrFromLH6(mc.v6.learned)) + c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned)) } for _, a := range mc.v6.reported { - c.Reported = append(c.Reported, NewUDPAddrFromLH6(a)) + c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a)) } } if mc.relay != nil { for _, a := range mc.relay.relay { - nip := iputil.VpnIp(a).ToIP() - c.Relay = append(c.Relay, &nip) + c.Relay = append(c.Relay, a) } } } @@ -337,8 +333,8 @@ func (r *RemoteList) CopyCache() *CacheMap { } // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list -func (r *RemoteList) BlockRemote(bad *udp.Addr) { - if bad == nil { +func (r *RemoteList) BlockRemote(bad netip.AddrPort) { + if !bad.IsValid() { // relays can have nil udp Addrs return } @@ -351,20 +347,20 @@ func (r *RemoteList) BlockRemote(bad *udp.Addr) { } // We copy here because we are taking something else's memory and we can't trust everything - r.badRemotes = append(r.badRemotes, bad.Copy()) + r.badRemotes = append(r.badRemotes, bad) // Mark the next interaction must recollect/dedupe r.shouldRebuild = true } // CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list -func (r *RemoteList) CopyBlockedRemotes() []*udp.Addr { +func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort { r.RLock() defer r.RUnlock() - c := make([]*udp.Addr, len(r.badRemotes)) + c := make([]netip.AddrPort, len(r.badRemotes)) for i, v := range r.badRemotes { - c[i] = v.Copy() + c[i] = v } return c } @@ -378,7 +374,7 @@ func (r *RemoteList) ResetBlockedRemotes() { // Rebuild locks and generates the deduplicated address list only if there is work to be done // There is generally no reason to call this directly but it is safe to do so -func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) { +func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) { r.Lock() defer r.Unlock() @@ -394,9 +390,9 @@ func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) { } // unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list -func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool { +func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool { for _, v := range r.badRemotes { - if v.Equals(remote) { + if v == remote { return true } } @@ -405,14 +401,14 @@ func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool { // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) { +func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { r.shouldRebuild = true r.unlockedGetOrMakeV4(ownerVpnIp).learned = to } // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip4AndPort, check checkFuncV4) { +func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) @@ -427,7 +423,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, } } -func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []uint32) { +func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.Addr) { r.shouldRebuild = true c := r.unlockedGetOrMakeRelay(ownerVpnIp) @@ -440,7 +436,7 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnI // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) { +func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) @@ -453,14 +449,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) { +func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { r.shouldRebuild = true r.unlockedGetOrMakeV6(ownerVpnIp).learned = to } // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip6AndPort, check checkFuncV6) { +func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) @@ -477,7 +473,7 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) { +func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) @@ -488,7 +484,7 @@ func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) } } -func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay { +func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp netip.Addr) *cacheRelay { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} @@ -503,7 +499,7 @@ func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay // unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established. // The caller must dirty the learned address cache if required -func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 { +func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp netip.Addr) *cacheV4 { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} @@ -518,7 +514,7 @@ func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 { // unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established. // The caller must dirty the learned address cache if required -func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp iputil.VpnIp) *cacheV6 { +func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp netip.Addr) *cacheV6 { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} @@ -540,14 +536,14 @@ func (r *RemoteList) unlockedCollect() { for _, c := range r.cache { if c.v4 != nil { if c.v4.learned != nil { - u := NewUDPAddrFromLH4(c.v4.learned) + u := AddrPortFromIp4AndPort(c.v4.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v4.reported { - u := NewUDPAddrFromLH4(v) + u := AddrPortFromIp4AndPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } @@ -556,14 +552,14 @@ func (r *RemoteList) unlockedCollect() { if c.v6 != nil { if c.v6.learned != nil { - u := NewUDPAddrFromLH6(c.v6.learned) + u := AddrPortFromIp6AndPort(c.v6.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v6.reported { - u := NewUDPAddrFromLH6(v) + u := AddrPortFromIp6AndPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } @@ -572,8 +568,7 @@ func (r *RemoteList) unlockedCollect() { if c.relay != nil { for _, v := range c.relay.relay { - ip := iputil.VpnIp(v) - relays = append(relays, &ip) + relays = append(relays, v) } } } @@ -581,11 +576,7 @@ func (r *RemoteList) unlockedCollect() { dnsAddrs := r.hr.GetIPs() for _, addr := range dnsAddrs { if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { - v6 := addr.Addr().As16() - addrs = append(addrs, &udp.Addr{ - IP: v6[:], - Port: addr.Port(), - }) + addrs = append(addrs, addr) } } @@ -595,7 +586,7 @@ func (r *RemoteList) unlockedCollect() { } // unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list -func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { +func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) { n := len(r.addrs) if n < 2 { return @@ -606,8 +597,8 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { b := r.addrs[j] // Preferred addresses first - aPref := isPreferred(a.IP, preferredRanges) - bPref := isPreferred(b.IP, preferredRanges) + aPref := isPreferred(a.Addr(), preferredRanges) + bPref := isPreferred(b.Addr(), preferredRanges) switch { case aPref && !bPref: // If i is preferred and j is not, i is less than j @@ -622,21 +613,21 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { } // ipv6 addresses 2nd - a4 := a.IP.To4() - b4 := b.IP.To4() + a4 := a.Addr().Is4() + b4 := b.Addr().Is4() switch { - case a4 == nil && b4 != nil: + case a4 == false && b4 == true: // If i is v6 and j is v4, i is less than j return true - case a4 != nil && b4 == nil: + case a4 == true && b4 == false: // If j is v6 and i is v4, i is not less than j return false - case a4 != nil && b4 != nil: - // Special case for ipv4, a4 and b4 are not nil - aPrivate := isPrivateIP(a4) - bPrivate := isPrivateIP(b4) + case a4 == true && b4 == true: + // i and j are both ipv4 + aPrivate := a.Addr().IsPrivate() + bPrivate := b.Addr().IsPrivate() switch { case !aPrivate && bPrivate: // If i is a public ip (not private) and j is a private ip, i is less then j @@ -655,10 +646,10 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { } // lexical order of ips 3rd - c := bytes.Compare(a.IP, b.IP) + c := a.Addr().Compare(b.Addr()) if c == 0 { // Ips are the same, Lexical order of ports 4th - return a.Port < b.Port + return a.Port() < b.Port() } // Ip wasn't the same @@ -671,7 +662,7 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { // Deduplicate a, b := 0, 1 for b < n { - if !r.addrs[a].Equals(r.addrs[b]) { + if r.addrs[a] != r.addrs[b] { a++ if a != b { r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a] @@ -693,7 +684,7 @@ func minInt(a, b int) int { } // isPreferred returns true of the ip is contained in the preferredRanges list -func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool { +func isPreferred(ip netip.Addr, preferredRanges []netip.Prefix) bool { //TODO: this would be better in a CIDR6Tree for _, p := range preferredRanges { if p.Contains(ip) { @@ -702,14 +693,3 @@ func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool { } return false } - -var _, private24BitBlock, _ = net.ParseCIDR("10.0.0.0/8") -var _, private20BitBlock, _ = net.ParseCIDR("172.16.0.0/12") -var _, private16BitBlock, _ = net.ParseCIDR("192.168.0.0/16") - -// isPrivateIP returns true if the ip is contained by a rfc 1918 private range -func isPrivateIP(ip net.IP) bool { - //TODO: another great cidrtree option - //TODO: Private for ipv6 or just let it ride? - return private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip) -} diff --git a/remote_list_test.go b/remote_list_test.go index 49aa171..62a892b 100644 --- a/remote_list_test.go +++ b/remote_list_test.go @@ -1,47 +1,47 @@ package nebula import ( - "net" + "encoding/binary" + "net/netip" "testing" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" ) func TestRemoteList_Rebuild(t *testing.T) { rl := NewRemoteList(nil) rl.unlockedSetV4( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip4AndPort{ - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is duped - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is duped - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is duped - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // almost dupe of 0 with a diff port - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is a dupe + newIp4AndPortFromString("70.199.182.92:1475"), // this is duped + newIp4AndPortFromString("172.17.0.182:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is duped + newIp4AndPortFromString("172.18.0.1:10101"), // this is duped + newIp4AndPortFromString("172.18.0.1:10101"), // this is a dupe + newIp4AndPortFromString("172.19.0.1:10101"), + newIp4AndPortFromString("172.31.0.1:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe + newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port + newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( - 1, - 1, + netip.MustParseAddr("0.0.0.1"), + netip.MustParseAddr("0.0.0.1"), []*Ip6AndPort{ - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is duped - NewIp6AndPort(net.ParseIP("1::1"), 2), // almost dupe of 0 with a diff port, also gets duped - NewIp6AndPort(net.ParseIP("1:100::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe - NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe + newIp6AndPortFromString("[1::1]:1"), // this is duped + newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped + newIp6AndPortFromString("[1:100::1]:1"), + newIp6AndPortFromString("[1::1]:1"), // this is a dupe + newIp6AndPortFromString("[1::1]:2"), // this is a dupe }, - func(iputil.VpnIp, *Ip6AndPort) bool { return true }, + func(netip.Addr, *Ip6AndPort) bool { return true }, ) - rl.Rebuild([]*net.IPNet{}) + rl.Rebuild([]netip.Prefix{}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") // ipv6 first, sorted lexically within @@ -59,9 +59,7 @@ func TestRemoteList_Rebuild(t *testing.T) { assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String()) // Now ensure we can hoist ipv4 up - _, ipNet, err := net.ParseCIDR("0.0.0.0/0") - assert.NoError(t, err) - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") // ipv4 first, public then private, lexically within them @@ -79,9 +77,7 @@ func TestRemoteList_Rebuild(t *testing.T) { assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String()) // Ensure we can hoist a specific ipv4 range over anything else - _, ipNet, err = net.ParseCIDR("172.17.0.0/16") - assert.NoError(t, err) - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("172.17.0.0/16")}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") // Preferred ipv4 first @@ -104,64 +100,61 @@ func TestRemoteList_Rebuild(t *testing.T) { func BenchmarkFullRebuild(b *testing.B) { rl := NewRemoteList(nil) rl.unlockedSetV4( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip4AndPort{ - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port + newIp4AndPortFromString("70.199.182.92:1475"), + newIp4AndPortFromString("172.17.0.182:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), + newIp4AndPortFromString("172.18.0.1:10101"), + newIp4AndPortFromString("172.19.0.1:10101"), + newIp4AndPortFromString("172.31.0.1:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe + newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip6AndPort{ - NewIp6AndPort(net.ParseIP("1::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port - NewIp6AndPort(net.ParseIP("1:100::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe + newIp6AndPortFromString("[1::1]:1"), + newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port + newIp6AndPortFromString("[1:100::1]:1"), + newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, - func(iputil.VpnIp, *Ip6AndPort) bool { return true }, + func(netip.Addr, *Ip6AndPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{}) + rl.Rebuild([]netip.Prefix{}) } }) - _, ipNet, err := net.ParseCIDR("172.17.0.0/16") - assert.NoError(b, err) + ipNet1 := netip.MustParsePrefix("172.17.0.0/16") b.Run("1 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{ipNet1}) } }) - _, ipNet2, err := net.ParseCIDR("70.0.0.0/8") - assert.NoError(b, err) + ipNet2 := netip.MustParsePrefix("70.0.0.0/8") b.Run("2 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) + rl.Rebuild([]netip.Prefix{ipNet2}) } }) - _, ipNet3, err := net.ParseCIDR("0.0.0.0/0") - assert.NoError(b, err) + ipNet3 := netip.MustParsePrefix("0.0.0.0/0") b.Run("3 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3}) } }) } @@ -169,67 +162,83 @@ func BenchmarkFullRebuild(b *testing.B) { func BenchmarkSortRebuild(b *testing.B) { rl := NewRemoteList(nil) rl.unlockedSetV4( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip4AndPort{ - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port + newIp4AndPortFromString("70.199.182.92:1475"), + newIp4AndPortFromString("172.17.0.182:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), + newIp4AndPortFromString("172.18.0.1:10101"), + newIp4AndPortFromString("172.19.0.1:10101"), + newIp4AndPortFromString("172.31.0.1:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe + newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip6AndPort{ - NewIp6AndPort(net.ParseIP("1::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port - NewIp6AndPort(net.ParseIP("1:100::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe + newIp6AndPortFromString("[1::1]:1"), + newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port + newIp6AndPortFromString("[1:100::1]:1"), + newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, - func(iputil.VpnIp, *Ip6AndPort) bool { return true }, + func(netip.Addr, *Ip6AndPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{}) + rl.Rebuild([]netip.Prefix{}) } }) - _, ipNet, err := net.ParseCIDR("172.17.0.0/16") - rl.Rebuild([]*net.IPNet{ipNet}) + ipNet1 := netip.MustParsePrefix("172.17.0.0/16") + rl.Rebuild([]netip.Prefix{ipNet1}) - assert.NoError(b, err) b.Run("1 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{ipNet1}) } }) - _, ipNet2, err := net.ParseCIDR("70.0.0.0/8") - rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) + ipNet2 := netip.MustParsePrefix("70.0.0.0/8") + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2}) - assert.NoError(b, err) b.Run("2 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { - rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2}) } }) - _, ipNet3, err := net.ParseCIDR("0.0.0.0/0") - rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) + ipNet3 := netip.MustParsePrefix("0.0.0.0/0") + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3}) - assert.NoError(b, err) b.Run("3 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { - rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3}) } }) } + +func newIp4AndPortFromString(s string) *Ip4AndPort { + a := netip.MustParseAddrPort(s) + v4Addr := a.Addr().As4() + return &Ip4AndPort{ + Ip: binary.BigEndian.Uint32(v4Addr[:]), + Port: uint32(a.Port()), + } +} + +func newIp6AndPortFromString(s string) *Ip6AndPort { + a := netip.MustParseAddrPort(s) + v6Addr := a.Addr().As16() + return &Ip6AndPort{ + Hi: binary.BigEndian.Uint64(v6Addr[:8]), + Lo: binary.BigEndian.Uint64(v6Addr[8:]), + Port: uint32(a.Port()), + } +} diff --git a/service/service.go b/service/service.go index 6816be6..50c1d4a 100644 --- a/service/service.go +++ b/service/service.go @@ -91,7 +91,7 @@ func New(config *config.C) (*Service, error) { ipNet := device.Cidr() pa := tcpip.ProtocolAddress{ - AddressWithPrefix: tcpip.AddrFromSlice(ipNet.IP).WithPrefix(), + AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(), Protocol: ipv4.ProtocolNumber, } if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{ diff --git a/service/service_test.go b/service/service_test.go index d1909cd..3176209 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -4,7 +4,7 @@ import ( "bytes" "context" "errors" - "net" + "net/netip" "testing" "time" @@ -18,12 +18,8 @@ import ( type m map[string]interface{} -func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service { - - vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} - copy(vpnIpNet.IP, udpIp) - - _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) +func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { + _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{}) caB, err := caCrt.MarshalToPEM() if err != nil { panic(err) @@ -83,8 +79,8 @@ func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, } func TestService(t *testing.T) { - ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{ + ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ "static_host_map": m{}, "lighthouse": m{ "am_lighthouse": true, @@ -94,7 +90,7 @@ func TestService(t *testing.T) { "port": 4243, }, }) - b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{ + b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{ "static_host_map": m{ "10.0.0.1": []string{"localhost:4243"}, }, diff --git a/ssh.go b/ssh.go index f096121..2ff0954 100644 --- a/ssh.go +++ b/ssh.go @@ -7,6 +7,7 @@ import ( "flag" "fmt" "net" + "net/netip" "os" "reflect" "runtime" @@ -18,9 +19,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/sshd" - "github.com/slackhq/nebula/udp" ) type sshListHostMapFlags struct { @@ -431,7 +430,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er } sort.Slice(hm, func(i, j int) bool { - return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0 + return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0 }) if fs.Json || fs.Pretty { @@ -545,13 +544,12 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri return w.WriteLine("No vpn ip was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -574,13 +572,12 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine("No vpn ip was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -616,13 +613,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine("No vpn ip was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -636,16 +632,16 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) } - var addr *udp.Addr + var addr netip.AddrPort if flags.Address != "" { - addr = udp.NewAddrFromString(flags.Address) - if addr == nil { + addr, err = netip.ParseAddrPort(flags.Address) + if err != nil { return w.WriteLine("Address could not be parsed") } } hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil) - if addr != nil { + if addr.IsValid() { hostInfo.SetRemote(addr) } @@ -667,18 +663,17 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine("No address was provided") } - addr := udp.NewAddrFromString(flags.Address) - if addr == nil { + addr, err := netip.ParseAddrPort(flags.Address) + if err != nil { return w.WriteLine("Address could not be parsed") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -792,13 +787,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit cert := ifce.pki.GetCertState().Certificate if len(a) > 0 { - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -862,14 +856,14 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr Error error Type string State string - PeerIp iputil.VpnIp + PeerIp netip.Addr LocalIndex uint32 RemoteIndex uint32 - RelayedThrough []iputil.VpnIp + RelayedThrough []netip.Addr } type RelayOutput struct { - NebulaIp iputil.VpnIp + NebulaIp netip.Addr RelayForIps []RelayFor } @@ -952,13 +946,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine("No vpn ip was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } diff --git a/test/tun.go b/test/tun.go index 86656c9..fbf5829 100644 --- a/test/tun.go +++ b/test/tun.go @@ -3,23 +3,21 @@ package test import ( "errors" "io" - "net" - - "github.com/slackhq/nebula/iputil" + "net/netip" ) type NoopTun struct{} -func (NoopTun) RouteFor(iputil.VpnIp) iputil.VpnIp { - return 0 +func (NoopTun) RouteFor(addr netip.Addr) netip.Addr { + return netip.Addr{} } func (NoopTun) Activate() error { return nil } -func (NoopTun) Cidr() *net.IPNet { - return nil +func (NoopTun) Cidr() netip.Prefix { + return netip.Prefix{} } func (NoopTun) Name() string { diff --git a/timeout_test.go b/timeout_test.go index 3f81ff4..4c6364e 100644 --- a/timeout_test.go +++ b/timeout_test.go @@ -1,6 +1,7 @@ package nebula import ( + "net/netip" "testing" "time" @@ -115,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) { assert.Equal(t, 0, tw.current) fps := []firewall.Packet{ - {LocalIP: 1}, - {LocalIP: 2}, - {LocalIP: 3}, - {LocalIP: 4}, + {LocalIP: netip.MustParseAddr("0.0.0.1")}, + {LocalIP: netip.MustParseAddr("0.0.0.2")}, + {LocalIP: netip.MustParseAddr("0.0.0.3")}, + {LocalIP: netip.MustParseAddr("0.0.0.4")}, } tw.Add(fps[0], time.Second*1) diff --git a/udp/conn.go b/udp/conn.go index a2c24a1..fa4e443 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -1,6 +1,8 @@ package udp import ( + "net/netip" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -9,7 +11,7 @@ import ( const MTU = 9001 type EncReader func( - addr *Addr, + addr netip.AddrPort, out []byte, packet []byte, header *header.H, @@ -22,9 +24,9 @@ type EncReader func( type Conn interface { Rebind() error - LocalAddr() (*Addr, error) + LocalAddr() (netip.AddrPort, error) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) - WriteTo(b []byte, addr *Addr) error + WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) Close() error } @@ -34,13 +36,13 @@ type NoopConn struct{} func (NoopConn) Rebind() error { return nil } -func (NoopConn) LocalAddr() (*Addr, error) { - return nil, nil +func (NoopConn) LocalAddr() (netip.AddrPort, error) { + return netip.AddrPort{}, nil } func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) { return } -func (NoopConn) WriteTo(_ []byte, _ *Addr) error { +func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } func (NoopConn) ReloadConfig(_ *config.C) { diff --git a/udp/temp.go b/udp/temp.go index 2efe31d..b281906 100644 --- a/udp/temp.go +++ b/udp/temp.go @@ -1,9 +1,10 @@ package udp import ( - "github.com/slackhq/nebula/iputil" + "net/netip" ) //TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare -type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte) +// TODO: IPV6-WORK this can likely be removed now +type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) diff --git a/udp/udp_all.go b/udp/udp_all.go deleted file mode 100644 index 093bf69..0000000 --- a/udp/udp_all.go +++ /dev/null @@ -1,100 +0,0 @@ -package udp - -import ( - "encoding/json" - "fmt" - "net" - "strconv" -) - -type m map[string]interface{} - -type Addr struct { - IP net.IP - Port uint16 -} - -func NewAddr(ip net.IP, port uint16) *Addr { - addr := Addr{IP: make([]byte, net.IPv6len), Port: port} - copy(addr.IP, ip.To16()) - return &addr -} - -func NewAddrFromString(s string) *Addr { - ip, port, err := ParseIPAndPort(s) - //TODO: handle err - _ = err - return &Addr{IP: ip.To16(), Port: port} -} - -func (ua *Addr) Equals(t *Addr) bool { - if t == nil || ua == nil { - return t == nil && ua == nil - } - return ua.IP.Equal(t.IP) && ua.Port == t.Port -} - -func (ua *Addr) String() string { - if ua == nil { - return "" - } - - return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port)) -} - -func (ua *Addr) MarshalJSON() ([]byte, error) { - if ua == nil { - return nil, nil - } - - return json.Marshal(m{"ip": ua.IP, "port": ua.Port}) -} - -func (ua *Addr) Copy() *Addr { - if ua == nil { - return nil - } - - nu := Addr{ - Port: ua.Port, - IP: make(net.IP, len(ua.IP)), - } - - copy(nu.IP, ua.IP) - return &nu -} - -type AddrSlice []*Addr - -func (a AddrSlice) Equal(b AddrSlice) bool { - if len(a) != len(b) { - return false - } - - for i := range a { - if !a[i].Equals(b[i]) { - return false - } - } - - return true -} - -func ParseIPAndPort(s string) (net.IP, uint16, error) { - rIp, sPort, err := net.SplitHostPort(s) - if err != nil { - return nil, 0, err - } - - addr, err := net.ResolveIPAddr("ip", rIp) - if err != nil { - return nil, 0, err - } - - iPort, err := strconv.Atoi(sPort) - if err != nil { - return nil, 0, err - } - - return addr.IP, uint16(iPort), nil -} diff --git a/udp/udp_android.go b/udp/udp_android.go index 8d69074..bb19195 100644 --- a/udp/udp_android.go +++ b/udp/udp_android.go @@ -6,13 +6,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_bsd.go b/udp/udp_bsd.go index 785aa6a..65ef31a 100644 --- a/udp/udp_bsd.go +++ b/udp/udp_bsd.go @@ -9,13 +9,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 08e1b6a..183ac7a 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -8,13 +8,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 1dd6d1d..2d84536 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -11,6 +11,7 @@ import ( "context" "fmt" "net" + "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -25,7 +26,7 @@ type GenericConn struct { var _ Conn = &GenericConn{} -func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { lc := NewListenConfig(multi) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { @@ -37,23 +38,24 @@ func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc) } -func (u *GenericConn) WriteTo(b []byte, addr *Addr) error { - _, err := u.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)}) +func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error { + _, err := u.UDPConn.WriteToUDPAddrPort(b, addr) return err } -func (u *GenericConn) LocalAddr() (*Addr, error) { +func (u *GenericConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() switch v := a.(type) { case *net.UDPAddr: - addr := &Addr{IP: make([]byte, len(v.IP))} - copy(addr.IP, v.IP) - addr.Port = uint16(v.Port) - return addr, nil + addr, ok := netip.AddrFromSlice(v.IP) + if !ok { + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP) + } + return netip.AddrPortFrom(addr, uint16(v.Port)), nil default: - return nil, fmt.Errorf("LocalAddr returned: %#v", a) + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a) } } @@ -75,19 +77,26 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f buffer := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} - udpAddr := &Addr{IP: make([]byte, 16)} nb := make([]byte, 12, 12) for { // Just read one packet at a time - n, rua, err := u.ReadFromUDP(buffer) + n, rua, err := u.ReadFromUDPAddrPort(buffer) if err != nil { u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return } - udpAddr.IP = rua.IP - udpAddr.Port = uint16(rua.Port) - r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r( + netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), + plaintext[:0], + buffer[:n], + h, + fwPacket, + lhf, + nb, + q, + cache.Get(u.l), + ) } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 02c8ce0..ef07243 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "fmt" "net" + "net/netip" "syscall" "unsafe" @@ -35,10 +36,9 @@ func maybeIPV4(ip net.IP) (net.IP, bool) { return ip, false } -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { - ipV4, isV4 := maybeIPV4(ip) +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { af := unix.AF_INET6 - if isV4 { + if ip.Is4() { af = unix.AF_INET } syscall.ForkLock.RLock() @@ -61,13 +61,13 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) ( //TODO: support multiple listening IPs (for limiting ipv6) var sa unix.Sockaddr - if isV4 { + if ip.Is4() { sa4 := &unix.SockaddrInet4{Port: port} - copy(sa4.Addr[:], ipV4) + sa4.Addr = ip.As4() sa = sa4 } else { sa6 := &unix.SockaddrInet6{Port: port} - copy(sa6.Addr[:], ip.To16()) + sa6.Addr = ip.As16() sa = sa6 } if err = unix.Bind(fd, sa); err != nil { @@ -79,7 +79,7 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) ( //v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU) //l.Println(v, err) - return &StdConn{sysFd: fd, isV4: isV4, l: l, batch: batch}, err + return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err } func (u *StdConn) Rebind() error { @@ -102,30 +102,29 @@ func (u *StdConn) GetSendBuffer() (int, error) { return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) } -func (u *StdConn) LocalAddr() (*Addr, error) { +func (u *StdConn) LocalAddr() (netip.AddrPort, error) { sa, err := unix.Getsockname(u.sysFd) if err != nil { - return nil, err + return netip.AddrPort{}, err } - addr := &Addr{} switch sa := sa.(type) { case *unix.SockaddrInet4: - addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16() - addr.Port = uint16(sa.Port) - case *unix.SockaddrInet6: - addr.IP = sa.Addr[0:] - addr.Port = uint16(sa.Port) - } + return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil - return addr, nil + case *unix.SockaddrInet6: + return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil + + default: + return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa) + } } func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { plaintext := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} - udpAddr := &Addr{} + var ip netip.Addr nb := make([]byte, 12, 12) //TODO: should we track this? @@ -146,12 +145,23 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew //metric.Update(int64(n)) for i := 0; i < n; i++ { if u.isV4 { - udpAddr.IP = names[i][4:8] + ip, _ = netip.AddrFromSlice(names[i][4:8]) + //TODO: IPV6-WORK what is not ok? } else { - udpAddr.IP = names[i][8:24] + ip, _ = netip.AddrFromSlice(names[i][8:24]) + //TODO: IPV6-WORK what is not ok? } - udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) - r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r( + netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), + plaintext[:0], + buffers[i][:msgs[i].Len], + h, + fwPacket, + lhf, + nb, + q, + cache.Get(u.l), + ) } } } @@ -197,19 +207,20 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { } } -func (u *StdConn) WriteTo(b []byte, addr *Addr) error { +func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { if u.isV4 { - return u.writeTo4(b, addr) + return u.writeTo4(b, ip) } - return u.writeTo6(b, addr) + return u.writeTo6(b, ip) } -func (u *StdConn) writeTo6(b []byte, addr *Addr) error { +func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 + rsa.Addr = ip.Addr().As16() + port := ip.Port() // Little Endian -> Network Endian - rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8) - copy(rsa.Addr[:], addr.IP.To16()) + rsa.Port = (port >> 8) | ((port & 0xff) << 8) for { _, _, err := unix.Syscall6( @@ -232,17 +243,17 @@ func (u *StdConn) writeTo6(b []byte, addr *Addr) error { } } -func (u *StdConn) writeTo4(b []byte, addr *Addr) error { - addrV4, isAddrV4 := maybeIPV4(addr.IP) - if !isAddrV4 { +func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { + if !ip.Addr().Is4() { return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote") } var rsa unix.RawSockaddrInet4 rsa.Family = unix.AF_INET + rsa.Addr = ip.Addr().As4() + port := ip.Port() // Little Endian -> Network Endian - rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8) - copy(rsa.Addr[:], addrV4) + rsa.Port = (port >> 8) | ((port & 0xff) << 8) for { _, _, err := unix.Syscall6( diff --git a/udp/udp_netbsd.go b/udp/udp_netbsd.go index 3c14fac..3b69159 100644 --- a/udp/udp_netbsd.go +++ b/udp/udp_netbsd.go @@ -8,13 +8,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 31c1a55..ee7e1e0 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "net" + "net/netip" "sync" "sync/atomic" "syscall" @@ -61,16 +62,14 @@ type RIOConn struct { results [packetsPerRing]winrio.Result } -func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) { +func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) { if !winrio.Initialize() { return nil, errors.New("could not initialize winrio") } u := &RIOConn{l: l} - addr := [16]byte{} - copy(addr[:], ip.To16()) - err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port}) + err := u.bind(&windows.SockaddrInet6{Addr: addr.As16(), Port: port}) if err != nil { return nil, fmt.Errorf("bind: %w", err) } @@ -124,7 +123,6 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew buffer := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} - udpAddr := &Addr{IP: make([]byte, 16)} nb := make([]byte, 12, 12) for { @@ -135,11 +133,17 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew return } - udpAddr.IP = rua.Addr[:] - p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port)) - p[0] = byte(rua.Port >> 8) - p[1] = byte(rua.Port) - r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r( + netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), + plaintext[:0], + buffer[:n], + h, + fwPacket, + lhf, + nb, + q, + cache.Get(u.l), + ) } } @@ -231,7 +235,7 @@ retry: return n, ep, nil } -func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error { +func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error { if !u.isOpen.Load() { return net.ErrClosed } @@ -274,10 +278,9 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error { packet := u.tx.Push() packet.addr.Family = windows.AF_INET6 - p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port)) - p[0] = byte(addr.Port >> 8) - p[1] = byte(addr.Port) - copy(packet.addr.Addr[:], addr.IP.To16()) + packet.addr.Addr = ip.Addr().As16() + port := ip.Port() + packet.addr.Port = (port >> 8) | ((port & 0xff) << 8) copy(packet.data[:], buf) dataBuffer := &winrio.Buffer{ @@ -295,17 +298,15 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error { return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } -func (u *RIOConn) LocalAddr() (*Addr, error) { +func (u *RIOConn) LocalAddr() (netip.AddrPort, error) { sa, err := windows.Getsockname(u.sock) if err != nil { - return nil, err + return netip.AddrPort{}, err } v6 := sa.(*windows.SockaddrInet6) - return &Addr{ - IP: v6.Addr[:], - Port: uint16(v6.Port), - }, nil + return netip.AddrPortFrom(netip.AddrFrom16(v6.Addr).Unmap(), uint16(v6.Port)), nil + } func (u *RIOConn) Rebind() error { diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 55985f4..f03a353 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -4,9 +4,8 @@ package udp import ( - "fmt" "io" - "net" + "net/netip" "sync/atomic" "github.com/sirupsen/logrus" @@ -16,30 +15,24 @@ import ( ) type Packet struct { - ToIp net.IP - ToPort uint16 - FromIp net.IP - FromPort uint16 - Data []byte + To netip.AddrPort + From netip.AddrPort + Data []byte } func (u *Packet) Copy() *Packet { n := &Packet{ - ToIp: make(net.IP, len(u.ToIp)), - ToPort: u.ToPort, - FromIp: make(net.IP, len(u.FromIp)), - FromPort: u.FromPort, - Data: make([]byte, len(u.Data)), + To: u.To, + From: u.From, + Data: make([]byte, len(u.Data)), } - copy(n.ToIp, u.ToIp) - copy(n.FromIp, u.FromIp) copy(n.Data, u.Data) return n } type TesterConn struct { - Addr *Addr + Addr netip.AddrPort RxPackets chan *Packet // Packets to receive into nebula TxPackets chan *Packet // Packets transmitted outside by nebula @@ -48,9 +41,9 @@ type TesterConn struct { l *logrus.Logger } -func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) { return &TesterConn{ - Addr: &Addr{ip, uint16(port)}, + Addr: netip.AddrPortFrom(ip, uint16(port)), RxPackets: make(chan *Packet, 10), TxPackets: make(chan *Packet, 10), l: l, @@ -71,7 +64,7 @@ func (u *TesterConn) Send(packet *Packet) { } if u.l.Level >= logrus.DebugLevel { u.l.WithField("header", h). - WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)). + WithField("udpAddr", packet.From). WithField("dataLen", len(packet.Data)). Debug("UDP receiving injected packet") } @@ -98,23 +91,18 @@ func (u *TesterConn) Get(block bool) *Packet { // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// -func (u *TesterConn) WriteTo(b []byte, addr *Addr) error { +func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { if u.closed.Load() { return io.ErrClosedPipe } p := &Packet{ - Data: make([]byte, len(b), len(b)), - FromIp: make([]byte, 16), - FromPort: u.Addr.Port, - ToIp: make([]byte, 16), - ToPort: addr.Port, + Data: make([]byte, len(b), len(b)), + From: u.Addr, + To: addr, } copy(p.Data, b) - copy(p.ToIp, addr.IP.To16()) - copy(p.FromIp, u.Addr.IP.To16()) - u.TxPackets <- p return nil } @@ -123,7 +111,6 @@ func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *fi plaintext := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} - ua := &Addr{IP: make([]byte, 16)} nb := make([]byte, 12, 12) for { @@ -131,9 +118,7 @@ func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *fi if !ok { return } - ua.Port = p.FromPort - copy(ua.IP, p.FromIp.To16()) - r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(p.From, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) } } @@ -144,7 +129,7 @@ func NewUDPStatsEmitter(_ []Conn) func() { return func() {} } -func (u *TesterConn) LocalAddr() (*Addr, error) { +func (u *TesterConn) LocalAddr() (netip.AddrPort, error) { return u.Addr, nil } diff --git a/udp/udp_windows.go b/udp/udp_windows.go index ebcace6..1b777c3 100644 --- a/udp/udp_windows.go +++ b/udp/udp_windows.go @@ -6,12 +6,13 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { if multi { //NOTE: Technically we can support it with RIO but it wouldn't be at the socket level // The udp stack would need to be reworked to hide away the implementation differences between From 9a63fa0a07330d9423247f65084838a0e1821332 Mon Sep 17 00:00:00 2001 From: brad-defined <77982333+brad-defined@users.noreply.github.com> Date: Thu, 1 Aug 2024 13:40:05 -0400 Subject: [PATCH 17/67] Make some Nebula state programmatically available via control object (#1188) --- control.go | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/control.go b/control.go index 7782b23..3468b35 100644 --- a/control.go +++ b/control.go @@ -129,6 +129,42 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { } } +// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found +func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) *cert.NebulaCertificate { + if c.f.myVpnNet.Addr() == vpnIp { + return c.f.pki.GetCertState().Certificate + } + hi := c.f.hostMap.QueryVpnIp(vpnIp) + if hi == nil { + return nil + } + return hi.GetCert() +} + +// CreateTunnel creates a new tunnel to the given vpn ip. +func (c *Control) CreateTunnel(vpnIp netip.Addr) { + c.f.handshakeManager.StartHandshake(vpnIp, nil) +} + +// PrintTunnel creates a new tunnel to the given vpn ip. +func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo { + hi := c.f.hostMap.QueryVpnIp(vpnIp) + if hi == nil { + return nil + } + chi := copyHostInfo(hi, c.f.hostMap.GetPreferredRanges()) + return &chi +} + +// QueryLighthouse queries the lighthouse. +func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap { + hi := c.f.lightHouse.Query(vpnIp) + if hi == nil { + return nil + } + return hi.CopyCache() +} + // GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found // Caller should take care to Unmap() any 4in6 addresses prior to calling. func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo { From f5f6c269ac70d54476f6ec8caf8ff10a7d4957a6 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Wed, 7 Aug 2024 11:53:32 -0400 Subject: [PATCH 18/67] fix rare panic when local index collision happens (#1191) A local index collision happens when two tunnels attempt to use the same random int32 index ID. This is a rare chance, and we have code to deal with it, but we have a panic because we return the wrong thing in this case. This change should fix the panic. --- handshake_manager.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handshake_manager.go b/handshake_manager.go index 7960435..217f11b 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -488,7 +488,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket existingPendingIndex, found := c.indexes[hostinfo.localIndexId] if found && existingPendingIndex.hostinfo != hostinfo { // We have a collision, but for a different hostinfo - return existingIndex, ErrLocalIndexCollision + return existingPendingIndex.hostinfo, ErrLocalIndexCollision } existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] From 8a6a0f0636cf9eade3b51270095dcb7a6bbf879c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 7 Aug 2024 11:58:46 -0400 Subject: [PATCH 19/67] Bump the golang-x-dependencies group with 2 updates (#1190) Bumps the golang-x-dependencies group with 2 updates: [golang.org/x/sync](https://github.com/golang/sync) and [golang.org/x/sys](https://github.com/golang/sys). Updates `golang.org/x/sync` from 0.7.0 to 0.8.0 - [Commits](https://github.com/golang/sync/compare/v0.7.0...v0.8.0) Updates `golang.org/x/sys` from 0.22.0 to 0.23.0 - [Commits](https://github.com/golang/sys/compare/v0.22.0...v0.23.0) --- updated-dependencies: - dependency-name: golang.org/x/sync dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/sys dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 6 +++--- go.sum | 10 ++++------ 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 7680d09..56871f1 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/flynn/noise v1.1.0 + github.com/gaissmai/bart v0.11.1 github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.2 @@ -25,8 +26,8 @@ require ( golang.org/x/crypto v0.25.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/net v0.27.0 - golang.org/x/sync v0.7.0 - golang.org/x/sys v0.22.0 + golang.org/x/sync v0.8.0 + golang.org/x/sys v0.23.0 golang.org/x/term v0.22.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b @@ -41,7 +42,6 @@ require ( github.com/bits-and-blooms/bitset v1.13.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/gaissmai/bart v0.11.1 // indirect github.com/google/btree v1.1.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.5.0 // indirect diff --git a/go.sum b/go.sum index 7ce7e0e..2688b7e 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= -github.com/gaissmai/bart v0.10.0 h1:yCZCYF8xzcRnqDe4jMk14NlJjL1WmMsE7ilBzvuHtiI= -github.com/gaissmai/bart v0.10.0/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc= github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= @@ -182,8 +180,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -201,8 +199,8 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= -golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= +golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= From 248cf194cdb8dfd4eb23753337a7f2fab14cf9a4 Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Tue, 13 Aug 2024 06:25:18 -0700 Subject: [PATCH 20/67] fix integer wraparound in the calculation of handshake timeouts on 32-bit targets (#1185) Fixes: #1169 --- handshake_manager.go | 8 ++++---- main.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/handshake_manager.go b/handshake_manager.go index 217f11b..1df37bd 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -35,7 +35,7 @@ var ( type HandshakeConfig struct { tryInterval time.Duration - retries int + retries int64 triggerBuffer int useRelays bool @@ -69,7 +69,7 @@ type HandshakeHostInfo struct { startTime time.Time // Time that we first started trying with this handshake ready bool // Is the handshake ready - counter int // How many attempts have we made so far + counter int64 // How many attempts have we made so far lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes @@ -665,6 +665,6 @@ func generateIndex(l *logrus.Logger) (uint32, error) { return index, nil } -func hsTimeout(tries int, interval time.Duration) time.Duration { - return time.Duration(tries / 2 * ((2 * int(interval)) + (tries-1)*int(interval))) +func hsTimeout(tries int64, interval time.Duration) time.Duration { + return time.Duration(tries / 2 * ((2 * int64(interval)) + (tries-1)*int64(interval))) } diff --git a/main.go b/main.go index 248f329..c6edc91 100644 --- a/main.go +++ b/main.go @@ -215,7 +215,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg handshakeConfig := HandshakeConfig{ tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), - retries: c.GetInt("handshakes.retries", DefaultHandshakeRetries), + retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)), triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), useRelays: useRelays, From 0736cfa5627f969a6506a02f3e59412b9377446a Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Wed, 14 Aug 2024 12:53:00 -0400 Subject: [PATCH 21/67] udp: fix endianness for port (#1194) If the host OS is already big endian, we were swapping bytes when we shouldn't have. Use the Go helper to make sure we do the endianness correctly Fixes: #1189 --- udp/udp_linux.go | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/udp/udp_linux.go b/udp/udp_linux.go index ef07243..2eee76e 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -218,9 +218,7 @@ func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 rsa.Addr = ip.Addr().As16() - port := ip.Port() - // Little Endian -> Network Endian - rsa.Port = (port >> 8) | ((port & 0xff) << 8) + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) for { _, _, err := unix.Syscall6( @@ -251,9 +249,7 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { var rsa unix.RawSockaddrInet4 rsa.Family = unix.AF_INET rsa.Addr = ip.Addr().As4() - port := ip.Port() - // Little Endian -> Network Endian - rsa.Port = (port >> 8) | ((port & 0xff) << 8) + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) for { _, _, err := unix.Syscall6( From 3dc56e1184ccdd3fb6c7466f416c9bfcbefe73dc Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Mon, 26 Aug 2024 10:38:32 -0700 Subject: [PATCH 22/67] Support UDP dialling with gvisor (#1181) --- examples/go_service/main.go | 27 ++++++++++++------ service/service.go | 55 +++++++++++++++++++++++++++---------- 2 files changed, 58 insertions(+), 24 deletions(-) diff --git a/examples/go_service/main.go b/examples/go_service/main.go index f46273a..30178c0 100644 --- a/examples/go_service/main.go +++ b/examples/go_service/main.go @@ -4,6 +4,7 @@ import ( "bufio" "fmt" "log" + "net" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/service" @@ -54,16 +55,16 @@ pki: cert: /home/rice/Developer/nebula-config/app.crt key: /home/rice/Developer/nebula-config/app.key ` - var config config.C - if err := config.LoadString(configStr); err != nil { + var cfg config.C + if err := cfg.LoadString(configStr); err != nil { return err } - service, err := service.New(&config) + svc, err := service.New(&cfg) if err != nil { return err } - ln, err := service.Listen("tcp", ":1234") + ln, err := svc.Listen("tcp", ":1234") if err != nil { return err } @@ -73,16 +74,24 @@ pki: log.Printf("accept error: %s", err) break } - defer conn.Close() + defer func(conn net.Conn) { + _ = conn.Close() + }(conn) log.Printf("got connection") - conn.Write([]byte("hello world\n")) + _, err = conn.Write([]byte("hello world\n")) + if err != nil { + log.Printf("write error: %s", err) + } scanner := bufio.NewScanner(conn) for scanner.Scan() { message := scanner.Text() - fmt.Fprintf(conn, "echo: %q\n", message) + _, err = fmt.Fprintf(conn, "echo: %q\n", message) + if err != nil { + log.Printf("write error: %s", err) + } log.Printf("got message %q", message) } @@ -92,8 +101,8 @@ pki: } } - service.Close() - if err := service.Wait(); err != nil { + _ = svc.Close() + if err := svc.Wait(); err != nil { return err } return nil diff --git a/service/service.go b/service/service.go index 50c1d4a..4ddd301 100644 --- a/service/service.go +++ b/service/service.go @@ -8,6 +8,7 @@ import ( "log" "math" "net" + "net/netip" "os" "strings" "sync" @@ -153,24 +154,48 @@ func New(config *config.C) (*Service, error) { return &s, nil } -// DialContext dials the provided address. Currently only TCP is supported. +func getProtocolNumber(addr netip.Addr) tcpip.NetworkProtocolNumber { + if addr.Is6() { + return ipv6.ProtocolNumber + } + return ipv4.ProtocolNumber +} + +// DialContext dials the provided address. func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - if network != "tcp" && network != "tcp4" { - return nil, errors.New("only tcp is supported") + switch network { + case "udp", "udp4", "udp6": + addr, err := net.ResolveUDPAddr(network, address) + if err != nil { + return nil, err + } + fullAddr := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(addr.IP), + Port: uint16(addr.Port), + } + num := getProtocolNumber(addr.AddrPort().Addr()) + return gonet.DialUDP(s.ipstack, nil, &fullAddr, num) + case "tcp", "tcp4", "tcp6": + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + fullAddr := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(addr.IP), + Port: uint16(addr.Port), + } + num := getProtocolNumber(addr.AddrPort().Addr()) + return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, num) + default: + return nil, fmt.Errorf("unknown network type: %s", network) } +} - addr, err := net.ResolveTCPAddr(network, address) - if err != nil { - return nil, err - } - - fullAddr := tcpip.FullAddress{ - NIC: nicID, - Addr: tcpip.AddrFromSlice(addr.IP), - Port: uint16(addr.Port), - } - - return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber) +// Dial dials the provided address +func (s *Service) Dial(network, address string) (net.Conn, error) { + return s.DialContext(context.Background(), network, address) } // Listen listens on the provided address. Currently only TCP with wildcard From 45bbad2f216a05ff6503f3c4b0951bf70982de8b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 3 Sep 2024 16:47:36 -0400 Subject: [PATCH 23/67] Bump the golang-x-dependencies group with 4 updates (#1195) Bumps the golang-x-dependencies group with 4 updates: [golang.org/x/crypto](https://github.com/golang/crypto), [golang.org/x/net](https://github.com/golang/net), [golang.org/x/sys](https://github.com/golang/sys) and [golang.org/x/term](https://github.com/golang/term). Updates `golang.org/x/crypto` from 0.25.0 to 0.26.0 - [Commits](https://github.com/golang/crypto/compare/v0.25.0...v0.26.0) Updates `golang.org/x/net` from 0.27.0 to 0.28.0 - [Commits](https://github.com/golang/net/compare/v0.27.0...v0.28.0) Updates `golang.org/x/sys` from 0.23.0 to 0.24.0 - [Commits](https://github.com/golang/sys/compare/v0.23.0...v0.24.0) Updates `golang.org/x/term` from 0.22.0 to 0.23.0 - [Commits](https://github.com/golang/term/compare/v0.22.0...v0.23.0) --- updated-dependencies: - dependency-name: golang.org/x/crypto dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/net dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/sys dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/term dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 8 ++++---- go.sum | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index 56871f1..adb2e84 100644 --- a/go.mod +++ b/go.mod @@ -23,12 +23,12 @@ require ( github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/stretchr/testify v1.9.0 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.25.0 + golang.org/x/crypto v0.26.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 - golang.org/x/net v0.27.0 + golang.org/x/net v0.28.0 golang.org/x/sync v0.8.0 - golang.org/x/sys v0.23.0 - golang.org/x/term v0.22.0 + golang.org/x/sys v0.24.0 + golang.org/x/term v0.23.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 diff --git a/go.sum b/go.sum index 2688b7e..3afd6cb 100644 --- a/go.sum +++ b/go.sum @@ -151,8 +151,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= -golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= @@ -171,8 +171,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= -golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -199,11 +199,11 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= -golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= +golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= -golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= +golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= +golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= From ab81b62ea05814f3eebc0ad545aabcc8b704052b Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Mon, 9 Sep 2024 14:11:44 -0400 Subject: [PATCH 24/67] v1.9.4 (#1210) Update CHANGELOG for Nebula v1.9.4 --- CHANGELOG.md | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f763b69..ad17147 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.9.4] - 2024-09-09 + +### Added + +- Support UDP dialing with gVisor. (#1181) + +### Changed + +- Make some Nebula state programmatically available via control object. (#1188) +- Switch internal representation of IPs to netip, to prepare for IPv6 support + in the overlay. (#1173) +- Minor build and cleanup changes. (#1171, #1164, #1162) +- Various dependency updates. (#1195, #1190, #1174, #1168, #1167, #1161, #1147, #1146) + +### Fixed + +- Fix a bug on big endian hosts, like mips. (#1194) +- Fix a rare panic if a local index collision happens. (#1191) +- Fix integer wraparound in the calculation of handshake timeouts on 32-bit targets. (#1185) + ## [1.9.3] - 2024-06-06 ### Fixed @@ -644,7 +664,8 @@ created.) - Initial public release. -[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.3...HEAD +[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.4...HEAD +[1.9.4]: https://github.com/slackhq/nebula/releases/tag/v1.9.4 [1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3 [1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2 [1.9.1]: https://github.com/slackhq/nebula/releases/tag/v1.9.1 From 35603d1c39fa8bfb0d35ef7ee29716023d0c65c0 Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Mon, 9 Sep 2024 17:51:58 -0400 Subject: [PATCH 25/67] add PKCS11 support (#1153) * add PKCS11 support * add pkcs11 build option to the makefile, add a stub pkclient to avoid forcing CGO onto people * don't print the pkcs11 option on nebula-cert keygen if not compiled in * remove linux-arm64-pkcs11 from the all target to fix CI * correctly serialize ec keys * nebula-cert: support PKCS#11 for sign and ca * fix gofmt lint * clean up some logic with regard to closing sessions * pkclient: handle empty correctly for TPM2 * Update Makefile and Actions --------- Co-authored-by: Morgan Jones Co-authored-by: John Maguire --- .github/workflows/test.yml | 18 +++ .gitignore | 1 + Makefile | 13 +- cert/cert.go | 37 +++++- cmd/nebula-cert/ca.go | 125 ++++++++++++------ cmd/nebula-cert/ca_test.go | 1 + cmd/nebula-cert/keygen.go | 67 +++++++--- cmd/nebula-cert/keygen_test.go | 3 +- cmd/nebula-cert/main_test.go | 11 +- cmd/nebula-cert/p11_cgo.go | 15 +++ cmd/nebula-cert/p11_stub.go | 16 +++ cmd/nebula-cert/sign.go | 128 +++++++++++------- cmd/nebula-cert/sign_test.go | 1 + connection_state.go | 6 +- go.mod | 2 + go.sum | 4 + noiseutil/pkcs11.go | 50 +++++++ pkclient/pkclient.go | 87 +++++++++++++ pkclient/pkclient_cgo.go | 229 +++++++++++++++++++++++++++++++++ pkclient/pkclient_stub.go | 30 +++++ pki.go | 44 ++++--- 21 files changed, 761 insertions(+), 127 deletions(-) create mode 100644 cmd/nebula-cert/p11_cgo.go create mode 100644 cmd/nebula-cert/p11_stub.go create mode 100644 noiseutil/pkcs11.go create mode 100644 pkclient/pkclient.go create mode 100644 pkclient/pkclient_cgo.go create mode 100644 pkclient/pkclient_stub.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 65a6e3e..2b27f52 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -67,6 +67,24 @@ jobs: - name: End 2 end run: make e2evv GOEXPERIMENT=boringcrypto CGO_ENABLED=1 + test-linux-pkcs11: + name: Build and test on linux with pkcs11 + runs-on: ubuntu-latest + steps: + + - uses: actions/checkout@v4 + + - uses: actions/setup-go@v5 + with: + go-version: '1.22' + check-latest: true + + - name: Build + run: make bin-pkcs11 + + - name: Test + run: make test-pkcs11 + test: name: Build and test on ${{ matrix.os }} runs-on: ${{ matrix.os }} diff --git a/.gitignore b/.gitignore index 0efb967..0bffc85 100644 --- a/.gitignore +++ b/.gitignore @@ -13,5 +13,6 @@ **.crt **.key **.pem +**.pub !/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.key !/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.crt diff --git a/Makefile b/Makefile index 0d0943f..6922cc3 100644 --- a/Makefile +++ b/Makefile @@ -40,7 +40,7 @@ ALL_LINUX = linux-amd64 \ linux-mips64le \ linux-mips-softfloat \ linux-riscv64 \ - linux-loong64 + linux-loong64 ALL_FREEBSD = freebsd-amd64 \ freebsd-arm64 @@ -63,7 +63,7 @@ ALL = $(ALL_LINUX) \ e2e: $(TEST_ENV) go test -tags=e2e_testing -count=1 $(TEST_FLAGS) ./e2e -e2ev: TEST_FLAGS = -v +e2ev: TEST_FLAGS += -v e2ev: e2e e2evv: TEST_ENV += TEST_LOGS=1 @@ -96,7 +96,7 @@ release-netbsd: $(ALL_NETBSD:%=build/nebula-%.tar.gz) release-boringcrypto: build/nebula-linux-$(shell go env GOARCH)-boringcrypto.tar.gz -BUILD_ARGS = -trimpath +BUILD_ARGS += -trimpath bin-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe mv $? . @@ -116,6 +116,10 @@ bin-freebsd-arm64: build/freebsd-arm64/nebula build/freebsd-arm64/nebula-cert bin-boringcrypto: build/linux-$(shell go env GOARCH)-boringcrypto/nebula build/linux-$(shell go env GOARCH)-boringcrypto/nebula-cert mv $? . +bin-pkcs11: BUILD_ARGS += -tags pkcs11 +bin-pkcs11: CGO_ENABLED = 1 +bin-pkcs11: bin + bin: go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula${NEBULA_CMD_SUFFIX} ${NEBULA_CMD_PATH} go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula-cert${NEBULA_CMD_SUFFIX} ./cmd/nebula-cert @@ -168,6 +172,9 @@ test: test-boringcrypto: GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -v ./... +test-pkcs11: + CGO_ENABLED=1 go test -v -tags pkcs11 ./... + test-cov-html: go test -coverprofile=coverage.out go tool cover -html=coverage.out diff --git a/cert/cert.go b/cert/cert.go index a0164f7..dd08923 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -20,6 +20,7 @@ import ( "sync/atomic" "time" + "github.com/slackhq/nebula/pkclient" "golang.org/x/crypto/curve25519" "google.golang.org/protobuf/proto" ) @@ -41,8 +42,9 @@ const ( ) type NebulaCertificate struct { - Details NebulaCertificateDetails - Signature []byte + Details NebulaCertificateDetails + Pkcs11Backed bool + Signature []byte // the cached hex string of the calculated sha256sum // for VerifyWithCache @@ -555,6 +557,34 @@ func (nc *NebulaCertificate) Sign(curve Curve, key []byte) error { return nil } +// SignPkcs11 signs a nebula cert with the provided private key +func (nc *NebulaCertificate) SignPkcs11(curve Curve, client *pkclient.PKClient) error { + if !nc.Pkcs11Backed { + return fmt.Errorf("certificate is not PKCS#11 backed") + } + + if curve != nc.Details.Curve { + return fmt.Errorf("curve in cert and private key supplied don't match") + } + + if curve != Curve_P256 { + return fmt.Errorf("only P256 is supported by PKCS#11") + } + + b, err := proto.Marshal(nc.getRawDetails()) + if err != nil { + return err + } + + sig, err := client.SignASN1(b) + if err != nil { + return err + } + + nc.Signature = sig + return nil +} + // CheckSignature verifies the signature against the provided public key func (nc *NebulaCertificate) CheckSignature(key []byte) bool { b, err := proto.Marshal(nc.getRawDetails()) @@ -693,6 +723,9 @@ func (nc *NebulaCertificate) CheckRootConstrains(signer *NebulaCertificate) erro // VerifyPrivateKey checks that the public key in the Nebula certificate and a supplied private key match func (nc *NebulaCertificate) VerifyPrivateKey(curve Curve, key []byte) error { + if nc.Pkcs11Backed { + return nil //todo! + } if curve != nc.Details.Curve { return fmt.Errorf("curve in cert and private key supplied don't match") } diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go index 4e5d51d..757f883 100644 --- a/cmd/nebula-cert/ca.go +++ b/cmd/nebula-cert/ca.go @@ -4,6 +4,7 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" + "flag" "fmt" "io" @@ -15,6 +16,7 @@ import ( "github.com/skip2/go-qrcode" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/pkclient" "golang.org/x/crypto/ed25519" ) @@ -33,7 +35,8 @@ type caFlags struct { argonParallelism *uint encryption *bool - curve *string + curve *string + p11url *string } func newCaFlags() *caFlags { @@ -52,6 +55,7 @@ func newCaFlags() *caFlags { cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase") cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format") cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)") + cf.p11url = p11Flag(cf.set) return &cf } @@ -76,17 +80,21 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error return err } + isP11 := len(*cf.p11url) > 0 + if err := mustFlagString("name", cf.name); err != nil { return err } - if err := mustFlagString("out-key", cf.outKeyPath); err != nil { - return err + if !isP11 { + if err = mustFlagString("out-key", cf.outKeyPath); err != nil { + return err + } } if err := mustFlagString("out-crt", cf.outCertPath); err != nil { return err } var kdfParams *cert.Argon2Parameters - if *cf.encryption { + if !isP11 && *cf.encryption { if kdfParams, err = parseArgonParameters(*cf.argonMemory, *cf.argonParallelism, *cf.argonIterations); err != nil { return err } @@ -143,7 +151,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error } var passphrase []byte - if *cf.encryption { + if !isP11 && *cf.encryption { for i := 0; i < 5; i++ { out.Write([]byte("Enter passphrase: ")) passphrase, err = pr.ReadPassword() @@ -166,29 +174,54 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error var curve cert.Curve var pub, rawPriv []byte - switch *cf.curve { - case "25519", "X25519", "Curve25519", "CURVE25519": - curve = cert.Curve_CURVE25519 - pub, rawPriv, err = ed25519.GenerateKey(rand.Reader) - if err != nil { - return fmt.Errorf("error while generating ed25519 keys: %s", err) - } - case "P256": - var key *ecdsa.PrivateKey - curve = cert.Curve_P256 - key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - return fmt.Errorf("error while generating ecdsa keys: %s", err) + var p11Client *pkclient.PKClient + + if isP11 { + switch *cf.curve { + case "P256": + curve = cert.Curve_P256 + default: + return fmt.Errorf("invalid curve for PKCS#11: %s", *cf.curve) } - // ecdh.PrivateKey lets us get at the encoded bytes, even though - // we aren't using ECDH here. - eKey, err := key.ECDH() + p11Client, err = pkclient.FromUrl(*cf.p11url) if err != nil { - return fmt.Errorf("error while converting ecdsa key: %s", err) + return fmt.Errorf("error while creating PKCS#11 client: %w", err) + } + defer func(client *pkclient.PKClient) { + _ = client.Close() + }(p11Client) + pub, err = p11Client.GetPubKey() + if err != nil { + return fmt.Errorf("error while getting public key with PKCS#11: %w", err) + } + } else { + switch *cf.curve { + case "25519", "X25519", "Curve25519", "CURVE25519": + curve = cert.Curve_CURVE25519 + pub, rawPriv, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + return fmt.Errorf("error while generating ed25519 keys: %s", err) + } + case "P256": + var key *ecdsa.PrivateKey + curve = cert.Curve_P256 + key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return fmt.Errorf("error while generating ecdsa keys: %s", err) + } + + // ecdh.PrivateKey lets us get at the encoded bytes, even though + // we aren't using ECDH here. + eKey, err := key.ECDH() + if err != nil { + return fmt.Errorf("error while converting ecdsa key: %s", err) + } + rawPriv = eKey.Bytes() + pub = eKey.PublicKey().Bytes() + default: + return fmt.Errorf("invalid curve: %s", *cf.curve) } - rawPriv = eKey.Bytes() - pub = eKey.PublicKey().Bytes() } nc := cert.NebulaCertificate{ @@ -203,34 +236,48 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error IsCA: true, Curve: curve, }, + Pkcs11Backed: isP11, } - if _, err := os.Stat(*cf.outKeyPath); err == nil { - return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath) + if !isP11 { + if _, err := os.Stat(*cf.outKeyPath); err == nil { + return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath) + } } if _, err := os.Stat(*cf.outCertPath); err == nil { return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath) } - err = nc.Sign(curve, rawPriv) - if err != nil { - return fmt.Errorf("error while signing: %s", err) - } - var b []byte - if *cf.encryption { - b, err = cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams) + + if isP11 { + err = nc.SignPkcs11(curve, p11Client) if err != nil { - return fmt.Errorf("error while encrypting out-key: %s", err) + return fmt.Errorf("error while signing with PKCS#11: %w", err) } } else { - b = cert.MarshalSigningPrivateKey(curve, rawPriv) - } + err = nc.Sign(curve, rawPriv) + if err != nil { + return fmt.Errorf("error while signing: %s", err) + } - err = os.WriteFile(*cf.outKeyPath, b, 0600) - if err != nil { - return fmt.Errorf("error while writing out-key: %s", err) + if *cf.encryption { + b, err = cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams) + if err != nil { + return fmt.Errorf("error while encrypting out-key: %s", err) + } + } else { + b = cert.MarshalSigningPrivateKey(curve, rawPriv) + } + + err = os.WriteFile(*cf.outKeyPath, b, 0600) + if err != nil { + return fmt.Errorf("error while writing out-key: %s", err) + } + if _, err := os.Stat(*cf.outCertPath); err == nil { + return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath) + } } b, err = nc.MarshalToPEM() diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index 3a53405..cb8b57a 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -52,6 +52,7 @@ func Test_caHelp(t *testing.T) { " \tOptional: path to write the private key to (default \"ca.key\")\n"+ " -out-qr string\n"+ " \tOptional: output a qr code image (png) of the certificate\n"+ + optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+ " -subnets string\n"+ " \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets\n", ob.String(), diff --git a/cmd/nebula-cert/keygen.go b/cmd/nebula-cert/keygen.go index d94cbf1..2355c4f 100644 --- a/cmd/nebula-cert/keygen.go +++ b/cmd/nebula-cert/keygen.go @@ -6,6 +6,8 @@ import ( "io" "os" + "github.com/slackhq/nebula/pkclient" + "github.com/slackhq/nebula/cert" ) @@ -13,8 +15,8 @@ type keygenFlags struct { set *flag.FlagSet outKeyPath *string outPubPath *string - - curve *string + curve *string + p11url *string } func newKeygenFlags() *keygenFlags { @@ -23,6 +25,7 @@ func newKeygenFlags() *keygenFlags { cf.outPubPath = cf.set.String("out-pub", "", "Required: path to write the public key to") cf.outKeyPath = cf.set.String("out-key", "", "Required: path to write the private key to") cf.curve = cf.set.String("curve", "25519", "ECDH Curve (25519, P256)") + cf.p11url = p11Flag(cf.set) return &cf } @@ -33,31 +36,57 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error { return err } - if err := mustFlagString("out-key", cf.outKeyPath); err != nil { - return err + isP11 := len(*cf.p11url) > 0 + + if !isP11 { + if err = mustFlagString("out-key", cf.outKeyPath); err != nil { + return err + } } - if err := mustFlagString("out-pub", cf.outPubPath); err != nil { + if err = mustFlagString("out-pub", cf.outPubPath); err != nil { return err } var pub, rawPriv []byte var curve cert.Curve - switch *cf.curve { - case "25519", "X25519", "Curve25519", "CURVE25519": - pub, rawPriv = x25519Keypair() - curve = cert.Curve_CURVE25519 - case "P256": - pub, rawPriv = p256Keypair() - curve = cert.Curve_P256 - default: - return fmt.Errorf("invalid curve: %s", *cf.curve) + if isP11 { + switch *cf.curve { + case "P256": + curve = cert.Curve_P256 + default: + return fmt.Errorf("invalid curve for PKCS#11: %s", *cf.curve) + } + } else { + switch *cf.curve { + case "25519", "X25519", "Curve25519", "CURVE25519": + pub, rawPriv = x25519Keypair() + curve = cert.Curve_CURVE25519 + case "P256": + pub, rawPriv = p256Keypair() + curve = cert.Curve_P256 + default: + return fmt.Errorf("invalid curve: %s", *cf.curve) + } } - err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600) - if err != nil { - return fmt.Errorf("error while writing out-key: %s", err) + if isP11 { + p11Client, err := pkclient.FromUrl(*cf.p11url) + if err != nil { + return fmt.Errorf("error while creating PKCS#11 client: %w", err) + } + defer func(client *pkclient.PKClient) { + _ = client.Close() + }(p11Client) + pub, err = p11Client.GetPubKey() + if err != nil { + return fmt.Errorf("error while getting public key: %w", err) + } + } else { + err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600) + if err != nil { + return fmt.Errorf("error while writing out-key: %s", err) + } } - err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKey(curve, pub), 0600) if err != nil { return fmt.Errorf("error while writing out-pub: %s", err) @@ -72,7 +101,7 @@ func keygenSummary() string { func keygenHelp(out io.Writer) { cf := newKeygenFlags() - out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n")) + _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n")) cf.set.SetOutput(out) cf.set.PrintDefaults() } diff --git a/cmd/nebula-cert/keygen_test.go b/cmd/nebula-cert/keygen_test.go index 9a3b3f3..925b266 100644 --- a/cmd/nebula-cert/keygen_test.go +++ b/cmd/nebula-cert/keygen_test.go @@ -26,7 +26,8 @@ func Test_keygenHelp(t *testing.T) { " -out-key string\n"+ " \tRequired: path to write the private key to\n"+ " -out-pub string\n"+ - " \tRequired: path to write the public key to\n", + " \tRequired: path to write the public key to\n"+ + optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n"), ob.String(), ) } diff --git a/cmd/nebula-cert/main_test.go b/cmd/nebula-cert/main_test.go index 3d0fa1b..2502824 100644 --- a/cmd/nebula-cert/main_test.go +++ b/cmd/nebula-cert/main_test.go @@ -3,6 +3,7 @@ package main import ( "bytes" "errors" + "fmt" "io" "os" "testing" @@ -77,8 +78,16 @@ func assertHelpError(t *testing.T, err error, msg string) { case *helpError: // good default: - t.Fatal("err was not a helpError") + t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg)) } assert.EqualError(t, err, msg) } + +func optionalPkcs11String(msg string) string { + if p11Supported() { + return msg + } else { + return "" + } +} diff --git a/cmd/nebula-cert/p11_cgo.go b/cmd/nebula-cert/p11_cgo.go new file mode 100644 index 0000000..f1f1ec6 --- /dev/null +++ b/cmd/nebula-cert/p11_cgo.go @@ -0,0 +1,15 @@ +//go:build cgo && pkcs11 + +package main + +import ( + "flag" +) + +func p11Supported() bool { + return true +} + +func p11Flag(set *flag.FlagSet) *string { + return set.String("pkcs11", "", "Optional: PKCS#11 URI to an existing private key") +} diff --git a/cmd/nebula-cert/p11_stub.go b/cmd/nebula-cert/p11_stub.go new file mode 100644 index 0000000..5afeaea --- /dev/null +++ b/cmd/nebula-cert/p11_stub.go @@ -0,0 +1,16 @@ +//go:build !cgo || !pkcs11 + +package main + +import ( + "flag" +) + +func p11Supported() bool { + return false +} + +func p11Flag(set *flag.FlagSet) *string { + var ret = "" + return &ret +} diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index 35d6446..8e86fe5 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -13,6 +13,7 @@ import ( "github.com/skip2/go-qrcode" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/pkclient" "golang.org/x/crypto/curve25519" ) @@ -29,6 +30,7 @@ type signFlags struct { outQRPath *string groups *string subnets *string + p11url *string } func newSignFlags() *signFlags { @@ -45,8 +47,8 @@ func newSignFlags() *signFlags { sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate") sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups") sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for") + sf.p11url = p11Flag(sf.set) return &sf - } func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error { @@ -56,8 +58,12 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return err } - if err := mustFlagString("ca-key", sf.caKeyPath); err != nil { - return err + isP11 := len(*sf.p11url) > 0 + + if !isP11 { + if err := mustFlagString("ca-key", sf.caKeyPath); err != nil { + return err + } } if err := mustFlagString("ca-crt", sf.caCertPath); err != nil { return err @@ -68,47 +74,49 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) if err := mustFlagString("ip", sf.ip); err != nil { return err } - if *sf.inPubPath != "" && *sf.outKeyPath != "" { + if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" { return newHelpErrorf("cannot set both -in-pub and -out-key") } - rawCAKey, err := os.ReadFile(*sf.caKeyPath) - if err != nil { - return fmt.Errorf("error while reading ca-key: %s", err) - } - var curve cert.Curve var caKey []byte - - // naively attempt to decode the private key as though it is not encrypted - caKey, _, curve, err = cert.UnmarshalSigningPrivateKey(rawCAKey) - if err == cert.ErrPrivateKeyEncrypted { - // ask for a passphrase until we get one - var passphrase []byte - for i := 0; i < 5; i++ { - out.Write([]byte("Enter passphrase: ")) - passphrase, err = pr.ReadPassword() - - if err == ErrNoTerminal { - return fmt.Errorf("ca-key is encrypted and must be decrypted interactively") - } else if err != nil { - return fmt.Errorf("error reading password: %s", err) - } - - if len(passphrase) > 0 { - break - } - } - if len(passphrase) == 0 { - return fmt.Errorf("cannot open encrypted ca-key without passphrase") - } - - curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey) + if !isP11 { + var rawCAKey []byte + rawCAKey, err := os.ReadFile(*sf.caKeyPath) if err != nil { - return fmt.Errorf("error while parsing encrypted ca-key: %s", err) + return fmt.Errorf("error while reading ca-key: %s", err) + } + + // naively attempt to decode the private key as though it is not encrypted + caKey, _, curve, err = cert.UnmarshalSigningPrivateKey(rawCAKey) + if err == cert.ErrPrivateKeyEncrypted { + // ask for a passphrase until we get one + var passphrase []byte + for i := 0; i < 5; i++ { + out.Write([]byte("Enter passphrase: ")) + passphrase, err = pr.ReadPassword() + + if err == ErrNoTerminal { + return fmt.Errorf("ca-key is encrypted and must be decrypted interactively") + } else if err != nil { + return fmt.Errorf("error reading password: %s", err) + } + + if len(passphrase) > 0 { + break + } + } + if len(passphrase) == 0 { + return fmt.Errorf("cannot open encrypted ca-key without passphrase") + } + + curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey) + if err != nil { + return fmt.Errorf("error while parsing encrypted ca-key: %s", err) + } + } else if err != nil { + return fmt.Errorf("error while parsing ca-key: %s", err) } - } else if err != nil { - return fmt.Errorf("error while parsing ca-key: %s", err) } rawCACert, err := os.ReadFile(*sf.caCertPath) @@ -121,8 +129,10 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("error while parsing ca-crt: %s", err) } - if err := caCert.VerifyPrivateKey(curve, caKey); err != nil { - return fmt.Errorf("refusing to sign, root certificate does not match private key") + if !isP11 { + if err := caCert.VerifyPrivateKey(curve, caKey); err != nil { + return fmt.Errorf("refusing to sign, root certificate does not match private key") + } } issuer, err := caCert.Sha256Sum() @@ -176,12 +186,25 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } var pub, rawPriv []byte + var p11Client *pkclient.PKClient + + if isP11 { + curve = cert.Curve_P256 + p11Client, err = pkclient.FromUrl(*sf.p11url) + if err != nil { + return fmt.Errorf("error while creating PKCS#11 client: %w", err) + } + defer func(client *pkclient.PKClient) { + _ = client.Close() + }(p11Client) + } + if *sf.inPubPath != "" { + var pubCurve cert.Curve rawPub, err := os.ReadFile(*sf.inPubPath) if err != nil { return fmt.Errorf("error while reading in-pub: %s", err) } - var pubCurve cert.Curve pub, _, pubCurve, err = cert.UnmarshalPublicKey(rawPub) if err != nil { return fmt.Errorf("error while parsing in-pub: %s", err) @@ -189,6 +212,11 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) if pubCurve != curve { return fmt.Errorf("curve of in-pub does not match ca") } + } else if isP11 { + pub, err = p11Client.GetPubKey() + if err != nil { + return fmt.Errorf("error while getting public key with PKCS#11: %w", err) + } } else { pub, rawPriv = newKeypair(curve) } @@ -206,6 +234,19 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) Issuer: issuer, Curve: curve, }, + Pkcs11Backed: isP11, + } + + if p11Client == nil { + err = nc.Sign(curve, caKey) + if err != nil { + return fmt.Errorf("error while signing: %w", err) + } + } else { + err = nc.SignPkcs11(curve, p11Client) + if err != nil { + return fmt.Errorf("error while signing with PKCS#11: %w", err) + } } if err := nc.CheckRootConstrains(caCert); err != nil { @@ -224,12 +265,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath) } - err = nc.Sign(curve, caKey) - if err != nil { - return fmt.Errorf("error while signing: %s", err) - } - - if *sf.inPubPath == "" { + if !isP11 && *sf.inPubPath == "" { if _, err := os.Stat(*sf.outKeyPath); err == nil { return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath) } diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index adf83a2..d6e2a39 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -48,6 +48,7 @@ func Test_signHelp(t *testing.T) { " \tOptional (if in-pub not set): path to write the private key to\n"+ " -out-qr string\n"+ " \tOptional: output a qr code image (png) of the certificate\n"+ + optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+ " -subnets string\n"+ " \tOptional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for\n", ob.String(), diff --git a/connection_state.go b/connection_state.go index 1dd3c8c..aa17a13 100644 --- a/connection_state.go +++ b/connection_state.go @@ -32,7 +32,11 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i case cert.Curve_CURVE25519: dhFunc = noise.DH25519 case cert.Curve_P256: - dhFunc = noiseutil.DHP256 + if certState.Certificate.Pkcs11Backed { + dhFunc = noiseutil.DHP256PKCS11 + } else { + dhFunc = noiseutil.DHP256 + } default: l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve) return nil diff --git a/go.mod b/go.mod index adb2e84..4d36f34 100644 --- a/go.mod +++ b/go.mod @@ -15,12 +15,14 @@ require ( github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.2 github.com/miekg/dns v1.1.61 + github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f github.com/prometheus/client_golang v1.19.1 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/sirupsen/logrus v1.9.3 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 + github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stretchr/testify v1.9.0 github.com/vishvananda/netlink v1.2.1-beta.2 golang.org/x/crypto v0.26.0 diff --git a/go.sum b/go.sum index 3afd6cb..e90ae8e 100644 --- a/go.sum +++ b/go.sum @@ -83,6 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/miekg/dns v1.1.61 h1:nLxbwF3XxhwVSm8g9Dghm9MHPaUZuqhPiGL+675ZmEs= github.com/miekg/dns v1.1.61/go.mod h1:mnAarhS3nWaW+NVP2wTkYVIZyHNJ098SJZUki3eykwQ= +github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk= +github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= @@ -131,6 +133,8 @@ github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= +github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw= +github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= diff --git a/noiseutil/pkcs11.go b/noiseutil/pkcs11.go new file mode 100644 index 0000000..d1c7ba9 --- /dev/null +++ b/noiseutil/pkcs11.go @@ -0,0 +1,50 @@ +package noiseutil + +import ( + "crypto/ecdh" + "fmt" + "strings" + + "github.com/slackhq/nebula/pkclient" + + "github.com/flynn/noise" +) + +// DHP256PKCS11 is the NIST P-256 ECDH function +var DHP256PKCS11 noise.DHFunc = newNISTP11Curve("P256", ecdh.P256(), 32) + +type nistP11Curve struct { + nistCurve +} + +func newNISTP11Curve(name string, curve ecdh.Curve, byteLen int) nistP11Curve { + return nistP11Curve{ + newNISTCurve(name, curve, byteLen), + } +} + +func (c nistP11Curve) DH(privkey, pubkey []byte) ([]byte, error) { + //for this function "privkey" is actually a pkcs11 URI + pkStr := string(privkey) + + //to set up a handshake, we need to also do non-pkcs11-DH. Handle that here. + if !strings.HasPrefix(pkStr, "pkcs11:") { + return DHP256.DH(privkey, pubkey) + } + ecdhPubKey, err := c.curve.NewPublicKey(pubkey) + if err != nil { + return nil, fmt.Errorf("unable to unmarshal pubkey: %w", err) + } + + //this is not the most performant way to do this (a long-lived client would be better) + //but, it works, and helps avoid problems with stale sessions and HSMs used by multiple users. + client, err := pkclient.FromUrl(pkStr) + if err != nil { + return nil, err + } + defer func(client *pkclient.PKClient) { + _ = client.Close() + }(client) + + return client.DeriveNoise(ecdhPubKey.Bytes()) +} diff --git a/pkclient/pkclient.go b/pkclient/pkclient.go new file mode 100644 index 0000000..7061de6 --- /dev/null +++ b/pkclient/pkclient.go @@ -0,0 +1,87 @@ +package pkclient + +import ( + "crypto/ecdsa" + "crypto/x509" + "fmt" + "io" + "strconv" + + "github.com/stefanberger/go-pkcs11uri" +) + +type Client interface { + io.Closer + GetPubKey() ([]byte, error) + DeriveNoise(peerPubKey []byte) ([]byte, error) + Test() error +} + +const NoiseKeySize = 32 + +func FromUrl(pkurl string) (*PKClient, error) { + uri := pkcs11uri.New() + uri.SetAllowAnyModule(true) //todo + err := uri.Parse(pkurl) + if err != nil { + return nil, err + } + + module, err := uri.GetModule() + if err != nil { + return nil, err + } + + slotid := 0 + slot, ok := uri.GetPathAttribute("slot-id", false) + if !ok { + slotid = 0 + } else { + slotid, err = strconv.Atoi(slot) + if err != nil { + return nil, err + } + } + + pin, _ := uri.GetPIN() + id, _ := uri.GetPathAttribute("id", false) + label, _ := uri.GetPathAttribute("object", false) + + return New(module, uint(slotid), pin, id, label) +} + +func ecKeyToArray(key *ecdsa.PublicKey) []byte { + x := make([]byte, 32) + y := make([]byte, 32) + key.X.FillBytes(x) + key.Y.FillBytes(y) + return append([]byte{0x04}, append(x, y...)...) +} + +func formatPubkeyFromPublicKeyInfoAttr(d []byte) ([]byte, error) { + e, err := x509.ParsePKIXPublicKey(d) + if err != nil { + return nil, err + } + switch t := e.(type) { + case *ecdsa.PublicKey: + return ecKeyToArray(e.(*ecdsa.PublicKey)), nil + default: + return nil, fmt.Errorf("unknown public key type: %T", t) + } +} + +func (c *PKClient) Test() error { + pub, err := c.GetPubKey() + if err != nil { + return fmt.Errorf("failed to get public key: %w", err) + } + out, err := c.DeriveNoise(pub) //do an ECDH with ourselves as a quick test + if err != nil { + return err + } + if len(out) != NoiseKeySize { + return fmt.Errorf("got a key of %d bytes, expected %d", len(out), NoiseKeySize) + } + return nil +} diff --git a/pkclient/pkclient_cgo.go b/pkclient/pkclient_cgo.go new file mode 100644 index 0000000..a2ead55 --- /dev/null +++ b/pkclient/pkclient_cgo.go @@ -0,0 +1,229 @@ +//go:build cgo && pkcs11 + +package pkclient + +import ( + "encoding/asn1" + "errors" + "fmt" + "log" + "math/big" + + "github.com/miekg/pkcs11" + "github.com/miekg/pkcs11/p11" +) + +type PKClient struct { + module p11.Module + session p11.Session + id []byte + label []byte + privKeyObj p11.Object + pubKeyObj p11.Object +} + +type ecdsaSignature struct { + R, S *big.Int +} + +// New tries to open a session with the HSM, select the slot and login to it +func New(hsmPath string, slotId uint, pin string, id string, label string) (*PKClient, error) { + module, err := p11.OpenModule(hsmPath) + if err != nil { + return nil, fmt.Errorf("failed to load module library: %s", hsmPath) + } + + slots, err := module.Slots() + if err != nil { + module.Destroy() + return nil, err + } + + // Try to open a session on the slot + slotIdx := 0 + for i, slot := range slots { + if slot.ID() == slotId { + slotIdx = i + break + } + } + + client := &PKClient{ + module: module, + id: []byte(id), + label: []byte(label), + } + + client.session, err = slots[slotIdx].OpenWriteSession() + if err != nil { + module.Destroy() + return nil, fmt.Errorf("failed to open session on slot %d", slotId) + } + + if len(pin) != 0 { + err = client.session.Login(pin) + if err != nil { + // ignore "already logged in" + if !errors.Is(err, pkcs11.Error(256)) { + _ = client.session.Close() + return nil, fmt.Errorf("unable to login. error: %w", err) + } + } + } + + // Make sure the hsm has a private key for deriving + client.privKeyObj, err = client.findDeriveKey(client.id, client.label, true) + if err != nil { + _ = client.Close() //log out, close session, destroy module + return nil, fmt.Errorf("failed to find private key for deriving: %w", err) + } + + return client, nil +} + +// Close cleans up properly and logs out +func (c *PKClient) Close() error { + var err error = nil + if c.session != nil { + _ = c.session.Logout() //if logout fails, we still want to close + err = c.session.Close() + } + + c.module.Destroy() + return err +} + +// Try to find a suitable key on the hsm for key derivation +// parameter GET_PUB_KEY sets the search pattern for a public or private key +func (c *PKClient) findDeriveKey(id []byte, label []byte, private bool) (key p11.Object, err error) { + keyClass := pkcs11.CKO_PRIVATE_KEY + if !private { + keyClass = pkcs11.CKO_PUBLIC_KEY + } + keyAttrs := []*pkcs11.Attribute{ + //todo, not all HSMs seem to report this, even if its true: pkcs11.NewAttribute(pkcs11.CKA_DERIVE, true), + pkcs11.NewAttribute(pkcs11.CKA_CLASS, keyClass), + } + + if id != nil && len(id) != 0 { + keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_ID, id)) + } + if label != nil && len(label) != 0 { + keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label)) + } + + return c.session.FindObject(keyAttrs) +} + +func (c *PKClient) listDeriveKeys(id []byte, label []byte, private bool) { + keyClass := pkcs11.CKO_PRIVATE_KEY + if !private { + keyClass = pkcs11.CKO_PUBLIC_KEY + } + keyAttrs := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_CLASS, keyClass), + } + + if id != nil && len(id) != 0 { + keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_ID, id)) + } + if label != nil && len(label) != 0 { + keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label)) + } + + objects, err := c.session.FindObjects(keyAttrs) + if err != nil { + return + } + + for _, obj := range objects { + l, err := obj.Label() + log.Printf("%s, %v", l, err) + a, err := obj.Attribute(pkcs11.CKA_DERIVE) + log.Printf("DERIVE: %s %v, %v", l, a, err) + } +} + +// SignASN1 signs some data. Returns the ASN.1 encoded signature. +func (c *PKClient) SignASN1(data []byte) ([]byte, error) { + mech := pkcs11.NewMechanism(pkcs11.CKM_ECDSA_SHA256, nil) + sk := p11.PrivateKey(c.privKeyObj) + rawSig, err := sk.Sign(*mech, data) + if err != nil { + return nil, err + } + + // PKCS #11 Mechanisms v2.30: + // "The signature octets correspond to the concatenation of the ECDSA values r and s, + // both represented as an octet string of equal length of at most nLen with the most + // significant byte first. If r and s have different octet length, the shorter of both + // must be padded with leading zero octets such that both have the same octet length. + // Loosely spoken, the first half of the signature is r and the second half is s." + r := new(big.Int).SetBytes(rawSig[:len(rawSig)/2]) + s := new(big.Int).SetBytes(rawSig[len(rawSig)/2:]) + return asn1.Marshal(ecdsaSignature{r, s}) +} + +// DeriveNoise derives a shared secret using the input public key against the private key that was found during setup. +// Returns a fixed 32 byte array. +func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) { + // Before we call derive, we need to have an array of attributes which specify the type of + // key to be returned, in our case, it's the shared secret key, produced via deriving + // This template pulled from OpenSC pkclient-tool.c line 4038 + attrTemplate := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_TOKEN, false), + pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_SECRET_KEY), + pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, pkcs11.CKK_GENERIC_SECRET), + pkcs11.NewAttribute(pkcs11.CKA_SENSITIVE, false), + pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, true), + pkcs11.NewAttribute(pkcs11.CKA_ENCRYPT, true), + pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true), + pkcs11.NewAttribute(pkcs11.CKA_WRAP, true), + pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true), + } + + // Set up the parameters which include the peer's public key + ecdhParams := pkcs11.NewECDH1DeriveParams(pkcs11.CKD_NULL, nil, peerPubKey) + mech := pkcs11.NewMechanism(pkcs11.CKM_ECDH1_DERIVE, ecdhParams) + sk := p11.PrivateKey(c.privKeyObj) + + tmpKey, err := sk.Derive(*mech, attrTemplate) + if err != nil { + return nil, err + } + if tmpKey == nil || len(tmpKey) == 0 { + return nil, fmt.Errorf("got an empty secret key") + } + secret := make([]byte, NoiseKeySize) + copy(secret[:], tmpKey[:NoiseKeySize]) + return secret, nil +} + +func (c *PKClient) GetPubKey() ([]byte, error) { + d, err := c.privKeyObj.Attribute(pkcs11.CKA_PUBLIC_KEY_INFO) + if err != nil { + return nil, err + } + if d != nil && len(d) > 0 { + return formatPubkeyFromPublicKeyInfoAttr(d) + } + c.pubKeyObj, err = c.findDeriveKey(c.id, c.label, false) + if err != nil { + return nil, fmt.Errorf("pkcs11 module gave us a nil CKA_PUBLIC_KEY_INFO, and looking up the public key also failed: %w", err) + } + d, err = c.pubKeyObj.Attribute(pkcs11.CKA_EC_POINT) + if err != nil { + return nil, fmt.Errorf("pkcs11 module gave us a nil CKA_PUBLIC_KEY_INFO, and reading CKA_EC_POINT also failed: %w", err) + } + if d == nil || len(d) < 1 { + return nil, fmt.Errorf("pkcs11 module gave us a nil or empty CKA_EC_POINT") + } + switch len(d) { + case 65: //length of 0x04 + len(X) + len(Y) + return d, nil + case 67: //as above, DER-encoded IIRC? + return d[2:], nil + default: + return nil, fmt.Errorf("unknown public key length: %d", len(d)) + } +} diff --git a/pkclient/pkclient_stub.go b/pkclient/pkclient_stub.go new file mode 100644 index 0000000..36b0fc9 --- /dev/null +++ b/pkclient/pkclient_stub.go @@ -0,0 +1,30 @@ +//go:build !cgo || !pkcs11 + +package pkclient + +import "errors" + +type PKClient struct { +} + +var notImplemented = errors.New("not implemented") + +func New(hsmPath string, slotId uint, pin string, id string, label string) (*PKClient, error) { + return nil, notImplemented +} + +func (c *PKClient) Close() error { + return nil +} + +func (c *PKClient) SignASN1(data []byte) ([]byte, error) { + return nil, notImplemented +} + +func (c *PKClient) DeriveNoise(_ []byte) ([]byte, error) { + return nil, notImplemented +} + +func (c *PKClient) GetPubKey() ([]byte, error) { + return nil, notImplemented +} diff --git a/pki.go b/pki.go index ab95a04..511d305 100644 --- a/pki.go +++ b/pki.go @@ -141,8 +141,33 @@ func newCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*Cert return cs, nil } -func newCertStateFromConfig(c *config.C) (*CertState, error) { +func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) { var pemPrivateKey []byte + if strings.Contains(privPathOrPEM, "-----BEGIN") { + pemPrivateKey = []byte(privPathOrPEM) + privPathOrPEM = "" + rawKey, _, curve, err = cert.UnmarshalPrivateKey(pemPrivateKey) + if err != nil { + return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + } + } else if strings.HasPrefix(privPathOrPEM, "pkcs11:") { + rawKey = []byte(privPathOrPEM) + return rawKey, cert.Curve_P256, true, nil + } else { + pemPrivateKey, err = os.ReadFile(privPathOrPEM) + if err != nil { + return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) + } + rawKey, _, curve, err = cert.UnmarshalPrivateKey(pemPrivateKey) + if err != nil { + return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + } + } + + return +} + +func newCertStateFromConfig(c *config.C) (*CertState, error) { var err error privPathOrPEM := c.GetString("pki.key", "") @@ -150,20 +175,9 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { return nil, errors.New("no pki.key path or PEM data provided") } - if strings.Contains(privPathOrPEM, "-----BEGIN") { - pemPrivateKey = []byte(privPathOrPEM) - privPathOrPEM = "" - - } else { - pemPrivateKey, err = os.ReadFile(privPathOrPEM) - if err != nil { - return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) - } - } - - rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey) + rawKey, curve, isPkcs11, err := loadPrivateKey(privPathOrPEM) if err != nil { - return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + return nil, err } var rawCert []byte @@ -188,7 +202,7 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { if err != nil { return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err) } - + nebulaCert.Pkcs11Backed = isPkcs11 if nebulaCert.Expired(time.Now()) { return nil, fmt.Errorf("nebula certificate for this host is expired") } From 16eaae306afe95ac9876e333fc1636c98cd999a4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 17:53:26 -0400 Subject: [PATCH 26/67] Bump dario.cat/mergo from 1.0.0 to 1.0.1 (#1200) Bumps [dario.cat/mergo](https://github.com/imdario/mergo) from 1.0.0 to 1.0.1. - [Release notes](https://github.com/imdario/mergo/releases) - [Commits](https://github.com/imdario/mergo/compare/v1.0.0...v1.0.1) --- updated-dependencies: - dependency-name: dario.cat/mergo dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 4d36f34..be7f0b6 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.22.0 toolchain go1.22.2 require ( - dario.cat/mergo v1.0.0 + dario.cat/mergo v1.0.1 github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 diff --git a/go.sum b/go.sum index e90ae8e..10d12a1 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,6 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= -dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= +dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= +dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= From 08ac65362e344a1ac1905802b36fd1b5cb5b034b Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 10 Oct 2024 18:00:22 -0500 Subject: [PATCH 27/67] Cert interface (#1212) --- cert/Makefile | 2 +- cert/ca.go | 140 ---- cert/ca_pool.go | 296 +++++++ cert/ca_pool_test.go | 109 +++ cert/cert.go | 1141 +++------------------------ cert/cert_test.go | 1165 ++++++++-------------------- cert/cert_v1.go | 496 ++++++++++++ cert/{cert.pb.go => cert_v1.pb.go} | 220 +++--- cert/{cert.proto => cert_v1.proto} | 0 cert/crypto.go | 161 +++- cert/crypto_test.go | 87 +++ cert/errors.go | 25 +- cert/pem.go | 155 ++++ cert/pem_test.go | 292 +++++++ cert/sign.go | 76 ++ cmd/nebula-cert/ca.go | 57 +- cmd/nebula-cert/ca_test.go | 29 +- cmd/nebula-cert/keygen.go | 4 +- cmd/nebula-cert/keygen_test.go | 6 +- cmd/nebula-cert/print.go | 6 +- cmd/nebula-cert/print_test.go | 96 ++- cmd/nebula-cert/sign.go | 94 +-- cmd/nebula-cert/sign_test.go | 69 +- cmd/nebula-cert/verify.go | 8 +- cmd/nebula-cert/verify_test.go | 39 +- connection_manager.go | 12 +- connection_manager_test.go | 178 ++++- connection_state.go | 10 +- control.go | 26 +- control_test.go | 26 +- control_tester.go | 2 +- dns_server.go | 8 +- e2e/handshakes_test.go | 35 +- e2e/helpers.go | 83 +- e2e/helpers_test.go | 6 +- e2e/router/hostmap.go | 6 +- firewall.go | 70 +- firewall_test.go | 267 +++---- handshake_ix.go | 26 +- handshake_manager.go | 2 +- handshake_manager_test.go | 3 +- hostmap.go | 26 +- interface.go | 33 +- main.go | 12 +- outside.go | 29 +- pki.go | 46 +- service/service_test.go | 6 +- ssh.go | 4 +- test/assert.go | 6 + 49 files changed, 2862 insertions(+), 2833 deletions(-) delete mode 100644 cert/ca.go create mode 100644 cert/ca_pool.go create mode 100644 cert/ca_pool_test.go create mode 100644 cert/cert_v1.go rename cert/{cert.pb.go => cert_v1.pb.go} (62%) rename cert/{cert.proto => cert_v1.proto} (100%) create mode 100644 cert/pem.go create mode 100644 cert/pem_test.go create mode 100644 cert/sign.go diff --git a/cert/Makefile b/cert/Makefile index 28170b6..311afc2 100644 --- a/cert/Makefile +++ b/cert/Makefile @@ -1,7 +1,7 @@ GO111MODULE = on export GO111MODULE -cert.pb.go: cert.proto .FORCE +cert_v1.pb.go: cert_v1.proto .FORCE go build google.golang.org/protobuf/cmd/protoc-gen-go PATH="$(CURDIR):$(PATH)" protoc --go_out=. --go_opt=paths=source_relative $< rm protoc-gen-go diff --git a/cert/ca.go b/cert/ca.go deleted file mode 100644 index 0ffbd87..0000000 --- a/cert/ca.go +++ /dev/null @@ -1,140 +0,0 @@ -package cert - -import ( - "errors" - "fmt" - "strings" - "time" -) - -type NebulaCAPool struct { - CAs map[string]*NebulaCertificate - certBlocklist map[string]struct{} -} - -// NewCAPool creates a CAPool -func NewCAPool() *NebulaCAPool { - ca := NebulaCAPool{ - CAs: make(map[string]*NebulaCertificate), - certBlocklist: make(map[string]struct{}), - } - - return &ca -} - -// NewCAPoolFromBytes will create a new CA pool from the provided -// input bytes, which must be a PEM-encoded set of nebula certificates. -// If the pool contains any expired certificates, an ErrExpired will be -// returned along with the pool. The caller must handle any such errors. -func NewCAPoolFromBytes(caPEMs []byte) (*NebulaCAPool, error) { - pool := NewCAPool() - var err error - var expired bool - for { - caPEMs, err = pool.AddCACertificate(caPEMs) - if errors.Is(err, ErrExpired) { - expired = true - err = nil - } - if err != nil { - return nil, err - } - if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" { - break - } - } - - if expired { - return pool, ErrExpired - } - - return pool, nil -} - -// AddCACertificate verifies a Nebula CA certificate and adds it to the pool -// Only the first pem encoded object will be consumed, any remaining bytes are returned. -// Parsed certificates will be verified and must be a CA -func (ncp *NebulaCAPool) AddCACertificate(pemBytes []byte) ([]byte, error) { - c, pemBytes, err := UnmarshalNebulaCertificateFromPEM(pemBytes) - if err != nil { - return pemBytes, err - } - - if !c.Details.IsCA { - return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotCA) - } - - if !c.CheckSignature(c.Details.PublicKey) { - return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotSelfSigned) - } - - sum, err := c.Sha256Sum() - if err != nil { - return pemBytes, fmt.Errorf("could not calculate shasum for provided CA; error: %s; %s", err, c.Details.Name) - } - - ncp.CAs[sum] = c - if c.Expired(time.Now()) { - return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrExpired) - } - - return pemBytes, nil -} - -// BlocklistFingerprint adds a cert fingerprint to the blocklist -func (ncp *NebulaCAPool) BlocklistFingerprint(f string) { - ncp.certBlocklist[f] = struct{}{} -} - -// ResetCertBlocklist removes all previously blocklisted cert fingerprints -func (ncp *NebulaCAPool) ResetCertBlocklist() { - ncp.certBlocklist = make(map[string]struct{}) -} - -// NOTE: This uses an internal cache for Sha256Sum() that will not be invalidated -// automatically if you manually change any fields in the NebulaCertificate. -func (ncp *NebulaCAPool) IsBlocklisted(c *NebulaCertificate) bool { - return ncp.isBlocklistedWithCache(c, false) -} - -// IsBlocklisted returns true if the fingerprint fails to generate or has been explicitly blocklisted -func (ncp *NebulaCAPool) isBlocklistedWithCache(c *NebulaCertificate, useCache bool) bool { - h, err := c.sha256SumWithCache(useCache) - if err != nil { - return true - } - - if _, ok := ncp.certBlocklist[h]; ok { - return true - } - - return false -} - -// GetCAForCert attempts to return the signing certificate for the provided certificate. -// No signature validation is performed -func (ncp *NebulaCAPool) GetCAForCert(c *NebulaCertificate) (*NebulaCertificate, error) { - if c.Details.Issuer == "" { - return nil, fmt.Errorf("no issuer in certificate") - } - - signer, ok := ncp.CAs[c.Details.Issuer] - if ok { - return signer, nil - } - - return nil, fmt.Errorf("could not find ca for the certificate") -} - -// GetFingerprints returns an array of trusted CA fingerprints -func (ncp *NebulaCAPool) GetFingerprints() []string { - fp := make([]string, len(ncp.CAs)) - - i := 0 - for k := range ncp.CAs { - fp[i] = k - i++ - } - - return fp -} diff --git a/cert/ca_pool.go b/cert/ca_pool.go new file mode 100644 index 0000000..d525830 --- /dev/null +++ b/cert/ca_pool.go @@ -0,0 +1,296 @@ +package cert + +import ( + "errors" + "fmt" + "net/netip" + "slices" + "strings" + "time" +) + +type CAPool struct { + CAs map[string]*CachedCertificate + certBlocklist map[string]struct{} +} + +// NewCAPool creates an empty CAPool +func NewCAPool() *CAPool { + ca := CAPool{ + CAs: make(map[string]*CachedCertificate), + certBlocklist: make(map[string]struct{}), + } + + return &ca +} + +// NewCAPoolFromPEM will create a new CA pool from the provided +// input bytes, which must be a PEM-encoded set of nebula certificates. +// If the pool contains any expired certificates, an ErrExpired will be +// returned along with the pool. The caller must handle any such errors. +func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) { + pool := NewCAPool() + var err error + var expired bool + for { + caPEMs, err = pool.AddCAFromPEM(caPEMs) + if errors.Is(err, ErrExpired) { + expired = true + err = nil + } + if err != nil { + return nil, err + } + if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" { + break + } + } + + if expired { + return pool, ErrExpired + } + + return pool, nil +} + +// AddCAFromPEM verifies a Nebula CA certificate and adds it to the pool. +// Only the first pem encoded object will be consumed, any remaining bytes are returned. +// Parsed certificates will be verified and must be a CA +func (ncp *CAPool) AddCAFromPEM(pemBytes []byte) ([]byte, error) { + c, pemBytes, err := UnmarshalCertificateFromPEM(pemBytes) + if err != nil { + return pemBytes, err + } + + err = ncp.AddCA(c) + if err != nil { + return pemBytes, err + } + + return pemBytes, nil +} + +// AddCA verifies a Nebula CA certificate and adds it to the pool. +func (ncp *CAPool) AddCA(c Certificate) error { + if !c.IsCA() { + return fmt.Errorf("%s: %w", c.Name(), ErrNotCA) + } + + if !c.CheckSignature(c.PublicKey()) { + return fmt.Errorf("%s: %w", c.Name(), ErrNotSelfSigned) + } + + sum, err := c.Fingerprint() + if err != nil { + return fmt.Errorf("could not calculate fingerprint for provided CA; error: %w; %s", err, c.Name()) + } + + cc := &CachedCertificate{ + Certificate: c, + Fingerprint: sum, + InvertedGroups: make(map[string]struct{}), + } + + for _, g := range c.Groups() { + cc.InvertedGroups[g] = struct{}{} + } + + ncp.CAs[sum] = cc + + if c.Expired(time.Now()) { + return fmt.Errorf("%s: %w", c.Name(), ErrExpired) + } + + return nil +} + +// BlocklistFingerprint adds a cert fingerprint to the blocklist +func (ncp *CAPool) BlocklistFingerprint(f string) { + ncp.certBlocklist[f] = struct{}{} +} + +// ResetCertBlocklist removes all previously blocklisted cert fingerprints +func (ncp *CAPool) ResetCertBlocklist() { + ncp.certBlocklist = make(map[string]struct{}) +} + +// IsBlocklisted tests the provided fingerprint against the pools blocklist. +// Returns true if the fingerprint is blocked. +func (ncp *CAPool) IsBlocklisted(fingerprint string) bool { + if _, ok := ncp.certBlocklist[fingerprint]; ok { + return true + } + + return false +} + +// VerifyCertificate verifies the certificate is valid and is signed by a trusted CA in the pool. +// If the certificate is valid then the returned CachedCertificate can be used in subsequent verification attempts +// to increase performance. +func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCertificate, error) { + if c == nil { + return nil, fmt.Errorf("no certificate") + } + fp, err := c.Fingerprint() + if err != nil { + return nil, fmt.Errorf("could not calculate fingerprint to verify: %w", err) + } + + signer, err := ncp.verify(c, now, fp, "") + if err != nil { + return nil, err + } + + cc := CachedCertificate{ + Certificate: c, + InvertedGroups: make(map[string]struct{}), + Fingerprint: fp, + signerFingerprint: signer.Fingerprint, + } + + for _, g := range c.Groups() { + cc.InvertedGroups[g] = struct{}{} + } + + return &cc, nil +} + +// VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and +// is a cheaper operation to perform as a result. +func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error { + _, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint) + return err +} + +func (ncp *CAPool) verify(c Certificate, now time.Time, certFp string, signerFp string) (*CachedCertificate, error) { + if ncp.IsBlocklisted(certFp) { + return nil, ErrBlockListed + } + + signer, err := ncp.GetCAForCert(c) + if err != nil { + return nil, err + } + + if signer.Certificate.Expired(now) { + return nil, ErrRootExpired + } + + if c.Expired(now) { + return nil, ErrExpired + } + + // If we are checking a cached certificate then we can bail early here + // Either the root is no longer trusted or everything is fine + if len(signerFp) > 0 { + if signerFp != signer.Fingerprint { + return nil, ErrFingerprintMismatch + } + return signer, nil + } + if !c.CheckSignature(signer.Certificate.PublicKey()) { + return nil, ErrSignatureMismatch + } + + err = CheckCAConstraints(signer.Certificate, c) + if err != nil { + return nil, err + } + + return signer, nil +} + +// GetCAForCert attempts to return the signing certificate for the provided certificate. +// No signature validation is performed +func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) { + issuer := c.Issuer() + if issuer == "" { + return nil, fmt.Errorf("no issuer in certificate") + } + + signer, ok := ncp.CAs[issuer] + if ok { + return signer, nil + } + + return nil, fmt.Errorf("could not find ca for the certificate") +} + +// GetFingerprints returns an array of trusted CA fingerprints +func (ncp *CAPool) GetFingerprints() []string { + fp := make([]string, len(ncp.CAs)) + + i := 0 + for k := range ncp.CAs { + fp[i] = k + i++ + } + + return fp +} + +// CheckCAConstraints returns an error if the sub certificate violates constraints present in the signer certificate. +func CheckCAConstraints(signer Certificate, sub Certificate) error { + return checkCAConstraints(signer, sub.NotBefore(), sub.NotAfter(), sub.Groups(), sub.Networks(), sub.UnsafeNetworks()) +} + +// checkCAConstraints is a very generic function allowing both Certificates and TBSCertificates to be tested. +func checkCAConstraints(signer Certificate, notBefore, notAfter time.Time, groups []string, networks, unsafeNetworks []netip.Prefix) error { + // Make sure this cert isn't valid after the root + if notAfter.After(signer.NotAfter()) { + return fmt.Errorf("certificate expires after signing certificate") + } + + // Make sure this cert wasn't valid before the root + if notBefore.Before(signer.NotBefore()) { + return fmt.Errorf("certificate is valid before the signing certificate") + } + + // If the signer has a limited set of groups make sure the cert only contains a subset + signerGroups := signer.Groups() + if len(signerGroups) > 0 { + for _, g := range groups { + if !slices.Contains(signerGroups, g) { + return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g) + } + } + } + + // If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset + signingNetworks := signer.Networks() + if len(signingNetworks) > 0 { + for _, certNetwork := range networks { + found := false + for _, signingNetwork := range signingNetworks { + if signingNetwork.Contains(certNetwork.Addr()) && signingNetwork.Bits() <= certNetwork.Bits() { + found = true + break + } + } + + if !found { + return fmt.Errorf("certificate contained a network assignment outside the limitations of the signing ca: %s", certNetwork.String()) + } + } + } + + // If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset + signingUnsafeNetworks := signer.UnsafeNetworks() + if len(signingUnsafeNetworks) > 0 { + for _, certUnsafeNetwork := range unsafeNetworks { + found := false + for _, caNetwork := range signingUnsafeNetworks { + if caNetwork.Contains(certUnsafeNetwork.Addr()) && caNetwork.Bits() <= certUnsafeNetwork.Bits() { + found = true + break + } + } + + if !found { + return fmt.Errorf("certificate contained an unsafe network assignment outside the limitations of the signing ca: %s", certUnsafeNetwork.String()) + } + } + } + + return nil +} diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go new file mode 100644 index 0000000..053640d --- /dev/null +++ b/cert/ca_pool_test.go @@ -0,0 +1,109 @@ +package cert + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewCAPoolFromBytes(t *testing.T) { + noNewLines := ` +# Current provisional, Remove once everything moves over to the real root. +-----BEGIN NEBULA CERTIFICATE----- +CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL +vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv +bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB +-----END NEBULA CERTIFICATE----- +# root-ca01 +-----BEGIN NEBULA CERTIFICATE----- +CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG +BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf +8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF +-----END NEBULA CERTIFICATE----- +` + + withNewLines := ` +# Current provisional, Remove once everything moves over to the real root. + +-----BEGIN NEBULA CERTIFICATE----- +CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL +vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv +bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB +-----END NEBULA CERTIFICATE----- + +# root-ca01 + + +-----BEGIN NEBULA CERTIFICATE----- +CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG +BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf +8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF +-----END NEBULA CERTIFICATE----- + +` + + expired := ` +# expired certificate +-----BEGIN NEBULA CERTIFICATE----- +CjkKB2V4cGlyZWQouPmWjQYwufmWjQY6ILCRaoCkJlqHgv5jfDN4lzLHBvDzaQm4 +vZxfu144hmgjQAESQG4qlnZi8DncvD/LDZnLgJHOaX1DWCHHEh59epVsC+BNgTie +WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs= +-----END NEBULA CERTIFICATE----- +` + + p256 := ` +# p256 certificate +-----BEGIN NEBULA CERTIFICATE----- +CmYKEG5lYnVsYSBQMjU2IHRlc3Qo4s+7mgYw4tXrsAc6QQRkaW2jFmllYvN4+/k2 +6tctO9sPT3jOx8ES6M1nIqOhpTmZeabF/4rELDqPV4aH5jfJut798DUXql0FlF8H +76gvQAGgBgESRzBFAiEAib0/te6eMiZOKD8gdDeloMTS0wGuX2t0C7TFdUhAQzgC +IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX +-----END NEBULA CERTIFICATE----- +` + + rootCA := certificateV1{ + details: detailsV1{ + Name: "nebula root ca", + }, + } + + rootCA01 := certificateV1{ + details: detailsV1{ + Name: "nebula root ca 01", + }, + } + + rootCAP256 := certificateV1{ + details: detailsV1{ + Name: "nebula P256 test", + }, + } + + p, err := NewCAPoolFromPEM([]byte(noNewLines)) + assert.Nil(t, err) + assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) + assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) + + pp, err := NewCAPoolFromPEM([]byte(withNewLines)) + assert.Nil(t, err) + assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) + assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) + + // expired cert, no valid certs + ppp, err := NewCAPoolFromPEM([]byte(expired)) + assert.Equal(t, ErrExpired, err) + assert.Equal(t, ppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired") + + // expired cert, with valid certs + pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...)) + assert.Equal(t, ErrExpired, err) + assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) + assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) + assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired") + assert.Equal(t, len(pppp.CAs), 3) + + ppppp, err := NewCAPoolFromPEM([]byte(p256)) + assert.Nil(t, err) + assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Certificate.Name(), rootCAP256.details.Name) + assert.Equal(t, len(ppppp.CAs), 1) +} diff --git a/cert/cert.go b/cert/cert.go index dd08923..02c8877 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -1,1062 +1,129 @@ package cert import ( - "bytes" - "crypto/ecdh" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/elliptic" - "crypto/rand" - "crypto/sha256" - "encoding/binary" - "encoding/hex" - "encoding/json" - "encoding/pem" - "errors" - "fmt" - "math" - "math/big" - "net" - "sync/atomic" + "net/netip" "time" - - "github.com/slackhq/nebula/pkclient" - "golang.org/x/crypto/curve25519" - "google.golang.org/protobuf/proto" ) -const publicKeyLen = 32 +type Version int const ( - CertBanner = "NEBULA CERTIFICATE" - X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" - X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY" - EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY" - Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY" - Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY" - - P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY" - P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY" - EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY" - ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY" + Version1 Version = 1 + Version2 Version = 2 ) -type NebulaCertificate struct { - Details NebulaCertificateDetails - Pkcs11Backed bool - Signature []byte +type Certificate interface { + // Version defines the underlying certificate structure and wire protocol version + // Version1 certificates are ipv4 only and uses protobuf serialization + // Version2 certificates are ipv4 or ipv6 and uses asn.1 serialization + Version() Version - // the cached hex string of the calculated sha256sum - // for VerifyWithCache - sha256sum atomic.Pointer[string] + // Name is the human-readable name that identifies this certificate. + Name() string - // the cached public key bytes if they were verified as the signer - // for VerifyWithCache - signatureVerified atomic.Pointer[[]byte] + // Networks is a list of ip addresses and network sizes assigned to this certificate. + // If IsCA is true then certificates signed by this CA can only have ip addresses and + // networks that are contained by an entry in this list. + Networks() []netip.Prefix + + // UnsafeNetworks is a list of networks that this host can act as an unsafe router for. + // If IsCA is true then certificates signed by this CA can only have networks that are + // contained by an entry in this list. + UnsafeNetworks() []netip.Prefix + + // Groups is a list of identities that can be used to write more general firewall rule + // definitions. + // If IsCA is true then certificates signed by this CA can only use groups that are + // in this list. + Groups() []string + + // IsCA signifies if this is a certificate authority (true) or a host certificate (false). + // It is invalid to use a CA certificate as a host certificate. + IsCA() bool + + // NotBefore is the time at which this certificate becomes valid. + // If IsCA is true then certificate signed by this CA can not have a time before this. + NotBefore() time.Time + + // NotAfter is the time at which this certificate becomes invalid. + // If IsCA is true then certificate signed by this CA can not have a time after this. + NotAfter() time.Time + + // Issuer is the fingerprint of the CA that signed this certificate. + // If IsCA is true then this will be empty. + Issuer() string + + // PublicKey is the raw bytes to be used in asymmetric cryptographic operations. + PublicKey() []byte + + // Curve identifies which curve was used for the PublicKey and Signature. + Curve() Curve + + // Signature is the cryptographic seal for all the details of this certificate. + // CheckSignature can be used to verify that the details of this certificate are valid. + Signature() []byte + + // CheckSignature will check that the certificate Signature() matches the + // computed signature. A true result means this certificate has not been tampered with. + CheckSignature(signingPublicKey []byte) bool + + // Fingerprint returns the hex encoded sha256 sum of the certificate. + // This acts as a unique fingerprint and can be used to blocklist certificates. + Fingerprint() (string, error) + + // Expired tests if the certificate is valid for the provided time. + Expired(t time.Time) bool + + // VerifyPrivateKey returns an error if the private key is not a pair with the certificates public key. + VerifyPrivateKey(curve Curve, privateKey []byte) error + + // Marshal will return the byte representation of this certificate + // This is primarily the format transmitted on the wire. + Marshal() ([]byte, error) + + // MarshalForHandshakes prepares the bytes needed to use directly in a handshake + MarshalForHandshakes() ([]byte, error) + + // MarshalPEM will return a PEM encoded representation of this certificate + // This is primarily the format stored on disk + MarshalPEM() ([]byte, error) + + // MarshalJSON will return the json representation of this certificate + MarshalJSON() ([]byte, error) + + // String will return a human-readable representation of this certificate + String() string + + // Copy creates a copy of the certificate + Copy() Certificate } -type NebulaCertificateDetails struct { - Name string - Ips []*net.IPNet - Subnets []*net.IPNet - Groups []string - NotBefore time.Time - NotAfter time.Time - PublicKey []byte - IsCA bool - Issuer string - - // Map of groups for faster lookup - InvertedGroups map[string]struct{} - - Curve Curve +// CachedCertificate represents a verified certificate with some cached fields to improve +// performance. +type CachedCertificate struct { + Certificate Certificate + InvertedGroups map[string]struct{} + Fingerprint string + signerFingerprint string } -type NebulaEncryptedData struct { - EncryptionMetadata NebulaEncryptionMetadata - Ciphertext []byte -} - -type NebulaEncryptionMetadata struct { - EncryptionAlgorithm string - Argon2Parameters Argon2Parameters -} - -type m map[string]interface{} - -// Returned if we try to unmarshal an encrypted private key without a passphrase -var ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") - -// UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert -func UnmarshalNebulaCertificate(b []byte) (*NebulaCertificate, error) { - if len(b) == 0 { - return nil, fmt.Errorf("nil byte array") - } - var rc RawNebulaCertificate - err := proto.Unmarshal(b, &rc) +// UnmarshalCertificate will attempt to unmarshal a wire protocol level certificate. +func UnmarshalCertificate(b []byte) (Certificate, error) { + c, err := unmarshalCertificateV1(b, true) if err != nil { return nil, err } - - if rc.Details == nil { - return nil, fmt.Errorf("encoded Details was nil") - } - - if len(rc.Details.Ips)%2 != 0 { - return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found") - } - - if len(rc.Details.Subnets)%2 != 0 { - return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found") - } - - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: rc.Details.Name, - Groups: make([]string, len(rc.Details.Groups)), - Ips: make([]*net.IPNet, len(rc.Details.Ips)/2), - Subnets: make([]*net.IPNet, len(rc.Details.Subnets)/2), - NotBefore: time.Unix(rc.Details.NotBefore, 0), - NotAfter: time.Unix(rc.Details.NotAfter, 0), - PublicKey: make([]byte, len(rc.Details.PublicKey)), - IsCA: rc.Details.IsCA, - InvertedGroups: make(map[string]struct{}), - Curve: rc.Details.Curve, - }, - Signature: make([]byte, len(rc.Signature)), - } - - copy(nc.Signature, rc.Signature) - copy(nc.Details.Groups, rc.Details.Groups) - nc.Details.Issuer = hex.EncodeToString(rc.Details.Issuer) - - if len(rc.Details.PublicKey) < publicKeyLen { - return nil, fmt.Errorf("Public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey)) - } - copy(nc.Details.PublicKey, rc.Details.PublicKey) - - for i, rawIp := range rc.Details.Ips { - if i%2 == 0 { - nc.Details.Ips[i/2] = &net.IPNet{IP: int2ip(rawIp)} - } else { - nc.Details.Ips[i/2].Mask = net.IPMask(int2ip(rawIp)) - } - } - - for i, rawIp := range rc.Details.Subnets { - if i%2 == 0 { - nc.Details.Subnets[i/2] = &net.IPNet{IP: int2ip(rawIp)} - } else { - nc.Details.Subnets[i/2].Mask = net.IPMask(int2ip(rawIp)) - } - } - - for _, g := range rc.Details.Groups { - nc.Details.InvertedGroups[g] = struct{}{} - } - - return &nc, nil + return c, nil } -// UnmarshalNebulaCertificateFromPEM will unmarshal the first pem block in a byte array, returning any non consumed data -// or an error on failure -func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) { - p, r := pem.Decode(b) - if p == nil { - return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") - } - if p.Type != CertBanner { - return nil, r, fmt.Errorf("bytes did not contain a proper nebula certificate banner") - } - nc, err := UnmarshalNebulaCertificate(p.Bytes) - return nc, r, err -} - -func MarshalPrivateKey(curve Curve, b []byte) []byte { - switch curve { - case Curve_CURVE25519: - return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b}) - case Curve_P256: - return pem.EncodeToMemory(&pem.Block{Type: P256PrivateKeyBanner, Bytes: b}) - default: - return nil - } -} - -func MarshalSigningPrivateKey(curve Curve, b []byte) []byte { - switch curve { - case Curve_CURVE25519: - return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: b}) - case Curve_P256: - return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PrivateKeyBanner, Bytes: b}) - default: - return nil - } -} - -// MarshalX25519PrivateKey is a simple helper to PEM encode an X25519 private key -func MarshalX25519PrivateKey(b []byte) []byte { - return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b}) -} - -// MarshalEd25519PrivateKey is a simple helper to PEM encode an Ed25519 private key -func MarshalEd25519PrivateKey(key ed25519.PrivateKey) []byte { - return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: key}) -} - -func UnmarshalPrivateKey(b []byte) ([]byte, []byte, Curve, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") - } - var expectedLen int - var curve Curve - switch k.Type { - case X25519PrivateKeyBanner: - expectedLen = 32 - curve = Curve_CURVE25519 - case P256PrivateKeyBanner: - expectedLen = 32 - curve = Curve_P256 - default: - return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula private key banner") - } - if len(k.Bytes) != expectedLen { - return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s private key", expectedLen, curve) - } - return k.Bytes, r, curve, nil -} - -func UnmarshalSigningPrivateKey(b []byte) ([]byte, []byte, Curve, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") - } - var curve Curve - switch k.Type { - case EncryptedEd25519PrivateKeyBanner: - return nil, nil, Curve_CURVE25519, ErrPrivateKeyEncrypted - case EncryptedECDSAP256PrivateKeyBanner: - return nil, nil, Curve_P256, ErrPrivateKeyEncrypted - case Ed25519PrivateKeyBanner: - curve = Curve_CURVE25519 - if len(k.Bytes) != ed25519.PrivateKeySize { - return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid Ed25519 private key", ed25519.PrivateKeySize) - } - case ECDSAP256PrivateKeyBanner: - curve = Curve_P256 - if len(k.Bytes) != 32 { - return nil, r, 0, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key") - } - default: - return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula Ed25519/ECDSA private key banner") - } - return k.Bytes, r, curve, nil -} - -// EncryptAndMarshalSigningPrivateKey is a simple helper to encrypt and PEM encode a private key -func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) { - ciphertext, err := aes256Encrypt(passphrase, kdfParams, b) +// UnmarshalCertificateFromHandshake will attempt to unmarshal a certificate received in a handshake. +// Handshakes save space by placing the peers public key in a different part of the packet, we have to +// reassemble the actual certificate structure with that in mind. +func UnmarshalCertificateFromHandshake(b []byte, publicKey []byte) (Certificate, error) { + c, err := unmarshalCertificateV1(b, false) if err != nil { return nil, err } - - b, err = proto.Marshal(&RawNebulaEncryptedData{ - EncryptionMetadata: &RawNebulaEncryptionMetadata{ - EncryptionAlgorithm: "AES-256-GCM", - Argon2Parameters: &RawNebulaArgon2Parameters{ - Version: kdfParams.version, - Memory: kdfParams.Memory, - Parallelism: uint32(kdfParams.Parallelism), - Iterations: kdfParams.Iterations, - Salt: kdfParams.salt, - }, - }, - Ciphertext: ciphertext, - }) - if err != nil { - return nil, err - } - - switch curve { - case Curve_CURVE25519: - return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil - case Curve_P256: - return pem.EncodeToMemory(&pem.Block{Type: EncryptedECDSAP256PrivateKeyBanner, Bytes: b}), nil - default: - return nil, fmt.Errorf("invalid curve: %v", curve) - } -} - -// UnmarshalX25519PrivateKey will try to pem decode an X25519 private key, returning any other bytes b -// or an error on failure -func UnmarshalX25519PrivateKey(b []byte) ([]byte, []byte, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") - } - if k.Type != X25519PrivateKeyBanner { - return nil, r, fmt.Errorf("bytes did not contain a proper nebula X25519 private key banner") - } - if len(k.Bytes) != publicKeyLen { - return nil, r, fmt.Errorf("key was not 32 bytes, is invalid X25519 private key") - } - - return k.Bytes, r, nil -} - -// UnmarshalEd25519PrivateKey will try to pem decode an Ed25519 private key, returning any other bytes b -// or an error on failure -func UnmarshalEd25519PrivateKey(b []byte) (ed25519.PrivateKey, []byte, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") - } - - if k.Type == EncryptedEd25519PrivateKeyBanner { - return nil, r, ErrPrivateKeyEncrypted - } else if k.Type != Ed25519PrivateKeyBanner { - return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 private key banner") - } - - if len(k.Bytes) != ed25519.PrivateKeySize { - return nil, r, fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") - } - - return k.Bytes, r, nil -} - -// UnmarshalNebulaEncryptedData will unmarshal a protobuf byte representation of a nebula cert into its -// protobuf-generated struct. -func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) { - if len(b) == 0 { - return nil, fmt.Errorf("nil byte array") - } - var rned RawNebulaEncryptedData - err := proto.Unmarshal(b, &rned) - if err != nil { - return nil, err - } - - if rned.EncryptionMetadata == nil { - return nil, fmt.Errorf("encoded EncryptionMetadata was nil") - } - - if rned.EncryptionMetadata.Argon2Parameters == nil { - return nil, fmt.Errorf("encoded Argon2Parameters was nil") - } - - params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters) - if err != nil { - return nil, err - } - - ned := NebulaEncryptedData{ - EncryptionMetadata: NebulaEncryptionMetadata{ - EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm, - Argon2Parameters: *params, - }, - Ciphertext: rned.Ciphertext, - } - - return &ned, nil -} - -func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) { - if params.Version < math.MinInt32 || params.Version > math.MaxInt32 { - return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32) - } - if params.Memory <= 0 || params.Memory > math.MaxUint32 { - return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32)) - } - if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 { - return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8) - } - if params.Iterations <= 0 || params.Iterations > math.MaxUint32 { - return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32)) - } - - return &Argon2Parameters{ - version: rune(params.Version), - Memory: uint32(params.Memory), - Parallelism: uint8(params.Parallelism), - Iterations: uint32(params.Iterations), - salt: params.Salt, - }, nil - -} - -// DecryptAndUnmarshalSigningPrivateKey will try to pem decode and decrypt an Ed25519/ECDSA private key with -// the given passphrase, returning any other bytes b or an error on failure -func DecryptAndUnmarshalSigningPrivateKey(passphrase, b []byte) (Curve, []byte, []byte, error) { - var curve Curve - - k, r := pem.Decode(b) - if k == nil { - return curve, nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") - } - - switch k.Type { - case EncryptedEd25519PrivateKeyBanner: - curve = Curve_CURVE25519 - case EncryptedECDSAP256PrivateKeyBanner: - curve = Curve_P256 - default: - return curve, nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") - } - - ned, err := UnmarshalNebulaEncryptedData(k.Bytes) - if err != nil { - return curve, nil, r, err - } - - var bytes []byte - switch ned.EncryptionMetadata.EncryptionAlgorithm { - case "AES-256-GCM": - bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext) - if err != nil { - return curve, nil, r, err - } - default: - return curve, nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm) - } - - switch curve { - case Curve_CURVE25519: - if len(bytes) != ed25519.PrivateKeySize { - return curve, nil, r, fmt.Errorf("key was not %d bytes, is invalid ed25519 private key", ed25519.PrivateKeySize) - } - case Curve_P256: - if len(bytes) != 32 { - return curve, nil, r, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key") - } - } - - return curve, bytes, r, nil -} - -func MarshalPublicKey(curve Curve, b []byte) []byte { - switch curve { - case Curve_CURVE25519: - return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b}) - case Curve_P256: - return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b}) - default: - return nil - } -} - -// MarshalX25519PublicKey is a simple helper to PEM encode an X25519 public key -func MarshalX25519PublicKey(b []byte) []byte { - return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b}) -} - -// MarshalEd25519PublicKey is a simple helper to PEM encode an Ed25519 public key -func MarshalEd25519PublicKey(key ed25519.PublicKey) []byte { - return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: key}) -} - -func UnmarshalPublicKey(b []byte) ([]byte, []byte, Curve, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") - } - var expectedLen int - var curve Curve - switch k.Type { - case X25519PublicKeyBanner: - expectedLen = 32 - curve = Curve_CURVE25519 - case P256PublicKeyBanner: - // Uncompressed - expectedLen = 65 - curve = Curve_P256 - default: - return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula public key banner") - } - if len(k.Bytes) != expectedLen { - return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s public key", expectedLen, curve) - } - return k.Bytes, r, curve, nil -} - -// UnmarshalX25519PublicKey will try to pem decode an X25519 public key, returning any other bytes b -// or an error on failure -func UnmarshalX25519PublicKey(b []byte) ([]byte, []byte, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") - } - if k.Type != X25519PublicKeyBanner { - return nil, r, fmt.Errorf("bytes did not contain a proper nebula X25519 public key banner") - } - if len(k.Bytes) != publicKeyLen { - return nil, r, fmt.Errorf("key was not 32 bytes, is invalid X25519 public key") - } - - return k.Bytes, r, nil -} - -// UnmarshalEd25519PublicKey will try to pem decode an Ed25519 public key, returning any other bytes b -// or an error on failure -func UnmarshalEd25519PublicKey(b []byte) (ed25519.PublicKey, []byte, error) { - k, r := pem.Decode(b) - if k == nil { - return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") - } - if k.Type != Ed25519PublicKeyBanner { - return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 public key banner") - } - if len(k.Bytes) != ed25519.PublicKeySize { - return nil, r, fmt.Errorf("key was not 32 bytes, is invalid ed25519 public key") - } - - return k.Bytes, r, nil -} - -// Sign signs a nebula cert with the provided private key -func (nc *NebulaCertificate) Sign(curve Curve, key []byte) error { - if curve != nc.Details.Curve { - return fmt.Errorf("curve in cert and private key supplied don't match") - } - - b, err := proto.Marshal(nc.getRawDetails()) - if err != nil { - return err - } - - var sig []byte - - switch curve { - case Curve_CURVE25519: - signer := ed25519.PrivateKey(key) - sig = ed25519.Sign(signer, b) - case Curve_P256: - signer := &ecdsa.PrivateKey{ - PublicKey: ecdsa.PublicKey{ - Curve: elliptic.P256(), - }, - // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95 - D: new(big.Int).SetBytes(key), - } - // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119 - signer.X, signer.Y = signer.Curve.ScalarBaseMult(key) - - // We need to hash first for ECDSA - // - https://pkg.go.dev/crypto/ecdsa#SignASN1 - hashed := sha256.Sum256(b) - sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:]) - if err != nil { - return err - } - default: - return fmt.Errorf("invalid curve: %s", nc.Details.Curve) - } - - nc.Signature = sig - return nil -} - -// SignPkcs11 signs a nebula cert with the provided private key -func (nc *NebulaCertificate) SignPkcs11(curve Curve, client *pkclient.PKClient) error { - if !nc.Pkcs11Backed { - return fmt.Errorf("certificate is not PKCS#11 backed") - } - - if curve != nc.Details.Curve { - return fmt.Errorf("curve in cert and private key supplied don't match") - } - - if curve != Curve_P256 { - return fmt.Errorf("only P256 is supported by PKCS#11") - } - - b, err := proto.Marshal(nc.getRawDetails()) - if err != nil { - return err - } - - sig, err := client.SignASN1(b) - if err != nil { - return err - } - - nc.Signature = sig - return nil -} - -// CheckSignature verifies the signature against the provided public key -func (nc *NebulaCertificate) CheckSignature(key []byte) bool { - b, err := proto.Marshal(nc.getRawDetails()) - if err != nil { - return false - } - switch nc.Details.Curve { - case Curve_CURVE25519: - return ed25519.Verify(ed25519.PublicKey(key), b, nc.Signature) - case Curve_P256: - x, y := elliptic.Unmarshal(elliptic.P256(), key) - pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} - hashed := sha256.Sum256(b) - return ecdsa.VerifyASN1(pubKey, hashed[:], nc.Signature) - default: - return false - } -} - -// NOTE: This uses an internal cache that will not be invalidated automatically -// if you manually change any fields in the NebulaCertificate. -func (nc *NebulaCertificate) checkSignatureWithCache(key []byte, useCache bool) bool { - if !useCache { - return nc.CheckSignature(key) - } - - if v := nc.signatureVerified.Load(); v != nil { - return bytes.Equal(*v, key) - } - - verified := nc.CheckSignature(key) - if verified { - keyCopy := make([]byte, len(key)) - copy(keyCopy, key) - nc.signatureVerified.Store(&keyCopy) - } - - return verified -} - -// Expired will return true if the nebula cert is too young or too old compared to the provided time, otherwise false -func (nc *NebulaCertificate) Expired(t time.Time) bool { - return nc.Details.NotBefore.After(t) || nc.Details.NotAfter.Before(t) -} - -// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc) -func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) { - return nc.verify(t, ncp, false) -} - -// VerifyWithCache will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc) -// -// NOTE: This uses an internal cache that will not be invalidated automatically -// if you manually change any fields in the NebulaCertificate. -func (nc *NebulaCertificate) VerifyWithCache(t time.Time, ncp *NebulaCAPool) (bool, error) { - return nc.verify(t, ncp, true) -} - -// ResetCache resets the cache used by VerifyWithCache. -func (nc *NebulaCertificate) ResetCache() { - nc.sha256sum.Store(nil) - nc.signatureVerified.Store(nil) -} - -// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc) -func (nc *NebulaCertificate) verify(t time.Time, ncp *NebulaCAPool, useCache bool) (bool, error) { - if ncp.isBlocklistedWithCache(nc, useCache) { - return false, ErrBlockListed - } - - signer, err := ncp.GetCAForCert(nc) - if err != nil { - return false, err - } - - if signer.Expired(t) { - return false, ErrRootExpired - } - - if nc.Expired(t) { - return false, ErrExpired - } - - if !nc.checkSignatureWithCache(signer.Details.PublicKey, useCache) { - return false, ErrSignatureMismatch - } - - if err := nc.CheckRootConstrains(signer); err != nil { - return false, err - } - - return true, nil -} - -// CheckRootConstrains returns an error if the certificate violates constraints set on the root (groups, ips, subnets) -func (nc *NebulaCertificate) CheckRootConstrains(signer *NebulaCertificate) error { - // Make sure this cert wasn't valid before the root - if signer.Details.NotAfter.Before(nc.Details.NotAfter) { - return fmt.Errorf("certificate expires after signing certificate") - } - - // Make sure this cert isn't valid after the root - if signer.Details.NotBefore.After(nc.Details.NotBefore) { - return fmt.Errorf("certificate is valid before the signing certificate") - } - - // If the signer has a limited set of groups make sure the cert only contains a subset - if len(signer.Details.InvertedGroups) > 0 { - for _, g := range nc.Details.Groups { - if _, ok := signer.Details.InvertedGroups[g]; !ok { - return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g) - } - } - } - - // If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset - if len(signer.Details.Ips) > 0 { - for _, ip := range nc.Details.Ips { - if !netMatch(ip, signer.Details.Ips) { - return fmt.Errorf("certificate contained an ip assignment outside the limitations of the signing ca: %s", ip.String()) - } - } - } - - // If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset - if len(signer.Details.Subnets) > 0 { - for _, subnet := range nc.Details.Subnets { - if !netMatch(subnet, signer.Details.Subnets) { - return fmt.Errorf("certificate contained a subnet assignment outside the limitations of the signing ca: %s", subnet) - } - } - } - - return nil -} - -// VerifyPrivateKey checks that the public key in the Nebula certificate and a supplied private key match -func (nc *NebulaCertificate) VerifyPrivateKey(curve Curve, key []byte) error { - if nc.Pkcs11Backed { - return nil //todo! - } - if curve != nc.Details.Curve { - return fmt.Errorf("curve in cert and private key supplied don't match") - } - if nc.Details.IsCA { - switch curve { - case Curve_CURVE25519: - // the call to PublicKey below will panic slice bounds out of range otherwise - if len(key) != ed25519.PrivateKeySize { - return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") - } - - if !ed25519.PublicKey(nc.Details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) { - return fmt.Errorf("public key in cert and private key supplied don't match") - } - case Curve_P256: - privkey, err := ecdh.P256().NewPrivateKey(key) - if err != nil { - return fmt.Errorf("cannot parse private key as P256") - } - pub := privkey.PublicKey().Bytes() - if !bytes.Equal(pub, nc.Details.PublicKey) { - return fmt.Errorf("public key in cert and private key supplied don't match") - } - default: - return fmt.Errorf("invalid curve: %s", curve) - } - return nil - } - - var pub []byte - switch curve { - case Curve_CURVE25519: - var err error - pub, err = curve25519.X25519(key, curve25519.Basepoint) - if err != nil { - return err - } - case Curve_P256: - privkey, err := ecdh.P256().NewPrivateKey(key) - if err != nil { - return err - } - pub = privkey.PublicKey().Bytes() - default: - return fmt.Errorf("invalid curve: %s", curve) - } - if !bytes.Equal(pub, nc.Details.PublicKey) { - return fmt.Errorf("public key in cert and private key supplied don't match") - } - - return nil -} - -// String will return a pretty printed representation of a nebula cert -func (nc *NebulaCertificate) String() string { - if nc == nil { - return "NebulaCertificate {}\n" - } - - s := "NebulaCertificate {\n" - s += "\tDetails {\n" - s += fmt.Sprintf("\t\tName: %v\n", nc.Details.Name) - - if len(nc.Details.Ips) > 0 { - s += "\t\tIps: [\n" - for _, ip := range nc.Details.Ips { - s += fmt.Sprintf("\t\t\t%v\n", ip.String()) - } - s += "\t\t]\n" - } else { - s += "\t\tIps: []\n" - } - - if len(nc.Details.Subnets) > 0 { - s += "\t\tSubnets: [\n" - for _, ip := range nc.Details.Subnets { - s += fmt.Sprintf("\t\t\t%v\n", ip.String()) - } - s += "\t\t]\n" - } else { - s += "\t\tSubnets: []\n" - } - - if len(nc.Details.Groups) > 0 { - s += "\t\tGroups: [\n" - for _, g := range nc.Details.Groups { - s += fmt.Sprintf("\t\t\t\"%v\"\n", g) - } - s += "\t\t]\n" - } else { - s += "\t\tGroups: []\n" - } - - s += fmt.Sprintf("\t\tNot before: %v\n", nc.Details.NotBefore) - s += fmt.Sprintf("\t\tNot After: %v\n", nc.Details.NotAfter) - s += fmt.Sprintf("\t\tIs CA: %v\n", nc.Details.IsCA) - s += fmt.Sprintf("\t\tIssuer: %s\n", nc.Details.Issuer) - s += fmt.Sprintf("\t\tPublic key: %x\n", nc.Details.PublicKey) - s += fmt.Sprintf("\t\tCurve: %s\n", nc.Details.Curve) - s += "\t}\n" - fp, err := nc.Sha256Sum() - if err == nil { - s += fmt.Sprintf("\tFingerprint: %s\n", fp) - } - s += fmt.Sprintf("\tSignature: %x\n", nc.Signature) - s += "}" - - return s -} - -// getRawDetails marshals the raw details into protobuf ready struct -func (nc *NebulaCertificate) getRawDetails() *RawNebulaCertificateDetails { - rd := &RawNebulaCertificateDetails{ - Name: nc.Details.Name, - Groups: nc.Details.Groups, - NotBefore: nc.Details.NotBefore.Unix(), - NotAfter: nc.Details.NotAfter.Unix(), - PublicKey: make([]byte, len(nc.Details.PublicKey)), - IsCA: nc.Details.IsCA, - Curve: nc.Details.Curve, - } - - for _, ipNet := range nc.Details.Ips { - rd.Ips = append(rd.Ips, ip2int(ipNet.IP), ip2int(ipNet.Mask)) - } - - for _, ipNet := range nc.Details.Subnets { - rd.Subnets = append(rd.Subnets, ip2int(ipNet.IP), ip2int(ipNet.Mask)) - } - - copy(rd.PublicKey, nc.Details.PublicKey[:]) - - // I know, this is terrible - rd.Issuer, _ = hex.DecodeString(nc.Details.Issuer) - - return rd -} - -// Marshal will marshal a nebula cert into a protobuf byte array -func (nc *NebulaCertificate) Marshal() ([]byte, error) { - rc := RawNebulaCertificate{ - Details: nc.getRawDetails(), - Signature: nc.Signature, - } - - return proto.Marshal(&rc) -} - -// MarshalToPEM will marshal a nebula cert into a protobuf byte array and pem encode the result -func (nc *NebulaCertificate) MarshalToPEM() ([]byte, error) { - b, err := nc.Marshal() - if err != nil { - return nil, err - } - return pem.EncodeToMemory(&pem.Block{Type: CertBanner, Bytes: b}), nil -} - -// Sha256Sum calculates a sha-256 sum of the marshaled certificate -func (nc *NebulaCertificate) Sha256Sum() (string, error) { - b, err := nc.Marshal() - if err != nil { - return "", err - } - - sum := sha256.Sum256(b) - return hex.EncodeToString(sum[:]), nil -} - -// NOTE: This uses an internal cache that will not be invalidated automatically -// if you manually change any fields in the NebulaCertificate. -func (nc *NebulaCertificate) sha256SumWithCache(useCache bool) (string, error) { - if !useCache { - return nc.Sha256Sum() - } - - if s := nc.sha256sum.Load(); s != nil { - return *s, nil - } - s, err := nc.Sha256Sum() - if err != nil { - return s, err - } - - nc.sha256sum.Store(&s) - return s, nil -} - -func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) { - toString := func(ips []*net.IPNet) []string { - s := []string{} - for _, ip := range ips { - s = append(s, ip.String()) - } - return s - } - - fp, _ := nc.Sha256Sum() - jc := m{ - "details": m{ - "name": nc.Details.Name, - "ips": toString(nc.Details.Ips), - "subnets": toString(nc.Details.Subnets), - "groups": nc.Details.Groups, - "notBefore": nc.Details.NotBefore, - "notAfter": nc.Details.NotAfter, - "publicKey": fmt.Sprintf("%x", nc.Details.PublicKey), - "isCa": nc.Details.IsCA, - "issuer": nc.Details.Issuer, - "curve": nc.Details.Curve.String(), - }, - "fingerprint": fp, - "signature": fmt.Sprintf("%x", nc.Signature), - } - return json.Marshal(jc) -} - -//func (nc *NebulaCertificate) Copy() *NebulaCertificate { -// r, err := nc.Marshal() -// if err != nil { -// //TODO -// return nil -// } -// -// c, err := UnmarshalNebulaCertificate(r) -// return c -//} - -func (nc *NebulaCertificate) Copy() *NebulaCertificate { - c := &NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: nc.Details.Name, - Groups: make([]string, len(nc.Details.Groups)), - Ips: make([]*net.IPNet, len(nc.Details.Ips)), - Subnets: make([]*net.IPNet, len(nc.Details.Subnets)), - NotBefore: nc.Details.NotBefore, - NotAfter: nc.Details.NotAfter, - PublicKey: make([]byte, len(nc.Details.PublicKey)), - IsCA: nc.Details.IsCA, - Issuer: nc.Details.Issuer, - InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)), - }, - Signature: make([]byte, len(nc.Signature)), - } - - copy(c.Signature, nc.Signature) - copy(c.Details.Groups, nc.Details.Groups) - copy(c.Details.PublicKey, nc.Details.PublicKey) - - for i, p := range nc.Details.Ips { - c.Details.Ips[i] = &net.IPNet{ - IP: make(net.IP, len(p.IP)), - Mask: make(net.IPMask, len(p.Mask)), - } - copy(c.Details.Ips[i].IP, p.IP) - copy(c.Details.Ips[i].Mask, p.Mask) - } - - for i, p := range nc.Details.Subnets { - c.Details.Subnets[i] = &net.IPNet{ - IP: make(net.IP, len(p.IP)), - Mask: make(net.IPMask, len(p.Mask)), - } - copy(c.Details.Subnets[i].IP, p.IP) - copy(c.Details.Subnets[i].Mask, p.Mask) - } - - for g := range nc.Details.InvertedGroups { - c.Details.InvertedGroups[g] = struct{}{} - } - - return c -} - -func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool { - for _, net := range rootIps { - if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) { - return true - } - } - - return false -} - -func maskContains(caMask, certMask net.IPMask) bool { - caM := maskTo4(caMask) - cM := maskTo4(certMask) - // Make sure forcing to ipv4 didn't nuke us - if caM == nil || cM == nil { - return false - } - - // Make sure the cert mask is not greater than the ca mask - for i := 0; i < len(caMask); i++ { - if caM[i] > cM[i] { - return false - } - } - - return true -} - -func maskTo4(ip net.IPMask) net.IPMask { - if len(ip) == net.IPv4len { - return ip - } - - if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff { - return ip[12:16] - } - - return nil -} - -func isZeros(b []byte) bool { - for i := 0; i < len(b); i++ { - if b[i] != 0 { - return false - } - } - return true -} - -func ip2int(ip []byte) uint32 { - if len(ip) == 16 { - return binary.BigEndian.Uint32(ip[12:16]) - } - return binary.BigEndian.Uint32(ip) -} - -func int2ip(nn uint32) net.IP { - ip := make(net.IP, net.IPv4len) - binary.BigEndian.PutUint32(ip, nn) - return ip + c.details.PublicKey = publicKey + return c, nil } diff --git a/cert/cert_test.go b/cert/cert_test.go index 30e99ec..12bbd97 100644 --- a/cert/cert_test.go +++ b/cert/cert_test.go @@ -7,7 +7,7 @@ import ( "crypto/rand" "fmt" "io" - "net" + "net/netip" "testing" "time" @@ -15,7 +15,6 @@ import ( "github.com/stretchr/testify/assert" "golang.org/x/crypto/curve25519" "golang.org/x/crypto/ed25519" - "google.golang.org/protobuf/proto" ) func TestMarshalingNebulaCertificate(t *testing.T) { @@ -23,18 +22,16 @@ func TestMarshalingNebulaCertificate(t *testing.T) { after := time.Now().Add(time.Second * 60).Round(time.Second) pubKey := []byte("1234567890abcedfghij1234567890ab") - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ + nc := certificateV1{ + details: detailsV1{ Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + Ips: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, + Subnets: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), }, Groups: []string{"test-group1", "test-group2", "test-group3"}, NotBefore: before, @@ -43,120 +40,116 @@ func TestMarshalingNebulaCertificate(t *testing.T) { IsCA: false, Issuer: "1234567890abcedfghij1234567890ab", }, - Signature: []byte("1234567890abcedfghij1234567890ab"), + signature: []byte("1234567890abcedfghij1234567890ab"), } b, err := nc.Marshal() assert.Nil(t, err) //t.Log("Cert size:", len(b)) - nc2, err := UnmarshalNebulaCertificate(b) + nc2, err := unmarshalCertificateV1(b, true) assert.Nil(t, err) - assert.Equal(t, nc.Signature, nc2.Signature) - assert.Equal(t, nc.Details.Name, nc2.Details.Name) - assert.Equal(t, nc.Details.NotBefore, nc2.Details.NotBefore) - assert.Equal(t, nc.Details.NotAfter, nc2.Details.NotAfter) - assert.Equal(t, nc.Details.PublicKey, nc2.Details.PublicKey) - assert.Equal(t, nc.Details.IsCA, nc2.Details.IsCA) + assert.Equal(t, nc.signature, nc2.Signature()) + assert.Equal(t, nc.details.Name, nc2.Name()) + assert.Equal(t, nc.details.NotBefore, nc2.NotBefore()) + assert.Equal(t, nc.details.NotAfter, nc2.NotAfter()) + assert.Equal(t, nc.details.PublicKey, nc2.PublicKey()) + assert.Equal(t, nc.details.IsCA, nc2.IsCA()) - // IP byte arrays can be 4 or 16 in length so we have to go this route - assert.Equal(t, len(nc.Details.Ips), len(nc2.Details.Ips)) - for i, wIp := range nc.Details.Ips { - assert.Equal(t, wIp.String(), nc2.Details.Ips[i].String()) - } + assert.Equal(t, nc.details.Ips, nc2.Networks()) + assert.Equal(t, nc.details.Subnets, nc2.UnsafeNetworks()) - assert.Equal(t, len(nc.Details.Subnets), len(nc2.Details.Subnets)) - for i, wIp := range nc.Details.Subnets { - assert.Equal(t, wIp.String(), nc2.Details.Subnets[i].String()) - } - - assert.EqualValues(t, nc.Details.Groups, nc2.Details.Groups) + assert.Equal(t, nc.details.Groups, nc2.Groups()) } -func TestNebulaCertificate_Sign(t *testing.T) { - before := time.Now().Add(time.Second * -60).Round(time.Second) - after := time.Now().Add(time.Second * 60).Round(time.Second) - pubKey := []byte("1234567890abcedfghij1234567890ab") +//func TestNebulaCertificate_Sign(t *testing.T) { +// before := time.Now().Add(time.Second * -60).Round(time.Second) +// after := time.Now().Add(time.Second * 60).Round(time.Second) +// pubKey := []byte("1234567890abcedfghij1234567890ab") +// +// nc := certificateV1{ +// details: detailsV1{ +// Name: "testing", +// Ips: []netip.Prefix{ +// mustParsePrefixUnmapped("10.1.1.1/24"), +// mustParsePrefixUnmapped("10.1.1.2/16"), +// //TODO: netip cant do it +// //{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, +// }, +// Subnets: []netip.Prefix{ +// //TODO: netip cant do it +// //{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, +// mustParsePrefixUnmapped("9.1.1.2/24"), +// mustParsePrefixUnmapped("9.1.1.3/24"), +// }, +// Groups: []string{"test-group1", "test-group2", "test-group3"}, +// NotBefore: before, +// NotAfter: after, +// PublicKey: pubKey, +// IsCA: false, +// Issuer: "1234567890abcedfghij1234567890ab", +// }, +// } +// +// pub, priv, err := ed25519.GenerateKey(rand.Reader) +// assert.Nil(t, err) +// assert.False(t, nc.CheckSignature(pub)) +// assert.Nil(t, nc.Sign(Curve_CURVE25519, priv)) +// assert.True(t, nc.CheckSignature(pub)) +// +// _, err = nc.Marshal() +// assert.Nil(t, err) +// //t.Log("Cert size:", len(b)) +//} - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: before, - NotAfter: after, - PublicKey: pubKey, - IsCA: false, - Issuer: "1234567890abcedfghij1234567890ab", - }, - } - - pub, priv, err := ed25519.GenerateKey(rand.Reader) - assert.Nil(t, err) - assert.False(t, nc.CheckSignature(pub)) - assert.Nil(t, nc.Sign(Curve_CURVE25519, priv)) - assert.True(t, nc.CheckSignature(pub)) - - _, err = nc.Marshal() - assert.Nil(t, err) - //t.Log("Cert size:", len(b)) -} - -func TestNebulaCertificate_SignP256(t *testing.T) { - before := time.Now().Add(time.Second * -60).Round(time.Second) - after := time.Now().Add(time.Second * 60).Round(time.Second) - pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab") - - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: before, - NotAfter: after, - PublicKey: pubKey, - IsCA: false, - Curve: Curve_P256, - Issuer: "1234567890abcedfghij1234567890ab", - }, - } - - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) - rawPriv := priv.D.FillBytes(make([]byte, 32)) - - assert.Nil(t, err) - assert.False(t, nc.CheckSignature(pub)) - assert.Nil(t, nc.Sign(Curve_P256, rawPriv)) - assert.True(t, nc.CheckSignature(pub)) - - _, err = nc.Marshal() - assert.Nil(t, err) - //t.Log("Cert size:", len(b)) -} +//func TestNebulaCertificate_SignP256(t *testing.T) { +// before := time.Now().Add(time.Second * -60).Round(time.Second) +// after := time.Now().Add(time.Second * 60).Round(time.Second) +// pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab") +// +// nc := certificateV1{ +// details: detailsV1{ +// Name: "testing", +// Ips: []netip.Prefix{ +// mustParsePrefixUnmapped("10.1.1.1/24"), +// mustParsePrefixUnmapped("10.1.1.2/16"), +// //TODO: netip no can do +// //{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, +// }, +// Subnets: []netip.Prefix{ +// //TODO: netip bad +// //{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, +// mustParsePrefixUnmapped("9.1.1.2/24"), +// mustParsePrefixUnmapped("9.1.1.3/16"), +// }, +// Groups: []string{"test-group1", "test-group2", "test-group3"}, +// NotBefore: before, +// NotAfter: after, +// PublicKey: pubKey, +// IsCA: false, +// Curve: Curve_P256, +// Issuer: "1234567890abcedfghij1234567890ab", +// }, +// } +// +// priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) +// pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) +// rawPriv := priv.D.FillBytes(make([]byte, 32)) +// +// assert.Nil(t, err) +// assert.False(t, nc.CheckSignature(pub)) +// assert.Nil(t, nc.Sign(Curve_P256, rawPriv)) +// assert.True(t, nc.CheckSignature(pub)) +// +// _, err = nc.Marshal() +// assert.Nil(t, err) +// //t.Log("Cert size:", len(b)) +//} func TestNebulaCertificate_Expired(t *testing.T) { - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ + nc := certificateV1{ + details: detailsV1{ NotBefore: time.Now().Add(time.Second * -60).Round(time.Second), NotAfter: time.Now().Add(time.Second * 60).Round(time.Second), }, @@ -171,18 +164,16 @@ func TestNebulaCertificate_MarshalJSON(t *testing.T) { time.Local = time.UTC pubKey := []byte("1234567890abcedfghij1234567890ab") - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ + nc := certificateV1{ + details: detailsV1{ Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + Ips: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, + Subnets: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), }, Groups: []string{"test-group1", "test-group2", "test-group3"}, NotBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC), @@ -191,306 +182,256 @@ func TestNebulaCertificate_MarshalJSON(t *testing.T) { IsCA: false, Issuer: "1234567890abcedfghij1234567890ab", }, - Signature: []byte("1234567890abcedfghij1234567890ab"), + signature: []byte("1234567890abcedfghij1234567890ab"), } b, err := nc.MarshalJSON() assert.Nil(t, err) assert.Equal( t, - "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\",\"10.1.1.3/ff00ff00\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.1/ff00ff00\",\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"26cb1c30ad7872c804c166b5150fa372f437aa3856b04edb4334b4470ec728e4\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}", + "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}", string(b), ) } func TestNebulaCertificate_Verify(t *testing.T) { - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) assert.Nil(t, err) - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - - h, err := ca.Sha256Sum() + c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) assert.Nil(t, err) caPool := NewCAPool() - caPool.CAs[h] = ca + assert.NoError(t, caPool.AddCA(ca)) - f, err := c.Sha256Sum() + f, err := c.Fingerprint() assert.Nil(t, err) caPool.BlocklistFingerprint(f) - v, err := c.Verify(time.Now(), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) - v, err = c.Verify(time.Now().Add(time.Hour*1000), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) assert.EqualError(t, err, "root certificate is expired") - c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - v, err = c.Verify(time.Now().Add(time.Minute*6), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate is expired") + c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil) + assert.EqualError(t, err, "certificate is valid before the signing certificate") // Test group assertion - ca, _, caKey, err = newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "test2"}) + ca, _, caKey, err = newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) assert.Nil(t, err) - caPem, err := ca.MarshalToPEM() + caPem, err := ca.MarshalPEM() assert.Nil(t, err) caPool = NewCAPool() - caPool.AddCACertificate(caPem) + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "bad"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1"}) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) } func TestNebulaCertificate_VerifyP256(t *testing.T) { - ca, _, caKey, err := newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, err := newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) assert.Nil(t, err) - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - - h, err := ca.Sha256Sum() + c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) assert.Nil(t, err) caPool := NewCAPool() - caPool.CAs[h] = ca + assert.NoError(t, caPool.AddCA(ca)) - f, err := c.Sha256Sum() + f, err := c.Fingerprint() assert.Nil(t, err) caPool.BlocklistFingerprint(f) - v, err := c.Verify(time.Now(), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) - v, err = c.Verify(time.Now().Add(time.Hour*1000), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) assert.EqualError(t, err, "root certificate is expired") - c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - v, err = c.Verify(time.Now().Add(time.Minute*6), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate is expired") + c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil) + assert.EqualError(t, err, "certificate is valid before the signing certificate") // Test group assertion - ca, _, caKey, err = newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "test2"}) + ca, _, caKey, err = newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) assert.Nil(t, err) - caPem, err := ca.MarshalToPEM() + caPem, err := ca.MarshalPEM() assert.Nil(t, err) caPool = NewCAPool() - caPool.AddCACertificate(caPem) + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "bad"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1"}) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) } func TestNebulaCertificate_Verify_IPs(t *testing.T) { - _, caIp1, _ := net.ParseCIDR("10.0.0.0/16") - _, caIp2, _ := net.ParseCIDR("192.168.0.0/24") - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{caIp1, caIp2}, []*net.IPNet{}, []string{"test"}) + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) assert.Nil(t, err) - caPem, err := ca.MarshalToPEM() + caPem, err := ca.MarshalPEM() assert.Nil(t, err) caPool := NewCAPool() - caPool.AddCACertificate(caPem) + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) // ip is outside the network - cIp1 := &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}} - cIp2 := &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 0, 0}} - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err := c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.1.0.0/24") + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24") // ip is outside the network reversed order of above - cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.1.0.0/24") + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24") // ip is within the network but mask is outside - cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.0.1.0/15") + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15") // ip is within the network but mask is outside reversed order of above - cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.0.1.0/15") + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15") // ip and mask are within the network - cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 255, 0, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 128}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp1, caIp2}, []*net.IPNet{}, []string{"test"}) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches reversed - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp2, caIp1}, []*net.IPNet{}, []string{"test"}) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches reversed with just 1 - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp1}, []*net.IPNet{}, []string{"test"}) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) } func TestNebulaCertificate_Verify_Subnets(t *testing.T) { - _, caIp1, _ := net.ParseCIDR("10.0.0.0/16") - _, caIp2, _ := net.ParseCIDR("192.168.0.0/24") - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1, caIp2}, []string{"test"}) + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) assert.Nil(t, err) - caPem, err := ca.MarshalToPEM() + caPem, err := ca.MarshalPEM() assert.Nil(t, err) caPool := NewCAPool() - caPool.AddCACertificate(caPem) + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) // ip is outside the network - cIp1 := &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}} - cIp2 := &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 0, 0}} - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) - v, err := c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.1.0.0/24") + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24") // ip is outside the network reversed order of above - cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.1.0.0/24") + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24") // ip is within the network but mask is outside - cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.0.1.0/15") + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15") // ip is within the network but mask is outside reversed order of above - cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) - assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.0.1.0/15") + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15") // ip and mask are within the network - cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 255, 0, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 128}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1, caIp2}, []string{"test"}) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches reversed - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp2, caIp1}, []string{"test"}) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches reversed with just 1 - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1}, []string{"test"}) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) } func TestNebulaCertificate_VerifyPrivateKey(t *testing.T) { - ca, _, caKey, err := newTestCaCert(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, err := newTestCaCert(time.Time{}, time.Time{}, nil, nil, nil) assert.Nil(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) assert.Nil(t, err) - _, _, caKey2, err := newTestCaCert(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) + _, _, caKey2, err := newTestCaCert(time.Time{}, time.Time{}, nil, nil, nil) assert.Nil(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) assert.NotNil(t, err) - c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) + c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil) err = c.VerifyPrivateKey(Curve_CURVE25519, priv) assert.Nil(t, err) @@ -500,17 +441,17 @@ func TestNebulaCertificate_VerifyPrivateKey(t *testing.T) { } func TestNebulaCertificate_VerifyPrivateKeyP256(t *testing.T) { - ca, _, caKey, err := newTestCaCertP256(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, err := newTestCaCertP256(time.Time{}, time.Time{}, nil, nil, nil) assert.Nil(t, err) err = ca.VerifyPrivateKey(Curve_P256, caKey) assert.Nil(t, err) - _, _, caKey2, err := newTestCaCertP256(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) + _, _, caKey2, err := newTestCaCertP256(time.Time{}, time.Time{}, nil, nil, nil) assert.Nil(t, err) err = ca.VerifyPrivateKey(Curve_P256, caKey2) assert.NotNil(t, err) - c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) + c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil) err = c.VerifyPrivateKey(Curve_P256, priv) assert.Nil(t, err) @@ -519,108 +460,6 @@ func TestNebulaCertificate_VerifyPrivateKeyP256(t *testing.T) { assert.NotNil(t, err) } -func TestNewCAPoolFromBytes(t *testing.T) { - noNewLines := ` -# Current provisional, Remove once everything moves over to the real root. ------BEGIN NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL -vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv -bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB ------END NEBULA CERTIFICATE----- -# root-ca01 ------BEGIN NEBULA CERTIFICATE----- -CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG -BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf -8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF ------END NEBULA CERTIFICATE----- -` - - withNewLines := ` -# Current provisional, Remove once everything moves over to the real root. - ------BEGIN NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL -vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv -bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB ------END NEBULA CERTIFICATE----- - -# root-ca01 - - ------BEGIN NEBULA CERTIFICATE----- -CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG -BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf -8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF ------END NEBULA CERTIFICATE----- - -` - - expired := ` -# expired certificate ------BEGIN NEBULA CERTIFICATE----- -CjkKB2V4cGlyZWQouPmWjQYwufmWjQY6ILCRaoCkJlqHgv5jfDN4lzLHBvDzaQm4 -vZxfu144hmgjQAESQG4qlnZi8DncvD/LDZnLgJHOaX1DWCHHEh59epVsC+BNgTie -WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs= ------END NEBULA CERTIFICATE----- -` - - p256 := ` -# p256 certificate ------BEGIN NEBULA CERTIFICATE----- -CmYKEG5lYnVsYSBQMjU2IHRlc3Qo4s+7mgYw4tXrsAc6QQRkaW2jFmllYvN4+/k2 -6tctO9sPT3jOx8ES6M1nIqOhpTmZeabF/4rELDqPV4aH5jfJut798DUXql0FlF8H -76gvQAGgBgESRzBFAiEAib0/te6eMiZOKD8gdDeloMTS0wGuX2t0C7TFdUhAQzgC -IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX ------END NEBULA CERTIFICATE----- -` - - rootCA := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "nebula root ca", - }, - } - - rootCA01 := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "nebula root ca 01", - }, - } - - rootCAP256 := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "nebula P256 test", - }, - } - - p, err := NewCAPoolFromBytes([]byte(noNewLines)) - assert.Nil(t, err) - assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name) - assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name) - - pp, err := NewCAPoolFromBytes([]byte(withNewLines)) - assert.Nil(t, err) - assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name) - assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name) - - // expired cert, no valid certs - ppp, err := NewCAPoolFromBytes([]byte(expired)) - assert.Equal(t, ErrExpired, err) - assert.Equal(t, ppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired") - - // expired cert, with valid certs - pppp, err := NewCAPoolFromBytes(append([]byte(expired), noNewLines...)) - assert.Equal(t, ErrExpired, err) - assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name) - assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name) - assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired") - assert.Equal(t, len(pppp.CAs), 3) - - ppppp, err := NewCAPoolFromBytes([]byte(p256)) - assert.Nil(t, err) - assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Details.Name, rootCAP256.Details.Name) - assert.Equal(t, len(ppppp.CAs), 1) -} - func appendByteSlices(b ...[]byte) []byte { retSlice := []byte{} for _, v := range b { @@ -629,420 +468,55 @@ func appendByteSlices(b ...[]byte) []byte { return retSlice } -func TestUnmrshalCertPEM(t *testing.T) { - goodCert := []byte(` -# A good cert ------BEGIN NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL -vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv -bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB ------END NEBULA CERTIFICATE----- -`) - badBanner := []byte(`# A bad banner ------BEGIN NOT A NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL -vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv -bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB ------END NOT A NEBULA CERTIFICATE----- -`) - invalidPem := []byte(`# Not a valid PEM format --BEGIN NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL -vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv -bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB --END NEBULA CERTIFICATE----`) - - certBundle := appendByteSlices(goodCert, badBanner, invalidPem) - - // Success test case - cert, rest, err := UnmarshalNebulaCertificateFromPEM(certBundle) - assert.NotNil(t, cert) - assert.Equal(t, rest, append(badBanner, invalidPem...)) - assert.Nil(t, err) - - // Fail due to invalid banner. - cert, rest, err = UnmarshalNebulaCertificateFromPEM(rest) - assert.Nil(t, cert) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper nebula certificate banner") - - // Fail due to ivalid PEM format, because - // it's missing the requisite pre-encapsulation boundary. - cert, rest, err = UnmarshalNebulaCertificateFromPEM(rest) - assert.Nil(t, cert) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") -} - -func TestUnmarshalSigningPrivateKey(t *testing.T) { - privKey := []byte(`# A good key ------BEGIN NEBULA ED25519 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== ------END NEBULA ED25519 PRIVATE KEY----- -`) - privP256Key := []byte(`# A good key ------BEGIN NEBULA ECDSA P256 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NEBULA ECDSA P256 PRIVATE KEY----- -`) - shortKey := []byte(`# A short key ------BEGIN NEBULA ED25519 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA ------END NEBULA ED25519 PRIVATE KEY----- -`) - invalidBanner := []byte(`# Invalid banner ------BEGIN NOT A NEBULA PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== ------END NOT A NEBULA PRIVATE KEY----- -`) - invalidPem := []byte(`# Not a valid PEM format --BEGIN NEBULA ED25519 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== --END NEBULA ED25519 PRIVATE KEY-----`) - - keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem) - - // Success test case - k, rest, curve, err := UnmarshalSigningPrivateKey(keyBundle) - assert.Len(t, k, 64) - assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) - assert.Equal(t, Curve_CURVE25519, curve) - assert.Nil(t, err) - - // Success test case - k, rest, curve, err = UnmarshalSigningPrivateKey(rest) - assert.Len(t, k, 32) - assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) - assert.Equal(t, Curve_P256, curve) - assert.Nil(t, err) - - // Fail due to short key - k, rest, curve, err = UnmarshalSigningPrivateKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key") - - // Fail due to invalid banner - k, rest, curve, err = UnmarshalSigningPrivateKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519/ECDSA private key banner") - - // Fail due to ivalid PEM format, because - // it's missing the requisite pre-encapsulation boundary. - k, rest, curve, err = UnmarshalSigningPrivateKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") -} - -func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) { - passphrase := []byte("DO NOT USE THIS KEY") - privKey := []byte(`# A good key ------BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT -oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl -+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB -qrlJ69wer3ZUHFXA ------END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -`) - shortKey := []byte(`# A key which, once decrypted, is too short ------BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7 -k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe -GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs -rQr3bdH3Oy/WiYU= ------END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -`) - invalidBanner := []byte(`# Invalid banner (not encrypted) ------BEGIN NEBULA ED25519 PRIVATE KEY----- -bWRp2CTVFhW9HD/qCd28ltDgK3w8VXSeaEYczDWos8sMUBqDb9jP3+NYwcS4lURG -XgLvodMXZJuaFPssp+WwtA== ------END NEBULA ED25519 PRIVATE KEY----- -`) - invalidPem := []byte(`# Not a valid PEM format --BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT -oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl -+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB -qrlJ69wer3ZUHFXA --END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- -`) - - keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem) - - // Success test case - curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle) - assert.Nil(t, err) - assert.Equal(t, Curve_CURVE25519, curve) - assert.Len(t, k, 64) - assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) - - // Fail due to short key - curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) - assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") - assert.Nil(t, k) - assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - - // Fail due to invalid banner - curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) - assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - - // Fail due to ivalid PEM format, because - // it's missing the requisite pre-encapsulation boundary. - curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - - // Fail due to invalid passphrase - curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey) - assert.EqualError(t, err, "invalid passphrase or corrupt private key") - assert.Nil(t, k) - assert.Equal(t, rest, []byte{}) -} - -func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) { - // Having proved that decryption works correctly above, we can test the - // encryption function produces a value which can be decrypted - passphrase := []byte("passphrase") - bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") - kdfParams := NewArgon2Parameters(64*1024, 4, 3) - key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams) - assert.Nil(t, err) - - // Verify the "key" can be decrypted successfully - curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key) - assert.Len(t, k, 64) - assert.Equal(t, Curve_CURVE25519, curve) - assert.Equal(t, rest, []byte{}) - assert.Nil(t, err) - - // EncryptAndMarshalEd25519PrivateKey does not create any errors itself -} - -func TestUnmarshalPrivateKey(t *testing.T) { - privKey := []byte(`# A good key ------BEGIN NEBULA X25519 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NEBULA X25519 PRIVATE KEY----- -`) - privP256Key := []byte(`# A good key ------BEGIN NEBULA P256 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NEBULA P256 PRIVATE KEY----- -`) - shortKey := []byte(`# A short key ------BEGIN NEBULA X25519 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== ------END NEBULA X25519 PRIVATE KEY----- -`) - invalidBanner := []byte(`# Invalid banner ------BEGIN NOT A NEBULA PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NOT A NEBULA PRIVATE KEY----- -`) - invalidPem := []byte(`# Not a valid PEM format --BEGIN NEBULA X25519 PRIVATE KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= --END NEBULA X25519 PRIVATE KEY-----`) - - keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem) - - // Success test case - k, rest, curve, err := UnmarshalPrivateKey(keyBundle) - assert.Len(t, k, 32) - assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) - assert.Equal(t, Curve_CURVE25519, curve) - assert.Nil(t, err) - - // Success test case - k, rest, curve, err = UnmarshalPrivateKey(rest) - assert.Len(t, k, 32) - assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) - assert.Equal(t, Curve_P256, curve) - assert.Nil(t, err) - - // Fail due to short key - k, rest, curve, err = UnmarshalPrivateKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key") - - // Fail due to invalid banner - k, rest, curve, err = UnmarshalPrivateKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper nebula private key banner") - - // Fail due to ivalid PEM format, because - // it's missing the requisite pre-encapsulation boundary. - k, rest, curve, err = UnmarshalPrivateKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") -} - -func TestUnmarshalEd25519PublicKey(t *testing.T) { - pubKey := []byte(`# A good key ------BEGIN NEBULA ED25519 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NEBULA ED25519 PUBLIC KEY----- -`) - shortKey := []byte(`# A short key ------BEGIN NEBULA ED25519 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== ------END NEBULA ED25519 PUBLIC KEY----- -`) - invalidBanner := []byte(`# Invalid banner ------BEGIN NOT A NEBULA PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NOT A NEBULA PUBLIC KEY----- -`) - invalidPem := []byte(`# Not a valid PEM format --BEGIN NEBULA ED25519 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= --END NEBULA ED25519 PUBLIC KEY-----`) - - keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem) - - // Success test case - k, rest, err := UnmarshalEd25519PublicKey(keyBundle) - assert.Equal(t, len(k), 32) - assert.Nil(t, err) - assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) - - // Fail due to short key - k, rest, err = UnmarshalEd25519PublicKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid ed25519 public key") - - // Fail due to invalid banner - k, rest, err = UnmarshalEd25519PublicKey(rest) - assert.Nil(t, k) - assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519 public key banner") - assert.Equal(t, rest, invalidPem) - - // Fail due to ivalid PEM format, because - // it's missing the requisite pre-encapsulation boundary. - k, rest, err = UnmarshalEd25519PublicKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") -} - -func TestUnmarshalX25519PublicKey(t *testing.T) { - pubKey := []byte(`# A good key ------BEGIN NEBULA X25519 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NEBULA X25519 PUBLIC KEY----- -`) - pubP256Key := []byte(`# A good key ------BEGIN NEBULA P256 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA -AAAAAAAAAAAAAAAAAAAAAAA= ------END NEBULA P256 PUBLIC KEY----- -`) - shortKey := []byte(`# A short key ------BEGIN NEBULA X25519 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== ------END NEBULA X25519 PUBLIC KEY----- -`) - invalidBanner := []byte(`# Invalid banner ------BEGIN NOT A NEBULA PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= ------END NOT A NEBULA PUBLIC KEY----- -`) - invalidPem := []byte(`# Not a valid PEM format --BEGIN NEBULA X25519 PUBLIC KEY----- -AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= --END NEBULA X25519 PUBLIC KEY-----`) - - keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem) - - // Success test case - k, rest, curve, err := UnmarshalPublicKey(keyBundle) - assert.Equal(t, len(k), 32) - assert.Nil(t, err) - assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem)) - assert.Equal(t, Curve_CURVE25519, curve) - - // Success test case - k, rest, curve, err = UnmarshalPublicKey(rest) - assert.Equal(t, len(k), 65) - assert.Nil(t, err) - assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) - assert.Equal(t, Curve_P256, curve) - - // Fail due to short key - k, rest, curve, err = UnmarshalPublicKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") - - // Fail due to invalid banner - k, rest, curve, err = UnmarshalPublicKey(rest) - assert.Nil(t, k) - assert.EqualError(t, err, "bytes did not contain a proper nebula public key banner") - assert.Equal(t, rest, invalidPem) - - // Fail due to ivalid PEM format, because - // it's missing the requisite pre-encapsulation boundary. - k, rest, curve, err = UnmarshalPublicKey(rest) - assert.Nil(t, k) - assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") -} - // Ensure that upgrading the protobuf library does not change how certificates // are marshalled, since this would break signature verification -func TestMarshalingNebulaCertificateConsistency(t *testing.T) { - before := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) - after := time.Date(2017, time.January, 18, 28, 40, 0, 0, time.UTC) - pubKey := []byte("1234567890abcedfghij1234567890ab") - - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: before, - NotAfter: after, - PublicKey: pubKey, - IsCA: false, - Issuer: "1234567890abcedfghij1234567890ab", - }, - Signature: []byte("1234567890abcedfghij1234567890ab"), - } - - b, err := nc.Marshal() - assert.Nil(t, err) - //t.Log("Cert size:", len(b)) - assert.Equal(t, "0aa2010a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) - - b, err = proto.Marshal(nc.getRawDetails()) - assert.Nil(t, err) - //t.Log("Raw cert size:", len(b)) - assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) -} +//TODO: since netip cant represent 255.0.255.0 netmask we can't verify the old certs are ok +//func TestMarshalingNebulaCertificateConsistency(t *testing.T) { +// before := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) +// after := time.Date(2017, time.January, 18, 28, 40, 0, 0, time.UTC) +// pubKey := []byte("1234567890abcedfghij1234567890ab") +// +// nc := certificateV1{ +// details: detailsV1{ +// Name: "testing", +// Ips: []netip.Prefix{ +// mustParsePrefixUnmapped("10.1.1.1/24"), +// mustParsePrefixUnmapped("10.1.1.2/16"), +// //TODO: netip bad +// //{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, +// }, +// Subnets: []netip.Prefix{ +// //TODO: netip bad +// //{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, +// mustParsePrefixUnmapped("9.1.1.2/24"), +// mustParsePrefixUnmapped("9.1.1.3/16"), +// }, +// Groups: []string{"test-group1", "test-group2", "test-group3"}, +// NotBefore: before, +// NotAfter: after, +// PublicKey: pubKey, +// IsCA: false, +// Issuer: "1234567890abcedfghij1234567890ab", +// }, +// signature: []byte("1234567890abcedfghij1234567890ab"), +// } +// +// b, err := nc.Marshal() +// assert.Nil(t, err) +// //t.Log("Cert size:", len(b)) +// assert.Equal(t, "0aa2010a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) +// +// b, err = proto.Marshal(nc.getRawDetails()) +// assert.Nil(t, err) +// //t.Log("Raw cert size:", len(b)) +// assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) +//} func TestNebulaCertificate_Copy(t *testing.T) { - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) assert.Nil(t, err) - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) assert.Nil(t, err) cc := c.Copy() @@ -1052,11 +526,11 @@ func TestNebulaCertificate_Copy(t *testing.T) { func TestUnmarshalNebulaCertificate(t *testing.T) { // Test that we don't panic with an invalid certificate (#332) data := []byte("\x98\x00\x00") - _, err := UnmarshalNebulaCertificate(data) + _, err := unmarshalCertificateV1(data, true) assert.EqualError(t, err, "encoded Details was nil") } -func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) { +func newTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (Certificate, []byte, []byte, error) { pub, priv, err := ed25519.GenerateKey(rand.Reader) if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) @@ -1065,37 +539,35 @@ func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups [] after = time.Now().Add(time.Second * 60).Round(time.Second) } - nc := &NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "test ca", - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: true, - InvertedGroups: make(map[string]struct{}), - }, + tbs := &TBSCertificate{ + Version: Version1, + Name: "test ca", + IsCA: true, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, } if len(ips) > 0 { - nc.Details.Ips = ips + tbs.Networks = ips } if len(subnets) > 0 { - nc.Details.Subnets = subnets + tbs.UnsafeNetworks = subnets } if len(groups) > 0 { - nc.Details.Groups = groups + tbs.Groups = groups } - err = nc.Sign(Curve_CURVE25519, priv) + nc, err := tbs.Sign(nil, Curve_CURVE25519, priv) if err != nil { return nil, nil, nil, err } return nc, pub, priv, nil } -func newTestCaCertP256(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) { +func newTestCaCertP256(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (Certificate, []byte, []byte, error) { priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) rawPriv := priv.D.FillBytes(make([]byte, 32)) @@ -1107,43 +579,36 @@ func newTestCaCertP256(before, after time.Time, ips, subnets []*net.IPNet, group after = time.Now().Add(time.Second * 60).Round(time.Second) } - nc := &NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "test ca", - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: true, - Curve: Curve_P256, - InvertedGroups: make(map[string]struct{}), - }, + tbs := &TBSCertificate{ + Version: Version1, + Name: "test ca", + IsCA: true, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + Curve: Curve_P256, } if len(ips) > 0 { - nc.Details.Ips = ips + tbs.Networks = ips } if len(subnets) > 0 { - nc.Details.Subnets = subnets + tbs.UnsafeNetworks = subnets } if len(groups) > 0 { - nc.Details.Groups = groups + tbs.Groups = groups } - err = nc.Sign(Curve_P256, rawPriv) + nc, err := tbs.Sign(nil, Curve_P256, rawPriv) if err != nil { return nil, nil, nil, err } return nc, pub, rawPriv, nil } -func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) { - issuer, err := ca.Sha256Sum() - if err != nil { - return nil, nil, nil, err - } - +func newTestCert(ca Certificate, key []byte, before, after time.Time, ips, subnets []netip.Prefix, groups []string) (Certificate, []byte, []byte, error) { if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) } @@ -1156,49 +621,44 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips } if len(ips) == 0 { - ips = []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())}, - {IP: net.ParseIP("10.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())}, - {IP: net.ParseIP("10.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())}, + ips = []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), } } if len(subnets) == 0 { - subnets = []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())}, - {IP: net.ParseIP("9.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())}, - {IP: net.ParseIP("9.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())}, + subnets = []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), } } var pub, rawPriv []byte - switch ca.Details.Curve { + switch ca.Curve() { case Curve_CURVE25519: pub, rawPriv = x25519Keypair() case Curve_P256: pub, rawPriv = p256Keypair() default: - return nil, nil, nil, fmt.Errorf("unknown curve: %v", ca.Details.Curve) + return nil, nil, nil, fmt.Errorf("unknown curve: %v", ca.Curve()) } - nc := &NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "testing", - Ips: ips, - Subnets: subnets, - Groups: groups, - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: false, - Curve: ca.Details.Curve, - Issuer: issuer, - InvertedGroups: make(map[string]struct{}), - }, + tbs := &TBSCertificate{ + Version: Version1, + Name: "testing", + Networks: ips, + UnsafeNetworks: subnets, + Groups: groups, + IsCA: false, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + Curve: ca.Curve(), } - err = nc.Sign(ca.Details.Curve, key) + nc, err := tbs.Sign(ca, ca.Curve(), key) if err != nil { return nil, nil, nil, err } @@ -1228,3 +688,8 @@ func p256Keypair() ([]byte, []byte) { pubkey := privkey.PublicKey() return pubkey.Bytes(), privkey.Bytes() } + +func mustParsePrefixUnmapped(s string) netip.Prefix { + prefix := netip.MustParsePrefix(s) + return netip.PrefixFrom(prefix.Addr().Unmap(), prefix.Bits()) +} diff --git a/cert/cert_v1.go b/cert/cert_v1.go new file mode 100644 index 0000000..165e409 --- /dev/null +++ b/cert/cert_v1.go @@ -0,0 +1,496 @@ +package cert + +import ( + "bytes" + "crypto/ecdh" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "encoding/json" + "encoding/pem" + "fmt" + "math/big" + "net" + "net/netip" + "time" + + "github.com/slackhq/nebula/pkclient" + "golang.org/x/crypto/curve25519" + "google.golang.org/protobuf/proto" +) + +const publicKeyLen = 32 + +type certificateV1 struct { + details detailsV1 + signature []byte +} + +type detailsV1 struct { + Name string + Ips []netip.Prefix + Subnets []netip.Prefix + Groups []string + NotBefore time.Time + NotAfter time.Time + PublicKey []byte + IsCA bool + Issuer string + + Curve Curve +} + +type m map[string]interface{} + +func (nc *certificateV1) Version() Version { + return Version1 +} + +func (nc *certificateV1) Curve() Curve { + return nc.details.Curve +} + +func (nc *certificateV1) Groups() []string { + return nc.details.Groups +} + +func (nc *certificateV1) IsCA() bool { + return nc.details.IsCA +} + +func (nc *certificateV1) Issuer() string { + return nc.details.Issuer +} + +func (nc *certificateV1) Name() string { + return nc.details.Name +} + +func (nc *certificateV1) Networks() []netip.Prefix { + return nc.details.Ips +} + +func (nc *certificateV1) NotAfter() time.Time { + return nc.details.NotAfter +} + +func (nc *certificateV1) NotBefore() time.Time { + return nc.details.NotBefore +} + +func (nc *certificateV1) PublicKey() []byte { + return nc.details.PublicKey +} + +func (nc *certificateV1) Signature() []byte { + return nc.signature +} + +func (nc *certificateV1) UnsafeNetworks() []netip.Prefix { + return nc.details.Subnets +} + +func (nc *certificateV1) Fingerprint() (string, error) { + b, err := nc.Marshal() + if err != nil { + return "", err + } + + sum := sha256.Sum256(b) + return hex.EncodeToString(sum[:]), nil +} + +func (nc *certificateV1) CheckSignature(key []byte) bool { + b, err := proto.Marshal(nc.getRawDetails()) + if err != nil { + return false + } + switch nc.details.Curve { + case Curve_CURVE25519: + return ed25519.Verify(key, b, nc.signature) + case Curve_P256: + x, y := elliptic.Unmarshal(elliptic.P256(), key) + pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} + hashed := sha256.Sum256(b) + return ecdsa.VerifyASN1(pubKey, hashed[:], nc.signature) + default: + return false + } +} + +func (nc *certificateV1) Expired(t time.Time) bool { + return nc.details.NotBefore.After(t) || nc.details.NotAfter.Before(t) +} + +func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { + if curve != nc.details.Curve { + return fmt.Errorf("curve in cert and private key supplied don't match") + } + if nc.details.IsCA { + switch curve { + case Curve_CURVE25519: + // the call to PublicKey below will panic slice bounds out of range otherwise + if len(key) != ed25519.PrivateKeySize { + return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") + } + + if !ed25519.PublicKey(nc.details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) { + return fmt.Errorf("public key in cert and private key supplied don't match") + } + case Curve_P256: + privkey, err := ecdh.P256().NewPrivateKey(key) + if err != nil { + return fmt.Errorf("cannot parse private key as P256: %w", err) + } + pub := privkey.PublicKey().Bytes() + if !bytes.Equal(pub, nc.details.PublicKey) { + return fmt.Errorf("public key in cert and private key supplied don't match") + } + default: + return fmt.Errorf("invalid curve: %s", curve) + } + return nil + } + + var pub []byte + switch curve { + case Curve_CURVE25519: + var err error + pub, err = curve25519.X25519(key, curve25519.Basepoint) + if err != nil { + return err + } + case Curve_P256: + privkey, err := ecdh.P256().NewPrivateKey(key) + if err != nil { + return err + } + pub = privkey.PublicKey().Bytes() + default: + return fmt.Errorf("invalid curve: %s", curve) + } + if !bytes.Equal(pub, nc.details.PublicKey) { + return fmt.Errorf("public key in cert and private key supplied don't match") + } + + return nil +} + +// getRawDetails marshals the raw details into protobuf ready struct +func (nc *certificateV1) getRawDetails() *RawNebulaCertificateDetails { + rd := &RawNebulaCertificateDetails{ + Name: nc.details.Name, + Groups: nc.details.Groups, + NotBefore: nc.details.NotBefore.Unix(), + NotAfter: nc.details.NotAfter.Unix(), + PublicKey: make([]byte, len(nc.details.PublicKey)), + IsCA: nc.details.IsCA, + Curve: nc.details.Curve, + } + + for _, ipNet := range nc.details.Ips { + mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) + rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask)) + } + + for _, ipNet := range nc.details.Subnets { + mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) + rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask)) + } + + copy(rd.PublicKey, nc.details.PublicKey[:]) + + // I know, this is terrible + rd.Issuer, _ = hex.DecodeString(nc.details.Issuer) + + return rd +} + +func (nc *certificateV1) String() string { + if nc == nil { + return "Certificate {}\n" + } + + s := "NebulaCertificate {\n" + s += "\tDetails {\n" + s += fmt.Sprintf("\t\tName: %v\n", nc.details.Name) + + if len(nc.details.Ips) > 0 { + s += "\t\tIps: [\n" + for _, ip := range nc.details.Ips { + s += fmt.Sprintf("\t\t\t%v\n", ip.String()) + } + s += "\t\t]\n" + } else { + s += "\t\tIps: []\n" + } + + if len(nc.details.Subnets) > 0 { + s += "\t\tSubnets: [\n" + for _, ip := range nc.details.Subnets { + s += fmt.Sprintf("\t\t\t%v\n", ip.String()) + } + s += "\t\t]\n" + } else { + s += "\t\tSubnets: []\n" + } + + if len(nc.details.Groups) > 0 { + s += "\t\tGroups: [\n" + for _, g := range nc.details.Groups { + s += fmt.Sprintf("\t\t\t\"%v\"\n", g) + } + s += "\t\t]\n" + } else { + s += "\t\tGroups: []\n" + } + + s += fmt.Sprintf("\t\tNot before: %v\n", nc.details.NotBefore) + s += fmt.Sprintf("\t\tNot After: %v\n", nc.details.NotAfter) + s += fmt.Sprintf("\t\tIs CA: %v\n", nc.details.IsCA) + s += fmt.Sprintf("\t\tIssuer: %s\n", nc.details.Issuer) + s += fmt.Sprintf("\t\tPublic key: %x\n", nc.details.PublicKey) + s += fmt.Sprintf("\t\tCurve: %s\n", nc.details.Curve) + s += "\t}\n" + fp, err := nc.Fingerprint() + if err == nil { + s += fmt.Sprintf("\tFingerprint: %s\n", fp) + } + s += fmt.Sprintf("\tSignature: %x\n", nc.Signature()) + s += "}" + + return s +} + +func (nc *certificateV1) MarshalForHandshakes() ([]byte, error) { + pubKey := nc.details.PublicKey + nc.details.PublicKey = nil + rawCertNoKey, err := nc.Marshal() + if err != nil { + return nil, err + } + nc.details.PublicKey = pubKey + return rawCertNoKey, nil +} + +func (nc *certificateV1) Marshal() ([]byte, error) { + rc := RawNebulaCertificate{ + Details: nc.getRawDetails(), + Signature: nc.signature, + } + + return proto.Marshal(&rc) +} + +func (nc *certificateV1) MarshalPEM() ([]byte, error) { + b, err := nc.Marshal() + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil +} + +func (nc *certificateV1) MarshalJSON() ([]byte, error) { + fp, _ := nc.Fingerprint() + jc := m{ + "details": m{ + "name": nc.details.Name, + "ips": nc.details.Ips, + "subnets": nc.details.Subnets, + "groups": nc.details.Groups, + "notBefore": nc.details.NotBefore, + "notAfter": nc.details.NotAfter, + "publicKey": fmt.Sprintf("%x", nc.details.PublicKey), + "isCa": nc.details.IsCA, + "issuer": nc.details.Issuer, + "curve": nc.details.Curve.String(), + }, + "fingerprint": fp, + "signature": fmt.Sprintf("%x", nc.Signature()), + } + return json.Marshal(jc) +} + +func (nc *certificateV1) Copy() Certificate { + c := &certificateV1{ + details: detailsV1{ + Name: nc.details.Name, + Groups: make([]string, len(nc.details.Groups)), + Ips: make([]netip.Prefix, len(nc.details.Ips)), + Subnets: make([]netip.Prefix, len(nc.details.Subnets)), + NotBefore: nc.details.NotBefore, + NotAfter: nc.details.NotAfter, + PublicKey: make([]byte, len(nc.details.PublicKey)), + IsCA: nc.details.IsCA, + Issuer: nc.details.Issuer, + }, + signature: make([]byte, len(nc.signature)), + } + + copy(c.signature, nc.signature) + copy(c.details.Groups, nc.details.Groups) + copy(c.details.PublicKey, nc.details.PublicKey) + + for i, p := range nc.details.Ips { + c.details.Ips[i] = p + } + + for i, p := range nc.details.Subnets { + c.details.Subnets[i] = p + } + + return c +} + +// unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert +func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, error) { + if len(b) == 0 { + return nil, fmt.Errorf("nil byte array") + } + var rc RawNebulaCertificate + err := proto.Unmarshal(b, &rc) + if err != nil { + return nil, err + } + + if rc.Details == nil { + return nil, fmt.Errorf("encoded Details was nil") + } + + if len(rc.Details.Ips)%2 != 0 { + return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found") + } + + if len(rc.Details.Subnets)%2 != 0 { + return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found") + } + + nc := certificateV1{ + details: detailsV1{ + Name: rc.Details.Name, + Groups: make([]string, len(rc.Details.Groups)), + Ips: make([]netip.Prefix, len(rc.Details.Ips)/2), + Subnets: make([]netip.Prefix, len(rc.Details.Subnets)/2), + NotBefore: time.Unix(rc.Details.NotBefore, 0), + NotAfter: time.Unix(rc.Details.NotAfter, 0), + PublicKey: make([]byte, len(rc.Details.PublicKey)), + IsCA: rc.Details.IsCA, + Curve: rc.Details.Curve, + }, + signature: make([]byte, len(rc.Signature)), + } + + copy(nc.signature, rc.Signature) + copy(nc.details.Groups, rc.Details.Groups) + nc.details.Issuer = hex.EncodeToString(rc.Details.Issuer) + + if len(rc.Details.PublicKey) < publicKeyLen && assertPublicKey { + return nil, fmt.Errorf("public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey)) + } + copy(nc.details.PublicKey, rc.Details.PublicKey) + + var ip netip.Addr + for i, rawIp := range rc.Details.Ips { + if i%2 == 0 { + ip = int2addr(rawIp) + } else { + ones, _ := net.IPMask(int2ip(rawIp)).Size() + nc.details.Ips[i/2] = netip.PrefixFrom(ip, ones) + } + } + + for i, rawIp := range rc.Details.Subnets { + if i%2 == 0 { + ip = int2addr(rawIp) + } else { + ones, _ := net.IPMask(int2ip(rawIp)).Size() + nc.details.Subnets[i/2] = netip.PrefixFrom(ip, ones) + } + } + + return &nc, nil +} + +func signV1(t *TBSCertificate, curve Curve, key []byte, client *pkclient.PKClient) (*certificateV1, error) { + c := &certificateV1{ + details: detailsV1{ + Name: t.Name, + Ips: t.Networks, + Subnets: t.UnsafeNetworks, + Groups: t.Groups, + NotBefore: t.NotBefore, + NotAfter: t.NotAfter, + PublicKey: t.PublicKey, + IsCA: t.IsCA, + Curve: t.Curve, + Issuer: t.issuer, + }, + } + b, err := proto.Marshal(c.getRawDetails()) + if err != nil { + return nil, err + } + + var sig []byte + + switch curve { + case Curve_CURVE25519: + signer := ed25519.PrivateKey(key) + sig = ed25519.Sign(signer, b) + case Curve_P256: + if client != nil { + sig, err = client.SignASN1(b) + } else { + signer := &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{ + Curve: elliptic.P256(), + }, + // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95 + D: new(big.Int).SetBytes(key), + } + // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119 + signer.X, signer.Y = signer.Curve.ScalarBaseMult(key) + + // We need to hash first for ECDSA + // - https://pkg.go.dev/crypto/ecdsa#SignASN1 + hashed := sha256.Sum256(b) + sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:]) + if err != nil { + return nil, err + } + } + default: + return nil, fmt.Errorf("invalid curve: %s", c.details.Curve) + } + + c.signature = sig + return c, nil +} + +func ip2int(ip []byte) uint32 { + if len(ip) == 16 { + return binary.BigEndian.Uint32(ip[12:16]) + } + return binary.BigEndian.Uint32(ip) +} + +func int2ip(nn uint32) net.IP { + ip := make(net.IP, net.IPv4len) + binary.BigEndian.PutUint32(ip, nn) + return ip +} + +func addr2int(addr netip.Addr) uint32 { + b := addr.Unmap().As4() + return binary.BigEndian.Uint32(b[:]) +} + +func int2addr(nn uint32) netip.Addr { + ip := [4]byte{} + binary.BigEndian.PutUint32(ip[:], nn) + return netip.AddrFrom4(ip).Unmap() +} diff --git a/cert/cert.pb.go b/cert/cert_v1.pb.go similarity index 62% rename from cert/cert.pb.go rename to cert/cert_v1.pb.go index 3570e07..32de1a0 100644 --- a/cert/cert.pb.go +++ b/cert/cert_v1.pb.go @@ -1,8 +1,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.30.0 +// protoc-gen-go v1.34.2 // protoc v3.21.5 -// source: cert.proto +// source: cert_v1.proto package cert @@ -50,11 +50,11 @@ func (x Curve) String() string { } func (Curve) Descriptor() protoreflect.EnumDescriptor { - return file_cert_proto_enumTypes[0].Descriptor() + return file_cert_v1_proto_enumTypes[0].Descriptor() } func (Curve) Type() protoreflect.EnumType { - return &file_cert_proto_enumTypes[0] + return &file_cert_v1_proto_enumTypes[0] } func (x Curve) Number() protoreflect.EnumNumber { @@ -63,7 +63,7 @@ func (x Curve) Number() protoreflect.EnumNumber { // Deprecated: Use Curve.Descriptor instead. func (Curve) EnumDescriptor() ([]byte, []int) { - return file_cert_proto_rawDescGZIP(), []int{0} + return file_cert_v1_proto_rawDescGZIP(), []int{0} } type RawNebulaCertificate struct { @@ -78,7 +78,7 @@ type RawNebulaCertificate struct { func (x *RawNebulaCertificate) Reset() { *x = RawNebulaCertificate{} if protoimpl.UnsafeEnabled { - mi := &file_cert_proto_msgTypes[0] + mi := &file_cert_v1_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -91,7 +91,7 @@ func (x *RawNebulaCertificate) String() string { func (*RawNebulaCertificate) ProtoMessage() {} func (x *RawNebulaCertificate) ProtoReflect() protoreflect.Message { - mi := &file_cert_proto_msgTypes[0] + mi := &file_cert_v1_proto_msgTypes[0] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -104,7 +104,7 @@ func (x *RawNebulaCertificate) ProtoReflect() protoreflect.Message { // Deprecated: Use RawNebulaCertificate.ProtoReflect.Descriptor instead. func (*RawNebulaCertificate) Descriptor() ([]byte, []int) { - return file_cert_proto_rawDescGZIP(), []int{0} + return file_cert_v1_proto_rawDescGZIP(), []int{0} } func (x *RawNebulaCertificate) GetDetails() *RawNebulaCertificateDetails { @@ -143,7 +143,7 @@ type RawNebulaCertificateDetails struct { func (x *RawNebulaCertificateDetails) Reset() { *x = RawNebulaCertificateDetails{} if protoimpl.UnsafeEnabled { - mi := &file_cert_proto_msgTypes[1] + mi := &file_cert_v1_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -156,7 +156,7 @@ func (x *RawNebulaCertificateDetails) String() string { func (*RawNebulaCertificateDetails) ProtoMessage() {} func (x *RawNebulaCertificateDetails) ProtoReflect() protoreflect.Message { - mi := &file_cert_proto_msgTypes[1] + mi := &file_cert_v1_proto_msgTypes[1] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -169,7 +169,7 @@ func (x *RawNebulaCertificateDetails) ProtoReflect() protoreflect.Message { // Deprecated: Use RawNebulaCertificateDetails.ProtoReflect.Descriptor instead. func (*RawNebulaCertificateDetails) Descriptor() ([]byte, []int) { - return file_cert_proto_rawDescGZIP(), []int{1} + return file_cert_v1_proto_rawDescGZIP(), []int{1} } func (x *RawNebulaCertificateDetails) GetName() string { @@ -254,7 +254,7 @@ type RawNebulaEncryptedData struct { func (x *RawNebulaEncryptedData) Reset() { *x = RawNebulaEncryptedData{} if protoimpl.UnsafeEnabled { - mi := &file_cert_proto_msgTypes[2] + mi := &file_cert_v1_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -267,7 +267,7 @@ func (x *RawNebulaEncryptedData) String() string { func (*RawNebulaEncryptedData) ProtoMessage() {} func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message { - mi := &file_cert_proto_msgTypes[2] + mi := &file_cert_v1_proto_msgTypes[2] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -280,7 +280,7 @@ func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message { // Deprecated: Use RawNebulaEncryptedData.ProtoReflect.Descriptor instead. func (*RawNebulaEncryptedData) Descriptor() ([]byte, []int) { - return file_cert_proto_rawDescGZIP(), []int{2} + return file_cert_v1_proto_rawDescGZIP(), []int{2} } func (x *RawNebulaEncryptedData) GetEncryptionMetadata() *RawNebulaEncryptionMetadata { @@ -309,7 +309,7 @@ type RawNebulaEncryptionMetadata struct { func (x *RawNebulaEncryptionMetadata) Reset() { *x = RawNebulaEncryptionMetadata{} if protoimpl.UnsafeEnabled { - mi := &file_cert_proto_msgTypes[3] + mi := &file_cert_v1_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -322,7 +322,7 @@ func (x *RawNebulaEncryptionMetadata) String() string { func (*RawNebulaEncryptionMetadata) ProtoMessage() {} func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message { - mi := &file_cert_proto_msgTypes[3] + mi := &file_cert_v1_proto_msgTypes[3] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -335,7 +335,7 @@ func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message { // Deprecated: Use RawNebulaEncryptionMetadata.ProtoReflect.Descriptor instead. func (*RawNebulaEncryptionMetadata) Descriptor() ([]byte, []int) { - return file_cert_proto_rawDescGZIP(), []int{3} + return file_cert_v1_proto_rawDescGZIP(), []int{3} } func (x *RawNebulaEncryptionMetadata) GetEncryptionAlgorithm() string { @@ -367,7 +367,7 @@ type RawNebulaArgon2Parameters struct { func (x *RawNebulaArgon2Parameters) Reset() { *x = RawNebulaArgon2Parameters{} if protoimpl.UnsafeEnabled { - mi := &file_cert_proto_msgTypes[4] + mi := &file_cert_v1_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -380,7 +380,7 @@ func (x *RawNebulaArgon2Parameters) String() string { func (*RawNebulaArgon2Parameters) ProtoMessage() {} func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message { - mi := &file_cert_proto_msgTypes[4] + mi := &file_cert_v1_proto_msgTypes[4] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -393,7 +393,7 @@ func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message { // Deprecated: Use RawNebulaArgon2Parameters.ProtoReflect.Descriptor instead. func (*RawNebulaArgon2Parameters) Descriptor() ([]byte, []int) { - return file_cert_proto_rawDescGZIP(), []int{4} + return file_cert_v1_proto_rawDescGZIP(), []int{4} } func (x *RawNebulaArgon2Parameters) GetVersion() int32 { @@ -431,87 +431,87 @@ func (x *RawNebulaArgon2Parameters) GetSalt() []byte { return nil } -var File_cert_proto protoreflect.FileDescriptor +var File_cert_v1_proto protoreflect.FileDescriptor -var file_cert_proto_rawDesc = []byte{ - 0x0a, 0x0a, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x63, 0x65, - 0x72, 0x74, 0x22, 0x71, 0x0a, 0x14, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, - 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x3b, 0x0a, 0x07, 0x44, 0x65, - 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65, - 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, - 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x52, 0x07, - 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69, 0x67, 0x6e, 0x61, - 0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53, 0x69, 0x67, 0x6e, - 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x9c, 0x02, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, - 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, - 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x49, 0x70, 0x73, - 0x18, 0x02, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x03, 0x49, 0x70, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x53, - 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x07, 0x53, 0x75, - 0x62, 0x6e, 0x65, 0x74, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, - 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x1c, 0x0a, - 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, - 0x52, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x4e, - 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x4e, - 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x50, 0x75, 0x62, 0x6c, 0x69, - 0x63, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x50, 0x75, 0x62, 0x6c, - 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41, 0x18, 0x08, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06, 0x49, 0x73, 0x73, - 0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73, 0x73, 0x75, 0x65, - 0x72, 0x12, 0x21, 0x0a, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x18, 0x64, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x0b, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x43, 0x75, 0x72, 0x76, 0x65, 0x52, 0x05, 0x63, - 0x75, 0x72, 0x76, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, - 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61, 0x74, 0x61, 0x12, - 0x51, 0x0a, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, - 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65, - 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x12, - 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, - 0x74, 0x61, 0x12, 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, - 0x78, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, - 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, - 0x74, 0x61, 0x12, 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, - 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72, - 0x69, 0x74, 0x68, 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, - 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, - 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41, - 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x52, - 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, - 0x73, 0x22, 0xa3, 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41, - 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x12, - 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, - 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x6d, - 0x6f, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72, - 0x79, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, - 0x69, 0x73, 0x6d, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, - 0x0c, 0x52, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x2a, 0x21, 0x0a, 0x05, 0x43, 0x75, 0x72, 0x76, 0x65, - 0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x55, 0x52, 0x56, 0x45, 0x32, 0x35, 0x35, 0x31, 0x39, 0x10, 0x00, - 0x12, 0x08, 0x0a, 0x04, 0x50, 0x32, 0x35, 0x36, 0x10, 0x01, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69, - 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71, - 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x33, +var file_cert_v1_proto_rawDesc = []byte{ + 0x0a, 0x0d, 0x63, 0x65, 0x72, 0x74, 0x5f, 0x76, 0x31, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, + 0x04, 0x63, 0x65, 0x72, 0x74, 0x22, 0x71, 0x0a, 0x14, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, + 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x3b, 0x0a, + 0x07, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, + 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, + 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, + 0x73, 0x52, 0x07, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69, + 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53, + 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x9c, 0x02, 0x0a, 0x1b, 0x52, 0x61, 0x77, + 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, + 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, + 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x03, 0x49, 0x70, 0x73, 0x12, 0x18, + 0x0a, 0x07, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0d, 0x52, + 0x07, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x47, 0x72, 0x6f, 0x75, + 0x70, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, + 0x12, 0x1c, 0x0a, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x18, 0x05, 0x20, + 0x01, 0x28, 0x03, 0x52, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x12, 0x1a, + 0x0a, 0x08, 0x4e, 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x08, 0x4e, 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x50, 0x75, + 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x50, + 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41, + 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06, + 0x49, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73, + 0x73, 0x75, 0x65, 0x72, 0x12, 0x21, 0x0a, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x18, 0x64, 0x20, + 0x01, 0x28, 0x0e, 0x32, 0x0b, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x43, 0x75, 0x72, 0x76, 0x65, + 0x52, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e, + 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61, + 0x74, 0x61, 0x12, 0x51, 0x0a, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, + 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, + 0x61, 0x52, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, + 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, + 0x65, 0x78, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, + 0x72, 0x74, 0x65, 0x78, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, + 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, + 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, + 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, + 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x1f, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, + 0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, + 0x72, 0x73, 0x52, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, + 0x74, 0x65, 0x72, 0x73, 0x22, 0xa3, 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, + 0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, + 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, + 0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65, + 0x6d, 0x6f, 0x72, 0x79, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, + 0x69, 0x73, 0x6d, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, + 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x2a, 0x21, 0x0a, 0x05, 0x43, 0x75, + 0x72, 0x76, 0x65, 0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x55, 0x52, 0x56, 0x45, 0x32, 0x35, 0x35, 0x31, + 0x39, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x50, 0x32, 0x35, 0x36, 0x10, 0x01, 0x42, 0x20, 0x5a, + 0x1e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63, + 0x6b, 0x68, 0x71, 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( - file_cert_proto_rawDescOnce sync.Once - file_cert_proto_rawDescData = file_cert_proto_rawDesc + file_cert_v1_proto_rawDescOnce sync.Once + file_cert_v1_proto_rawDescData = file_cert_v1_proto_rawDesc ) -func file_cert_proto_rawDescGZIP() []byte { - file_cert_proto_rawDescOnce.Do(func() { - file_cert_proto_rawDescData = protoimpl.X.CompressGZIP(file_cert_proto_rawDescData) +func file_cert_v1_proto_rawDescGZIP() []byte { + file_cert_v1_proto_rawDescOnce.Do(func() { + file_cert_v1_proto_rawDescData = protoimpl.X.CompressGZIP(file_cert_v1_proto_rawDescData) }) - return file_cert_proto_rawDescData + return file_cert_v1_proto_rawDescData } -var file_cert_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_cert_proto_msgTypes = make([]protoimpl.MessageInfo, 5) -var file_cert_proto_goTypes = []interface{}{ +var file_cert_v1_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_cert_v1_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_cert_v1_proto_goTypes = []any{ (Curve)(0), // 0: cert.Curve (*RawNebulaCertificate)(nil), // 1: cert.RawNebulaCertificate (*RawNebulaCertificateDetails)(nil), // 2: cert.RawNebulaCertificateDetails @@ -519,7 +519,7 @@ var file_cert_proto_goTypes = []interface{}{ (*RawNebulaEncryptionMetadata)(nil), // 4: cert.RawNebulaEncryptionMetadata (*RawNebulaArgon2Parameters)(nil), // 5: cert.RawNebulaArgon2Parameters } -var file_cert_proto_depIdxs = []int32{ +var file_cert_v1_proto_depIdxs = []int32{ 2, // 0: cert.RawNebulaCertificate.Details:type_name -> cert.RawNebulaCertificateDetails 0, // 1: cert.RawNebulaCertificateDetails.curve:type_name -> cert.Curve 4, // 2: cert.RawNebulaEncryptedData.EncryptionMetadata:type_name -> cert.RawNebulaEncryptionMetadata @@ -531,13 +531,13 @@ var file_cert_proto_depIdxs = []int32{ 0, // [0:4] is the sub-list for field type_name } -func init() { file_cert_proto_init() } -func file_cert_proto_init() { - if File_cert_proto != nil { +func init() { file_cert_v1_proto_init() } +func file_cert_v1_proto_init() { + if File_cert_v1_proto != nil { return } if !protoimpl.UnsafeEnabled { - file_cert_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + file_cert_v1_proto_msgTypes[0].Exporter = func(v any, i int) any { switch v := v.(*RawNebulaCertificate); i { case 0: return &v.state @@ -549,7 +549,7 @@ func file_cert_proto_init() { return nil } } - file_cert_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + file_cert_v1_proto_msgTypes[1].Exporter = func(v any, i int) any { switch v := v.(*RawNebulaCertificateDetails); i { case 0: return &v.state @@ -561,7 +561,7 @@ func file_cert_proto_init() { return nil } } - file_cert_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + file_cert_v1_proto_msgTypes[2].Exporter = func(v any, i int) any { switch v := v.(*RawNebulaEncryptedData); i { case 0: return &v.state @@ -573,7 +573,7 @@ func file_cert_proto_init() { return nil } } - file_cert_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + file_cert_v1_proto_msgTypes[3].Exporter = func(v any, i int) any { switch v := v.(*RawNebulaEncryptionMetadata); i { case 0: return &v.state @@ -585,7 +585,7 @@ func file_cert_proto_init() { return nil } } - file_cert_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + file_cert_v1_proto_msgTypes[4].Exporter = func(v any, i int) any { switch v := v.(*RawNebulaArgon2Parameters); i { case 0: return &v.state @@ -602,19 +602,19 @@ func file_cert_proto_init() { out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_cert_proto_rawDesc, + RawDescriptor: file_cert_v1_proto_rawDesc, NumEnums: 1, NumMessages: 5, NumExtensions: 0, NumServices: 0, }, - GoTypes: file_cert_proto_goTypes, - DependencyIndexes: file_cert_proto_depIdxs, - EnumInfos: file_cert_proto_enumTypes, - MessageInfos: file_cert_proto_msgTypes, + GoTypes: file_cert_v1_proto_goTypes, + DependencyIndexes: file_cert_v1_proto_depIdxs, + EnumInfos: file_cert_v1_proto_enumTypes, + MessageInfos: file_cert_v1_proto_msgTypes, }.Build() - File_cert_proto = out.File - file_cert_proto_rawDesc = nil - file_cert_proto_goTypes = nil - file_cert_proto_depIdxs = nil + File_cert_v1_proto = out.File + file_cert_v1_proto_rawDesc = nil + file_cert_v1_proto_goTypes = nil + file_cert_v1_proto_depIdxs = nil } diff --git a/cert/cert.proto b/cert/cert_v1.proto similarity index 100% rename from cert/cert.proto rename to cert/cert_v1.proto diff --git a/cert/crypto.go b/cert/crypto.go index 3558e1a..4c236ae 100644 --- a/cert/crypto.go +++ b/cert/crypto.go @@ -3,14 +3,28 @@ package cert import ( "crypto/aes" "crypto/cipher" + "crypto/ed25519" "crypto/rand" + "encoding/pem" "fmt" "io" + "math" "golang.org/x/crypto/argon2" + "google.golang.org/protobuf/proto" ) -// KDF factors +type NebulaEncryptedData struct { + EncryptionMetadata NebulaEncryptionMetadata + Ciphertext []byte +} + +type NebulaEncryptionMetadata struct { + EncryptionAlgorithm string + Argon2Parameters Argon2Parameters +} + +// Argon2Parameters KDF factors type Argon2Parameters struct { version rune Memory uint32 // KiB @@ -19,7 +33,7 @@ type Argon2Parameters struct { salt []byte } -// Returns a new Argon2Parameters object with current version set +// NewArgon2Parameters Returns a new Argon2Parameters object with current version set func NewArgon2Parameters(memory uint32, parallelism uint8, iterations uint32) *Argon2Parameters { return &Argon2Parameters{ version: argon2.Version, @@ -141,3 +155,146 @@ func splitNonceCiphertext(blob []byte, nonceSize int) ([]byte, []byte, error) { return blob[:nonceSize], blob[nonceSize:], nil } + +// EncryptAndMarshalSigningPrivateKey is a simple helper to encrypt and PEM encode a private key +func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) { + ciphertext, err := aes256Encrypt(passphrase, kdfParams, b) + if err != nil { + return nil, err + } + + b, err = proto.Marshal(&RawNebulaEncryptedData{ + EncryptionMetadata: &RawNebulaEncryptionMetadata{ + EncryptionAlgorithm: "AES-256-GCM", + Argon2Parameters: &RawNebulaArgon2Parameters{ + Version: kdfParams.version, + Memory: kdfParams.Memory, + Parallelism: uint32(kdfParams.Parallelism), + Iterations: kdfParams.Iterations, + Salt: kdfParams.salt, + }, + }, + Ciphertext: ciphertext, + }) + if err != nil { + return nil, err + } + + switch curve { + case Curve_CURVE25519: + return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil + case Curve_P256: + return pem.EncodeToMemory(&pem.Block{Type: EncryptedECDSAP256PrivateKeyBanner, Bytes: b}), nil + default: + return nil, fmt.Errorf("invalid curve: %v", curve) + } +} + +// UnmarshalNebulaEncryptedData will unmarshal a protobuf byte representation of a nebula cert into its +// protobuf-generated struct. +func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) { + if len(b) == 0 { + return nil, fmt.Errorf("nil byte array") + } + var rned RawNebulaEncryptedData + err := proto.Unmarshal(b, &rned) + if err != nil { + return nil, err + } + + if rned.EncryptionMetadata == nil { + return nil, fmt.Errorf("encoded EncryptionMetadata was nil") + } + + if rned.EncryptionMetadata.Argon2Parameters == nil { + return nil, fmt.Errorf("encoded Argon2Parameters was nil") + } + + params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters) + if err != nil { + return nil, err + } + + ned := NebulaEncryptedData{ + EncryptionMetadata: NebulaEncryptionMetadata{ + EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm, + Argon2Parameters: *params, + }, + Ciphertext: rned.Ciphertext, + } + + return &ned, nil +} + +func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) { + if params.Version < math.MinInt32 || params.Version > math.MaxInt32 { + return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32) + } + if params.Memory <= 0 || params.Memory > math.MaxUint32 { + return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32)) + } + if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 { + return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8) + } + if params.Iterations <= 0 || params.Iterations > math.MaxUint32 { + return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32)) + } + + return &Argon2Parameters{ + version: params.Version, + Memory: params.Memory, + Parallelism: uint8(params.Parallelism), + Iterations: params.Iterations, + salt: params.Salt, + }, nil + +} + +// DecryptAndUnmarshalSigningPrivateKey will try to pem decode and decrypt an Ed25519/ECDSA private key with +// the given passphrase, returning any other bytes b or an error on failure +func DecryptAndUnmarshalSigningPrivateKey(passphrase, b []byte) (Curve, []byte, []byte, error) { + var curve Curve + + k, r := pem.Decode(b) + if k == nil { + return curve, nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") + } + + switch k.Type { + case EncryptedEd25519PrivateKeyBanner: + curve = Curve_CURVE25519 + case EncryptedECDSAP256PrivateKeyBanner: + curve = Curve_P256 + default: + return curve, nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") + } + + ned, err := UnmarshalNebulaEncryptedData(k.Bytes) + if err != nil { + return curve, nil, r, err + } + + var bytes []byte + switch ned.EncryptionMetadata.EncryptionAlgorithm { + case "AES-256-GCM": + bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext) + if err != nil { + return curve, nil, r, err + } + default: + return curve, nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm) + } + + switch curve { + case Curve_CURVE25519: + if len(bytes) != ed25519.PrivateKeySize { + return curve, nil, r, fmt.Errorf("key was not %d bytes, is invalid ed25519 private key", ed25519.PrivateKeySize) + } + case Curve_P256: + if len(bytes) != 32 { + return curve, nil, r, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key") + } + } + + return curve, bytes, r, nil +} diff --git a/cert/crypto_test.go b/cert/crypto_test.go index c2e61df..c9aba3e 100644 --- a/cert/crypto_test.go +++ b/cert/crypto_test.go @@ -23,3 +23,90 @@ func TestNewArgon2Parameters(t *testing.T) { Iterations: 1, }, p) } + +func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) { + passphrase := []byte("DO NOT USE THIS KEY") + privKey := []byte(`# A good key +-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT +oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl ++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB +qrlJ69wer3ZUHFXA +-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +`) + shortKey := []byte(`# A key which, once decrypted, is too short +-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7 +k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe +GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs +rQr3bdH3Oy/WiYU= +-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +`) + invalidBanner := []byte(`# Invalid banner (not encrypted) +-----BEGIN NEBULA ED25519 PRIVATE KEY----- +bWRp2CTVFhW9HD/qCd28ltDgK3w8VXSeaEYczDWos8sMUBqDb9jP3+NYwcS4lURG +XgLvodMXZJuaFPssp+WwtA== +-----END NEBULA ED25519 PRIVATE KEY----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT +oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl ++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB +qrlJ69wer3ZUHFXA +-END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- +`) + + keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem) + + // Success test case + curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle) + assert.Nil(t, err) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Len(t, k, 64) + assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + + // Fail due to short key + curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) + assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") + assert.Nil(t, k) + assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) + + // Fail due to invalid banner + curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) + assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) + assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + + // Fail due to invalid passphrase + curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey) + assert.EqualError(t, err, "invalid passphrase or corrupt private key") + assert.Nil(t, k) + assert.Equal(t, rest, []byte{}) +} + +func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) { + // Having proved that decryption works correctly above, we can test the + // encryption function produces a value which can be decrypted + passphrase := []byte("passphrase") + bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") + kdfParams := NewArgon2Parameters(64*1024, 4, 3) + key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams) + assert.Nil(t, err) + + // Verify the "key" can be decrypted successfully + curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key) + assert.Len(t, k, 64) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Equal(t, rest, []byte{}) + assert.Nil(t, err) + + // EncryptAndMarshalEd25519PrivateKey does not create any errors itself +} diff --git a/cert/errors.go b/cert/errors.go index 05b42d1..da0d1be 100644 --- a/cert/errors.go +++ b/cert/errors.go @@ -5,10 +5,23 @@ import ( ) var ( - ErrRootExpired = errors.New("root certificate is expired") - ErrExpired = errors.New("certificate is expired") - ErrNotCA = errors.New("certificate is not a CA") - ErrNotSelfSigned = errors.New("certificate is not self-signed") - ErrBlockListed = errors.New("certificate is in the block list") - ErrSignatureMismatch = errors.New("certificate signature did not match") + ErrBadFormat = errors.New("bad wire format") + ErrRootExpired = errors.New("root certificate is expired") + ErrExpired = errors.New("certificate is expired") + ErrNotCA = errors.New("certificate is not a CA") + ErrNotSelfSigned = errors.New("certificate is not self-signed") + ErrBlockListed = errors.New("certificate is in the block list") + ErrFingerprintMismatch = errors.New("certificate fingerprint did not match") + ErrSignatureMismatch = errors.New("certificate signature did not match") + ErrInvalidPublicKeyLength = errors.New("invalid public key length") + ErrInvalidPrivateKeyLength = errors.New("invalid private key length") + + ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") + + ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block") + ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner") + ErrInvalidPEMX25519PublicKeyBanner = errors.New("bytes did not contain a proper X25519 public key banner") + ErrInvalidPEMX25519PrivateKeyBanner = errors.New("bytes did not contain a proper X25519 private key banner") + ErrInvalidPEMEd25519PublicKeyBanner = errors.New("bytes did not contain a proper Ed25519 public key banner") + ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner") ) diff --git a/cert/pem.go b/cert/pem.go new file mode 100644 index 0000000..744ae2e --- /dev/null +++ b/cert/pem.go @@ -0,0 +1,155 @@ +package cert + +import ( + "encoding/pem" + "fmt" + + "golang.org/x/crypto/ed25519" +) + +const ( + CertificateBanner = "NEBULA CERTIFICATE" + CertificateV2Banner = "NEBULA CERTIFICATE V2" + X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" + X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY" + EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY" + Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY" + Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY" + + P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY" + P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY" + EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY" + ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY" +) + +// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed +// data or an error on failure +func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { + p, r := pem.Decode(b) + if p == nil { + return nil, r, ErrInvalidPEMBlock + } + + switch p.Type { + case CertificateBanner: + c, err := unmarshalCertificateV1(p.Bytes, true) + if err != nil { + return nil, nil, err + } + return c, r, nil + case CertificateV2Banner: + //TODO + panic("TODO") + default: + return nil, r, ErrInvalidPEMCertificateBanner + } +} + +func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte { + switch curve { + case Curve_CURVE25519: + return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b}) + case Curve_P256: + return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b}) + default: + return nil + } +} + +func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) { + k, r := pem.Decode(b) + if k == nil { + return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") + } + var expectedLen int + var curve Curve + switch k.Type { + case X25519PublicKeyBanner, Ed25519PublicKeyBanner: + expectedLen = 32 + curve = Curve_CURVE25519 + case P256PublicKeyBanner: + // Uncompressed + expectedLen = 65 + curve = Curve_P256 + default: + return nil, r, 0, fmt.Errorf("bytes did not contain a proper public key banner") + } + if len(k.Bytes) != expectedLen { + return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s public key", expectedLen, curve) + } + return k.Bytes, r, curve, nil +} + +func MarshalPrivateKeyToPEM(curve Curve, b []byte) []byte { + switch curve { + case Curve_CURVE25519: + return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b}) + case Curve_P256: + return pem.EncodeToMemory(&pem.Block{Type: P256PrivateKeyBanner, Bytes: b}) + default: + return nil + } +} + +func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte { + switch curve { + case Curve_CURVE25519: + return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: b}) + case Curve_P256: + return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PrivateKeyBanner, Bytes: b}) + default: + return nil + } +} + +// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non +// consumed data or an error on failure +func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) { + k, r := pem.Decode(b) + if k == nil { + return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") + } + var expectedLen int + var curve Curve + switch k.Type { + case X25519PrivateKeyBanner: + expectedLen = 32 + curve = Curve_CURVE25519 + case P256PrivateKeyBanner: + expectedLen = 32 + curve = Curve_P256 + default: + return nil, r, 0, fmt.Errorf("bytes did not contain a proper private key banner") + } + if len(k.Bytes) != expectedLen { + return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s private key", expectedLen, curve) + } + return k.Bytes, r, curve, nil +} + +func UnmarshalSigningPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) { + k, r := pem.Decode(b) + if k == nil { + return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") + } + var curve Curve + switch k.Type { + case EncryptedEd25519PrivateKeyBanner: + return nil, nil, Curve_CURVE25519, ErrPrivateKeyEncrypted + case EncryptedECDSAP256PrivateKeyBanner: + return nil, nil, Curve_P256, ErrPrivateKeyEncrypted + case Ed25519PrivateKeyBanner: + curve = Curve_CURVE25519 + if len(k.Bytes) != ed25519.PrivateKeySize { + return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid Ed25519 private key", ed25519.PrivateKeySize) + } + case ECDSAP256PrivateKeyBanner: + curve = Curve_P256 + if len(k.Bytes) != 32 { + return nil, r, 0, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key") + } + default: + return nil, r, 0, fmt.Errorf("bytes did not contain a proper Ed25519/ECDSA private key banner") + } + return k.Bytes, r, curve, nil +} diff --git a/cert/pem_test.go b/cert/pem_test.go new file mode 100644 index 0000000..a0c6e74 --- /dev/null +++ b/cert/pem_test.go @@ -0,0 +1,292 @@ +package cert + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUnmarshalCertificateFromPEM(t *testing.T) { + goodCert := []byte(` +# A good cert +-----BEGIN NEBULA CERTIFICATE----- +CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL +vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv +bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB +-----END NEBULA CERTIFICATE----- +`) + badBanner := []byte(`# A bad banner +-----BEGIN NOT A NEBULA CERTIFICATE----- +CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL +vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv +bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB +-----END NOT A NEBULA CERTIFICATE----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA CERTIFICATE----- +CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL +vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv +bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB +-END NEBULA CERTIFICATE----`) + + certBundle := appendByteSlices(goodCert, badBanner, invalidPem) + + // Success test case + cert, rest, err := UnmarshalCertificateFromPEM(certBundle) + assert.NotNil(t, cert) + assert.Equal(t, rest, append(badBanner, invalidPem...)) + assert.Nil(t, err) + + // Fail due to invalid banner. + cert, rest, err = UnmarshalCertificateFromPEM(rest) + assert.Nil(t, cert) + assert.Equal(t, rest, invalidPem) + assert.EqualError(t, err, "bytes did not contain a proper certificate banner") + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + cert, rest, err = UnmarshalCertificateFromPEM(rest) + assert.Nil(t, cert) + assert.Equal(t, rest, invalidPem) + assert.EqualError(t, err, "input did not contain a valid PEM encoded block") +} + +func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) { + privKey := []byte(`# A good key +-----BEGIN NEBULA ED25519 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== +-----END NEBULA ED25519 PRIVATE KEY----- +`) + privP256Key := []byte(`# A good key +-----BEGIN NEBULA ECDSA P256 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA ECDSA P256 PRIVATE KEY----- +`) + shortKey := []byte(`# A short key +-----BEGIN NEBULA ED25519 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA +-----END NEBULA ED25519 PRIVATE KEY----- +`) + invalidBanner := []byte(`# Invalid banner +-----BEGIN NOT A NEBULA PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== +-----END NOT A NEBULA PRIVATE KEY----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA ED25519 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== +-END NEBULA ED25519 PRIVATE KEY-----`) + + keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem) + + // Success test case + k, rest, curve, err := UnmarshalSigningPrivateKeyFromPEM(keyBundle) + assert.Len(t, k, 64) + assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Nil(t, err) + + // Success test case + k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) + assert.Len(t, k, 32) + assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_P256, curve) + assert.Nil(t, err) + + // Fail due to short key + k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) + assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key") + + // Fail due to invalid banner + k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + assert.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner") + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + assert.EqualError(t, err, "input did not contain a valid PEM encoded block") +} + +func TestUnmarshalPrivateKeyFromPEM(t *testing.T) { + privKey := []byte(`# A good key +-----BEGIN NEBULA X25519 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA X25519 PRIVATE KEY----- +`) + privP256Key := []byte(`# A good key +-----BEGIN NEBULA P256 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA P256 PRIVATE KEY----- +`) + shortKey := []byte(`# A short key +-----BEGIN NEBULA X25519 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== +-----END NEBULA X25519 PRIVATE KEY----- +`) + invalidBanner := []byte(`# Invalid banner +-----BEGIN NOT A NEBULA PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NOT A NEBULA PRIVATE KEY----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA X25519 PRIVATE KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-END NEBULA X25519 PRIVATE KEY-----`) + + keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem) + + // Success test case + k, rest, curve, err := UnmarshalPrivateKeyFromPEM(keyBundle) + assert.Len(t, k, 32) + assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Nil(t, err) + + // Success test case + k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) + assert.Len(t, k, 32) + assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_P256, curve) + assert.Nil(t, err) + + // Fail due to short key + k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) + assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key") + + // Fail due to invalid banner + k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + assert.EqualError(t, err, "bytes did not contain a proper private key banner") + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + assert.EqualError(t, err, "input did not contain a valid PEM encoded block") +} + +func TestUnmarshalPublicKeyFromPEM(t *testing.T) { + pubKey := []byte(`# A good key +-----BEGIN NEBULA ED25519 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA ED25519 PUBLIC KEY----- +`) + shortKey := []byte(`# A short key +-----BEGIN NEBULA ED25519 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== +-----END NEBULA ED25519 PUBLIC KEY----- +`) + invalidBanner := []byte(`# Invalid banner +-----BEGIN NOT A NEBULA PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NOT A NEBULA PUBLIC KEY----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA ED25519 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-END NEBULA ED25519 PUBLIC KEY-----`) + + keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem) + + // Success test case + k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) + assert.Equal(t, 32, len(k)) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Nil(t, err) + assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + + // Fail due to short key + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) + assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") + + // Fail due to invalid banner + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, Curve_CURVE25519, curve) + assert.EqualError(t, err, "bytes did not contain a proper public key banner") + assert.Equal(t, rest, invalidPem) + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, Curve_CURVE25519, curve) + assert.Equal(t, rest, invalidPem) + assert.EqualError(t, err, "input did not contain a valid PEM encoded block") +} + +func TestUnmarshalX25519PublicKey(t *testing.T) { + pubKey := []byte(`# A good key +-----BEGIN NEBULA X25519 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA X25519 PUBLIC KEY----- +`) + pubP256Key := []byte(`# A good key +-----BEGIN NEBULA P256 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA +AAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA P256 PUBLIC KEY----- +`) + shortKey := []byte(`# A short key +-----BEGIN NEBULA X25519 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== +-----END NEBULA X25519 PUBLIC KEY----- +`) + invalidBanner := []byte(`# Invalid banner +-----BEGIN NOT A NEBULA PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-----END NOT A NEBULA PUBLIC KEY----- +`) + invalidPem := []byte(`# Not a valid PEM format +-BEGIN NEBULA X25519 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +-END NEBULA X25519 PUBLIC KEY-----`) + + keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem) + + // Success test case + k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) + assert.Equal(t, 32, len(k)) + assert.Nil(t, err) + assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_CURVE25519, curve) + + // Success test case + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Equal(t, 65, len(k)) + assert.Nil(t, err) + assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_P256, curve) + + // Fail due to short key + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) + assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") + + // Fail due to invalid banner + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Nil(t, k) + assert.EqualError(t, err, "bytes did not contain a proper public key banner") + assert.Equal(t, rest, invalidPem) + + // Fail due to ivalid PEM format, because + // it's missing the requisite pre-encapsulation boundary. + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Nil(t, k) + assert.Equal(t, rest, invalidPem) + assert.EqualError(t, err, "input did not contain a valid PEM encoded block") +} diff --git a/cert/sign.go b/cert/sign.go new file mode 100644 index 0000000..e446aa1 --- /dev/null +++ b/cert/sign.go @@ -0,0 +1,76 @@ +package cert + +import ( + "fmt" + "net/netip" + "time" + + "github.com/slackhq/nebula/pkclient" +) + +// TBSCertificate represents a certificate intended to be signed. +// It is invalid to use this structure as a Certificate. +type TBSCertificate struct { + Version Version + Name string + Networks []netip.Prefix + UnsafeNetworks []netip.Prefix + Groups []string + IsCA bool + NotBefore time.Time + NotAfter time.Time + PublicKey []byte + Curve Curve + issuer string +} + +// Sign will create a sealed certificate using details provided by the TBSCertificate as long as those +// details do not violate constraints of the signing certificate. +// If the TBSCertificate is a CA then signer must be nil. +func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) { + return t.sign(signer, curve, key, nil) +} + +func (t *TBSCertificate) SignPkcs11(signer Certificate, curve Curve, client *pkclient.PKClient) (Certificate, error) { + if curve != Curve_P256 { + return nil, fmt.Errorf("only P256 is supported by PKCS#11") + } + + return t.sign(signer, curve, nil, client) +} + +func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, client *pkclient.PKClient) (Certificate, error) { + if curve != t.Curve { + return nil, fmt.Errorf("curve in cert and private key supplied don't match") + } + + //TODO: make sure we have all minimum properties to sign, like a public key + + if signer != nil { + if t.IsCA { + return nil, fmt.Errorf("can not sign a CA certificate with another") + } + + err := checkCAConstraints(signer, t.NotBefore, t.NotAfter, t.Groups, t.Networks, t.UnsafeNetworks) + if err != nil { + return nil, err + } + + issuer, err := signer.Fingerprint() + if err != nil { + return nil, fmt.Errorf("error computing issuer: %v", err) + } + t.issuer = issuer + } else { + if !t.IsCA { + return nil, fmt.Errorf("self signed certificates must have IsCA set to true") + } + } + + switch t.Version { + case Version1: + return signV1(t, curve, key, client) + default: + return nil, fmt.Errorf("unknown cert version %d", t.Version) + } +} diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go index 757f883..90ea8ff 100644 --- a/cmd/nebula-cert/ca.go +++ b/cmd/nebula-cert/ca.go @@ -4,12 +4,11 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" - "flag" "fmt" "io" "math" - "net" + "net/netip" "os" "strings" "time" @@ -114,38 +113,36 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error } } - var ips []*net.IPNet + var ips []netip.Prefix if *cf.ips != "" { for _, rs := range strings.Split(*cf.ips, ",") { rs := strings.Trim(rs, " ") if rs != "" { - ip, ipNet, err := net.ParseCIDR(rs) + n, err := netip.ParsePrefix(rs) if err != nil { return newHelpErrorf("invalid ip definition: %s", err) } - if ip.To4() == nil { + if !n.Addr().Is4() { return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", rs) } - - ipNet.IP = ip - ips = append(ips, ipNet) + ips = append(ips, n) } } } - var subnets []*net.IPNet + var subnets []netip.Prefix if *cf.subnets != "" { for _, rs := range strings.Split(*cf.subnets, ",") { rs := strings.Trim(rs, " ") if rs != "" { - _, s, err := net.ParseCIDR(rs) + n, err := netip.ParsePrefix(rs) if err != nil { return newHelpErrorf("invalid subnet definition: %s", err) } - if s.IP.To4() == nil { + if !n.Addr().Is4() { return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs) } - subnets = append(subnets, s) + subnets = append(subnets, n) } } } @@ -224,19 +221,17 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error } } - nc := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: *cf.name, - Groups: groups, - Ips: ips, - Subnets: subnets, - NotBefore: time.Now(), - NotAfter: time.Now().Add(*cf.duration), - PublicKey: pub, - IsCA: true, - Curve: curve, - }, - Pkcs11Backed: isP11, + t := &cert.TBSCertificate{ + Version: cert.Version1, + Name: *cf.name, + Groups: groups, + Networks: ips, + UnsafeNetworks: subnets, + NotBefore: time.Now(), + NotAfter: time.Now().Add(*cf.duration), + PublicKey: pub, + IsCA: true, + Curve: curve, } if !isP11 { @@ -249,15 +244,16 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath) } + var c cert.Certificate var b []byte if isP11 { - err = nc.SignPkcs11(curve, p11Client) + c, err = t.SignPkcs11(nil, curve, p11Client) if err != nil { return fmt.Errorf("error while signing with PKCS#11: %w", err) } } else { - err = nc.Sign(curve, rawPriv) + c, err = t.Sign(nil, curve, rawPriv) if err != nil { return fmt.Errorf("error while signing: %s", err) } @@ -268,19 +264,16 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error return fmt.Errorf("error while encrypting out-key: %s", err) } } else { - b = cert.MarshalSigningPrivateKey(curve, rawPriv) + b = cert.MarshalSigningPrivateKeyToPEM(curve, rawPriv) } err = os.WriteFile(*cf.outKeyPath, b, 0600) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } - if _, err := os.Stat(*cf.outCertPath); err == nil { - return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath) - } } - b, err = nc.MarshalToPEM() + b, err = c.MarshalPEM() if err != nil { return fmt.Errorf("error while marshalling certificate: %s", err) } diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index cb8b57a..06a24ed 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -109,7 +109,7 @@ func Test_ca(t *testing.T) { // create temp key file keyF, err := os.CreateTemp("", "test.key") assert.Nil(t, err) - os.Remove(keyF.Name()) + assert.Nil(t, os.Remove(keyF.Name())) // failed cert write ob.Reset() @@ -122,8 +122,8 @@ func Test_ca(t *testing.T) { // create temp cert file crtF, err := os.CreateTemp("", "test.crt") assert.Nil(t, err) - os.Remove(crtF.Name()) - os.Remove(keyF.Name()) + assert.Nil(t, os.Remove(crtF.Name())) + assert.Nil(t, os.Remove(keyF.Name())) // test proper cert with removed empty groups and subnets ob.Reset() @@ -135,25 +135,26 @@ func Test_ca(t *testing.T) { // read cert and key files rb, _ := os.ReadFile(keyF.Name()) - lKey, b, err := cert.UnmarshalEd25519PrivateKey(rb) + lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb) + assert.Equal(t, cert.Curve_CURVE25519, c) assert.Len(t, b, 0) assert.Nil(t, err) assert.Len(t, lKey, 64) rb, _ = os.ReadFile(crtF.Name()) - lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb) + lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) assert.Len(t, b, 0) assert.Nil(t, err) - assert.Equal(t, "test", lCrt.Details.Name) - assert.Len(t, lCrt.Details.Ips, 0) - assert.True(t, lCrt.Details.IsCA) - assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups) - assert.Len(t, lCrt.Details.Subnets, 0) - assert.Len(t, lCrt.Details.PublicKey, 32) - assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore)) - assert.Equal(t, "", lCrt.Details.Issuer) - assert.True(t, lCrt.CheckSignature(lCrt.Details.PublicKey)) + assert.Equal(t, "test", lCrt.Name()) + assert.Len(t, lCrt.Networks(), 0) + assert.True(t, lCrt.IsCA()) + assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups()) + assert.Len(t, lCrt.UnsafeNetworks(), 0) + assert.Len(t, lCrt.PublicKey(), 32) + assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore())) + assert.Equal(t, "", lCrt.Issuer()) + assert.True(t, lCrt.CheckSignature(lCrt.PublicKey())) // test encrypted key os.Remove(keyF.Name()) diff --git a/cmd/nebula-cert/keygen.go b/cmd/nebula-cert/keygen.go index 2355c4f..496f84c 100644 --- a/cmd/nebula-cert/keygen.go +++ b/cmd/nebula-cert/keygen.go @@ -82,12 +82,12 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("error while getting public key: %w", err) } } else { - err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600) + err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } } - err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKey(curve, pub), 0600) + err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600) if err != nil { return fmt.Errorf("error while writing out-pub: %s", err) } diff --git a/cmd/nebula-cert/keygen_test.go b/cmd/nebula-cert/keygen_test.go index 925b266..18ceb4b 100644 --- a/cmd/nebula-cert/keygen_test.go +++ b/cmd/nebula-cert/keygen_test.go @@ -81,13 +81,15 @@ func Test_keygen(t *testing.T) { // read cert and key files rb, _ := os.ReadFile(keyF.Name()) - lKey, b, err := cert.UnmarshalX25519PrivateKey(rb) + lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) + assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Len(t, b, 0) assert.Nil(t, err) assert.Len(t, lKey, 32) rb, _ = os.ReadFile(pubF.Name()) - lPub, b, err := cert.UnmarshalX25519PublicKey(rb) + lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb) + assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Len(t, b, 0) assert.Nil(t, err) assert.Len(t, lPub, 32) diff --git a/cmd/nebula-cert/print.go b/cmd/nebula-cert/print.go index 746d6a3..a62c223 100644 --- a/cmd/nebula-cert/print.go +++ b/cmd/nebula-cert/print.go @@ -45,12 +45,12 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("unable to read cert; %s", err) } - var c *cert.NebulaCertificate + var c cert.Certificate var qrBytes []byte part := 0 for { - c, rawCert, err = cert.UnmarshalNebulaCertificateFromPEM(rawCert) + c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert) if err != nil { return fmt.Errorf("error while unmarshaling cert: %s", err) } @@ -66,7 +66,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { } if *pf.outQRPath != "" { - b, err := c.MarshalToPEM() + b, err := c.MarshalPEM() if err != nil { return fmt.Errorf("error while marshalling cert to PEM: %s", err) } diff --git a/cmd/nebula-cert/print_test.go b/cmd/nebula-cert/print_test.go index 9fa8a54..4c9a72d 100644 --- a/cmd/nebula-cert/print_test.go +++ b/cmd/nebula-cert/print_test.go @@ -2,6 +2,10 @@ package main import ( "bytes" + "crypto/ed25519" + "crypto/rand" + "encoding/hex" + "net/netip" "os" "testing" "time" @@ -68,25 +72,22 @@ func Test_printCert(t *testing.T) { eb.Reset() tf.Truncate(0) tf.Seek(0, 0) - c := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test", - Groups: []string{"hi"}, - PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}, - }, - Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}, - } + ca, caKey := NewTestCaCert("test ca", nil, nil, time.Time{}, time.Time{}, nil, nil, nil) + c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, []string{"hi"}) - p, _ := c.MarshalToPEM() + p, _ := c.MarshalPEM() tf.Write(p) tf.Write(p) tf.Write(p) err = printCert([]string{"-path", tf.Name()}, ob, eb) + fp, _ := c.Fingerprint() + pk := hex.EncodeToString(c.PublicKey()) + sig := hex.EncodeToString(c.Signature()) assert.Nil(t, err) assert.Equal( t, - "NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\n", + "NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n", ob.String(), ) assert.Equal(t, "", eb.String()) @@ -96,26 +97,79 @@ func Test_printCert(t *testing.T) { eb.Reset() tf.Truncate(0) tf.Seek(0, 0) - c = cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test", - Groups: []string{"hi"}, - PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}, - }, - Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}, - } - - p, _ = c.MarshalToPEM() tf.Write(p) tf.Write(p) tf.Write(p) err = printCert([]string{"-json", "-path", tf.Name()}, ob, eb) + fp, _ = c.Fingerprint() + pk = hex.EncodeToString(c.PublicKey()) + sig = hex.EncodeToString(c.Signature()) assert.Nil(t, err) assert.Equal( t, - "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n", + "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n", ob.String(), ) assert.Equal(t, "", eb.String()) } + +// NewTestCaCert will generate a CA cert +func NewTestCaCert(name string, pubKey, privKey []byte, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte) { + var err error + if pubKey == nil || privKey == nil { + pubKey, privKey, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + } + + t := &cert.TBSCertificate{ + Version: cert.Version1, + Name: name, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pubKey, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + IsCA: true, + } + + c, err := t.Sign(nil, cert.Curve_CURVE25519, privKey) + if err != nil { + panic(err) + } + + return c, privKey +} + +func NewTestCert(ca cert.Certificate, signerKey []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte) { + if before.IsZero() { + before = ca.NotBefore() + } + + if after.IsZero() { + after = ca.NotAfter() + } + + pub, rawPriv := x25519Keypair() + nc := &cert.TBSCertificate{ + Version: cert.Version1, + Name: name, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: false, + } + + c, err := nc.Sign(ca, ca.Curve(), signerKey) + if err != nil { + panic(err) + } + + return c, rawPriv +} diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index 8e86fe5..13e807f 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -6,7 +6,7 @@ import ( "flag" "fmt" "io" - "net" + "net/netip" "os" "strings" "time" @@ -80,15 +80,17 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) var curve cert.Curve var caKey []byte + if !isP11 { var rawCAKey []byte rawCAKey, err := os.ReadFile(*sf.caKeyPath) + if err != nil { return fmt.Errorf("error while reading ca-key: %s", err) } // naively attempt to decode the private key as though it is not encrypted - caKey, _, curve, err = cert.UnmarshalSigningPrivateKey(rawCAKey) + caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey) if err == cert.ErrPrivateKeyEncrypted { // ask for a passphrase until we get one var passphrase []byte @@ -124,7 +126,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("error while reading ca-crt: %s", err) } - caCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCACert) + caCert, _, err := cert.UnmarshalCertificateFromPEM(rawCACert) if err != nil { return fmt.Errorf("error while parsing ca-crt: %s", err) } @@ -135,30 +137,24 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } } - issuer, err := caCert.Sha256Sum() - if err != nil { - return fmt.Errorf("error while getting -ca-crt fingerprint: %s", err) - } - if caCert.Expired(time.Now()) { return fmt.Errorf("ca certificate is expired") } // if no duration is given, expire one second before the root expires if *sf.duration <= 0 { - *sf.duration = time.Until(caCert.Details.NotAfter) - time.Second*1 + *sf.duration = time.Until(caCert.NotAfter()) - time.Second*1 } - ip, ipNet, err := net.ParseCIDR(*sf.ip) + network, err := netip.ParsePrefix(*sf.ip) if err != nil { - return newHelpErrorf("invalid ip definition: %s", err) + return newHelpErrorf("invalid ip definition: %s", *sf.ip) } - if ip.To4() == nil { + if !network.Addr().Is4() { return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip) } - ipNet.IP = ip - groups := []string{} + var groups []string if *sf.groups != "" { for _, rg := range strings.Split(*sf.groups, ",") { g := strings.TrimSpace(rg) @@ -168,16 +164,16 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } } - subnets := []*net.IPNet{} + var subnets []netip.Prefix if *sf.subnets != "" { for _, rs := range strings.Split(*sf.subnets, ",") { rs := strings.Trim(rs, " ") if rs != "" { - _, s, err := net.ParseCIDR(rs) + s, err := netip.ParsePrefix(rs) if err != nil { - return newHelpErrorf("invalid subnet definition: %s", err) + return newHelpErrorf("invalid subnet definition: %s", rs) } - if s.IP.To4() == nil { + if !s.Addr().Is4() { return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs) } subnets = append(subnets, s) @@ -205,7 +201,8 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) if err != nil { return fmt.Errorf("error while reading in-pub: %s", err) } - pub, _, pubCurve, err = cert.UnmarshalPublicKey(rawPub) + + pub, _, pubCurve, err = cert.UnmarshalPublicKeyFromPEM(rawPub) if err != nil { return fmt.Errorf("error while parsing in-pub: %s", err) } @@ -221,36 +218,17 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) pub, rawPriv = newKeypair(curve) } - nc := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: *sf.name, - Ips: []*net.IPNet{ipNet}, - Groups: groups, - Subnets: subnets, - NotBefore: time.Now(), - NotAfter: time.Now().Add(*sf.duration), - PublicKey: pub, - IsCA: false, - Issuer: issuer, - Curve: curve, - }, - Pkcs11Backed: isP11, - } - - if p11Client == nil { - err = nc.Sign(curve, caKey) - if err != nil { - return fmt.Errorf("error while signing: %w", err) - } - } else { - err = nc.SignPkcs11(curve, p11Client) - if err != nil { - return fmt.Errorf("error while signing with PKCS#11: %w", err) - } - } - - if err := nc.CheckRootConstrains(caCert); err != nil { - return fmt.Errorf("refusing to sign, root certificate constraints violated: %s", err) + t := &cert.TBSCertificate{ + Version: cert.Version1, + Name: *sf.name, + Networks: []netip.Prefix{network}, + Groups: groups, + UnsafeNetworks: subnets, + NotBefore: time.Now(), + NotAfter: time.Now().Add(*sf.duration), + PublicKey: pub, + IsCA: false, + Curve: curve, } if *sf.outKeyPath == "" { @@ -265,18 +243,32 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath) } + var c cert.Certificate + + if p11Client == nil { + c, err = t.Sign(caCert, curve, caKey) + if err != nil { + return fmt.Errorf("error while signing: %w", err) + } + } else { + c, err = t.SignPkcs11(caCert, curve, p11Client) + if err != nil { + return fmt.Errorf("error while signing with PKCS#11: %w", err) + } + } + if !isP11 && *sf.inPubPath == "" { if _, err := os.Stat(*sf.outKeyPath); err == nil { return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath) } - err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600) + err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } } - b, err := nc.MarshalToPEM() + b, err := c.MarshalPEM() if err != nil { return fmt.Errorf("error while marshalling certificate: %s", err) } diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index d6e2a39..b68434d 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -117,7 +117,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) - caKeyF.Write(cert.MarshalEd25519PrivateKey(caPriv)) + caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv)) // failed to read cert args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} @@ -138,16 +138,8 @@ func Test_signCert(t *testing.T) { assert.Empty(t, eb.String()) // write a proper ca cert for later - ca := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "ca", - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Minute * 200), - PublicKey: caPub, - IsCA: true, - }, - } - b, _ := ca.MarshalToPEM() + ca, _ := NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil) + b, _ := ca.MarshalPEM() caCrtF.Write(b) // failed to read pub @@ -172,13 +164,13 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() inPub, _ := x25519Keypair() - inPubF.Write(cert.MarshalX25519PublicKey(inPub)) + inPubF.Write(cert.MarshalPublicKeyToPEM(cert.Curve_CURVE25519, inPub)) // bad ip cidr ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: invalid CIDR address: a1.1.1.1/24") + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: a1.1.1.1/24") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -193,7 +185,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: invalid CIDR address: a") + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: a") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -209,7 +201,7 @@ func Test_signCert(t *testing.T) { caKeyF2, err := os.CreateTemp("", "sign-cert-2.key") assert.Nil(t, err) defer os.Remove(caKeyF2.Name()) - caKeyF2.Write(cert.MarshalEd25519PrivateKey(caPriv2)) + caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2)) ob.Reset() eb.Reset() @@ -255,33 +247,34 @@ func Test_signCert(t *testing.T) { // read cert and key files rb, _ := os.ReadFile(keyF.Name()) - lKey, b, err := cert.UnmarshalX25519PrivateKey(rb) + lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) + assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Len(t, b, 0) assert.Nil(t, err) assert.Len(t, lKey, 32) rb, _ = os.ReadFile(crtF.Name()) - lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb) + lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) assert.Len(t, b, 0) assert.Nil(t, err) - assert.Equal(t, "test", lCrt.Details.Name) - assert.Equal(t, "1.1.1.1/24", lCrt.Details.Ips[0].String()) - assert.Len(t, lCrt.Details.Ips, 1) - assert.False(t, lCrt.Details.IsCA) - assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups) - assert.Len(t, lCrt.Details.Subnets, 3) - assert.Len(t, lCrt.Details.PublicKey, 32) - assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore)) + assert.Equal(t, "test", lCrt.Name()) + assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String()) + assert.Len(t, lCrt.Networks(), 1) + assert.False(t, lCrt.IsCA()) + assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups()) + assert.Len(t, lCrt.UnsafeNetworks(), 3) + assert.Len(t, lCrt.PublicKey(), 32) + assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore())) sns := []string{} - for _, sn := range lCrt.Details.Subnets { + for _, sn := range lCrt.UnsafeNetworks() { sns = append(sns, sn.String()) } assert.Equal(t, []string{"10.1.1.1/32", "10.2.2.2/32", "10.5.5.5/32"}, sns) - issuer, _ := ca.Sha256Sum() - assert.Equal(t, issuer, lCrt.Details.Issuer) + issuer, _ := ca.Fingerprint() + assert.Equal(t, issuer, lCrt.Issuer()) assert.True(t, lCrt.CheckSignature(caPub)) @@ -297,16 +290,18 @@ func Test_signCert(t *testing.T) { // read cert file and check pub key matches in-pub rb, _ = os.ReadFile(crtF.Name()) - lCrt, b, err = cert.UnmarshalNebulaCertificateFromPEM(rb) + lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb) assert.Len(t, b, 0) assert.Nil(t, err) - assert.Equal(t, lCrt.Details.PublicKey, inPub) + assert.Equal(t, lCrt.PublicKey(), inPub) // test refuse to sign cert with duration beyond root ob.Reset() eb.Reset() + os.Remove(keyF.Name()) + os.Remove(crtF.Name()) args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate constraints violated: certificate expires after signing certificate") + assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -362,16 +357,8 @@ func Test_signCert(t *testing.T) { b, _ = cert.EncryptAndMarshalSigningPrivateKey(cert.Curve_CURVE25519, caPriv, passphrase, kdfParams) caKeyF.Write(b) - ca = cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "ca", - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Minute * 200), - PublicKey: caPub, - IsCA: true, - }, - } - b, _ = ca.MarshalToPEM() + ca, _ = NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil) + b, _ = ca.MarshalPEM() caCrtF.Write(b) // test with the proper password diff --git a/cmd/nebula-cert/verify.go b/cmd/nebula-cert/verify.go index c955913..80cfef3 100644 --- a/cmd/nebula-cert/verify.go +++ b/cmd/nebula-cert/verify.go @@ -46,7 +46,7 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { caPool := cert.NewCAPool() for { - rawCACert, err = caPool.AddCACertificate(rawCACert) + rawCACert, err = caPool.AddCAFromPEM(rawCACert) if err != nil { return fmt.Errorf("error while adding ca cert to pool: %s", err) } @@ -61,13 +61,13 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("unable to read crt; %s", err) } - c, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert) + c, _, err := cert.UnmarshalCertificateFromPEM(rawCert) if err != nil { return fmt.Errorf("error while parsing crt: %s", err) } - good, err := c.Verify(time.Now(), caPool) - if !good { + _, err = caPool.VerifyCertificate(time.Now(), c) + if err != nil { return err } diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index f0f4c78..204ff09 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" "golang.org/x/crypto/ed25519" ) @@ -67,17 +66,8 @@ func Test_verify(t *testing.T) { // make a ca for later caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) - ca := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test-ca", - NotBefore: time.Now().Add(time.Hour * -1), - NotAfter: time.Now().Add(time.Hour * 2), - PublicKey: caPub, - IsCA: true, - }, - } - ca.Sign(cert.Curve_CURVE25519, caPriv) - b, _ := ca.MarshalToPEM() + ca, _ := NewTestCaCert("test-ca", caPub, caPriv, time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour*2), nil, nil, nil) + b, _ := ca.MarshalPEM() caFile.Truncate(0) caFile.Seek(0, 0) caFile.Write(b) @@ -102,22 +92,13 @@ func Test_verify(t *testing.T) { assert.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block") // unverifiable cert at path - _, badPriv, _ := ed25519.GenerateKey(rand.Reader) - certPub, _ := x25519Keypair() - signer, _ := ca.Sha256Sum() - crt := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test-cert", - NotBefore: time.Now().Add(time.Hour * -1), - NotAfter: time.Now().Add(time.Hour), - PublicKey: certPub, - IsCA: false, - Issuer: signer, - }, + crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) + // Slightly evil hack to modify the certificate after it was sealed to generate an invalid signature + pub := crt.PublicKey() + for i, _ := range pub { + pub[i] = 0 } - - crt.Sign(cert.Curve_CURVE25519, badPriv) - b, _ = crt.MarshalToPEM() + b, _ = crt.MarshalPEM() certFile.Truncate(0) certFile.Seek(0, 0) certFile.Write(b) @@ -128,8 +109,8 @@ func Test_verify(t *testing.T) { assert.EqualError(t, err, "certificate signature did not match") // verified cert at path - crt.Sign(cert.Curve_CURVE25519, caPriv) - b, _ = crt.MarshalToPEM() + crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) + b, _ = crt.MarshalPEM() certFile.Truncate(0) certFile.Seek(0, 0) certFile.Write(b) diff --git a/connection_manager.go b/connection_manager.go index d2e8616..7718252 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -415,7 +415,7 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { } certState := n.intf.pki.GetCertState() - return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature) + return bytes.Equal(current.ConnectionState.myCert.Signature(), certState.Certificate.Signature()) } func (n *connectionManager) swapPrimary(current, primary *HostInfo) { @@ -436,8 +436,9 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn return false } - valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool()) - if valid { + caPool := n.intf.pki.GetCAPool() + err := caPool.VerifyCachedCertificate(now, remoteCert) + if err == nil { return false } @@ -446,9 +447,8 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn return false } - fingerprint, _ := remoteCert.Sha256Sum() hostinfo.logger(n.l).WithError(err). - WithField("fingerprint", fingerprint). + WithField("fingerprint", remoteCert.Fingerprint). Info("Remote certificate is no longer valid, tearing down the tunnel") return true @@ -474,7 +474,7 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { certState := n.intf.pki.GetCertState() - if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) { + if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), certState.Certificate.Signature()) { return } diff --git a/connection_manager_test.go b/connection_manager_test.go index 5f97cad..9f222c8 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -4,7 +4,6 @@ import ( "context" "crypto/ed25519" "crypto/rand" - "net" "net/netip" "testing" "time" @@ -47,7 +46,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { cs := &CertState{ RawCertificate: []byte{}, PrivateKey: []byte{}, - Certificate: &cert.NebulaCertificate{}, + Certificate: &dummyCert{}, RawCertificateNoKey: []byte{}, } @@ -80,7 +79,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - myCert: &cert.NebulaCertificate{}, + myCert: &dummyCert{}, H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -130,7 +129,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { cs := &CertState{ RawCertificate: []byte{}, PrivateKey: []byte{}, - Certificate: &cert.NebulaCertificate{}, + Certificate: &dummyCert{}, RawCertificateNoKey: []byte{}, } @@ -163,7 +162,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - myCert: &cert.NebulaCertificate{}, + myCert: &dummyCert{}, H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -206,10 +205,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { now := time.Now() l := test.NewLogger() - ipNet := net.IPNet{ - IP: net.IPv4(172, 1, 1, 2), - Mask: net.IPMask{255, 255, 255, 0}, - } + vpncidr := netip.MustParsePrefix("172.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24") vpnIp := netip.MustParseAddr("172.1.1.2") @@ -219,41 +215,38 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { // Generate keys for CA and peer's cert. pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader) - caCert := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "ca", - NotBefore: now, - NotAfter: now.Add(1 * time.Hour), - IsCA: true, - PublicKey: pubCA, - }, + tbs := &cert.TBSCertificate{ + Version: 1, + Name: "ca", + IsCA: true, + NotBefore: now, + NotAfter: now.Add(1 * time.Hour), + PublicKey: pubCA, } - assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA)) - ncp := &cert.NebulaCAPool{ - CAs: cert.NewCAPool().CAs, - } - ncp.CAs["ca"] = &caCert + caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA) + assert.NoError(t, err) + ncp := cert.NewCAPool() + assert.NoError(t, ncp.AddCA(caCert)) pubCrt, _, _ := ed25519.GenerateKey(rand.Reader) - peerCert := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host", - Ips: []*net.IPNet{&ipNet}, - Subnets: []*net.IPNet{}, - NotBefore: now, - NotAfter: now.Add(60 * time.Second), - PublicKey: pubCrt, - IsCA: false, - Issuer: "ca", - }, + tbs = &cert.TBSCertificate{ + Version: 1, + Name: "host", + Networks: []netip.Prefix{vpncidr}, + NotBefore: now, + NotAfter: now.Add(60 * time.Second), + PublicKey: pubCrt, } - assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA)) + peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA) + assert.NoError(t, err) + + cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert) cs := &CertState{ RawCertificate: []byte{}, PrivateKey: []byte{}, - Certificate: &cert.NebulaCertificate{}, + Certificate: &dummyCert{}, RawCertificateNoKey: []byte{}, } @@ -282,8 +275,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { hostinfo := &HostInfo{ vpnIp: vpnIp, ConnectionState: &ConnectionState{ - myCert: &cert.NebulaCertificate{}, - peerCert: &peerCert, + myCert: &dummyCert{}, + peerCert: cachedPeerCert, H: &noise.HandshakeState{}, }, } @@ -303,3 +296,114 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { invalid = nc.isInvalidCertificate(nextTick, hostinfo) assert.True(t, invalid) } + +type dummyCert struct { + version cert.Version + curve cert.Curve + groups []string + isCa bool + issuer string + name string + networks []netip.Prefix + notAfter time.Time + notBefore time.Time + publicKey []byte + signature []byte + unsafeNetworks []netip.Prefix +} + +func (d *dummyCert) Version() cert.Version { + return d.version +} + +func (d *dummyCert) Curve() cert.Curve { + return d.curve +} + +func (d *dummyCert) Groups() []string { + return d.groups +} + +func (d *dummyCert) IsCA() bool { + return d.isCa +} + +func (d *dummyCert) Issuer() string { + return d.issuer +} + +func (d *dummyCert) Name() string { + return d.name +} + +func (d *dummyCert) Networks() []netip.Prefix { + return d.networks +} + +func (d *dummyCert) NotAfter() time.Time { + return d.notAfter +} + +func (d *dummyCert) NotBefore() time.Time { + return d.notBefore +} + +func (d *dummyCert) PublicKey() []byte { + return d.publicKey +} + +func (d *dummyCert) Signature() []byte { + return d.signature +} + +func (d *dummyCert) UnsafeNetworks() []netip.Prefix { + return d.unsafeNetworks +} + +func (d *dummyCert) MarshalForHandshakes() ([]byte, error) { + return nil, nil +} + +func (d *dummyCert) Sign(curve cert.Curve, key []byte) error { + return nil +} + +func (d *dummyCert) CheckSignature(key []byte) bool { + return true +} + +func (d *dummyCert) Expired(t time.Time) bool { + return false +} + +func (d *dummyCert) CheckRootConstraints(signer cert.Certificate) error { + return nil +} + +func (d *dummyCert) VerifyPrivateKey(curve cert.Curve, key []byte) error { + return nil +} + +func (d *dummyCert) String() string { + return "" +} + +func (d *dummyCert) Marshal() ([]byte, error) { + return nil, nil +} + +func (d *dummyCert) MarshalPEM() ([]byte, error) { + return nil, nil +} + +func (d *dummyCert) Fingerprint() (string, error) { + return "", nil +} + +func (d *dummyCert) MarshalJSON() ([]byte, error) { + return nil, nil +} + +func (d *dummyCert) Copy() cert.Certificate { + return d +} diff --git a/connection_state.go b/connection_state.go index aa17a13..bcc9e5d 100644 --- a/connection_state.go +++ b/connection_state.go @@ -18,8 +18,8 @@ type ConnectionState struct { eKey *NebulaCipherState dKey *NebulaCipherState H *noise.HandshakeState - myCert *cert.NebulaCertificate - peerCert *cert.NebulaCertificate + myCert cert.Certificate + peerCert *cert.CachedCertificate initiator bool messageCounter atomic.Uint64 window *Bits @@ -28,17 +28,17 @@ type ConnectionState struct { func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { var dhFunc noise.DHFunc - switch certState.Certificate.Details.Curve { + switch certState.Certificate.Curve() { case cert.Curve_CURVE25519: dhFunc = noise.DH25519 case cert.Curve_P256: - if certState.Certificate.Pkcs11Backed { + if certState.pkcs11Backed { dhFunc = noiseutil.DHP256PKCS11 } else { dhFunc = noiseutil.DHP256 } default: - l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve) + l.Errorf("invalid curve: %s", certState.Certificate.Curve()) return nil } diff --git a/control.go b/control.go index 3468b35..2615984 100644 --- a/control.go +++ b/control.go @@ -37,15 +37,15 @@ type Control struct { } type ControlHostInfo struct { - VpnIp netip.Addr `json:"vpnIp"` - LocalIndex uint32 `json:"localIndex"` - RemoteIndex uint32 `json:"remoteIndex"` - RemoteAddrs []netip.AddrPort `json:"remoteAddrs"` - Cert *cert.NebulaCertificate `json:"cert"` - MessageCounter uint64 `json:"messageCounter"` - CurrentRemote netip.AddrPort `json:"currentRemote"` - CurrentRelaysToMe []netip.Addr `json:"currentRelaysToMe"` - CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` + VpnIp netip.Addr `json:"vpnIp"` + LocalIndex uint32 `json:"localIndex"` + RemoteIndex uint32 `json:"remoteIndex"` + RemoteAddrs []netip.AddrPort `json:"remoteAddrs"` + Cert cert.Certificate `json:"cert"` + MessageCounter uint64 `json:"messageCounter"` + CurrentRemote netip.AddrPort `json:"currentRemote"` + CurrentRelaysToMe []netip.Addr `json:"currentRelaysToMe"` + CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` } // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() @@ -130,15 +130,15 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { } // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found -func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) *cert.NebulaCertificate { +func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate { if c.f.myVpnNet.Addr() == vpnIp { - return c.f.pki.GetCertState().Certificate + return c.f.pki.GetCertState().Certificate.Copy() } hi := c.f.hostMap.QueryVpnIp(vpnIp) if hi == nil { return nil } - return hi.GetCert() + return hi.GetCert().Certificate.Copy() } // CreateTunnel creates a new tunnel to the given vpn ip. @@ -290,7 +290,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { } if c := h.GetCert(); c != nil { - chi.Cert = c.Copy() + chi.Cert = c.Certificate.Copy() } return chi diff --git a/control_test.go b/control_test.go index fbf29c0..fdfc0a5 100644 --- a/control_test.go +++ b/control_test.go @@ -5,7 +5,6 @@ import ( "net/netip" "reflect" "testing" - "time" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" @@ -14,6 +13,9 @@ import ( ) func TestControl_GetHostInfoByVpnIp(t *testing.T) { + //TODO: with multiple certificate versions we have a problem with this test + // Some certs versions have different characteristics and each version implements their own Copy() func + // which means this is not a good place to test for exposing memory l := test.NewLogger() // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // To properly ensure we are not exposing core memory to the caller @@ -33,22 +35,6 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { Mask: net.IPMask{255, 255, 255, 0}, } - crt := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test", - Ips: []*net.IPNet{&ipNet}, - Subnets: []*net.IPNet{}, - Groups: []string{"default-group"}, - NotBefore: time.Unix(1, 0), - NotAfter: time.Unix(2, 0), - PublicKey: []byte{5, 6, 7, 8}, - IsCA: false, - Issuer: "the-issuer", - InvertedGroups: map[string]struct{}{"default-group": {}}, - }, - Signature: []byte{1, 2, 1, 2, 1, 3}, - } - remotes := NewRemoteList(nil) remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port())) remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port())) @@ -56,11 +42,12 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { vpnIp, ok := netip.AddrFromSlice(ipNet.IP) assert.True(t, ok) + crt := &dummyCert{} hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, ConnectionState: &ConnectionState{ - peerCert: crt, + peerCert: &cert.CachedCertificate{Certificate: crt}, }, remoteIndexId: 200, localIndexId: 201, @@ -115,8 +102,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { // Make sure we don't have any unexpected fields assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) assert.EqualValues(t, &expectedInfo, thi) - //TODO: netip.Addr reuses global memory for zone identifiers which breaks our "no reused memory check" here - //test.AssertDeepCopyEqual(t, &expectedInfo, thi) + test.AssertDeepCopyEqual(t, &expectedInfo, thi) // Make sure we don't panic if the host info doesn't have a cert yet assert.NotPanics(t, func() { diff --git a/control_tester.go b/control_tester.go index d46540f..fa87e53 100644 --- a/control_tester.go +++ b/control_tester.go @@ -153,7 +153,7 @@ func (c *Control) GetHostmap() *HostMap { return c.f.hostMap } -func (c *Control) GetCert() *cert.NebulaCertificate { +func (c *Control) GetCert() cert.Certificate { return c.f.pki.GetCertState().Certificate } diff --git a/dns_server.go b/dns_server.go index 5fea65c..7501231 100644 --- a/dns_server.go +++ b/dns_server.go @@ -57,9 +57,11 @@ func (d *dnsRecords) QueryCert(data string) string { return "" } - cert := q.Details - c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer) - return c + b, err := q.Certificate.MarshalJSON() + if err != nil { + return "" + } + return string(b) } func (d *dnsRecords) Add(host, data string) { diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 3d42a56..6be94ad 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -6,6 +6,7 @@ package e2e import ( "fmt" "net/netip" + "slices" "testing" "time" @@ -538,9 +539,9 @@ func TestRehandshakingRelays(t *testing.T) { // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"}) - caB, err := ca.MarshalToPEM() + caB, err := ca.MarshalPEM() if err != nil { panic(err) } @@ -558,7 +559,7 @@ func TestRehandshakingRelays(t *testing.T) { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) - if len(c.Cert.Details.Groups) != 0 { + if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") break @@ -571,7 +572,7 @@ func TestRehandshakingRelays(t *testing.T) { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) - if len(c.Cert.Details.Groups) != 0 { + if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") break @@ -642,9 +643,9 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"}) - caB, err := ca.MarshalToPEM() + caB, err := ca.MarshalPEM() if err != nil { panic(err) } @@ -662,7 +663,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) - if len(c.Cert.Details.Groups) != 0 { + if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") break @@ -675,7 +676,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) - if len(c.Cert.Details.Groups) != 0 { + if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") break @@ -737,9 +738,9 @@ func TestRehandshaking(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew my certificate and spin until their sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{myVpnIpNet}, nil, []string{"new group"}) - caB, err := ca.MarshalToPEM() + caB, err := ca.MarshalPEM() if err != nil { panic(err) } @@ -756,7 +757,7 @@ func TestRehandshaking(t *testing.T) { for { assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) - if len(c.Cert.Details.Groups) != 0 { + if len(c.Cert.Groups()) != 0 { // We have a new certificate now break } @@ -764,6 +765,7 @@ func TestRehandshaking(t *testing.T) { time.Sleep(time.Second) } + r.Log("Got the new cert") // Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly rc, err = yaml.Marshal(theirConfig.Settings) assert.NoError(t, err) @@ -794,7 +796,7 @@ func TestRehandshaking(t *testing.T) { // Make sure the correct tunnel won c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) - assert.Contains(t, c.Cert.Details.Groups, "new group") + assert.Contains(t, c.Cert.Groups(), "new group") // We should only have a single tunnel now on both sides assert.Len(t, myFinalHostmapHosts, 1) @@ -837,9 +839,9 @@ func TestRehandshakingLoser(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew their certificate and spin until mine sees it") - _, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}) + _, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{theirVpnIpNet}, nil, []string{"their new group"}) - caB, err := ca.MarshalToPEM() + caB, err := ca.MarshalPEM() if err != nil { panic(err) } @@ -857,8 +859,7 @@ func TestRehandshakingLoser(t *testing.T) { assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) - _, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"] - if theirNewGroup { + if slices.Contains(theirCertInMe.Cert.Groups(), "their new group") { break } @@ -895,7 +896,7 @@ func TestRehandshakingLoser(t *testing.T) { // Make sure the correct tunnel won theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) - assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group") + assert.Contains(t, theirCertInMe.Cert.Groups(), "their new group") // We should only have a single tunnel now on both sides assert.Len(t, myFinalHostmapHosts, 1) diff --git a/e2e/helpers.go b/e2e/helpers.go index 71df805..c0893ac 100644 --- a/e2e/helpers.go +++ b/e2e/helpers.go @@ -3,7 +3,6 @@ package e2e import ( "crypto/rand" "io" - "net" "net/netip" "time" @@ -13,7 +12,7 @@ import ( ) // NewTestCaCert will generate a CA cert -func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { +func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { pub, priv, err := ed25519.GenerateKey(rand.Reader) if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) @@ -22,56 +21,34 @@ func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups after = time.Now().Add(time.Second * 60).Round(time.Second) } - nc := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test ca", - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: true, - InvertedGroups: make(map[string]struct{}), - }, + t := &cert.TBSCertificate{ + Version: cert.Version1, + Name: "test ca", + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + IsCA: true, } - if len(ips) > 0 { - nc.Details.Ips = make([]*net.IPNet, len(ips)) - for i, ip := range ips { - nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())} - } - } - - if len(subnets) > 0 { - nc.Details.Subnets = make([]*net.IPNet, len(subnets)) - for i, ip := range subnets { - nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())} - } - } - - if len(groups) > 0 { - nc.Details.Groups = groups - } - - err = nc.Sign(cert.Curve_CURVE25519, priv) + c, err := t.Sign(nil, cert.Curve_CURVE25519, priv) if err != nil { panic(err) } - pem, err := nc.MarshalToPEM() + pem, err := c.MarshalPEM() if err != nil { panic(err) } - return nc, pub, priv, pem + return c, pub, priv, pem } // NewTestCert will generate a signed certificate with the provided details. // Expiry times are defaulted if you do not pass them in -func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip netip.Prefix, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { - issuer, err := ca.Sha256Sum() - if err != nil { - panic(err) - } - +func NewTestCert(ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) } @@ -81,33 +58,29 @@ func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, af } pub, rawPriv := x25519Keypair() - ipb := ip.Addr().AsSlice() - nc := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: name, - Ips: []*net.IPNet{{IP: ipb[:], Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}}, - //Subnets: subnets, - Groups: groups, - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: false, - Issuer: issuer, - InvertedGroups: make(map[string]struct{}), - }, + nc := &cert.TBSCertificate{ + Version: cert.Version1, + Name: name, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: false, } - err = nc.Sign(ca.Details.Curve, key) + c, err := nc.Sign(ca, ca.Curve(), key) if err != nil { panic(err) } - pem, err := nc.MarshalToPEM() + pem, err := c.MarshalPEM() if err != nil { panic(err) } - return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem + return c, pub, cert.MarshalPrivateKeyToPEM(cert.Curve_CURVE25519, rawPriv), pem } func x25519Keypair() ([]byte, []byte) { diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 527f55b..77996f3 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -26,7 +26,7 @@ import ( type m map[string]interface{} // newSimpleServer creates a nebula instance with many assumptions -func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) { +func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) { l := NewTestLogger() vpnIpNet, err := netip.ParsePrefix(sVpnIpNet) @@ -44,9 +44,9 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, s budpIp[13] -= 128 udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) } - _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) + _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{vpnIpNet}, nil, []string{}) - caB, err := caCrt.MarshalToPEM() + caB, err := caCrt.MarshalPEM() if err != nil { panic(err) } diff --git a/e2e/router/hostmap.go b/e2e/router/hostmap.go index c14ab2e..29fa959 100644 --- a/e2e/router/hostmap.go +++ b/e2e/router/hostmap.go @@ -58,8 +58,8 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { var lines []string var globalLines []*edge - clusterName := strings.Trim(c.GetCert().Details.Name, " ") - clusterVpnIp := c.GetCert().Details.Ips[0].IP + clusterName := strings.Trim(c.GetCert().Name(), " ") + clusterVpnIp := c.GetCert().Networks()[0].Addr() r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp) hm := c.GetHostmap() @@ -102,7 +102,7 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { hi, ok := hm.Indexes[idx] if ok { r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp()) - remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ") + remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ") globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())}) _ = hi } diff --git a/firewall.go b/firewall.go index 8a409d2..80a8280 100644 --- a/firewall.go +++ b/firewall.go @@ -52,9 +52,9 @@ type Firewall struct { DefaultTimeout time.Duration //linux: 600s // Used to ensure we don't emit local packets for ips we don't own - localIps *bart.Table[struct{}] - assignedCIDR netip.Prefix - hasSubnets bool + localIps *bart.Table[struct{}] + assignedCIDR netip.Prefix + hasUnsafeNetworks bool rules string rulesVersion uint16 @@ -126,7 +126,7 @@ type firewallLocalCIDR struct { } // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. -func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall { +func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { //TODO: error on 0 duration var min, max time.Duration @@ -147,11 +147,8 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D localIps := new(bart.Table[struct{}]) var assignedCIDR netip.Prefix var assignedSet bool - for _, ip := range c.Details.Ips { - //TODO: IPV6-WORK the unmap is a bit unfortunate - nip, _ := netip.AddrFromSlice(ip.IP) - nip = nip.Unmap() - nprefix := netip.PrefixFrom(nip, nip.BitLen()) + for _, network := range c.Networks() { + nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) localIps.Insert(nprefix, struct{}{}) if !assignedSet { @@ -161,11 +158,10 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D } } - for _, n := range c.Details.Subnets { - nip, _ := netip.AddrFromSlice(n.IP) - ones, _ := n.Mask.Size() - nip = nip.Unmap() - localIps.Insert(netip.PrefixFrom(nip, ones), struct{}{}) + hasUnsafeNetworks := false + for _, n := range c.UnsafeNetworks() { + localIps.Insert(n, struct{}{}) + hasUnsafeNetworks = true } return &Firewall{ @@ -173,15 +169,15 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D Conns: make(map[firewall.Packet]*conn), TimerWheel: NewTimerWheel[firewall.Packet](min, max), }, - InRules: newFirewallTable(), - OutRules: newFirewallTable(), - TCPTimeout: tcpTimeout, - UDPTimeout: UDPTimeout, - DefaultTimeout: defaultTimeout, - localIps: localIps, - assignedCIDR: assignedCIDR, - hasSubnets: len(c.Details.Subnets) > 0, - l: l, + InRules: newFirewallTable(), + OutRules: newFirewallTable(), + TCPTimeout: tcpTimeout, + UDPTimeout: UDPTimeout, + DefaultTimeout: defaultTimeout, + localIps: localIps, + assignedCIDR: assignedCIDR, + hasUnsafeNetworks: hasUnsafeNetworks, + l: l, incomingMetrics: firewallMetrics{ droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil), @@ -196,7 +192,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D } } -func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *config.C) (*Firewall, error) { +func NewFirewallFromConfig(l *logrus.Logger, nc cert.Certificate, c *config.C) (*Firewall, error) { fw := NewFirewall( l, c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), @@ -421,7 +417,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table") // Drop returns an error if the packet should be dropped, explaining why. It // returns nil if the packet should not be dropped. -func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error { +func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { // Check if we spoke to this tuple, if we did then allow this packet if f.inConns(fp, h, caPool, localCache) { return nil @@ -492,7 +488,7 @@ func (f *Firewall) EmitStats() { metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) } -func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool { +func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool { if localCache != nil { if _, ok := localCache[fp]; ok { return true @@ -619,7 +615,7 @@ func (f *Firewall) evict(p firewall.Packet) { delete(conntrack.Conns, p) } -func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { +func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool { if ft.AnyProto.match(p, incoming, c, caPool) { return true } @@ -663,7 +659,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou return nil } -func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { +func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool { // We don't have any allowed ports, bail if fp == nil { return false @@ -726,7 +722,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc return nil } -func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { +func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool *cert.CAPool) bool { if fc == nil { return false } @@ -735,18 +731,18 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool return true } - if t, ok := fc.CAShas[c.Details.Issuer]; ok { + if t, ok := fc.CAShas[c.Certificate.Issuer()]; ok { if t.match(p, c) { return true } } - s, err := caPool.GetCAForCert(c) + s, err := caPool.GetCAForCert(c.Certificate) if err != nil { return false } - return fc.CANames[s.Details.Name].match(p, c) + return fc.CANames[s.Certificate.Name()].match(p, c) } func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error { @@ -826,7 +822,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) boo return false } -func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool { +func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool { if fr == nil { return false } @@ -841,7 +837,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool found := false for _, g := range sg.Groups { - if _, ok := c.Details.InvertedGroups[g]; !ok { + if _, ok := c.InvertedGroups[g]; !ok { found = false break } @@ -855,7 +851,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool } if fr.Hosts != nil { - if flc, ok := fr.Hosts[c.Details.Name]; ok { + if flc, ok := fr.Hosts[c.Certificate.Name()]; ok { if flc.match(p, c) { return true } @@ -876,7 +872,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { if !localIp.IsValid() { - if !f.hasSubnets || f.defaultLocalCIDRAny { + if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny { flc.Any = true return nil } @@ -890,7 +886,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { return nil } -func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool { +func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate) bool { if flc == nil { return false } diff --git a/firewall_test.go b/firewall_test.go index 4d47e78..57cd32a 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -4,7 +4,6 @@ import ( "bytes" "errors" "math" - "net" "net/netip" "testing" "time" @@ -18,7 +17,7 @@ import ( func TestNewFirewall(t *testing.T) { l := test.NewLogger() - c := &cert.NebulaCertificate{} + c := &dummyCert{} fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) conntrack := fw.Conntrack assert.NotNil(t, conntrack) @@ -60,7 +59,7 @@ func TestFirewall_AddRule(t *testing.T) { ob := &bytes.Buffer{} l.SetOutput(ob) - c := &cert.NebulaCertificate{} + c := &dummyCert{} fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.OutRules) @@ -137,23 +136,18 @@ func TestFirewall_Drop(t *testing.T) { Fragment: false, } - ipNet := net.IPNet{ - IP: net.IPv4(1, 2, 3, 4), - Mask: net.IPMask{255, 255, 255, 0}, - } - - c := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host1", - Ips: []*net.IPNet{&ipNet}, - Groups: []string{"default-group"}, - InvertedGroups: map[string]struct{}{"default-group": {}}, - Issuer: "signer-shasum", - }, + c := dummyCert{ + name: "host1", + networks: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/24")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", } h := HostInfo{ ConnectionState: &ConnectionState{ - peerCert: &c, + peerCert: &cert.CachedCertificate{ + Certificate: &c, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, }, vpnIp: netip.MustParseAddr("1.2.3.4"), } @@ -190,14 +184,14 @@ func TestFirewall_Drop(t *testing.T) { assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks - cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} + cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match - cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} + cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) @@ -217,7 +211,9 @@ func BenchmarkFirewallTable_match(b *testing.B) { b.Run("fail on proto", func(b *testing.B) { // This benchmark is showing us the cost of failing to match the protocol - c := &cert.NebulaCertificate{} + c := &cert.CachedCertificate{ + Certificate: &dummyCert{}, + } for n := 0; n < b.N; n++ { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp)) } @@ -225,14 +221,18 @@ func BenchmarkFirewallTable_match(b *testing.B) { b.Run("pass proto, fail on port", func(b *testing.B) { // This benchmark is showing us the cost of matching a specific protocol but failing to match the port - c := &cert.NebulaCertificate{} + c := &cert.CachedCertificate{ + Certificate: &dummyCert{}, + } for n := 0; n < b.N; n++ { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp)) } }) b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) { - c := &cert.NebulaCertificate{} + c := &cert.CachedCertificate{ + Certificate: &dummyCert{}, + } ip := netip.MustParsePrefix("9.254.254.254/32") for n := 0; n < b.N; n++ { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp)) @@ -240,13 +240,12 @@ func BenchmarkFirewallTable_match(b *testing.B) { }) b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) { - _, ip, _ := net.ParseCIDR("9.254.254.254/32") - c := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - InvertedGroups: map[string]struct{}{"nope": {}}, - Name: "nope", - Ips: []*net.IPNet{ip}, + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "nope", + networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")}, }, + InvertedGroups: map[string]struct{}{"nope": {}}, } for n := 0; n < b.N; n++ { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)) @@ -254,13 +253,12 @@ func BenchmarkFirewallTable_match(b *testing.B) { }) b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) { - _, ip, _ := net.ParseCIDR("9.254.254.254/32") - c := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - InvertedGroups: map[string]struct{}{"nope": {}}, - Name: "nope", - Ips: []*net.IPNet{ip}, + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "nope", + networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")}, }, + InvertedGroups: map[string]struct{}{"nope": {}}, } for n := 0; n < b.N; n++ { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) @@ -268,11 +266,11 @@ func BenchmarkFirewallTable_match(b *testing.B) { }) b.Run("pass on group on any local cidr", func(b *testing.B) { - c := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - InvertedGroups: map[string]struct{}{"good-group": {}}, - Name: "nope", + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "nope", }, + InvertedGroups: map[string]struct{}{"good-group": {}}, } for n := 0; n < b.N; n++ { assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)) @@ -280,11 +278,11 @@ func BenchmarkFirewallTable_match(b *testing.B) { }) b.Run("pass on group on specific local cidr", func(b *testing.B) { - c := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - InvertedGroups: map[string]struct{}{"good-group": {}}, - Name: "nope", + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "nope", }, + InvertedGroups: map[string]struct{}{"good-group": {}}, } for n := 0; n < b.N; n++ { assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) @@ -292,70 +290,16 @@ func BenchmarkFirewallTable_match(b *testing.B) { }) b.Run("pass on name", func(b *testing.B) { - c := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - InvertedGroups: map[string]struct{}{"nope": {}}, - Name: "good-host", + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "good-host", }, + InvertedGroups: map[string]struct{}{"nope": {}}, } for n := 0; n < b.N; n++ { ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp) } }) - // - //b.Run("pass on ip", func(b *testing.B) { - // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) - // c := &cert.NebulaCertificate{ - // Details: cert.NebulaCertificateDetails{ - // InvertedGroups: map[string]struct{}{"nope": {}}, - // Name: "good-host", - // }, - // } - // for n := 0; n < b.N; n++ { - // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp) - // } - //}) - // - //b.Run("pass on local ip", func(b *testing.B) { - // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) - // c := &cert.NebulaCertificate{ - // Details: cert.NebulaCertificateDetails{ - // InvertedGroups: map[string]struct{}{"nope": {}}, - // Name: "good-host", - // }, - // } - // for n := 0; n < b.N; n++ { - // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp) - // } - //}) - // - //_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "") - // - //b.Run("pass on ip with any port", func(b *testing.B) { - // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) - // c := &cert.NebulaCertificate{ - // Details: cert.NebulaCertificateDetails{ - // InvertedGroups: map[string]struct{}{"nope": {}}, - // Name: "good-host", - // }, - // } - // for n := 0; n < b.N; n++ { - // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp) - // } - //}) - // - //b.Run("pass on local ip with any port", func(b *testing.B) { - // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) - // c := &cert.NebulaCertificate{ - // Details: cert.NebulaCertificateDetails{ - // InvertedGroups: map[string]struct{}{"nope": {}}, - // Name: "good-host", - // }, - // } - // for n := 0; n < b.N; n++ { - // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp) - // } - //}) } func TestFirewall_Drop2(t *testing.T) { @@ -372,41 +316,38 @@ func TestFirewall_Drop2(t *testing.T) { Fragment: false, } - ipNet := net.IPNet{ - IP: net.IPv4(1, 2, 3, 4), - Mask: net.IPMask{255, 255, 255, 0}, - } + network := netip.MustParsePrefix("1.2.3.4/24") - c := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host1", - Ips: []*net.IPNet{&ipNet}, - InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}}, + c := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host1", + networks: []netip.Prefix{network}, }, + InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}}, } h := HostInfo{ ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: netip.MustParseAddr(ipNet.IP.String()), + vpnIp: network.Addr(), } - h.CreateRemoteCIDR(&c) + h.CreateRemoteCIDR(c.Certificate) - c1 := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host1", - Ips: []*net.IPNet{&ipNet}, - InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}}, + c1 := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host1", + networks: []netip.Prefix{network}, }, + InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}}, } h1 := HostInfo{ ConnectionState: &ConnectionState{ peerCert: &c1, }, } - h1.CreateRemoteCIDR(&c1) + h1.CreateRemoteCIDR(c1.Certificate) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() @@ -431,64 +372,60 @@ func TestFirewall_Drop3(t *testing.T) { Fragment: false, } - ipNet := net.IPNet{ - IP: net.IPv4(1, 2, 3, 4), - Mask: net.IPMask{255, 255, 255, 0}, - } - - c := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host-owner", - Ips: []*net.IPNet{&ipNet}, + network := netip.MustParsePrefix("1.2.3.4/24") + c := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host-owner", + networks: []netip.Prefix{network}, }, } - c1 := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host1", - Ips: []*net.IPNet{&ipNet}, - Issuer: "signer-sha-bad", + c1 := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host1", + networks: []netip.Prefix{network}, + issuer: "signer-sha-bad", }, } h1 := HostInfo{ ConnectionState: &ConnectionState{ peerCert: &c1, }, - vpnIp: netip.MustParseAddr(ipNet.IP.String()), + vpnIp: network.Addr(), } - h1.CreateRemoteCIDR(&c1) + h1.CreateRemoteCIDR(c1.Certificate) - c2 := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host2", - Ips: []*net.IPNet{&ipNet}, - Issuer: "signer-sha", + c2 := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host2", + networks: []netip.Prefix{network}, + issuer: "signer-sha", }, } h2 := HostInfo{ ConnectionState: &ConnectionState{ peerCert: &c2, }, - vpnIp: netip.MustParseAddr(ipNet.IP.String()), + vpnIp: network.Addr(), } - h2.CreateRemoteCIDR(&c2) + h2.CreateRemoteCIDR(c2.Certificate) - c3 := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host3", - Ips: []*net.IPNet{&ipNet}, - Issuer: "signer-sha-bad", + c3 := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host3", + networks: []netip.Prefix{network}, + issuer: "signer-sha-bad", }, } h3 := HostInfo{ ConnectionState: &ConnectionState{ peerCert: &c3, }, - vpnIp: netip.MustParseAddr(ipNet.IP.String()), + vpnIp: network.Addr(), } - h3.CreateRemoteCIDR(&c3) + h3.CreateRemoteCIDR(c3.Certificate) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) cp := cert.NewCAPool() @@ -516,30 +453,26 @@ func TestFirewall_DropConntrackReload(t *testing.T) { Protocol: firewall.ProtoUDP, Fragment: false, } + network := netip.MustParsePrefix("1.2.3.4/24") - ipNet := net.IPNet{ - IP: net.IPv4(1, 2, 3, 4), - Mask: net.IPMask{255, 255, 255, 0}, - } - - c := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "host1", - Ips: []*net.IPNet{&ipNet}, - Groups: []string{"default-group"}, - InvertedGroups: map[string]struct{}{"default-group": {}}, - Issuer: "signer-shasum", + c := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host1", + networks: []netip.Prefix{network}, + groups: []string{"default-group"}, + issuer: "signer-shasum", }, + InvertedGroups: map[string]struct{}{"default-group": {}}, } h := HostInfo{ ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: netip.MustParseAddr(ipNet.IP.String()), + vpnIp: network.Addr(), } - h.CreateRemoteCIDR(&c) + h.CreateRemoteCIDR(c.Certificate) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() @@ -552,7 +485,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) oldFw := fw - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -561,7 +494,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) oldFw = fw - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -688,7 +621,7 @@ func Test_parsePort(t *testing.T) { func TestNewFirewallFromConfig(t *testing.T) { l := test.NewLogger() // Test a bad rule definition - c := &cert.NebulaCertificate{} + c := &dummyCert{} conf := config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} _, err := NewFirewallFromConfig(l, c, conf) diff --git a/handshake_ix.go b/handshake_ix.go index 8cf5341..73e8541 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -99,8 +99,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } - vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP) - if !ok { + if len(remoteCert.Certificate.Networks()) == 0 { e := f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) @@ -112,10 +111,10 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } - vpnIp = vpnIp.Unmap() - certName := remoteCert.Details.Name - fingerprint, _ := remoteCert.Sha256Sum() - issuer := remoteCert.Details.Issuer + vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap() + certName := remoteCert.Certificate.Name() + fingerprint := remoteCert.Fingerprint + issuer := remoteCert.Certificate.Issuer() if vpnIp == f.myVpnNet.Addr() { f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). @@ -216,7 +215,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) hostinfo.SetRemote(addr) - hostinfo.CreateRemoteCIDR(remoteCert) + hostinfo.CreateRemoteCIDR(remoteCert.Certificate) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) if err != nil { @@ -402,8 +401,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha return true } - vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP) - if !ok { + if len(remoteCert.Certificate.Networks()) == 0 { e := f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) @@ -415,10 +413,10 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha return true } - vpnIp = vpnIp.Unmap() - certName := remoteCert.Details.Name - fingerprint, _ := remoteCert.Sha256Sum() - issuer := remoteCert.Details.Issuer + vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap() + certName := remoteCert.Certificate.Name() + fingerprint := remoteCert.Fingerprint + issuer := remoteCert.Certificate.Issuer() // Ensure the right host responded if vpnIp != hostinfo.vpnIp { @@ -486,7 +484,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha } // Build up the radix for the firewall if we have subnets in the cert - hostinfo.CreateRemoteCIDR(remoteCert) + hostinfo.CreateRemoteCIDR(remoteCert.Certificate) // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp f.handshakeManager.Complete(hostinfo, f) diff --git a/handshake_manager.go b/handshake_manager.go index 1df37bd..4834893 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "errors" "net/netip" + "slices" "sync" "time" @@ -14,7 +15,6 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" - "golang.org/x/exp/slices" ) const ( diff --git a/handshake_manager_test.go b/handshake_manager_test.go index a78b45f..daa8675 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -5,7 +5,6 @@ import ( "testing" "time" - "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" @@ -27,7 +26,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { cs := &CertState{ RawCertificate: []byte{}, PrivateKey: []byte{}, - Certificate: &cert.NebulaCertificate{}, + Certificate: &dummyCert{}, RawCertificateNoKey: []byte{}, } diff --git a/hostmap.go b/hostmap.go index fb97b76..d83151e 100644 --- a/hostmap.go +++ b/hostmap.go @@ -491,7 +491,7 @@ func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInf func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { if f.serveDns { remoteCert := hostinfo.ConnectionState.peerCert - dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) + dnsR.Add(remoteCert.Certificate.Name()+".", remoteCert.Certificate.Networks()[0].Addr().String()) } existing := hm.Hosts[hostinfo.vpnIp] @@ -585,7 +585,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac } } -func (i *HostInfo) GetCert() *cert.NebulaCertificate { +func (i *HostInfo) GetCert() *cert.CachedCertificate { if i.ConnectionState != nil { return i.ConnectionState.peerCert } @@ -647,27 +647,19 @@ func (i *HostInfo) RecvErrorExceeded() bool { return true } -func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) { - if len(c.Details.Ips) == 1 && len(c.Details.Subnets) == 0 { +func (i *HostInfo) CreateRemoteCIDR(c cert.Certificate) { + if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 { // Simple case, no CIDRTree needed return } remoteCidr := new(bart.Table[struct{}]) - for _, ip := range c.Details.Ips { - //TODO: IPV6-WORK what to do when ip is invalid? - nip, _ := netip.AddrFromSlice(ip.IP) - nip = nip.Unmap() - bits, _ := ip.Mask.Size() - remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{}) + for _, network := range c.Networks() { + remoteCidr.Insert(network, struct{}{}) } - for _, n := range c.Details.Subnets { - //TODO: IPV6-WORK what to do when ip is invalid? - nip, _ := netip.AddrFromSlice(n.IP) - nip = nip.Unmap() - bits, _ := n.Mask.Size() - remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{}) + for _, network := range c.UnsafeNetworks() { + remoteCidr.Insert(network, struct{}{}) } i.remoteCidr = remoteCidr } @@ -683,7 +675,7 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { if connState := i.ConnectionState; connState != nil { if peerCert := connState.peerCert; peerCert != nil { - li = li.WithField("certName", peerCert.Details.Name) + li = li.WithField("certName", peerCert.Certificate.Name()) } } diff --git a/interface.go b/interface.go index f251907..5d41a87 100644 --- a/interface.go +++ b/interface.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "net" "net/netip" "os" "runtime" @@ -157,26 +158,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { certificate := c.pki.GetCertState().Certificate - myVpnAddr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP) - if !ok { - return nil, fmt.Errorf("invalid ip address in certificate: %s", certificate.Details.Ips[0].IP) - } - - myVpnMask, ok := netip.AddrFromSlice(certificate.Details.Ips[0].Mask) - if !ok { - return nil, fmt.Errorf("invalid ip mask in certificate: %s", certificate.Details.Ips[0].Mask) - } - - myVpnAddr = myVpnAddr.Unmap() - myVpnMask = myVpnMask.Unmap() - - if myVpnAddr.BitLen() != myVpnMask.BitLen() { - return nil, fmt.Errorf("ip address and mask are different lengths in certificate") - } - - ones, _ := certificate.Details.Ips[0].Mask.Size() - myVpnNet := netip.PrefixFrom(myVpnAddr, ones) - ifce := &Interface{ pki: c.pki, hostMap: c.HostMap, @@ -194,7 +175,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { version: c.version, writers: make([]udp.Conn, c.routines), readers: make([]io.ReadWriteCloser, c.routines), - myVpnNet: myVpnNet, + myVpnNet: certificate.Networks()[0], relayManager: c.relayManager, conntrackCacheTimeout: c.ConntrackCacheTimeout, @@ -209,9 +190,11 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { l: c.l, } - if myVpnAddr.Is4() { - addr := myVpnNet.Masked().Addr().As4() - binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask)) + if ifce.myVpnNet.Addr().Is4() { + maskedAddr := certificate.Networks()[0].Masked() + addr := maskedAddr.Addr().As4() + mask := net.CIDRMask(maskedAddr.Bits(), maskedAddr.Addr().BitLen()) + binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask)) ifce.myBroadcastAddr = netip.AddrFrom4(addr) } @@ -434,7 +417,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { f.firewall.EmitStats() f.handshakeManager.EmitStats() udpStats() - certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.Details.NotAfter.Sub(time.Now()) / time.Second)) + certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.NotAfter().Sub(time.Now()) / time.Second)) } } } diff --git a/main.go b/main.go index c6edc91..8f45359 100644 --- a/main.go +++ b/main.go @@ -68,17 +68,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") - ones, _ := certificate.Details.Ips[0].Mask.Size() - addr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP) - if !ok { - err = util.NewContextualError( - "Invalid ip address in certificate", - m{"vpnIp": certificate.Details.Ips[0].IP}, - nil, - ) - return nil, err - } - tunCidr := netip.PrefixFrom(addr, ones) + tunCidr := certificate.Networks()[0] ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) if err != nil { diff --git a/outside.go b/outside.go index be60294..6a71fe7 100644 --- a/outside.go +++ b/outside.go @@ -14,7 +14,6 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" "golang.org/x/net/ipv4" - "google.golang.org/protobuf/proto" ) const ( @@ -494,7 +493,7 @@ func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *N } */ -func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.NebulaCAPool) (*cert.NebulaCertificate, error) { +func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.CAPool) (*cert.CachedCertificate, error) { pk := h.PeerStatic() if pk == nil { @@ -505,31 +504,15 @@ func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPo return nil, errors.New("provided payload was empty") } - r := &cert.RawNebulaCertificate{} - err := proto.Unmarshal(rawCertBytes, r) + c, err := cert.UnmarshalCertificateFromHandshake(rawCertBytes, pk) if err != nil { - return nil, fmt.Errorf("error unmarshaling cert: %s", err) + return nil, fmt.Errorf("error unmarshaling cert: %w", err) } - // If the Details are nil, just exit to avoid crashing - if r.Details == nil { - return nil, fmt.Errorf("certificate did not contain any details") - } - - r.Details.PublicKey = pk - recombined, err := proto.Marshal(r) + cc, err := caPool.VerifyCertificate(time.Now(), c) if err != nil { - return nil, fmt.Errorf("error while recombining certificate: %s", err) + return nil, fmt.Errorf("certificate validation failed: %w", err) } - c, _ := cert.UnmarshalNebulaCertificate(recombined) - isValid, err := c.Verify(time.Now(), caPool) - if err != nil { - return c, fmt.Errorf("certificate validation failed: %s", err) - } else if !isValid { - // This case should never happen but here's to defensive programming! - return c, errors.New("certificate validation failed but did not return an error") - } - - return c, nil + return cc, nil } diff --git a/pki.go b/pki.go index 511d305..fe64ea5 100644 --- a/pki.go +++ b/pki.go @@ -16,16 +16,17 @@ import ( type PKI struct { cs atomic.Pointer[CertState] - caPool atomic.Pointer[cert.NebulaCAPool] + caPool atomic.Pointer[cert.CAPool] l *logrus.Logger } type CertState struct { - Certificate *cert.NebulaCertificate + Certificate cert.Certificate RawCertificate []byte RawCertificateNoKey []byte PublicKey []byte PrivateKey []byte + pkcs11Backed bool } func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { @@ -49,7 +50,7 @@ func (p *PKI) GetCertState() *CertState { return p.cs.Load() } -func (p *PKI) GetCAPool() *cert.NebulaCAPool { +func (p *PKI) GetCAPool() *cert.CAPool { return p.caPool.Load() } @@ -84,12 +85,12 @@ func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError { // did IP in cert change? if so, don't set currentCert := p.cs.Load().Certificate - oldIPs := currentCert.Details.Ips - newIPs := cs.Certificate.Details.Ips + oldIPs := currentCert.Networks() + newIPs := cs.Certificate.Networks() if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { return util.NewContextualError( - "IP in new cert was different from old", - m{"new_ip": newIPs[0], "old_ip": oldIPs[0]}, + "Networks in new cert was different from old", + m{"new_network": newIPs[0], "old_network": oldIPs[0]}, nil, ) } @@ -115,29 +116,28 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError { return nil } -func newCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) { +func newCertState(certificate cert.Certificate, pkcs11backed bool, privateKey []byte) (*CertState, error) { // Marshal the certificate to ensure it is valid rawCertificate, err := certificate.Marshal() if err != nil { return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err) } - publicKey := certificate.Details.PublicKey + publicKey := certificate.PublicKey() cs := &CertState{ RawCertificate: rawCertificate, Certificate: certificate, PrivateKey: privateKey, PublicKey: publicKey, + pkcs11Backed: pkcs11backed, } - cs.Certificate.Details.PublicKey = nil - rawCertNoKey, err := cs.Certificate.Marshal() + rawCertNoKey, err := cs.Certificate.MarshalForHandshakes() if err != nil { return nil, fmt.Errorf("error marshalling certificate no key: %s", err) } cs.RawCertificateNoKey = rawCertNoKey - // put public key back - cs.Certificate.Details.PublicKey = cs.PublicKey + return cs, nil } @@ -146,7 +146,7 @@ func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPk if strings.Contains(privPathOrPEM, "-----BEGIN") { pemPrivateKey = []byte(privPathOrPEM) privPathOrPEM = "" - rawKey, _, curve, err = cert.UnmarshalPrivateKey(pemPrivateKey) + rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) if err != nil { return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) } @@ -158,7 +158,7 @@ func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPk if err != nil { return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) } - rawKey, _, curve, err = cert.UnmarshalPrivateKey(pemPrivateKey) + rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) if err != nil { return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) } @@ -198,27 +198,27 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { } } - nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert) + nebulaCert, _, err := cert.UnmarshalCertificateFromPEM(rawCert) if err != nil { return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err) } - nebulaCert.Pkcs11Backed = isPkcs11 + if nebulaCert.Expired(time.Now()) { return nil, fmt.Errorf("nebula certificate for this host is expired") } - if len(nebulaCert.Details.Ips) == 0 { - return nil, fmt.Errorf("no IPs encoded in certificate") + if len(nebulaCert.Networks()) == 0 { + return nil, fmt.Errorf("no networks encoded in certificate") } if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil { return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") } - return newCertState(nebulaCert, rawKey) + return newCertState(nebulaCert, isPkcs11, rawKey) } -func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) { +func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { var rawCA []byte var err error @@ -237,11 +237,11 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, er } } - caPool, err := cert.NewCAPoolFromBytes(rawCA) + caPool, err := cert.NewCAPoolFromPEM(rawCA) if errors.Is(err, cert.ErrExpired) { var expired int for _, crt := range caPool.CAs { - if crt.Expired(time.Now()) { + if crt.Certificate.Expired(time.Now()) { expired++ l.WithField("cert", crt).Warn("expired certificate present in CA pool") } diff --git a/service/service_test.go b/service/service_test.go index 3176209..e9fceef 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -18,9 +18,9 @@ import ( type m map[string]interface{} -func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { - _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{}) - caB, err := caCrt.MarshalToPEM() +func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { + _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{}) + caB, err := caCrt.MarshalPEM() if err != nil { panic(err) } diff --git a/ssh.go b/ssh.go index 2ff0954..881ee46 100644 --- a/ssh.go +++ b/ssh.go @@ -801,7 +801,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } - cert = hostInfo.GetCert() + cert = hostInfo.GetCert().Certificate } if args.Json || args.Pretty { @@ -825,7 +825,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit } if args.Raw { - b, err := cert.MarshalToPEM() + b, err := cert.MarshalPEM() if err != nil { //TODO: handle it return nil diff --git a/test/assert.go b/test/assert.go index 6c6c795..d34252e 100644 --- a/test/assert.go +++ b/test/assert.go @@ -2,6 +2,7 @@ package test import ( "fmt" + "net/netip" "reflect" "testing" "time" @@ -24,6 +25,11 @@ func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) { } func traverseDeepCopy(t *testing.T, v1 reflect.Value, v2 reflect.Value, name string) bool { + if v1.Type() == v2.Type() && v1.Type() == reflect.TypeOf(netip.Addr{}) { + // Ignore netip.Addr types since they reuse an interned global value + return false + } + switch v1.Kind() { case reflect.Array: for i := 0; i < v1.Len(); i++ { From 9c175b4faff7086c18d7ba543f74f3adbfbb9e33 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 11 Oct 2024 09:01:42 -0400 Subject: [PATCH 28/67] Bump the golang-x-dependencies group across 1 directory with 4 updates (#1237) Bumps the golang-x-dependencies group with 2 updates in the / directory: [golang.org/x/crypto](https://github.com/golang/crypto) and [golang.org/x/net](https://github.com/golang/net). Updates `golang.org/x/crypto` from 0.26.0 to 0.28.0 - [Commits](https://github.com/golang/crypto/compare/v0.26.0...v0.28.0) Updates `golang.org/x/net` from 0.28.0 to 0.30.0 - [Commits](https://github.com/golang/net/compare/v0.28.0...v0.30.0) Updates `golang.org/x/sys` from 0.24.0 to 0.26.0 - [Commits](https://github.com/golang/sys/compare/v0.24.0...v0.26.0) Updates `golang.org/x/term` from 0.23.0 to 0.25.0 - [Commits](https://github.com/golang/term/compare/v0.23.0...v0.25.0) --- updated-dependencies: - dependency-name: golang.org/x/crypto dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/net dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/sys dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/term dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 8 ++++---- go.sum | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index be7f0b6..bce8694 100644 --- a/go.mod +++ b/go.mod @@ -25,12 +25,12 @@ require ( github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stretchr/testify v1.9.0 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.26.0 + golang.org/x/crypto v0.28.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 - golang.org/x/net v0.28.0 + golang.org/x/net v0.30.0 golang.org/x/sync v0.8.0 - golang.org/x/sys v0.24.0 - golang.org/x/term v0.23.0 + golang.org/x/sys v0.26.0 + golang.org/x/term v0.25.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 diff --git a/go.sum b/go.sum index 10d12a1..e801759 100644 --- a/go.sum +++ b/go.sum @@ -155,8 +155,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= -golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= @@ -175,8 +175,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= -golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= +golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -203,11 +203,11 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= -golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= -golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= +golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= +golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= From 97dd8c53c944859d3e5dcc05dcbfdf87bc69c8f3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 11 Oct 2024 11:35:43 -0400 Subject: [PATCH 29/67] Bump github.com/vishvananda/netlink from 1.2.1-beta.2 to 1.3.0 (#1211) Bumps [github.com/vishvananda/netlink](https://github.com/vishvananda/netlink) from 1.2.1-beta.2 to 1.3.0. - [Release notes](https://github.com/vishvananda/netlink/releases) - [Commits](https://github.com/vishvananda/netlink/compare/v1.2.1-beta.2...v1.3.0) --- updated-dependencies: - dependency-name: github.com/vishvananda/netlink dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index bce8694..6720061 100644 --- a/go.mod +++ b/go.mod @@ -24,7 +24,7 @@ require ( github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stretchr/testify v1.9.0 - github.com/vishvananda/netlink v1.2.1-beta.2 + github.com/vishvananda/netlink v1.3.0 golang.org/x/crypto v0.28.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/net v0.30.0 diff --git a/go.sum b/go.sum index e801759..9260cbb 100644 --- a/go.sum +++ b/go.sum @@ -143,9 +143,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= -github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= -github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= +github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= +github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -192,17 +191,17 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= From 999418d0e9c2e37acbd1103217e5996025a04698 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 11 Oct 2024 11:36:43 -0400 Subject: [PATCH 30/67] Bump github.com/miekg/dns from 1.1.61 to 1.1.62 (#1201) Bumps [github.com/miekg/dns](https://github.com/miekg/dns) from 1.1.61 to 1.1.62. - [Changelog](https://github.com/miekg/dns/blob/master/Makefile.release) - [Commits](https://github.com/miekg/dns/compare/v1.1.61...v1.1.62) --- updated-dependencies: - dependency-name: github.com/miekg/dns dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 6720061..e844c6e 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.2 - github.com/miekg/dns v1.1.61 + github.com/miekg/dns v1.1.62 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f github.com/prometheus/client_golang v1.19.1 diff --git a/go.sum b/go.sum index 9260cbb..f357e57 100644 --- a/go.sum +++ b/go.sum @@ -81,8 +81,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/miekg/dns v1.1.61 h1:nLxbwF3XxhwVSm8g9Dghm9MHPaUZuqhPiGL+675ZmEs= -github.com/miekg/dns v1.1.61/go.mod h1:mnAarhS3nWaW+NVP2wTkYVIZyHNJ098SJZUki3eykwQ= +github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= +github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= From d3bf09ef8ebc71749340e59cb8caf096c0767915 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 11 Oct 2024 11:38:40 -0400 Subject: [PATCH 31/67] Bump github.com/gaissmai/bart from 0.11.1 to 0.13.0 (#1231) Bumps [github.com/gaissmai/bart](https://github.com/gaissmai/bart) from 0.11.1 to 0.13.0. - [Release notes](https://github.com/gaissmai/bart/releases) - [Commits](https://github.com/gaissmai/bart/compare/v0.11.1...v0.13.0) --- updated-dependencies: - dependency-name: github.com/gaissmai/bart dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index e844c6e..ec5baab 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/flynn/noise v1.1.0 - github.com/gaissmai/bart v0.11.1 + github.com/gaissmai/bart v0.13.0 github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.2 @@ -41,7 +41,7 @@ require ( require ( github.com/beorn7/perks v1.0.1 // indirect - github.com/bits-and-blooms/bitset v1.13.0 // indirect + github.com/bits-and-blooms/bitset v1.14.3 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/btree v1.1.2 // indirect diff --git a/go.sum b/go.sum index f357e57..4b1f29b 100644 --- a/go.sum +++ b/go.sum @@ -14,8 +14,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE= -github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= +github.com/bits-and-blooms/bitset v1.14.3 h1:Gd2c8lSNf9pKXom5JtD7AaKO8o7fGQ2LtFj1436qilA= +github.com/bits-and-blooms/bitset v1.14.3/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -26,8 +26,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= -github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc= -github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= +github.com/gaissmai/bart v0.13.0 h1:pItEhXDVVebUa+i978FfQ7ye8xZc1FrMgs8nJPPWAgA= +github.com/gaissmai/bart v0.13.0/go.mod h1:qSes2fnJ8hB410BW0ymHUN/eQkuGpTYyJcN8sKMYpJU= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= From 37415d57d0e0ea5c25007c63e3222deabb367eab Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 11 Oct 2024 12:50:31 -0400 Subject: [PATCH 32/67] Bump github.com/prometheus/client_golang from 1.19.1 to 1.20.4 (#1228) Bumps [github.com/prometheus/client_golang](https://github.com/prometheus/client_golang) from 1.19.1 to 1.20.4. - [Release notes](https://github.com/prometheus/client_golang/releases) - [Changelog](https://github.com/prometheus/client_golang/blob/main/CHANGELOG.md) - [Commits](https://github.com/prometheus/client_golang/compare/v1.19.1...v1.20.4) --- updated-dependencies: - dependency-name: github.com/prometheus/client_golang dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 12 +++++++----- go.sum | 26 ++++++++++++++++---------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index ec5baab..882c386 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/miekg/dns v1.1.62 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f - github.com/prometheus/client_golang v1.19.1 + github.com/prometheus/client_golang v1.20.4 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/sirupsen/logrus v1.9.3 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e @@ -42,13 +42,15 @@ require ( require ( github.com/beorn7/perks v1.0.1 // indirect github.com/bits-and-blooms/bitset v1.14.3 // indirect - github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/btree v1.1.2 // indirect + github.com/klauspost/compress v1.17.9 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/prometheus/client_model v0.5.0 // indirect - github.com/prometheus/common v0.48.0 // indirect - github.com/prometheus/procfs v0.12.0 // indirect + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.55.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect github.com/vishvananda/netns v0.0.4 // indirect golang.org/x/mod v0.18.0 // indirect golang.org/x/time v0.5.0 // indirect diff --git a/go.sum b/go.sum index 4b1f29b..fe010cb 100644 --- a/go.sum +++ b/go.sum @@ -17,8 +17,8 @@ github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6r github.com/bits-and-blooms/bitset v1.14.3 h1:Gd2c8lSNf9pKXom5JtD7AaKO8o7fGQ2LtFj1436qilA= github.com/bits-and-blooms/bitset v1.14.3/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 h1:M5QgkYacWj0Xs8MhpIK/5uwU02icXpEoSo9sM2aRCps= github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.mod h1:xwIwAxMvYnVrGJPe2FKx5prTrnAjGOD8zvDOnxnrrkM= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -70,6 +70,8 @@ github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= @@ -80,6 +82,8 @@ github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3x github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ= @@ -89,6 +93,8 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f h1:8dM0ilqKL0Uzl42GABzzC4Oqlc3kGRILz0vgoff7nwg= @@ -102,24 +108,24 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= -github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE= -github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho= +github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI= +github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= -github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= -github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE= -github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc= +github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc= +github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= -github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= From 3f6a7cb250a756314da65e8a440628dd94e7976f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 11 Oct 2024 13:05:58 -0400 Subject: [PATCH 33/67] Bump google.golang.org/protobuf in the protobuf-dependencies group (#1250) Bumps the protobuf-dependencies group with 1 update: google.golang.org/protobuf. Updates `google.golang.org/protobuf` from 1.34.2 to 1.35.1 --- updated-dependencies: - dependency-name: google.golang.org/protobuf dependency-type: direct:production update-type: version-update:semver-minor dependency-group: protobuf-dependencies ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 882c386..f464990 100644 --- a/go.mod +++ b/go.mod @@ -34,7 +34,7 @@ require ( golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/protobuf v1.34.2 + google.golang.org/protobuf v1.35.1 gopkg.in/yaml.v2 v2.4.0 gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe ) diff --git a/go.sum b/go.sum index fe010cb..dacc3d3 100644 --- a/go.sum +++ b/go.sum @@ -243,8 +243,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= -google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= +google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From 3e6c75573fd55d4bcdb6ad19f8c362d59bb1d9f0 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 23 Oct 2024 14:28:02 -0500 Subject: [PATCH 34/67] Fix static host map wrong responder situations, correct logging (#1259) --- e2e/handshakes_test.go | 115 +++++++++++++++++++++++++++++++++++++---- handshake_ix.go | 39 +++++++------- remote_list.go | 4 +- 3 files changed, 129 insertions(+), 29 deletions(-) diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 6be94ad..f6069bf 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -97,16 +97,10 @@ func TestGoodHandshake(t *testing.T) { func TestWrongResponderHandshake(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - // The IPs here are chosen on purpose: - // The current remote handling will sort by preference, public, and then lexically. - // So we need them to have a higher address than evil (we could apply a preference though) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil) evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil) - // Add their real udp addr, which should be tried after evil. - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr) @@ -119,10 +113,114 @@ func TestWrongResponderHandshake(t *testing.T) { theirControl.Start() evilControl.Start() - t.Log("Start the handshake process, we will route until we see our cached packet get sent to them") + t.Log("Start the handshake process, we will route until we see the evil tunnel closed") myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + + h := &header.H{} + r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { + err := h.Parse(p.Data) + if err != nil { + panic(err) + } + + if h.Type == header.CloseTunnel && p.To == evilUdpAddr { + return router.RouteAndExit + } + + return router.KeepRouting + }) + + t.Log("Evil tunnel is closed, inject the correct udp addr for them") + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + pendingHi := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), true) + assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr) + + t.Log("Route until we see the cached packet") + r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { + err := h.Parse(p.Data) + if err != nil { + panic(err) + } + + if p.To == theirUdpAddr && h.Type == 1 { + return router.RouteAndExit + } + + return router.KeepRouting + }) + + //TODO: Assert pending hostmap - I should have a correct hostinfo for them now + + t.Log("My cached packet should be received by them") + myCachedPacket := theirControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + + t.Log("Test the tunnel with them") + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + + t.Log("Flush all packets from all controllers") + r.FlushAll() + + t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") + assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil") + + //TODO: assert hostmaps for everyone + r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl) + t.Log("Success!") + myControl.Stop() + theirControl.Stop() +} + +func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil) + evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil) + o := m{ + "static_host_map": m{ + theirVpnIpNet.Addr().String(): []string{evilUdpAddr.String()}, + }, + } + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", o) + + // Put the evil udp addr in for their vpn addr, this is a case of a remote at a static entry changing its vpn addr. + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, theirControl, evilControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + theirControl.Start() + evilControl.Start() + + t.Log("Start the handshake process, we will route until we see the evil tunnel closed") + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + + h := &header.H{} + r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { + err := h.Parse(p.Data) + if err != nil { + panic(err) + } + + if h.Type == header.CloseTunnel && p.To == evilUdpAddr { + return router.RouteAndExit + } + + return router.KeepRouting + }) + + t.Log("Evil tunnel is closed, inject the correct udp addr for them") + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + pendingHi := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), true) + assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr) + + t.Log("Route until we see the cached packet") r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { - h := &header.H{} err := h.Parse(p.Data) if err != nil { panic(err) @@ -151,7 +249,6 @@ func TestWrongResponderHandshake(t *testing.T) { t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil") assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil") - //NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete //TODO: assert hostmaps for everyone r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl) diff --git a/handshake_ix.go b/handshake_ix.go index 73e8541..3add83d 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -418,6 +418,21 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha fingerprint := remoteCert.Fingerprint issuer := remoteCert.Certificate.Issuer() + hostinfo.remoteIndexId = hs.Details.ResponderIndex + hostinfo.lastHandshakeTime = hs.Details.Time + + // Store their cert and our symmetric keys + ci.peerCert = remoteCert + ci.dKey = NewNebulaCipherState(dKey) + ci.eKey = NewNebulaCipherState(eKey) + + // Make sure the current udpAddr being used is set for responding + if addr.IsValid() { + hostinfo.SetRemote(addr) + } else { + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + } + // Ensure the right host responded if vpnIp != hostinfo.vpnIp { f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp). @@ -435,10 +450,8 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha newHH.hostinfo.remotes = hostinfo.remotes newHH.hostinfo.remotes.BlockRemote(addr) - // Get the correct remote list for the host we did handshake with - hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) - - f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp). + f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()). + WithField("vpnIp", newHH.hostinfo.vpnIp). WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). Info("Blocked addresses for handshakes") @@ -446,6 +459,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha newHH.packetStore = hh.packetStore hh.packetStore = []*cachedPacket{} + // Get the correct remote list for the host we did handshake with + hostinfo.SetRemote(addr) + hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) // Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down hostinfo.vpnIp = vpnIp f.sendCloseTunnel(hostinfo) @@ -468,21 +484,6 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha WithField("sentCachedPackets", len(hh.packetStore)). Info("Handshake message received") - hostinfo.remoteIndexId = hs.Details.ResponderIndex - hostinfo.lastHandshakeTime = hs.Details.Time - - // Store their cert and our symmetric keys - ci.peerCert = remoteCert - ci.dKey = NewNebulaCipherState(dKey) - ci.eKey = NewNebulaCipherState(eKey) - - // Make sure the current udpAddr being used is set for responding - if addr.IsValid() { - hostinfo.SetRemote(addr) - } else { - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) - } - // Build up the radix for the firewall if we have subnets in the cert hostinfo.CreateRemoteCIDR(remoteCert.Certificate) diff --git a/remote_list.go b/remote_list.go index fa14f42..94db8f2 100644 --- a/remote_list.go +++ b/remote_list.go @@ -576,7 +576,9 @@ func (r *RemoteList) unlockedCollect() { dnsAddrs := r.hr.GetIPs() for _, addr := range dnsAddrs { if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { - addrs = append(addrs, addr) + if !r.unlockedIsBad(addr) { + addrs = append(addrs, addr) + } } } From 2b427a7e8934f0a436fea25eb40a6b979b34ee7a Mon Sep 17 00:00:00 2001 From: Ian VanSchooten Date: Mon, 13 Jan 2025 13:35:53 -0500 Subject: [PATCH 35/67] Update slack invitation link (#1308) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 65ea91f..56e4c9d 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ Further documentation can be found [here](https://nebula.defined.net/docs/). You can read more about Nebula [here](https://medium.com/p/884110a5579). -You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/enQtOTA5MDI4NDg3MTg4LTkwY2EwNTI4NzQyMzc0M2ZlODBjNWI3NTY1MzhiOThiMmZlZjVkMTI0NGY4YTMyNjUwMWEyNzNkZTJmYzQxOGU). +You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ). ## Supported Platforms From d97ed57a19396dbc8faa299d02a7f45ec760fd3b Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 6 Mar 2025 11:28:26 -0600 Subject: [PATCH 36/67] V2 certificate format (#1216) Co-authored-by: Nate Brown Co-authored-by: Jack Doan Co-authored-by: brad-defined <77982333+brad-defined@users.noreply.github.com> Co-authored-by: Jack Doan --- .gitignore | 3 +- Makefile | 2 +- allow_list.go | 37 +- calculated_remote.go | 59 +- calculated_remote_test.go | 66 +- cert/README.md | 19 +- cert/asn1.go | 52 ++ cert/ca_pool.go | 2 +- cert/ca_pool_test.go | 512 ++++++++++++++- cert/cert.go | 64 +- cert/cert_test.go | 695 --------------------- cert/cert_v1.go | 467 +++++++------- cert/cert_v1_test.go | 218 +++++++ cert/cert_v2.asn1 | 37 ++ cert/cert_v2.go | 730 ++++++++++++++++++++++ cert/cert_v2_test.go | 267 ++++++++ cert/errors.go | 46 +- cert/helper_test.go | 141 +++++ cert/pem.go | 20 +- cert/sign.go | 119 +++- cert/sign_test.go | 90 +++ e2e/helpers.go => cert_test/cert.go | 62 +- cmd/nebula-cert/ca.go | 68 +- cmd/nebula-cert/ca_test.go | 36 +- cmd/nebula-cert/keygen_test.go | 2 - cmd/nebula-cert/main_test.go | 2 - cmd/nebula-cert/print.go | 17 +- cmd/nebula-cert/print_test.go | 75 ++- cmd/nebula-cert/sign.go | 234 +++++-- cmd/nebula-cert/sign_test.go | 83 +-- cmd/nebula-cert/verify.go | 41 +- cmd/nebula-cert/verify_test.go | 6 +- config/config_test.go | 3 - connection_manager.go | 89 ++- connection_manager_test.go | 55 +- connection_state.go | 47 +- control.go | 51 +- control_test.go | 34 +- control_tester.go | 65 +- dns_server.go | 110 ++-- dns_server_test.go | 25 +- e2e/handshakes_test.go | 561 +++++++++++------ e2e/helpers_test.go | 111 ++-- e2e/router/hostmap.go | 7 +- e2e/router/router.go | 66 +- examples/config.yml | 17 +- firewall.go | 121 ++-- firewall/packet.go | 21 +- firewall_test.go | 79 +-- go.mod | 1 - go.sum | 2 - handshake_ix.go | 276 ++++++--- handshake_manager.go | 303 +++++---- handshake_manager_test.go | 30 +- hostmap.go | 263 +++++--- hostmap_test.go | 56 +- hostmap_tester.go | 4 +- inside.go | 58 +- interface.go | 143 +++-- iputil/packet.go | 2 - lighthouse.go | 926 ++++++++++++++++++---------- lighthouse_test.go | 319 +++++----- main.go | 29 +- message_metrics.go | 2 - nebula.pb.go | 921 ++++++++++++++++++++------- nebula.proto | 32 +- outside.go | 272 ++++---- outside_test.go | 542 +++++++++++++++- overlay/device.go | 2 +- overlay/route.go | 34 +- overlay/route_test.go | 72 ++- overlay/tun.go | 18 +- overlay/tun_android.go | 22 +- overlay/tun_darwin.go | 423 +++++++------ overlay/tun_disabled.go | 16 +- overlay/tun_freebsd.go | 40 +- overlay/tun_ios.go | 20 +- overlay/tun_linux.go | 198 +++--- overlay/tun_netbsd.go | 45 +- overlay/tun_openbsd.go | 48 +- overlay/tun_tester.go | 34 +- overlay/tun_water_windows.go | 208 ------- overlay/tun_windows.go | 265 +++++++- overlay/tun_wintun_windows.go | 252 -------- overlay/user.go | 12 +- pki.go | 414 ++++++++++--- relay_manager.go | 298 +++++---- remote_list.go | 86 +-- remote_list_test.go | 55 +- service/service.go | 4 +- service/service_test.go | 6 +- ssh.go | 161 +++-- sshd/command.go | 8 +- sshd/server.go | 4 +- sshd/session.go | 13 +- test/tun.go | 4 +- timeout_test.go | 8 +- udp/conn.go | 15 +- udp/temp.go | 10 - udp/udp_generic.go | 22 +- udp/udp_linux.go | 39 +- udp/udp_linux_32.go | 1 - udp/udp_linux_64.go | 1 - udp/udp_rio_windows.go | 21 +- udp/udp_tester.go | 10 +- 105 files changed, 8276 insertions(+), 4528 deletions(-) create mode 100644 cert/asn1.go delete mode 100644 cert/cert_test.go create mode 100644 cert/cert_v1_test.go create mode 100644 cert/cert_v2.asn1 create mode 100644 cert/cert_v2.go create mode 100644 cert/cert_v2_test.go create mode 100644 cert/helper_test.go create mode 100644 cert/sign_test.go rename e2e/helpers.go => cert_test/cert.go (51%) delete mode 100644 overlay/tun_water_windows.go delete mode 100644 overlay/tun_wintun_windows.go delete mode 100644 udp/temp.go diff --git a/.gitignore b/.gitignore index 0bffc85..55068f3 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,8 @@ /nebula-darwin /nebula.exe /nebula-cert.exe -/coverage.out +**/coverage.out +**/cover.out /cpu.pprof /build /*.tar.gz diff --git a/Makefile b/Makefile index 6922cc3..d3fbcaa 100644 --- a/Makefile +++ b/Makefile @@ -196,7 +196,7 @@ bench-cpu-long: go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof go tool pprof go-audit.test cpu.pprof -proto: nebula.pb.go cert/cert.pb.go +proto: nebula.pb.go cert/cert_v1.pb.go nebula.pb.go: nebula.proto .FORCE go build github.com/gogo/protobuf/protoc-gen-gogofaster diff --git a/allow_list.go b/allow_list.go index 90e0de2..cfdd983 100644 --- a/allow_list.go +++ b/allow_list.go @@ -128,7 +128,6 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits()) - // TODO: should we error on duplicate CIDRs in the config? tree.Insert(ipNet, value) maskBits := ipNet.Bits() @@ -251,20 +250,20 @@ func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error return remoteAllowRanges, nil } -func (al *AllowList) Allow(ip netip.Addr) bool { +func (al *AllowList) Allow(addr netip.Addr) bool { if al == nil { return true } - result, _ := al.cidrTree.Lookup(ip) + result, _ := al.cidrTree.Lookup(addr) return result } -func (al *LocalAllowList) Allow(ip netip.Addr) bool { +func (al *LocalAllowList) Allow(udpAddr netip.Addr) bool { if al == nil { return true } - return al.AllowList.Allow(ip) + return al.AllowList.Allow(udpAddr) } func (al *LocalAllowList) AllowName(name string) bool { @@ -282,23 +281,37 @@ func (al *LocalAllowList) AllowName(name string) bool { return !al.nameRules[0].Allow } -func (al *RemoteAllowList) AllowUnknownVpnIp(ip netip.Addr) bool { +func (al *RemoteAllowList) AllowUnknownVpnAddr(vpnAddr netip.Addr) bool { if al == nil { return true } - return al.AllowList.Allow(ip) + return al.AllowList.Allow(vpnAddr) } -func (al *RemoteAllowList) Allow(vpnIp netip.Addr, ip netip.Addr) bool { - if !al.getInsideAllowList(vpnIp).Allow(ip) { +func (al *RemoteAllowList) Allow(vpnAddr netip.Addr, udpAddr netip.Addr) bool { + if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) { return false } - return al.AllowList.Allow(ip) + return al.AllowList.Allow(udpAddr) } -func (al *RemoteAllowList) getInsideAllowList(vpnIp netip.Addr) *AllowList { +func (al *RemoteAllowList) AllowAll(vpnAddrs []netip.Addr, udpAddr netip.Addr) bool { + if !al.AllowList.Allow(udpAddr) { + return false + } + + for _, vpnAddr := range vpnAddrs { + if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) { + return false + } + } + + return true +} + +func (al *RemoteAllowList) getInsideAllowList(vpnAddr netip.Addr) *AllowList { if al.insideAllowLists != nil { - inside, ok := al.insideAllowLists.Lookup(vpnIp) + inside, ok := al.insideAllowLists.Lookup(vpnAddr) if ok { return inside } diff --git a/calculated_remote.go b/calculated_remote.go index ae2ed50..32d062a 100644 --- a/calculated_remote.go +++ b/calculated_remote.go @@ -21,7 +21,11 @@ type calculatedRemote struct { port uint32 } -func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) { +func newCalculatedRemote(cidr, maskCidr netip.Prefix, port int) (*calculatedRemote, error) { + if maskCidr.Addr().BitLen() != cidr.Addr().BitLen() { + return nil, fmt.Errorf("invalid mask: %s for cidr: %s", maskCidr, cidr) + } + masked := maskCidr.Masked() if port < 0 || port > math.MaxUint16 { return nil, fmt.Errorf("invalid port: %d", port) @@ -38,32 +42,38 @@ func (c *calculatedRemote) String() string { return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port) } -func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort { - // Combine the masked bytes of the "mask" IP with the unmasked bytes - // of the overlay IP - if c.ipNet.Addr().Is4() { - return c.apply4(ip) - } - return c.apply6(ip) -} - -func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort { - //TODO: IPV6-WORK this can be less crappy +func (c *calculatedRemote) ApplyV4(addr netip.Addr) *V4AddrPort { + // Combine the masked bytes of the "mask" IP with the unmasked bytes of the overlay IP maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen()) mask := binary.BigEndian.Uint32(maskb[:]) b := c.mask.Addr().As4() - maskIp := binary.BigEndian.Uint32(b[:]) + maskAddr := binary.BigEndian.Uint32(b[:]) - b = ip.As4() - intIp := binary.BigEndian.Uint32(b[:]) + b = addr.As4() + intAddr := binary.BigEndian.Uint32(b[:]) - return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port} + return &V4AddrPort{(maskAddr & mask) | (intAddr & ^mask), c.port} } -func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort { - //TODO: IPV6-WORK - panic("Can not calculate ipv6 remote addresses") +func (c *calculatedRemote) ApplyV6(addr netip.Addr) *V6AddrPort { + mask := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen()) + maskAddr := c.mask.Addr().As16() + calcAddr := addr.As16() + + ap := V6AddrPort{Port: c.port} + + maskb := binary.BigEndian.Uint64(mask[:8]) + maskAddrb := binary.BigEndian.Uint64(maskAddr[:8]) + calcAddrb := binary.BigEndian.Uint64(calcAddr[:8]) + ap.Hi = (maskAddrb & maskb) | (calcAddrb & ^maskb) + + maskb = binary.BigEndian.Uint64(mask[8:]) + maskAddrb = binary.BigEndian.Uint64(maskAddr[8:]) + calcAddrb = binary.BigEndian.Uint64(calcAddr[8:]) + ap.Lo = (maskAddrb & maskb) | (calcAddrb & ^maskb) + + return &ap } func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) { @@ -89,8 +99,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) } - //TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here - entry, err := newCalculatedRemotesListFromConfig(rawValue) + entry, err := newCalculatedRemotesListFromConfig(cidr, rawValue) if err != nil { return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err) } @@ -101,7 +110,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu return calculatedRemotes, nil } -func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) { +func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculatedRemote, error) { rawList, ok := raw.([]any) if !ok { return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw) @@ -109,7 +118,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) { var l []*calculatedRemote for _, e := range rawList { - c, err := newCalculatedRemotesEntryFromConfig(e) + c, err := newCalculatedRemotesEntryFromConfig(cidr, e) if err != nil { return nil, fmt.Errorf("calculated_remotes entry: %w", err) } @@ -119,7 +128,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) { return l, nil } -func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { +func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) { rawMap, ok := raw.(map[any]any) if !ok { return nil, fmt.Errorf("invalid type: %T", raw) @@ -155,5 +164,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue) } - return newCalculatedRemote(maskCidr, port) + return newCalculatedRemote(cidr, maskCidr, port) } diff --git a/calculated_remote_test.go b/calculated_remote_test.go index 6ff1cb0..066213e 100644 --- a/calculated_remote_test.go +++ b/calculated_remote_test.go @@ -9,10 +9,9 @@ import ( ) func TestCalculatedRemoteApply(t *testing.T) { - ipNet, err := netip.ParsePrefix("192.168.1.0/24") - require.NoError(t, err) - - c, err := newCalculatedRemote(ipNet, 4242) + // Test v4 addresses + ipNet := netip.MustParsePrefix("192.168.1.0/24") + c, err := newCalculatedRemote(ipNet, ipNet, 4242) require.NoError(t, err) input, err := netip.ParseAddr("10.0.10.182") @@ -21,5 +20,62 @@ func TestCalculatedRemoteApply(t *testing.T) { expected, err := netip.ParseAddr("192.168.1.182") assert.NoError(t, err) - assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input)) + assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input)) + + // Test v6 addresses + ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff::0/64") + c, err = newCalculatedRemote(ipNet, ipNet, 4242) + require.NoError(t, err) + + input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") + assert.NoError(t, err) + + expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef") + assert.NoError(t, err) + + assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) + + // Test v6 addresses part 2 + ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff:ffff::0/80") + c, err = newCalculatedRemote(ipNet, ipNet, 4242) + require.NoError(t, err) + + input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") + assert.NoError(t, err) + + expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef") + assert.NoError(t, err) + + assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) + + // Test v6 addresses part 2 + ipNet = netip.MustParsePrefix("ffff:ffff:ffff::0/48") + c, err = newCalculatedRemote(ipNet, ipNet, 4242) + require.NoError(t, err) + + input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") + assert.NoError(t, err) + + expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef") + assert.NoError(t, err) + + assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) +} + +func Test_newCalculatedRemote(t *testing.T) { + c, err := newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1.0.0.0/32"), 4242) + require.EqualError(t, err, "invalid mask: 1.0.0.0/32 for cidr: 1::1/128") + require.Nil(t, c) + + c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1::1/128"), 4242) + require.EqualError(t, err, "invalid mask: 1::1/128 for cidr: 1.0.0.0/32") + require.Nil(t, c) + + c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1.0.0.0/32"), 4242) + require.NoError(t, err) + require.NotNil(t, c) + + c, err = newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1::1/128"), 4242) + require.NoError(t, err) + require.NotNil(t, c) } diff --git a/cert/README.md b/cert/README.md index ae19a28..1e27a6b 100644 --- a/cert/README.md +++ b/cert/README.md @@ -2,14 +2,25 @@ This is a library for interacting with `nebula` style certificates and authorities. -A `protobuf` definition of the certificate format is also included +There are now 2 versions of `nebula` certificates: -### Compiling the protobuf definition +## v1 -Make sure you have `protoc` installed. +This version is deprecated. + +A `protobuf` definition of the certificate format is included at `cert_v1.proto` + +To compile the definition you will need `protoc` installed. To compile for `go` with the same version of protobuf specified in go.mod: ```bash -make +make proto ``` + +## v2 + +This is the latest version which uses asn.1 DER encoding. It can support ipv4 and ipv6 and tolerate +future certificate changes better than v1. + +`cert_v2.asn1` defines the wire format and can be used to compile marshalers. \ No newline at end of file diff --git a/cert/asn1.go b/cert/asn1.go new file mode 100644 index 0000000..6bf6a8d --- /dev/null +++ b/cert/asn1.go @@ -0,0 +1,52 @@ +package cert + +import ( + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" +) + +// readOptionalASN1Boolean reads an asn.1 boolean with a specific tag instead of a asn.1 tag wrapping a boolean with a value +// https://github.com/golang/go/issues/64811#issuecomment-1944446920 +func readOptionalASN1Boolean(b *cryptobyte.String, out *bool, tag asn1.Tag, defaultValue bool) bool { + var present bool + var child cryptobyte.String + if !b.ReadOptionalASN1(&child, &present, tag) { + return false + } + + if !present { + *out = defaultValue + return true + } + + // Ensure we have 1 byte + if len(child) == 1 { + *out = child[0] > 0 + return true + } + + return false +} + +// readOptionalASN1Byte reads an asn.1 uint8 with a specific tag instead of a asn.1 tag wrapping a uint8 with a value +// Similar issue as with readOptionalASN1Boolean +func readOptionalASN1Byte(b *cryptobyte.String, out *byte, tag asn1.Tag, defaultValue byte) bool { + var present bool + var child cryptobyte.String + if !b.ReadOptionalASN1(&child, &present, tag) { + return false + } + + if !present { + *out = defaultValue + return true + } + + // Ensure we have 1 byte + if len(child) == 1 { + *out = child[0] + return true + } + + return false +} diff --git a/cert/ca_pool.go b/cert/ca_pool.go index d525830..2bf480f 100644 --- a/cert/ca_pool.go +++ b/cert/ca_pool.go @@ -213,7 +213,7 @@ func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) { return signer, nil } - return nil, fmt.Errorf("could not find ca for the certificate") + return nil, ErrCaNotFound } // GetFingerprints returns an array of trusted CA fingerprints diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go index 053640d..f03b2ba 100644 --- a/cert/ca_pool_test.go +++ b/cert/ca_pool_test.go @@ -1,7 +1,9 @@ package cert import ( + "net/netip" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -10,15 +12,15 @@ func TestNewCAPoolFromBytes(t *testing.T) { noNewLines := ` # Current provisional, Remove once everything moves over to the real root. -----BEGIN NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL -vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv -bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB +Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+ +PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf +2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ== -----END NEBULA CERTIFICATE----- # root-ca01 -----BEGIN NEBULA CERTIFICATE----- -CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG -BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf -8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF +CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br +BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye +rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA== -----END NEBULA CERTIFICATE----- ` @@ -26,18 +28,18 @@ BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf # Current provisional, Remove once everything moves over to the real root. -----BEGIN NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL -vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv -bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB +Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+ +PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf +2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ== -----END NEBULA CERTIFICATE----- # root-ca01 -----BEGIN NEBULA CERTIFICATE----- -CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG -BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf -8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF +CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br +BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye +rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA== -----END NEBULA CERTIFICATE----- ` @@ -45,65 +47,513 @@ BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf expired := ` # expired certificate -----BEGIN NEBULA CERTIFICATE----- -CjkKB2V4cGlyZWQouPmWjQYwufmWjQY6ILCRaoCkJlqHgv5jfDN4lzLHBvDzaQm4 -vZxfu144hmgjQAESQG4qlnZi8DncvD/LDZnLgJHOaX1DWCHHEh59epVsC+BNgTie -WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs= +CjMKB2V4cGlyZWQozRwwzRw6ICJSG94CqX8wn5I65Pwn25V6HftVfWeIySVtp2DA +7TY/QAESQMaAk5iJT5EnQwK524ZaaHGEJLUqqbh5yyOHhboIGiVTWkFeH3HccTW8 +Tq5a8AyWDQdfXbtEZ1FwabeHfH5Asw0= -----END NEBULA CERTIFICATE----- ` p256 := ` # p256 certificate -----BEGIN NEBULA CERTIFICATE----- -CmYKEG5lYnVsYSBQMjU2IHRlc3Qo4s+7mgYw4tXrsAc6QQRkaW2jFmllYvN4+/k2 -6tctO9sPT3jOx8ES6M1nIqOhpTmZeabF/4rELDqPV4aH5jfJut798DUXql0FlF8H -76gvQAGgBgESRzBFAiEAib0/te6eMiZOKD8gdDeloMTS0wGuX2t0C7TFdUhAQzgC -IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX +CmQKEG5lYnVsYSBQMjU2IHRlc3QozRwwzbjM8K8HOkEEdrmmg40zQp44AkMq6DZp +k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe ++0ABoAYBEkcwRQIgVoTg38L7uWku9xQgsr06kxZ/viQLOO/w1Qj1vFUEnhcCIQCq +75SjTiV92kv/1GcbT3wWpAZQQDBiUHVMVmh1822szA== -----END NEBULA CERTIFICATE----- ` rootCA := certificateV1{ details: detailsV1{ - Name: "nebula root ca", + name: "nebula root ca", }, } rootCA01 := certificateV1{ details: detailsV1{ - Name: "nebula root ca 01", + name: "nebula root ca 01", }, } rootCAP256 := certificateV1{ details: detailsV1{ - Name: "nebula P256 test", + name: "nebula P256 test", }, } p, err := NewCAPoolFromPEM([]byte(noNewLines)) assert.Nil(t, err) - assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) - assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) + assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) + assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) pp, err := NewCAPoolFromPEM([]byte(withNewLines)) assert.Nil(t, err) - assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) - assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) + assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) + assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) // expired cert, no valid certs ppp, err := NewCAPoolFromPEM([]byte(expired)) assert.Equal(t, ErrExpired, err) - assert.Equal(t, ppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired") + assert.Equal(t, ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired") // expired cert, with valid certs pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...)) assert.Equal(t, ErrExpired, err) - assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) - assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) - assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired") + assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) + assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) + assert.Equal(t, pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired") assert.Equal(t, len(pppp.CAs), 3) ppppp, err := NewCAPoolFromPEM([]byte(p256)) assert.Nil(t, err) - assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Certificate.Name(), rootCAP256.details.Name) + assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name) assert.Equal(t, len(ppppp.CAs), 1) } + +func TestCertificateV1_Verify(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + + caPool := NewCAPool() + assert.NoError(t, caPool.AddCA(ca)) + + f, err := c.Fingerprint() + assert.Nil(t, err) + caPool.BlocklistFingerprint(f) + + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) + assert.EqualError(t, err, "root certificate is expired") + + assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) + }) + + // Test group assertion + ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool = NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + }) + + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV1_VerifyP256(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + + caPool := NewCAPool() + assert.NoError(t, caPool.AddCA(ca)) + + f, err := c.Fingerprint() + assert.Nil(t, err) + caPool.BlocklistFingerprint(f) + + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) + assert.EqualError(t, err, "root certificate is expired") + + assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { + NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + }) + + // Test group assertion + ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool = NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { + NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + }) + + c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV1_Verify_IPs(t *testing.T) { + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool := NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + // ip is outside the network + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is outside the network reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is within the network but mask is outside + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is within the network but mask is outside reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip and mask are within the network + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed with just 1 + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV1_Verify_Subnets(t *testing.T) { + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool := NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + // ip is outside the network + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is outside the network reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is within the network but mask is outside + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is within the network but mask is outside reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip and mask are within the network + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed with just 1 + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV2_Verify(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + + caPool := NewCAPool() + assert.NoError(t, caPool.AddCA(ca)) + + f, err := c.Fingerprint() + assert.Nil(t, err) + caPool.BlocklistFingerprint(f) + + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) + assert.EqualError(t, err, "root certificate is expired") + + assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) + }) + + // Test group assertion + ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool = NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + }) + + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV2_VerifyP256(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + + caPool := NewCAPool() + assert.NoError(t, caPool.AddCA(ca)) + + f, err := c.Fingerprint() + assert.Nil(t, err) + caPool.BlocklistFingerprint(f) + + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) + assert.EqualError(t, err, "root certificate is expired") + + assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { + NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + }) + + // Test group assertion + ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool = NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { + NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + }) + + c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV2_Verify_IPs(t *testing.T) { + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool := NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + // ip is outside the network + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is outside the network reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is within the network but mask is outside + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is within the network but mask is outside reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip and mask are within the network + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed with just 1 + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV2_Verify_Subnets(t *testing.T) { + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool := NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + // ip is outside the network + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is outside the network reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is within the network but mask is outside + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is within the network but mask is outside reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip and mask are within the network + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed with just 1 + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} diff --git a/cert/cert.go b/cert/cert.go index 02c8877..4246571 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -1,15 +1,17 @@ package cert import ( + "fmt" "net/netip" "time" ) -type Version int +type Version uint8 const ( - Version1 Version = 1 - Version2 Version = 2 + VersionPre1 Version = 0 + Version1 Version = 1 + Version2 Version = 2 ) type Certificate interface { @@ -107,23 +109,57 @@ type CachedCertificate struct { signerFingerprint string } -// UnmarshalCertificate will attempt to unmarshal a wire protocol level certificate. -func UnmarshalCertificate(b []byte) (Certificate, error) { - c, err := unmarshalCertificateV1(b, true) - if err != nil { - return nil, err - } - return c, nil +func (cc *CachedCertificate) String() string { + return cc.Certificate.String() } -// UnmarshalCertificateFromHandshake will attempt to unmarshal a certificate received in a handshake. +// RecombineAndValidate will attempt to unmarshal a certificate received in a handshake. // Handshakes save space by placing the peers public key in a different part of the packet, we have to // reassemble the actual certificate structure with that in mind. -func UnmarshalCertificateFromHandshake(b []byte, publicKey []byte) (Certificate, error) { - c, err := unmarshalCertificateV1(b, false) +func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) { + if publicKey == nil { + return nil, ErrNoPeerStaticKey + } + + if rawCertBytes == nil { + return nil, ErrNoPayload + } + + c, err := unmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve) + if err != nil { + return nil, fmt.Errorf("error unmarshaling cert: %w", err) + } + + cc, err := caPool.VerifyCertificate(time.Now(), c) + if err != nil { + return nil, fmt.Errorf("certificate validation failed: %w", err) + } + + return cc, nil +} + +func unmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) { + var c Certificate + var err error + + switch v { + // Implementations must ensure the result is a valid cert! + case VersionPre1, Version1: + c, err = unmarshalCertificateV1(b, publicKey) + case Version2: + c, err = unmarshalCertificateV2(b, publicKey, curve) + default: + //TODO: CERT-V2 make a static var + return nil, fmt.Errorf("unknown certificate version %d", v) + } + if err != nil { return nil, err } - c.details.PublicKey = publicKey + + if c.Curve() != curve { + return nil, fmt.Errorf("certificate curve %s does not match expected %s", c.Curve().String(), curve.String()) + } + return c, nil } diff --git a/cert/cert_test.go b/cert/cert_test.go deleted file mode 100644 index 12bbd97..0000000 --- a/cert/cert_test.go +++ /dev/null @@ -1,695 +0,0 @@ -package cert - -import ( - "crypto/ecdh" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "fmt" - "io" - "net/netip" - "testing" - "time" - - "github.com/slackhq/nebula/test" - "github.com/stretchr/testify/assert" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/ed25519" -) - -func TestMarshalingNebulaCertificate(t *testing.T) { - before := time.Now().Add(time.Second * -60).Round(time.Second) - after := time.Now().Add(time.Second * 60).Round(time.Second) - pubKey := []byte("1234567890abcedfghij1234567890ab") - - nc := certificateV1{ - details: detailsV1{ - Name: "testing", - Ips: []netip.Prefix{ - mustParsePrefixUnmapped("10.1.1.1/24"), - mustParsePrefixUnmapped("10.1.1.2/16"), - }, - Subnets: []netip.Prefix{ - mustParsePrefixUnmapped("9.1.1.2/24"), - mustParsePrefixUnmapped("9.1.1.3/16"), - }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: before, - NotAfter: after, - PublicKey: pubKey, - IsCA: false, - Issuer: "1234567890abcedfghij1234567890ab", - }, - signature: []byte("1234567890abcedfghij1234567890ab"), - } - - b, err := nc.Marshal() - assert.Nil(t, err) - //t.Log("Cert size:", len(b)) - - nc2, err := unmarshalCertificateV1(b, true) - assert.Nil(t, err) - - assert.Equal(t, nc.signature, nc2.Signature()) - assert.Equal(t, nc.details.Name, nc2.Name()) - assert.Equal(t, nc.details.NotBefore, nc2.NotBefore()) - assert.Equal(t, nc.details.NotAfter, nc2.NotAfter()) - assert.Equal(t, nc.details.PublicKey, nc2.PublicKey()) - assert.Equal(t, nc.details.IsCA, nc2.IsCA()) - - assert.Equal(t, nc.details.Ips, nc2.Networks()) - assert.Equal(t, nc.details.Subnets, nc2.UnsafeNetworks()) - - assert.Equal(t, nc.details.Groups, nc2.Groups()) -} - -//func TestNebulaCertificate_Sign(t *testing.T) { -// before := time.Now().Add(time.Second * -60).Round(time.Second) -// after := time.Now().Add(time.Second * 60).Round(time.Second) -// pubKey := []byte("1234567890abcedfghij1234567890ab") -// -// nc := certificateV1{ -// details: detailsV1{ -// Name: "testing", -// Ips: []netip.Prefix{ -// mustParsePrefixUnmapped("10.1.1.1/24"), -// mustParsePrefixUnmapped("10.1.1.2/16"), -// //TODO: netip cant do it -// //{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, -// }, -// Subnets: []netip.Prefix{ -// //TODO: netip cant do it -// //{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, -// mustParsePrefixUnmapped("9.1.1.2/24"), -// mustParsePrefixUnmapped("9.1.1.3/24"), -// }, -// Groups: []string{"test-group1", "test-group2", "test-group3"}, -// NotBefore: before, -// NotAfter: after, -// PublicKey: pubKey, -// IsCA: false, -// Issuer: "1234567890abcedfghij1234567890ab", -// }, -// } -// -// pub, priv, err := ed25519.GenerateKey(rand.Reader) -// assert.Nil(t, err) -// assert.False(t, nc.CheckSignature(pub)) -// assert.Nil(t, nc.Sign(Curve_CURVE25519, priv)) -// assert.True(t, nc.CheckSignature(pub)) -// -// _, err = nc.Marshal() -// assert.Nil(t, err) -// //t.Log("Cert size:", len(b)) -//} - -//func TestNebulaCertificate_SignP256(t *testing.T) { -// before := time.Now().Add(time.Second * -60).Round(time.Second) -// after := time.Now().Add(time.Second * 60).Round(time.Second) -// pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab") -// -// nc := certificateV1{ -// details: detailsV1{ -// Name: "testing", -// Ips: []netip.Prefix{ -// mustParsePrefixUnmapped("10.1.1.1/24"), -// mustParsePrefixUnmapped("10.1.1.2/16"), -// //TODO: netip no can do -// //{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, -// }, -// Subnets: []netip.Prefix{ -// //TODO: netip bad -// //{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, -// mustParsePrefixUnmapped("9.1.1.2/24"), -// mustParsePrefixUnmapped("9.1.1.3/16"), -// }, -// Groups: []string{"test-group1", "test-group2", "test-group3"}, -// NotBefore: before, -// NotAfter: after, -// PublicKey: pubKey, -// IsCA: false, -// Curve: Curve_P256, -// Issuer: "1234567890abcedfghij1234567890ab", -// }, -// } -// -// priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) -// pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) -// rawPriv := priv.D.FillBytes(make([]byte, 32)) -// -// assert.Nil(t, err) -// assert.False(t, nc.CheckSignature(pub)) -// assert.Nil(t, nc.Sign(Curve_P256, rawPriv)) -// assert.True(t, nc.CheckSignature(pub)) -// -// _, err = nc.Marshal() -// assert.Nil(t, err) -// //t.Log("Cert size:", len(b)) -//} - -func TestNebulaCertificate_Expired(t *testing.T) { - nc := certificateV1{ - details: detailsV1{ - NotBefore: time.Now().Add(time.Second * -60).Round(time.Second), - NotAfter: time.Now().Add(time.Second * 60).Round(time.Second), - }, - } - - assert.True(t, nc.Expired(time.Now().Add(time.Hour))) - assert.True(t, nc.Expired(time.Now().Add(-time.Hour))) - assert.False(t, nc.Expired(time.Now())) -} - -func TestNebulaCertificate_MarshalJSON(t *testing.T) { - time.Local = time.UTC - pubKey := []byte("1234567890abcedfghij1234567890ab") - - nc := certificateV1{ - details: detailsV1{ - Name: "testing", - Ips: []netip.Prefix{ - mustParsePrefixUnmapped("10.1.1.1/24"), - mustParsePrefixUnmapped("10.1.1.2/16"), - }, - Subnets: []netip.Prefix{ - mustParsePrefixUnmapped("9.1.1.2/24"), - mustParsePrefixUnmapped("9.1.1.3/16"), - }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC), - NotAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC), - PublicKey: pubKey, - IsCA: false, - Issuer: "1234567890abcedfghij1234567890ab", - }, - signature: []byte("1234567890abcedfghij1234567890ab"), - } - - b, err := nc.MarshalJSON() - assert.Nil(t, err) - assert.Equal( - t, - "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}", - string(b), - ) -} - -func TestNebulaCertificate_Verify(t *testing.T) { - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) - assert.Nil(t, err) - - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) - assert.Nil(t, err) - - caPool := NewCAPool() - assert.NoError(t, caPool.AddCA(ca)) - - f, err := c.Fingerprint() - assert.Nil(t, err) - caPool.BlocklistFingerprint(f) - - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.EqualError(t, err, "certificate is in the block list") - - caPool.ResetCertBlocklist() - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) - - _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) - assert.EqualError(t, err, "root certificate is expired") - - c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil) - assert.EqualError(t, err, "certificate is valid before the signing certificate") - - // Test group assertion - ca, _, caKey, err = newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) - assert.Nil(t, err) - - caPem, err := ca.MarshalPEM() - assert.Nil(t, err) - - caPool = NewCAPool() - b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) - assert.Empty(t, b) - - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) - assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad") - - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) -} - -func TestNebulaCertificate_VerifyP256(t *testing.T) { - ca, _, caKey, err := newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) - assert.Nil(t, err) - - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) - assert.Nil(t, err) - - caPool := NewCAPool() - assert.NoError(t, caPool.AddCA(ca)) - - f, err := c.Fingerprint() - assert.Nil(t, err) - caPool.BlocklistFingerprint(f) - - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.EqualError(t, err, "certificate is in the block list") - - caPool.ResetCertBlocklist() - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) - - _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) - assert.EqualError(t, err, "root certificate is expired") - - c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil) - assert.EqualError(t, err, "certificate is valid before the signing certificate") - - // Test group assertion - ca, _, caKey, err = newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) - assert.Nil(t, err) - - caPem, err := ca.MarshalPEM() - assert.Nil(t, err) - - caPool = NewCAPool() - b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) - assert.Empty(t, b) - - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) - assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad") - - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) -} - -func TestNebulaCertificate_Verify_IPs(t *testing.T) { - caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") - caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) - assert.Nil(t, err) - - caPem, err := ca.MarshalPEM() - assert.Nil(t, err) - - caPool := NewCAPool() - b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) - assert.Empty(t, b) - - // ip is outside the network - cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") - cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) - assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24") - - // ip is outside the network reversed order of above - cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") - cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) - assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24") - - // ip is within the network but mask is outside - cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") - cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) - assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15") - - // ip is within the network but mask is outside reversed order of above - cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") - cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) - assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15") - - // ip and mask are within the network - cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") - cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) - - // Exact matches - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) - - // Exact matches reversed - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) - - // Exact matches reversed with just 1 - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) -} - -func TestNebulaCertificate_Verify_Subnets(t *testing.T) { - caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") - caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) - assert.Nil(t, err) - - caPem, err := ca.MarshalPEM() - assert.Nil(t, err) - - caPool := NewCAPool() - b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) - assert.Empty(t, b) - - // ip is outside the network - cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") - cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24") - - // ip is outside the network reversed order of above - cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") - cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24") - - // ip is within the network but mask is outside - cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") - cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15") - - // ip is within the network but mask is outside reversed order of above - cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") - cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15") - - // ip and mask are within the network - cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") - cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) - - // Exact matches - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) - - // Exact matches reversed - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) - - // Exact matches reversed with just 1 - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) -} - -func TestNebulaCertificate_VerifyPrivateKey(t *testing.T) { - ca, _, caKey, err := newTestCaCert(time.Time{}, time.Time{}, nil, nil, nil) - assert.Nil(t, err) - err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) - assert.Nil(t, err) - - _, _, caKey2, err := newTestCaCert(time.Time{}, time.Time{}, nil, nil, nil) - assert.Nil(t, err) - err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) - assert.NotNil(t, err) - - c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil) - err = c.VerifyPrivateKey(Curve_CURVE25519, priv) - assert.Nil(t, err) - - _, priv2 := x25519Keypair() - err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) - assert.NotNil(t, err) -} - -func TestNebulaCertificate_VerifyPrivateKeyP256(t *testing.T) { - ca, _, caKey, err := newTestCaCertP256(time.Time{}, time.Time{}, nil, nil, nil) - assert.Nil(t, err) - err = ca.VerifyPrivateKey(Curve_P256, caKey) - assert.Nil(t, err) - - _, _, caKey2, err := newTestCaCertP256(time.Time{}, time.Time{}, nil, nil, nil) - assert.Nil(t, err) - err = ca.VerifyPrivateKey(Curve_P256, caKey2) - assert.NotNil(t, err) - - c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil) - err = c.VerifyPrivateKey(Curve_P256, priv) - assert.Nil(t, err) - - _, priv2 := p256Keypair() - err = c.VerifyPrivateKey(Curve_P256, priv2) - assert.NotNil(t, err) -} - -func appendByteSlices(b ...[]byte) []byte { - retSlice := []byte{} - for _, v := range b { - retSlice = append(retSlice, v...) - } - return retSlice -} - -// Ensure that upgrading the protobuf library does not change how certificates -// are marshalled, since this would break signature verification -//TODO: since netip cant represent 255.0.255.0 netmask we can't verify the old certs are ok -//func TestMarshalingNebulaCertificateConsistency(t *testing.T) { -// before := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) -// after := time.Date(2017, time.January, 18, 28, 40, 0, 0, time.UTC) -// pubKey := []byte("1234567890abcedfghij1234567890ab") -// -// nc := certificateV1{ -// details: detailsV1{ -// Name: "testing", -// Ips: []netip.Prefix{ -// mustParsePrefixUnmapped("10.1.1.1/24"), -// mustParsePrefixUnmapped("10.1.1.2/16"), -// //TODO: netip bad -// //{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, -// }, -// Subnets: []netip.Prefix{ -// //TODO: netip bad -// //{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, -// mustParsePrefixUnmapped("9.1.1.2/24"), -// mustParsePrefixUnmapped("9.1.1.3/16"), -// }, -// Groups: []string{"test-group1", "test-group2", "test-group3"}, -// NotBefore: before, -// NotAfter: after, -// PublicKey: pubKey, -// IsCA: false, -// Issuer: "1234567890abcedfghij1234567890ab", -// }, -// signature: []byte("1234567890abcedfghij1234567890ab"), -// } -// -// b, err := nc.Marshal() -// assert.Nil(t, err) -// //t.Log("Cert size:", len(b)) -// assert.Equal(t, "0aa2010a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) -// -// b, err = proto.Marshal(nc.getRawDetails()) -// assert.Nil(t, err) -// //t.Log("Raw cert size:", len(b)) -// assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) -//} - -func TestNebulaCertificate_Copy(t *testing.T) { - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) - assert.Nil(t, err) - - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) - assert.Nil(t, err) - cc := c.Copy() - - test.AssertDeepCopyEqual(t, c, cc) -} - -func TestUnmarshalNebulaCertificate(t *testing.T) { - // Test that we don't panic with an invalid certificate (#332) - data := []byte("\x98\x00\x00") - _, err := unmarshalCertificateV1(data, true) - assert.EqualError(t, err, "encoded Details was nil") -} - -func newTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (Certificate, []byte, []byte, error) { - pub, priv, err := ed25519.GenerateKey(rand.Reader) - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - tbs := &TBSCertificate{ - Version: Version1, - Name: "test ca", - IsCA: true, - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - } - - if len(ips) > 0 { - tbs.Networks = ips - } - - if len(subnets) > 0 { - tbs.UnsafeNetworks = subnets - } - - if len(groups) > 0 { - tbs.Groups = groups - } - - nc, err := tbs.Sign(nil, Curve_CURVE25519, priv) - if err != nil { - return nil, nil, nil, err - } - return nc, pub, priv, nil -} - -func newTestCaCertP256(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (Certificate, []byte, []byte, error) { - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) - rawPriv := priv.D.FillBytes(make([]byte, 32)) - - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - tbs := &TBSCertificate{ - Version: Version1, - Name: "test ca", - IsCA: true, - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - Curve: Curve_P256, - } - - if len(ips) > 0 { - tbs.Networks = ips - } - - if len(subnets) > 0 { - tbs.UnsafeNetworks = subnets - } - - if len(groups) > 0 { - tbs.Groups = groups - } - - nc, err := tbs.Sign(nil, Curve_P256, rawPriv) - if err != nil { - return nil, nil, nil, err - } - return nc, pub, rawPriv, nil -} - -func newTestCert(ca Certificate, key []byte, before, after time.Time, ips, subnets []netip.Prefix, groups []string) (Certificate, []byte, []byte, error) { - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - if len(groups) == 0 { - groups = []string{"test-group1", "test-group2", "test-group3"} - } - - if len(ips) == 0 { - ips = []netip.Prefix{ - mustParsePrefixUnmapped("10.1.1.1/24"), - mustParsePrefixUnmapped("10.1.1.2/16"), - } - } - - if len(subnets) == 0 { - subnets = []netip.Prefix{ - mustParsePrefixUnmapped("9.1.1.2/24"), - mustParsePrefixUnmapped("9.1.1.3/16"), - } - } - - var pub, rawPriv []byte - - switch ca.Curve() { - case Curve_CURVE25519: - pub, rawPriv = x25519Keypair() - case Curve_P256: - pub, rawPriv = p256Keypair() - default: - return nil, nil, nil, fmt.Errorf("unknown curve: %v", ca.Curve()) - } - - tbs := &TBSCertificate{ - Version: Version1, - Name: "testing", - Networks: ips, - UnsafeNetworks: subnets, - Groups: groups, - IsCA: false, - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - Curve: ca.Curve(), - } - - nc, err := tbs.Sign(ca, ca.Curve(), key) - if err != nil { - return nil, nil, nil, err - } - - return nc, pub, rawPriv, nil -} - -func x25519Keypair() ([]byte, []byte) { - privkey := make([]byte, 32) - if _, err := io.ReadFull(rand.Reader, privkey); err != nil { - panic(err) - } - - pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) - if err != nil { - panic(err) - } - - return pubkey, privkey -} - -func p256Keypair() ([]byte, []byte) { - privkey, err := ecdh.P256().GenerateKey(rand.Reader) - if err != nil { - panic(err) - } - pubkey := privkey.PublicKey() - return pubkey.Bytes(), privkey.Bytes() -} - -func mustParsePrefixUnmapped(s string) netip.Prefix { - prefix := netip.MustParsePrefix(s) - return netip.PrefixFrom(prefix.Addr().Unmap(), prefix.Bits()) -} diff --git a/cert/cert_v1.go b/cert/cert_v1.go index 165e409..6bb146f 100644 --- a/cert/cert_v1.go +++ b/cert/cert_v1.go @@ -6,19 +6,16 @@ import ( "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" - "crypto/rand" "crypto/sha256" "encoding/binary" "encoding/hex" "encoding/json" "encoding/pem" "fmt" - "math/big" "net" "net/netip" "time" - "github.com/slackhq/nebula/pkclient" "golang.org/x/crypto/curve25519" "google.golang.org/protobuf/proto" ) @@ -31,71 +28,71 @@ type certificateV1 struct { } type detailsV1 struct { - Name string - Ips []netip.Prefix - Subnets []netip.Prefix - Groups []string - NotBefore time.Time - NotAfter time.Time - PublicKey []byte - IsCA bool - Issuer string + name string + networks []netip.Prefix + unsafeNetworks []netip.Prefix + groups []string + notBefore time.Time + notAfter time.Time + publicKey []byte + isCA bool + issuer string - Curve Curve + curve Curve } type m map[string]interface{} -func (nc *certificateV1) Version() Version { +func (c *certificateV1) Version() Version { return Version1 } -func (nc *certificateV1) Curve() Curve { - return nc.details.Curve +func (c *certificateV1) Curve() Curve { + return c.details.curve } -func (nc *certificateV1) Groups() []string { - return nc.details.Groups +func (c *certificateV1) Groups() []string { + return c.details.groups } -func (nc *certificateV1) IsCA() bool { - return nc.details.IsCA +func (c *certificateV1) IsCA() bool { + return c.details.isCA } -func (nc *certificateV1) Issuer() string { - return nc.details.Issuer +func (c *certificateV1) Issuer() string { + return c.details.issuer } -func (nc *certificateV1) Name() string { - return nc.details.Name +func (c *certificateV1) Name() string { + return c.details.name } -func (nc *certificateV1) Networks() []netip.Prefix { - return nc.details.Ips +func (c *certificateV1) Networks() []netip.Prefix { + return c.details.networks } -func (nc *certificateV1) NotAfter() time.Time { - return nc.details.NotAfter +func (c *certificateV1) NotAfter() time.Time { + return c.details.notAfter } -func (nc *certificateV1) NotBefore() time.Time { - return nc.details.NotBefore +func (c *certificateV1) NotBefore() time.Time { + return c.details.notBefore } -func (nc *certificateV1) PublicKey() []byte { - return nc.details.PublicKey +func (c *certificateV1) PublicKey() []byte { + return c.details.publicKey } -func (nc *certificateV1) Signature() []byte { - return nc.signature +func (c *certificateV1) Signature() []byte { + return c.signature } -func (nc *certificateV1) UnsafeNetworks() []netip.Prefix { - return nc.details.Subnets +func (c *certificateV1) UnsafeNetworks() []netip.Prefix { + return c.details.unsafeNetworks } -func (nc *certificateV1) Fingerprint() (string, error) { - b, err := nc.Marshal() +func (c *certificateV1) Fingerprint() (string, error) { + b, err := c.Marshal() if err != nil { return "", err } @@ -104,33 +101,33 @@ func (nc *certificateV1) Fingerprint() (string, error) { return hex.EncodeToString(sum[:]), nil } -func (nc *certificateV1) CheckSignature(key []byte) bool { - b, err := proto.Marshal(nc.getRawDetails()) +func (c *certificateV1) CheckSignature(key []byte) bool { + b, err := proto.Marshal(c.getRawDetails()) if err != nil { return false } - switch nc.details.Curve { + switch c.details.curve { case Curve_CURVE25519: - return ed25519.Verify(key, b, nc.signature) + return ed25519.Verify(key, b, c.signature) case Curve_P256: x, y := elliptic.Unmarshal(elliptic.P256(), key) pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} hashed := sha256.Sum256(b) - return ecdsa.VerifyASN1(pubKey, hashed[:], nc.signature) + return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) default: return false } } -func (nc *certificateV1) Expired(t time.Time) bool { - return nc.details.NotBefore.After(t) || nc.details.NotAfter.Before(t) +func (c *certificateV1) Expired(t time.Time) bool { + return c.details.notBefore.After(t) || c.details.notAfter.Before(t) } -func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { - if curve != nc.details.Curve { +func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { + if curve != c.details.curve { return fmt.Errorf("curve in cert and private key supplied don't match") } - if nc.details.IsCA { + if c.details.isCA { switch curve { case Curve_CURVE25519: // the call to PublicKey below will panic slice bounds out of range otherwise @@ -138,7 +135,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") } - if !ed25519.PublicKey(nc.details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) { + if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) { return fmt.Errorf("public key in cert and private key supplied don't match") } case Curve_P256: @@ -147,7 +144,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { return fmt.Errorf("cannot parse private key as P256: %w", err) } pub := privkey.PublicKey().Bytes() - if !bytes.Equal(pub, nc.details.PublicKey) { + if !bytes.Equal(pub, c.details.publicKey) { return fmt.Errorf("public key in cert and private key supplied don't match") } default: @@ -173,7 +170,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { default: return fmt.Errorf("invalid curve: %s", curve) } - if !bytes.Equal(pub, nc.details.PublicKey) { + if !bytes.Equal(pub, c.details.publicKey) { return fmt.Errorf("public key in cert and private key supplied don't match") } @@ -181,173 +178,219 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { } // getRawDetails marshals the raw details into protobuf ready struct -func (nc *certificateV1) getRawDetails() *RawNebulaCertificateDetails { +func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails { rd := &RawNebulaCertificateDetails{ - Name: nc.details.Name, - Groups: nc.details.Groups, - NotBefore: nc.details.NotBefore.Unix(), - NotAfter: nc.details.NotAfter.Unix(), - PublicKey: make([]byte, len(nc.details.PublicKey)), - IsCA: nc.details.IsCA, - Curve: nc.details.Curve, + Name: c.details.name, + Groups: c.details.groups, + NotBefore: c.details.notBefore.Unix(), + NotAfter: c.details.notAfter.Unix(), + PublicKey: make([]byte, len(c.details.publicKey)), + IsCA: c.details.isCA, + Curve: c.details.curve, } - for _, ipNet := range nc.details.Ips { + for _, ipNet := range c.details.networks { mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask)) } - for _, ipNet := range nc.details.Subnets { + for _, ipNet := range c.details.unsafeNetworks { mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask)) } - copy(rd.PublicKey, nc.details.PublicKey[:]) + copy(rd.PublicKey, c.details.publicKey[:]) // I know, this is terrible - rd.Issuer, _ = hex.DecodeString(nc.details.Issuer) + rd.Issuer, _ = hex.DecodeString(c.details.issuer) return rd } -func (nc *certificateV1) String() string { - if nc == nil { - return "Certificate {}\n" +func (c *certificateV1) String() string { + b, err := json.MarshalIndent(c.marshalJSON(), "", "\t") + if err != nil { + return fmt.Sprintf("", err) } - - s := "NebulaCertificate {\n" - s += "\tDetails {\n" - s += fmt.Sprintf("\t\tName: %v\n", nc.details.Name) - - if len(nc.details.Ips) > 0 { - s += "\t\tIps: [\n" - for _, ip := range nc.details.Ips { - s += fmt.Sprintf("\t\t\t%v\n", ip.String()) - } - s += "\t\t]\n" - } else { - s += "\t\tIps: []\n" - } - - if len(nc.details.Subnets) > 0 { - s += "\t\tSubnets: [\n" - for _, ip := range nc.details.Subnets { - s += fmt.Sprintf("\t\t\t%v\n", ip.String()) - } - s += "\t\t]\n" - } else { - s += "\t\tSubnets: []\n" - } - - if len(nc.details.Groups) > 0 { - s += "\t\tGroups: [\n" - for _, g := range nc.details.Groups { - s += fmt.Sprintf("\t\t\t\"%v\"\n", g) - } - s += "\t\t]\n" - } else { - s += "\t\tGroups: []\n" - } - - s += fmt.Sprintf("\t\tNot before: %v\n", nc.details.NotBefore) - s += fmt.Sprintf("\t\tNot After: %v\n", nc.details.NotAfter) - s += fmt.Sprintf("\t\tIs CA: %v\n", nc.details.IsCA) - s += fmt.Sprintf("\t\tIssuer: %s\n", nc.details.Issuer) - s += fmt.Sprintf("\t\tPublic key: %x\n", nc.details.PublicKey) - s += fmt.Sprintf("\t\tCurve: %s\n", nc.details.Curve) - s += "\t}\n" - fp, err := nc.Fingerprint() - if err == nil { - s += fmt.Sprintf("\tFingerprint: %s\n", fp) - } - s += fmt.Sprintf("\tSignature: %x\n", nc.Signature()) - s += "}" - - return s + return string(b) } -func (nc *certificateV1) MarshalForHandshakes() ([]byte, error) { - pubKey := nc.details.PublicKey - nc.details.PublicKey = nil - rawCertNoKey, err := nc.Marshal() +func (c *certificateV1) MarshalForHandshakes() ([]byte, error) { + pubKey := c.details.publicKey + c.details.publicKey = nil + rawCertNoKey, err := c.Marshal() if err != nil { return nil, err } - nc.details.PublicKey = pubKey + c.details.publicKey = pubKey return rawCertNoKey, nil } -func (nc *certificateV1) Marshal() ([]byte, error) { +func (c *certificateV1) Marshal() ([]byte, error) { rc := RawNebulaCertificate{ - Details: nc.getRawDetails(), - Signature: nc.signature, + Details: c.getRawDetails(), + Signature: c.signature, } return proto.Marshal(&rc) } -func (nc *certificateV1) MarshalPEM() ([]byte, error) { - b, err := nc.Marshal() +func (c *certificateV1) MarshalPEM() ([]byte, error) { + b, err := c.Marshal() if err != nil { return nil, err } return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil } -func (nc *certificateV1) MarshalJSON() ([]byte, error) { - fp, _ := nc.Fingerprint() - jc := m{ - "details": m{ - "name": nc.details.Name, - "ips": nc.details.Ips, - "subnets": nc.details.Subnets, - "groups": nc.details.Groups, - "notBefore": nc.details.NotBefore, - "notAfter": nc.details.NotAfter, - "publicKey": fmt.Sprintf("%x", nc.details.PublicKey), - "isCa": nc.details.IsCA, - "issuer": nc.details.Issuer, - "curve": nc.details.Curve.String(), - }, - "fingerprint": fp, - "signature": fmt.Sprintf("%x", nc.Signature()), - } - return json.Marshal(jc) +func (c *certificateV1) MarshalJSON() ([]byte, error) { + return json.Marshal(c.marshalJSON()) } -func (nc *certificateV1) Copy() Certificate { - c := &certificateV1{ - details: detailsV1{ - Name: nc.details.Name, - Groups: make([]string, len(nc.details.Groups)), - Ips: make([]netip.Prefix, len(nc.details.Ips)), - Subnets: make([]netip.Prefix, len(nc.details.Subnets)), - NotBefore: nc.details.NotBefore, - NotAfter: nc.details.NotAfter, - PublicKey: make([]byte, len(nc.details.PublicKey)), - IsCA: nc.details.IsCA, - Issuer: nc.details.Issuer, +func (c *certificateV1) marshalJSON() m { + fp, _ := c.Fingerprint() + return m{ + "version": Version1, + "details": m{ + "name": c.details.name, + "networks": c.details.networks, + "unsafeNetworks": c.details.unsafeNetworks, + "groups": c.details.groups, + "notBefore": c.details.notBefore, + "notAfter": c.details.notAfter, + "publicKey": fmt.Sprintf("%x", c.details.publicKey), + "isCa": c.details.isCA, + "issuer": c.details.issuer, + "curve": c.details.curve.String(), }, - signature: make([]byte, len(nc.signature)), + "fingerprint": fp, + "signature": fmt.Sprintf("%x", c.Signature()), + } +} + +func (c *certificateV1) Copy() Certificate { + nc := &certificateV1{ + details: detailsV1{ + name: c.details.name, + notBefore: c.details.notBefore, + notAfter: c.details.notAfter, + publicKey: make([]byte, len(c.details.publicKey)), + isCA: c.details.isCA, + issuer: c.details.issuer, + curve: c.details.curve, + }, + signature: make([]byte, len(c.signature)), } - copy(c.signature, nc.signature) - copy(c.details.Groups, nc.details.Groups) - copy(c.details.PublicKey, nc.details.PublicKey) - - for i, p := range nc.details.Ips { - c.details.Ips[i] = p + if c.details.groups != nil { + nc.details.groups = make([]string, len(c.details.groups)) + copy(nc.details.groups, c.details.groups) } - for i, p := range nc.details.Subnets { - c.details.Subnets[i] = p + if c.details.networks != nil { + nc.details.networks = make([]netip.Prefix, len(c.details.networks)) + copy(nc.details.networks, c.details.networks) } - return c + if c.details.unsafeNetworks != nil { + nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks)) + copy(nc.details.unsafeNetworks, c.details.unsafeNetworks) + } + + copy(nc.signature, c.signature) + copy(nc.details.publicKey, c.details.publicKey) + + return nc +} + +func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error { + c.details = detailsV1{ + name: t.Name, + networks: t.Networks, + unsafeNetworks: t.UnsafeNetworks, + groups: t.Groups, + notBefore: t.NotBefore, + notAfter: t.NotAfter, + publicKey: t.PublicKey, + isCA: t.IsCA, + curve: t.Curve, + issuer: t.issuer, + } + + return c.validate() +} + +func (c *certificateV1) validate() error { + // Empty names are allowed + + if len(c.details.publicKey) == 0 { + return ErrInvalidPublicKey + } + + // Original v1 rules allowed multiple networks to be present but ignored all but the first one. + // Continue to allow this behavior + if !c.details.isCA && len(c.details.networks) == 0 { + return NewErrInvalidCertificateProperties("non-CA certificates must contain exactly one network") + } + + for _, network := range c.details.networks { + if !network.IsValid() || !network.Addr().IsValid() { + return NewErrInvalidCertificateProperties("invalid network: %s", network) + } + + if network.Addr().Is6() { + return NewErrInvalidCertificateProperties("certificate may not contain IPv6 networks: %v", network) + } + + if network.Addr().IsUnspecified() { + return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network) + } + + if network.Addr().Zone() != "" { + return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network) + } + } + + for _, network := range c.details.unsafeNetworks { + if !network.IsValid() || !network.Addr().IsValid() { + return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network) + } + + if network.Addr().Is6() { + return NewErrInvalidCertificateProperties("certificate may not contain IPv6 unsafe networks: %v", network) + } + + if network.Addr().Zone() != "" { + return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network) + } + } + + // v1 doesn't bother with sort order or uniqueness of networks or unsafe networks. + // We can't modify the unmarshalled data because verification requires re-marshalling and a re-ordered + // unsafe networks would result in a different signature. + + return nil +} + +func (c *certificateV1) marshalForSigning() ([]byte, error) { + b, err := proto.Marshal(c.getRawDetails()) + if err != nil { + return nil, err + } + return b, nil +} + +func (c *certificateV1) setSignature(b []byte) error { + if len(b) == 0 { + return ErrEmptySignature + } + c.signature = b + return nil } // unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert -func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, error) { +// if the publicKey is provided here then it is not required to be present in `b` +func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) { if len(b) == 0 { return nil, fmt.Errorf("nil byte array") } @@ -371,27 +414,28 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err nc := certificateV1{ details: detailsV1{ - Name: rc.Details.Name, - Groups: make([]string, len(rc.Details.Groups)), - Ips: make([]netip.Prefix, len(rc.Details.Ips)/2), - Subnets: make([]netip.Prefix, len(rc.Details.Subnets)/2), - NotBefore: time.Unix(rc.Details.NotBefore, 0), - NotAfter: time.Unix(rc.Details.NotAfter, 0), - PublicKey: make([]byte, len(rc.Details.PublicKey)), - IsCA: rc.Details.IsCA, - Curve: rc.Details.Curve, + name: rc.Details.Name, + groups: make([]string, len(rc.Details.Groups)), + networks: make([]netip.Prefix, len(rc.Details.Ips)/2), + unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2), + notBefore: time.Unix(rc.Details.NotBefore, 0), + notAfter: time.Unix(rc.Details.NotAfter, 0), + publicKey: make([]byte, len(rc.Details.PublicKey)), + isCA: rc.Details.IsCA, + curve: rc.Details.Curve, }, signature: make([]byte, len(rc.Signature)), } copy(nc.signature, rc.Signature) - copy(nc.details.Groups, rc.Details.Groups) - nc.details.Issuer = hex.EncodeToString(rc.Details.Issuer) + copy(nc.details.groups, rc.Details.Groups) + nc.details.issuer = hex.EncodeToString(rc.Details.Issuer) - if len(rc.Details.PublicKey) < publicKeyLen && assertPublicKey { - return nil, fmt.Errorf("public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey)) + if len(publicKey) > 0 { + nc.details.publicKey = publicKey } - copy(nc.details.PublicKey, rc.Details.PublicKey) + + copy(nc.details.publicKey, rc.Details.PublicKey) var ip netip.Addr for i, rawIp := range rc.Details.Ips { @@ -399,7 +443,7 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err ip = int2addr(rawIp) } else { ones, _ := net.IPMask(int2ip(rawIp)).Size() - nc.details.Ips[i/2] = netip.PrefixFrom(ip, ones) + nc.details.networks[i/2] = netip.PrefixFrom(ip, ones) } } @@ -408,67 +452,16 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err ip = int2addr(rawIp) } else { ones, _ := net.IPMask(int2ip(rawIp)).Size() - nc.details.Subnets[i/2] = netip.PrefixFrom(ip, ones) + nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones) } } - return &nc, nil -} - -func signV1(t *TBSCertificate, curve Curve, key []byte, client *pkclient.PKClient) (*certificateV1, error) { - c := &certificateV1{ - details: detailsV1{ - Name: t.Name, - Ips: t.Networks, - Subnets: t.UnsafeNetworks, - Groups: t.Groups, - NotBefore: t.NotBefore, - NotAfter: t.NotAfter, - PublicKey: t.PublicKey, - IsCA: t.IsCA, - Curve: t.Curve, - Issuer: t.issuer, - }, - } - b, err := proto.Marshal(c.getRawDetails()) + err = nc.validate() if err != nil { return nil, err } - var sig []byte - - switch curve { - case Curve_CURVE25519: - signer := ed25519.PrivateKey(key) - sig = ed25519.Sign(signer, b) - case Curve_P256: - if client != nil { - sig, err = client.SignASN1(b) - } else { - signer := &ecdsa.PrivateKey{ - PublicKey: ecdsa.PublicKey{ - Curve: elliptic.P256(), - }, - // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95 - D: new(big.Int).SetBytes(key), - } - // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119 - signer.X, signer.Y = signer.Curve.ScalarBaseMult(key) - - // We need to hash first for ECDSA - // - https://pkg.go.dev/crypto/ecdsa#SignASN1 - hashed := sha256.Sum256(b) - sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:]) - if err != nil { - return nil, err - } - } - default: - return nil, fmt.Errorf("invalid curve: %s", c.details.Curve) - } - - c.signature = sig - return c, nil + return &nc, nil } func ip2int(ip []byte) uint32 { diff --git a/cert/cert_v1_test.go b/cert/cert_v1_test.go new file mode 100644 index 0000000..8c3fe93 --- /dev/null +++ b/cert/cert_v1_test.go @@ -0,0 +1,218 @@ +package cert + +import ( + "fmt" + "net/netip" + "testing" + "time" + + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" +) + +func TestCertificateV1_Marshal(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV1{ + details: detailsV1{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: before, + notAfter: after, + publicKey: pubKey, + isCA: false, + issuer: "1234567890abcedfghij1234567890ab", + }, + signature: []byte("1234567890abcedfghij1234567890ab"), + } + + b, err := nc.Marshal() + assert.Nil(t, err) + //t.Log("Cert size:", len(b)) + + nc2, err := unmarshalCertificateV1(b, nil) + assert.Nil(t, err) + + assert.Equal(t, nc.Version(), Version1) + assert.Equal(t, nc.Curve(), Curve_CURVE25519) + assert.Equal(t, nc.Signature(), nc2.Signature()) + assert.Equal(t, nc.Name(), nc2.Name()) + assert.Equal(t, nc.NotBefore(), nc2.NotBefore()) + assert.Equal(t, nc.NotAfter(), nc2.NotAfter()) + assert.Equal(t, nc.PublicKey(), nc2.PublicKey()) + assert.Equal(t, nc.IsCA(), nc2.IsCA()) + + assert.Equal(t, nc.Networks(), nc2.Networks()) + assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks()) + + assert.Equal(t, nc.Groups(), nc2.Groups()) +} + +func TestCertificateV1_Expired(t *testing.T) { + nc := certificateV1{ + details: detailsV1{ + notBefore: time.Now().Add(time.Second * -60).Round(time.Second), + notAfter: time.Now().Add(time.Second * 60).Round(time.Second), + }, + } + + assert.True(t, nc.Expired(time.Now().Add(time.Hour))) + assert.True(t, nc.Expired(time.Now().Add(-time.Hour))) + assert.False(t, nc.Expired(time.Now())) +} + +func TestCertificateV1_MarshalJSON(t *testing.T) { + time.Local = time.UTC + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV1{ + details: detailsV1{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC), + notAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC), + publicKey: pubKey, + isCA: false, + issuer: "1234567890abcedfghij1234567890ab", + }, + signature: []byte("1234567890abcedfghij1234567890ab"), + } + + b, err := nc.MarshalJSON() + assert.Nil(t, err) + assert.Equal( + t, + "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}", + string(b), + ) +} + +func TestCertificateV1_VerifyPrivateKey(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) + assert.Nil(t, err) + + _, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + assert.Nil(t, err) + err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) + assert.NotNil(t, err) + + c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) + assert.NoError(t, err) + assert.Empty(t, b) + assert.Equal(t, Curve_CURVE25519, curve) + err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) + assert.Nil(t, err) + + _, priv2 := X25519Keypair() + err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) + assert.NotNil(t, err) +} + +func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + err := ca.VerifyPrivateKey(Curve_P256, caKey) + assert.Nil(t, err) + + _, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + assert.Nil(t, err) + err = ca.VerifyPrivateKey(Curve_P256, caKey2) + assert.NotNil(t, err) + + c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) + assert.NoError(t, err) + assert.Empty(t, b) + assert.Equal(t, Curve_P256, curve) + err = c.VerifyPrivateKey(Curve_P256, rawPriv) + assert.Nil(t, err) + + _, priv2 := P256Keypair() + err = c.VerifyPrivateKey(Curve_P256, priv2) + assert.NotNil(t, err) +} + +// Ensure that upgrading the protobuf library does not change how certificates +// are marshalled, since this would break signature verification +func TestMarshalingCertificateV1Consistency(t *testing.T) { + before := time.Date(1970, time.January, 1, 1, 1, 1, 1, time.UTC) + after := time.Date(9999, time.January, 1, 1, 1, 1, 1, time.UTC) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV1{ + details: detailsV1{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.2/16"), + mustParsePrefixUnmapped("10.1.1.1/24"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.3/16"), + mustParsePrefixUnmapped("9.1.1.2/24"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: before, + notAfter: after, + publicKey: pubKey, + isCA: false, + issuer: "1234567890abcedfghij1234567890ab", + }, + signature: []byte("1234567890abcedfghij1234567890ab"), + } + + b, err := nc.Marshal() + require.Nil(t, err) + assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) + + b, err = proto.Marshal(nc.getRawDetails()) + assert.Nil(t, err) + assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) +} + +func TestCertificateV1_Copy(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + cc := c.Copy() + test.AssertDeepCopyEqual(t, c, cc) +} + +func TestUnmarshalCertificateV1(t *testing.T) { + // Test that we don't panic with an invalid certificate (#332) + data := []byte("\x98\x00\x00") + _, err := unmarshalCertificateV1(data, nil) + assert.EqualError(t, err, "encoded Details was nil") +} + +func appendByteSlices(b ...[]byte) []byte { + retSlice := []byte{} + for _, v := range b { + retSlice = append(retSlice, v...) + } + return retSlice +} + +func mustParsePrefixUnmapped(s string) netip.Prefix { + prefix := netip.MustParsePrefix(s) + return netip.PrefixFrom(prefix.Addr().Unmap(), prefix.Bits()) +} diff --git a/cert/cert_v2.asn1 b/cert/cert_v2.asn1 new file mode 100644 index 0000000..f863133 --- /dev/null +++ b/cert/cert_v2.asn1 @@ -0,0 +1,37 @@ +Nebula DEFINITIONS AUTOMATIC TAGS ::= BEGIN + +Name ::= UTF8String (SIZE (1..253)) +Time ::= INTEGER (0..18446744073709551615) -- Seconds since unix epoch, uint64 maximum +Network ::= OCTET STRING (SIZE (5,17)) -- IP addresses are 4 or 16 bytes + 1 byte for the prefix length +Curve ::= ENUMERATED { + curve25519 (0), + p256 (1) +} + +-- The maximum size of a certificate must not exceed 65536 bytes +Certificate ::= SEQUENCE { + details OCTET STRING, + curve Curve DEFAULT curve25519, + publicKey OCTET STRING, + -- signature(details + curve + publicKey) using the appropriate method for curve + signature OCTET STRING +} + +Details ::= SEQUENCE { + name Name, + + -- At least 1 ipv4 or ipv6 address must be present if isCA is false + networks SEQUENCE OF Network OPTIONAL, + unsafeNetworks SEQUENCE OF Network OPTIONAL, + groups SEQUENCE OF Name OPTIONAL, + isCA BOOLEAN DEFAULT false, + notBefore Time, + notAfter Time, + + -- issuer is only required if isCA is false, if isCA is true then it must not be present + issuer OCTET STRING OPTIONAL, + ... + -- New fields can be added below here +} + +END \ No newline at end of file diff --git a/cert/cert_v2.go b/cert/cert_v2.go new file mode 100644 index 0000000..322463e --- /dev/null +++ b/cert/cert_v2.go @@ -0,0 +1,730 @@ +package cert + +import ( + "bytes" + "crypto/ecdh" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "encoding/pem" + "fmt" + "net/netip" + "slices" + "time" + + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" + "golang.org/x/crypto/curve25519" +) + +const ( + classConstructed = 0x20 + classContextSpecific = 0x80 + + TagCertDetails = 0 | classConstructed | classContextSpecific + TagCertCurve = 1 | classContextSpecific + TagCertPublicKey = 2 | classContextSpecific + TagCertSignature = 3 | classContextSpecific + + TagDetailsName = 0 | classContextSpecific + TagDetailsNetworks = 1 | classConstructed | classContextSpecific + TagDetailsUnsafeNetworks = 2 | classConstructed | classContextSpecific + TagDetailsGroups = 3 | classConstructed | classContextSpecific + TagDetailsIsCA = 4 | classContextSpecific + TagDetailsNotBefore = 5 | classContextSpecific + TagDetailsNotAfter = 6 | classContextSpecific + TagDetailsIssuer = 7 | classContextSpecific +) + +const ( + // MaxCertificateSize is the maximum length a valid certificate can be + MaxCertificateSize = 65536 + + // MaxNameLength is limited to a maximum realistic DNS domain name to help facilitate DNS systems + MaxNameLength = 253 + + // MaxNetworkLength is the maximum length a network value can be. + // 16 bytes for an ipv6 address + 1 byte for the prefix length + MaxNetworkLength = 17 +) + +type certificateV2 struct { + details detailsV2 + + // RawDetails contains the entire asn.1 DER encoded Details struct + // This is to benefit forwards compatibility in signature checking. + // signature(RawDetails + Curve + PublicKey) == Signature + rawDetails []byte + curve Curve + publicKey []byte + signature []byte +} + +type detailsV2 struct { + name string + networks []netip.Prefix // MUST BE SORTED + unsafeNetworks []netip.Prefix // MUST BE SORTED + groups []string + isCA bool + notBefore time.Time + notAfter time.Time + issuer string +} + +func (c *certificateV2) Version() Version { + return Version2 +} + +func (c *certificateV2) Curve() Curve { + return c.curve +} + +func (c *certificateV2) Groups() []string { + return c.details.groups +} + +func (c *certificateV2) IsCA() bool { + return c.details.isCA +} + +func (c *certificateV2) Issuer() string { + return c.details.issuer +} + +func (c *certificateV2) Name() string { + return c.details.name +} + +func (c *certificateV2) Networks() []netip.Prefix { + return c.details.networks +} + +func (c *certificateV2) NotAfter() time.Time { + return c.details.notAfter +} + +func (c *certificateV2) NotBefore() time.Time { + return c.details.notBefore +} + +func (c *certificateV2) PublicKey() []byte { + return c.publicKey +} + +func (c *certificateV2) Signature() []byte { + return c.signature +} + +func (c *certificateV2) UnsafeNetworks() []netip.Prefix { + return c.details.unsafeNetworks +} + +func (c *certificateV2) Fingerprint() (string, error) { + if len(c.rawDetails) == 0 { + return "", ErrMissingDetails + } + + b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)+len(c.signature)) + copy(b, c.rawDetails) + b[len(c.rawDetails)] = byte(c.curve) + copy(b[len(c.rawDetails)+1:], c.publicKey) + copy(b[len(c.rawDetails)+1+len(c.publicKey):], c.signature) + sum := sha256.Sum256(b) + return hex.EncodeToString(sum[:]), nil +} + +func (c *certificateV2) CheckSignature(key []byte) bool { + if len(c.rawDetails) == 0 { + return false + } + b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)) + copy(b, c.rawDetails) + b[len(c.rawDetails)] = byte(c.curve) + copy(b[len(c.rawDetails)+1:], c.publicKey) + + switch c.curve { + case Curve_CURVE25519: + return ed25519.Verify(key, b, c.signature) + case Curve_P256: + x, y := elliptic.Unmarshal(elliptic.P256(), key) + pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} + hashed := sha256.Sum256(b) + return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) + default: + return false + } +} + +func (c *certificateV2) Expired(t time.Time) bool { + return c.details.notBefore.After(t) || c.details.notAfter.Before(t) +} + +func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error { + if curve != c.curve { + return ErrPublicPrivateCurveMismatch + } + if c.details.isCA { + switch curve { + case Curve_CURVE25519: + // the call to PublicKey below will panic slice bounds out of range otherwise + if len(key) != ed25519.PrivateKeySize { + return ErrInvalidPrivateKey + } + + if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) { + return ErrPublicPrivateKeyMismatch + } + case Curve_P256: + privkey, err := ecdh.P256().NewPrivateKey(key) + if err != nil { + return ErrInvalidPrivateKey + } + pub := privkey.PublicKey().Bytes() + if !bytes.Equal(pub, c.publicKey) { + return ErrPublicPrivateKeyMismatch + } + default: + return fmt.Errorf("invalid curve: %s", curve) + } + return nil + } + + var pub []byte + switch curve { + case Curve_CURVE25519: + var err error + pub, err = curve25519.X25519(key, curve25519.Basepoint) + if err != nil { + return ErrInvalidPrivateKey + } + case Curve_P256: + privkey, err := ecdh.P256().NewPrivateKey(key) + if err != nil { + return ErrInvalidPrivateKey + } + pub = privkey.PublicKey().Bytes() + default: + return fmt.Errorf("invalid curve: %s", curve) + } + if !bytes.Equal(pub, c.publicKey) { + return ErrPublicPrivateKeyMismatch + } + + return nil +} + +func (c *certificateV2) String() string { + mb, err := c.marshalJSON() + if err != nil { + return fmt.Sprintf("", err) + } + + b, err := json.MarshalIndent(mb, "", "\t") + if err != nil { + return fmt.Sprintf("", err) + } + return string(b) +} + +func (c *certificateV2) MarshalForHandshakes() ([]byte, error) { + if c.rawDetails == nil { + return nil, ErrEmptyRawDetails + } + var b cryptobyte.Builder + // Outermost certificate + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + + // Add the cert details which is already marshalled + b.AddBytes(c.rawDetails) + + // Skipping the curve and public key since those come across in a different part of the handshake + + // Add the signature + b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) { + b.AddBytes(c.signature) + }) + }) + + return b.Bytes() +} + +func (c *certificateV2) Marshal() ([]byte, error) { + if c.rawDetails == nil { + return nil, ErrEmptyRawDetails + } + var b cryptobyte.Builder + // Outermost certificate + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + + // Add the cert details which is already marshalled + b.AddBytes(c.rawDetails) + + // Add the curve only if its not the default value + if c.curve != Curve_CURVE25519 { + b.AddASN1(TagCertCurve, func(b *cryptobyte.Builder) { + b.AddBytes([]byte{byte(c.curve)}) + }) + } + + // Add the public key if it is not empty + if c.publicKey != nil { + b.AddASN1(TagCertPublicKey, func(b *cryptobyte.Builder) { + b.AddBytes(c.publicKey) + }) + } + + // Add the signature + b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) { + b.AddBytes(c.signature) + }) + }) + + return b.Bytes() +} + +func (c *certificateV2) MarshalPEM() ([]byte, error) { + b, err := c.Marshal() + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{Type: CertificateV2Banner, Bytes: b}), nil +} + +func (c *certificateV2) MarshalJSON() ([]byte, error) { + b, err := c.marshalJSON() + if err != nil { + return nil, err + } + return json.Marshal(b) +} + +func (c *certificateV2) marshalJSON() (m, error) { + fp, err := c.Fingerprint() + if err != nil { + return nil, err + } + + return m{ + "details": m{ + "name": c.details.name, + "networks": c.details.networks, + "unsafeNetworks": c.details.unsafeNetworks, + "groups": c.details.groups, + "notBefore": c.details.notBefore, + "notAfter": c.details.notAfter, + "isCa": c.details.isCA, + "issuer": c.details.issuer, + }, + "version": Version2, + "publicKey": fmt.Sprintf("%x", c.publicKey), + "curve": c.curve.String(), + "fingerprint": fp, + "signature": fmt.Sprintf("%x", c.Signature()), + }, nil +} + +func (c *certificateV2) Copy() Certificate { + nc := &certificateV2{ + details: detailsV2{ + name: c.details.name, + notBefore: c.details.notBefore, + notAfter: c.details.notAfter, + isCA: c.details.isCA, + issuer: c.details.issuer, + }, + curve: c.curve, + publicKey: make([]byte, len(c.publicKey)), + signature: make([]byte, len(c.signature)), + rawDetails: make([]byte, len(c.rawDetails)), + } + + if c.details.groups != nil { + nc.details.groups = make([]string, len(c.details.groups)) + copy(nc.details.groups, c.details.groups) + } + + if c.details.networks != nil { + nc.details.networks = make([]netip.Prefix, len(c.details.networks)) + copy(nc.details.networks, c.details.networks) + } + + if c.details.unsafeNetworks != nil { + nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks)) + copy(nc.details.unsafeNetworks, c.details.unsafeNetworks) + } + + copy(nc.rawDetails, c.rawDetails) + copy(nc.signature, c.signature) + copy(nc.publicKey, c.publicKey) + + return nc +} + +func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error { + c.details = detailsV2{ + name: t.Name, + networks: t.Networks, + unsafeNetworks: t.UnsafeNetworks, + groups: t.Groups, + isCA: t.IsCA, + notBefore: t.NotBefore, + notAfter: t.NotAfter, + issuer: t.issuer, + } + c.curve = t.Curve + c.publicKey = t.PublicKey + return c.validate() +} + +func (c *certificateV2) validate() error { + // Empty names are allowed + + if len(c.publicKey) == 0 { + return ErrInvalidPublicKey + } + + if !c.details.isCA && len(c.details.networks) == 0 { + return NewErrInvalidCertificateProperties("non-CA certificate must contain at least 1 network") + } + + hasV4Networks := false + hasV6Networks := false + for _, network := range c.details.networks { + if !network.IsValid() || !network.Addr().IsValid() { + return NewErrInvalidCertificateProperties("invalid network: %s", network) + } + + if network.Addr().IsUnspecified() { + return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network) + } + + if network.Addr().Zone() != "" { + return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network) + } + + if network.Addr().Is4In6() { + return NewErrInvalidCertificateProperties("4in6 networks are not allowed: %s", network) + } + + hasV4Networks = hasV4Networks || network.Addr().Is4() + hasV6Networks = hasV6Networks || network.Addr().Is6() + } + + slices.SortFunc(c.details.networks, comparePrefix) + err := findDuplicatePrefix(c.details.networks) + if err != nil { + return err + } + + for _, network := range c.details.unsafeNetworks { + if !network.IsValid() || !network.Addr().IsValid() { + return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network) + } + + if network.Addr().Zone() != "" { + return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network) + } + + if !c.details.isCA { + if network.Addr().Is6() { + if !hasV6Networks { + return NewErrInvalidCertificateProperties("IPv6 unsafe networks require an IPv6 address assignment: %s", network) + } + } else if network.Addr().Is4() { + if !hasV4Networks { + return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network) + } + } + } + } + + slices.SortFunc(c.details.unsafeNetworks, comparePrefix) + err = findDuplicatePrefix(c.details.unsafeNetworks) + if err != nil { + return err + } + + return nil +} + +func (c *certificateV2) marshalForSigning() ([]byte, error) { + d, err := c.details.Marshal() + if err != nil { + return nil, fmt.Errorf("marshalling certificate details failed: %w", err) + } + c.rawDetails = d + + b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)) + copy(b, c.rawDetails) + b[len(c.rawDetails)] = byte(c.curve) + copy(b[len(c.rawDetails)+1:], c.publicKey) + return b, nil +} + +func (c *certificateV2) setSignature(b []byte) error { + if len(b) == 0 { + return ErrEmptySignature + } + c.signature = b + return nil +} + +func (d *detailsV2) Marshal() ([]byte, error) { + var b cryptobyte.Builder + var err error + + // Details are a structure + b.AddASN1(TagCertDetails, func(b *cryptobyte.Builder) { + + // Add the name + b.AddASN1(TagDetailsName, func(b *cryptobyte.Builder) { + b.AddBytes([]byte(d.name)) + }) + + // Add the networks if any exist + if len(d.networks) > 0 { + b.AddASN1(TagDetailsNetworks, func(b *cryptobyte.Builder) { + for _, n := range d.networks { + sb, innerErr := n.MarshalBinary() + if innerErr != nil { + // MarshalBinary never returns an error + err = fmt.Errorf("unable to marshal network: %w", innerErr) + return + } + b.AddASN1OctetString(sb) + } + }) + } + + // Add the unsafe networks if any exist + if len(d.unsafeNetworks) > 0 { + b.AddASN1(TagDetailsUnsafeNetworks, func(b *cryptobyte.Builder) { + for _, n := range d.unsafeNetworks { + sb, innerErr := n.MarshalBinary() + if innerErr != nil { + // MarshalBinary never returns an error + err = fmt.Errorf("unable to marshal unsafe network: %w", innerErr) + return + } + b.AddASN1OctetString(sb) + } + }) + } + + // Add groups if any exist + if len(d.groups) > 0 { + b.AddASN1(TagDetailsGroups, func(b *cryptobyte.Builder) { + for _, group := range d.groups { + b.AddASN1(asn1.UTF8String, func(b *cryptobyte.Builder) { + b.AddBytes([]byte(group)) + }) + } + }) + } + + // Add IsCA only if true + if d.isCA { + b.AddASN1(TagDetailsIsCA, func(b *cryptobyte.Builder) { + b.AddUint8(0xff) + }) + } + + // Add not before + b.AddASN1Int64WithTag(d.notBefore.Unix(), TagDetailsNotBefore) + + // Add not after + b.AddASN1Int64WithTag(d.notAfter.Unix(), TagDetailsNotAfter) + + // Add the issuer if present + if d.issuer != "" { + issuerBytes, innerErr := hex.DecodeString(d.issuer) + if innerErr != nil { + err = fmt.Errorf("failed to decode issuer: %w", innerErr) + return + } + b.AddASN1(TagDetailsIssuer, func(b *cryptobyte.Builder) { + b.AddBytes(issuerBytes) + }) + } + }) + + if err != nil { + return nil, err + } + + return b.Bytes() +} + +func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) { + l := len(b) + if l == 0 || l > MaxCertificateSize { + return nil, ErrBadFormat + } + + input := cryptobyte.String(b) + // Open the envelope + if !input.ReadASN1(&input, asn1.SEQUENCE) || input.Empty() { + return nil, ErrBadFormat + } + + // Grab the cert details, we need to preserve the tag and length + var rawDetails cryptobyte.String + if !input.ReadASN1Element(&rawDetails, TagCertDetails) || rawDetails.Empty() { + return nil, ErrBadFormat + } + + //Maybe grab the curve + var rawCurve byte + if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) { + return nil, ErrBadFormat + } + curve = Curve(rawCurve) + + // Maybe grab the public key + var rawPublicKey cryptobyte.String + if len(publicKey) > 0 { + rawPublicKey = publicKey + } else if !input.ReadOptionalASN1(&rawPublicKey, nil, TagCertPublicKey) { + return nil, ErrBadFormat + } + + if len(rawPublicKey) == 0 { + return nil, ErrBadFormat + } + + // Grab the signature + var rawSignature cryptobyte.String + if !input.ReadASN1(&rawSignature, TagCertSignature) || rawSignature.Empty() { + return nil, ErrBadFormat + } + + // Finally unmarshal the details + details, err := unmarshalDetails(rawDetails) + if err != nil { + return nil, err + } + + c := &certificateV2{ + details: details, + rawDetails: rawDetails, + curve: curve, + publicKey: rawPublicKey, + signature: rawSignature, + } + + err = c.validate() + if err != nil { + return nil, err + } + + return c, nil +} + +func unmarshalDetails(b cryptobyte.String) (detailsV2, error) { + // Open the envelope + if !b.ReadASN1(&b, TagCertDetails) || b.Empty() { + return detailsV2{}, ErrBadFormat + } + + // Read the name + var name cryptobyte.String + if !b.ReadASN1(&name, TagDetailsName) || name.Empty() || len(name) > MaxNameLength { + return detailsV2{}, ErrBadFormat + } + + // Read the network addresses + var subString cryptobyte.String + var found bool + + if !b.ReadOptionalASN1(&subString, &found, TagDetailsNetworks) { + return detailsV2{}, ErrBadFormat + } + + var networks []netip.Prefix + var val cryptobyte.String + if found { + for !subString.Empty() { + if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength { + return detailsV2{}, ErrBadFormat + } + + var n netip.Prefix + if err := n.UnmarshalBinary(val); err != nil { + return detailsV2{}, ErrBadFormat + } + networks = append(networks, n) + } + } + + // Read out any unsafe networks + if !b.ReadOptionalASN1(&subString, &found, TagDetailsUnsafeNetworks) { + return detailsV2{}, ErrBadFormat + } + + var unsafeNetworks []netip.Prefix + if found { + for !subString.Empty() { + if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength { + return detailsV2{}, ErrBadFormat + } + + var n netip.Prefix + if err := n.UnmarshalBinary(val); err != nil { + return detailsV2{}, ErrBadFormat + } + unsafeNetworks = append(unsafeNetworks, n) + } + } + + // Read out any groups + if !b.ReadOptionalASN1(&subString, &found, TagDetailsGroups) { + return detailsV2{}, ErrBadFormat + } + + var groups []string + if found { + for !subString.Empty() { + if !subString.ReadASN1(&val, asn1.UTF8String) || val.Empty() { + return detailsV2{}, ErrBadFormat + } + groups = append(groups, string(val)) + } + } + + // Read out IsCA + var isCa bool + if !readOptionalASN1Boolean(&b, &isCa, TagDetailsIsCA, false) { + return detailsV2{}, ErrBadFormat + } + + // Read not before and not after + var notBefore int64 + if !b.ReadASN1Int64WithTag(¬Before, TagDetailsNotBefore) { + return detailsV2{}, ErrBadFormat + } + + var notAfter int64 + if !b.ReadASN1Int64WithTag(¬After, TagDetailsNotAfter) { + return detailsV2{}, ErrBadFormat + } + + // Read issuer + var issuer cryptobyte.String + if !b.ReadOptionalASN1(&issuer, nil, TagDetailsIssuer) { + return detailsV2{}, ErrBadFormat + } + + return detailsV2{ + name: string(name), + networks: networks, + unsafeNetworks: unsafeNetworks, + groups: groups, + isCA: isCa, + notBefore: time.Unix(notBefore, 0), + notAfter: time.Unix(notAfter, 0), + issuer: hex.EncodeToString(issuer), + }, nil +} diff --git a/cert/cert_v2_test.go b/cert/cert_v2_test.go new file mode 100644 index 0000000..3afbcab --- /dev/null +++ b/cert/cert_v2_test.go @@ -0,0 +1,267 @@ +package cert + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/hex" + "net/netip" + "slices" + "testing" + "time" + + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCertificateV2_Marshal(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV2{ + details: detailsV2{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.2/16"), + mustParsePrefixUnmapped("10.1.1.1/24"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.3/16"), + mustParsePrefixUnmapped("9.1.1.2/24"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: before, + notAfter: after, + isCA: false, + issuer: "1234567890abcdef1234567890abcdef", + }, + signature: []byte("1234567890abcdef1234567890abcdef"), + publicKey: pubKey, + } + + db, err := nc.details.Marshal() + require.NoError(t, err) + nc.rawDetails = db + + b, err := nc.Marshal() + require.Nil(t, err) + //t.Log("Cert size:", len(b)) + + nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519) + assert.Nil(t, err) + + assert.Equal(t, nc.Version(), Version2) + assert.Equal(t, nc.Curve(), Curve_CURVE25519) + assert.Equal(t, nc.Signature(), nc2.Signature()) + assert.Equal(t, nc.Name(), nc2.Name()) + assert.Equal(t, nc.NotBefore(), nc2.NotBefore()) + assert.Equal(t, nc.NotAfter(), nc2.NotAfter()) + assert.Equal(t, nc.PublicKey(), nc2.PublicKey()) + assert.Equal(t, nc.IsCA(), nc2.IsCA()) + assert.Equal(t, nc.Issuer(), nc2.Issuer()) + + // unmarshalling will sort networks and unsafeNetworks, we need to do the same + // but first make sure it fails + assert.NotEqual(t, nc.Networks(), nc2.Networks()) + assert.NotEqual(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks()) + + slices.SortFunc(nc.details.networks, comparePrefix) + slices.SortFunc(nc.details.unsafeNetworks, comparePrefix) + + assert.Equal(t, nc.Networks(), nc2.Networks()) + assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks()) + + assert.Equal(t, nc.Groups(), nc2.Groups()) +} + +func TestCertificateV2_Expired(t *testing.T) { + nc := certificateV2{ + details: detailsV2{ + notBefore: time.Now().Add(time.Second * -60).Round(time.Second), + notAfter: time.Now().Add(time.Second * 60).Round(time.Second), + }, + } + + assert.True(t, nc.Expired(time.Now().Add(time.Hour))) + assert.True(t, nc.Expired(time.Now().Add(-time.Hour))) + assert.False(t, nc.Expired(time.Now())) +} + +func TestCertificateV2_MarshalJSON(t *testing.T) { + time.Local = time.UTC + pubKey := []byte("1234567890abcedf1234567890abcedf") + + nc := certificateV2{ + details: detailsV2{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC), + notAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC), + isCA: false, + issuer: "1234567890abcedf1234567890abcedf", + }, + publicKey: pubKey, + signature: []byte("1234567890abcedf1234567890abcedf1234567890abcedf1234567890abcedf"), + } + + b, err := nc.MarshalJSON() + assert.ErrorIs(t, err, ErrMissingDetails) + + rd, err := nc.details.Marshal() + assert.NoError(t, err) + + nc.rawDetails = rd + b, err = nc.MarshalJSON() + assert.Nil(t, err) + assert.Equal( + t, + "{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}", + string(b), + ) +} + +func TestCertificateV2_VerifyPrivateKey(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) + assert.Nil(t, err) + + err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16]) + assert.ErrorIs(t, err, ErrInvalidPrivateKey) + + _, caKey2, err := ed25519.GenerateKey(rand.Reader) + require.Nil(t, err) + err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) + assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) + + c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) + assert.NoError(t, err) + assert.Empty(t, b) + assert.Equal(t, Curve_CURVE25519, curve) + err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) + assert.Nil(t, err) + + _, priv2 := X25519Keypair() + err = c.VerifyPrivateKey(Curve_P256, priv2) + assert.ErrorIs(t, err, ErrPublicPrivateCurveMismatch) + + err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) + assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) + + err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16]) + assert.ErrorIs(t, err, ErrInvalidPrivateKey) + + ac, ok := c.(*certificateV2) + require.True(t, ok) + ac.curve = Curve(99) + err = c.VerifyPrivateKey(Curve(99), priv2) + assert.EqualError(t, err, "invalid curve: 99") + + ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) + assert.Nil(t, err) + + err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16]) + assert.ErrorIs(t, err, ErrInvalidPrivateKey) + + c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv) + + err = c.VerifyPrivateKey(Curve_P256, priv[:16]) + assert.ErrorIs(t, err, ErrInvalidPrivateKey) + + err = c.VerifyPrivateKey(Curve_P256, priv) + assert.ErrorIs(t, err, ErrInvalidPrivateKey) + + aCa, ok := ca2.(*certificateV2) + require.True(t, ok) + aCa.curve = Curve(99) + err = aCa.VerifyPrivateKey(Curve(99), priv2) + assert.EqualError(t, err, "invalid curve: 99") + +} + +func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + err := ca.VerifyPrivateKey(Curve_P256, caKey) + assert.Nil(t, err) + + _, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + assert.Nil(t, err) + err = ca.VerifyPrivateKey(Curve_P256, caKey2) + assert.NotNil(t, err) + + c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) + assert.NoError(t, err) + assert.Empty(t, b) + assert.Equal(t, Curve_P256, curve) + err = c.VerifyPrivateKey(Curve_P256, rawPriv) + assert.Nil(t, err) + + _, priv2 := P256Keypair() + err = c.VerifyPrivateKey(Curve_P256, priv2) + assert.NotNil(t, err) +} + +func TestCertificateV2_Copy(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + cc := c.Copy() + test.AssertDeepCopyEqual(t, c, cc) +} + +func TestUnmarshalCertificateV2(t *testing.T) { + data := []byte("\x98\x00\x00") + _, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519) + assert.EqualError(t, err, "bad wire format") +} + +func TestCertificateV2_marshalForSigningStability(t *testing.T) { + before := time.Date(1996, time.May, 5, 0, 0, 0, 0, time.UTC) + after := before.Add(time.Second * 60).Round(time.Second) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV2{ + details: detailsV2{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.2/16"), + mustParsePrefixUnmapped("10.1.1.1/24"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.3/16"), + mustParsePrefixUnmapped("9.1.1.2/24"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: before, + notAfter: after, + isCA: false, + issuer: "1234567890abcdef1234567890abcdef", + }, + signature: []byte("1234567890abcdef1234567890abcdef"), + publicKey: pubKey, + } + + const expectedRawDetailsStr = "a070800774657374696e67a10e04050a0101021004050a01010118a20e0405090101031004050901010218a3270c0b746573742d67726f7570310c0b746573742d67726f7570320c0b746573742d67726f7570338504318bef808604318befbc87101234567890abcdef1234567890abcdef" + expectedRawDetails, err := hex.DecodeString(expectedRawDetailsStr) + require.NoError(t, err) + + db, err := nc.details.Marshal() + require.NoError(t, err) + assert.Equal(t, expectedRawDetails, db) + + expectedForSigning, err := hex.DecodeString(expectedRawDetailsStr + "00313233343536373839306162636564666768696a313233343536373839306162") + b, err := nc.marshalForSigning() + require.NoError(t, err) + assert.Equal(t, expectedForSigning, b) +} diff --git a/cert/errors.go b/cert/errors.go index da0d1be..4bbc023 100644 --- a/cert/errors.go +++ b/cert/errors.go @@ -2,21 +2,24 @@ package cert import ( "errors" + "fmt" ) var ( - ErrBadFormat = errors.New("bad wire format") - ErrRootExpired = errors.New("root certificate is expired") - ErrExpired = errors.New("certificate is expired") - ErrNotCA = errors.New("certificate is not a CA") - ErrNotSelfSigned = errors.New("certificate is not self-signed") - ErrBlockListed = errors.New("certificate is in the block list") - ErrFingerprintMismatch = errors.New("certificate fingerprint did not match") - ErrSignatureMismatch = errors.New("certificate signature did not match") - ErrInvalidPublicKeyLength = errors.New("invalid public key length") - ErrInvalidPrivateKeyLength = errors.New("invalid private key length") - - ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") + ErrBadFormat = errors.New("bad wire format") + ErrRootExpired = errors.New("root certificate is expired") + ErrExpired = errors.New("certificate is expired") + ErrNotCA = errors.New("certificate is not a CA") + ErrNotSelfSigned = errors.New("certificate is not self-signed") + ErrBlockListed = errors.New("certificate is in the block list") + ErrFingerprintMismatch = errors.New("certificate fingerprint did not match") + ErrSignatureMismatch = errors.New("certificate signature did not match") + ErrInvalidPublicKey = errors.New("invalid public key") + ErrInvalidPrivateKey = errors.New("invalid private key") + ErrPublicPrivateCurveMismatch = errors.New("public key does not match private key curve") + ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair") + ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") + ErrCaNotFound = errors.New("could not find ca for the certificate") ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block") ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner") @@ -24,4 +27,23 @@ var ( ErrInvalidPEMX25519PrivateKeyBanner = errors.New("bytes did not contain a proper X25519 private key banner") ErrInvalidPEMEd25519PublicKeyBanner = errors.New("bytes did not contain a proper Ed25519 public key banner") ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner") + + ErrNoPeerStaticKey = errors.New("no peer static key was present") + ErrNoPayload = errors.New("provided payload was empty") + + ErrMissingDetails = errors.New("certificate did not contain details") + ErrEmptySignature = errors.New("empty signature") + ErrEmptyRawDetails = errors.New("empty rawDetails not allowed") ) + +type ErrInvalidCertificateProperties struct { + str string +} + +func NewErrInvalidCertificateProperties(format string, a ...any) error { + return &ErrInvalidCertificateProperties{fmt.Sprintf(format, a...)} +} + +func (e *ErrInvalidCertificateProperties) Error() string { + return e.str +} diff --git a/cert/helper_test.go b/cert/helper_test.go new file mode 100644 index 0000000..1b72a0f --- /dev/null +++ b/cert/helper_test.go @@ -0,0 +1,141 @@ +package cert + +import ( + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "io" + "net/netip" + "time" + + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/ed25519" +) + +// NewTestCaCert will create a new ca certificate +func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) { + var err error + var pub, priv []byte + + switch curve { + case Curve_CURVE25519: + pub, priv, err = ed25519.GenerateKey(rand.Reader) + case Curve_P256: + privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + + pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y) + priv = privk.D.FillBytes(make([]byte, 32)) + default: + // There is no default to allow the underlying lib to respond with an error + } + + if before.IsZero() { + before = time.Now().Add(time.Second * -60).Round(time.Second) + } + if after.IsZero() { + after = time.Now().Add(time.Second * 60).Round(time.Second) + } + + t := &TBSCertificate{ + Curve: curve, + Version: version, + Name: "test ca", + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + IsCA: true, + } + + c, err := t.Sign(nil, curve, priv) + if err != nil { + panic(err) + } + + pem, err := c.MarshalPEM() + if err != nil { + panic(err) + } + + return c, pub, priv, pem +} + +// NewTestCert will generate a signed certificate with the provided details. +// Expiry times are defaulted if you do not pass them in +func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) { + if before.IsZero() { + before = time.Now().Add(time.Second * -60).Round(time.Second) + } + + if after.IsZero() { + after = time.Now().Add(time.Second * 60).Round(time.Second) + } + + if len(networks) == 0 { + networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")} + } + + var pub, priv []byte + switch curve { + case Curve_CURVE25519: + pub, priv = X25519Keypair() + case Curve_P256: + pub, priv = P256Keypair() + default: + panic("unknown curve") + } + + nc := &TBSCertificate{ + Version: v, + Curve: curve, + Name: name, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: false, + } + + c, err := nc.Sign(ca, ca.Curve(), key) + if err != nil { + panic(err) + } + + pem, err := c.MarshalPEM() + if err != nil { + panic(err) + } + + return c, pub, MarshalPrivateKeyToPEM(curve, priv), pem +} + +func X25519Keypair() ([]byte, []byte) { + privkey := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, privkey); err != nil { + panic(err) + } + + pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) + if err != nil { + panic(err) + } + + return pubkey, privkey +} + +func P256Keypair() ([]byte, []byte) { + privkey, err := ecdh.P256().GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + pubkey := privkey.PublicKey() + return pubkey.Bytes(), privkey.Bytes() +} diff --git a/cert/pem.go b/cert/pem.go index 744ae2e..7ad28d1 100644 --- a/cert/pem.go +++ b/cert/pem.go @@ -30,19 +30,25 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { return nil, r, ErrInvalidPEMBlock } + var c Certificate + var err error + switch p.Type { + // Implementations must validate the resulting certificate contains valid information case CertificateBanner: - c, err := unmarshalCertificateV1(p.Bytes, true) - if err != nil { - return nil, nil, err - } - return c, r, nil + c, err = unmarshalCertificateV1(p.Bytes, nil) case CertificateV2Banner: - //TODO - panic("TODO") + c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519) default: return nil, r, ErrInvalidPEMCertificateBanner } + + if err != nil { + return nil, r, err + } + + return c, r, nil + } func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte { diff --git a/cert/sign.go b/cert/sign.go index e446aa1..12d4ee4 100644 --- a/cert/sign.go +++ b/cert/sign.go @@ -1,11 +1,15 @@ package cert import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" "fmt" + "math/big" "net/netip" "time" - - "github.com/slackhq/nebula/pkclient" ) // TBSCertificate represents a certificate intended to be signed. @@ -24,28 +28,61 @@ type TBSCertificate struct { issuer string } +type beingSignedCertificate interface { + // fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation + // Implementations must validate the resulting certificate contains valid information + fromTBSCertificate(*TBSCertificate) error + + // marshalForSigning returns the bytes that should be signed + marshalForSigning() ([]byte, error) + + // setSignature sets the signature for the certificate that has just been signed. The signature must not be blank. + setSignature([]byte) error +} + +type SignerLambda func(certBytes []byte) ([]byte, error) + // Sign will create a sealed certificate using details provided by the TBSCertificate as long as those // details do not violate constraints of the signing certificate. // If the TBSCertificate is a CA then signer must be nil. func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) { - return t.sign(signer, curve, key, nil) -} - -func (t *TBSCertificate) SignPkcs11(signer Certificate, curve Curve, client *pkclient.PKClient) (Certificate, error) { - if curve != Curve_P256 { - return nil, fmt.Errorf("only P256 is supported by PKCS#11") + switch t.Curve { + case Curve_CURVE25519: + pk := ed25519.PrivateKey(key) + sp := func(certBytes []byte) ([]byte, error) { + sig := ed25519.Sign(pk, certBytes) + return sig, nil + } + return t.SignWith(signer, curve, sp) + case Curve_P256: + pk := &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{ + Curve: elliptic.P256(), + }, + // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95 + D: new(big.Int).SetBytes(key), + } + // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119 + pk.X, pk.Y = pk.Curve.ScalarBaseMult(key) + sp := func(certBytes []byte) ([]byte, error) { + // We need to hash first for ECDSA + // - https://pkg.go.dev/crypto/ecdsa#SignASN1 + hashed := sha256.Sum256(certBytes) + return ecdsa.SignASN1(rand.Reader, pk, hashed[:]) + } + return t.SignWith(signer, curve, sp) + default: + return nil, fmt.Errorf("invalid curve: %s", t.Curve) } - - return t.sign(signer, curve, nil, client) } -func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, client *pkclient.PKClient) (Certificate, error) { +// SignWith does the same thing as sign, but uses the function in `sp` to calculate the signature. +// You should only use SignWith if you do not have direct access to your private key. +func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLambda) (Certificate, error) { if curve != t.Curve { return nil, fmt.Errorf("curve in cert and private key supplied don't match") } - //TODO: make sure we have all minimum properties to sign, like a public key - if signer != nil { if t.IsCA { return nil, fmt.Errorf("can not sign a CA certificate with another") @@ -67,10 +104,64 @@ func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, clien } } + var c beingSignedCertificate switch t.Version { case Version1: - return signV1(t, curve, key, client) + c = &certificateV1{} + err := c.fromTBSCertificate(t) + if err != nil { + return nil, err + } + case Version2: + c = &certificateV2{} + err := c.fromTBSCertificate(t) + if err != nil { + return nil, err + } default: return nil, fmt.Errorf("unknown cert version %d", t.Version) } + + certBytes, err := c.marshalForSigning() + if err != nil { + return nil, err + } + + sig, err := sp(certBytes) + if err != nil { + return nil, err + } + + err = c.setSignature(sig) + if err != nil { + return nil, err + } + + sc, ok := c.(Certificate) + if !ok { + return nil, fmt.Errorf("invalid certificate") + } + + return sc, nil +} + +func comparePrefix(a, b netip.Prefix) int { + addr := a.Addr().Compare(b.Addr()) + if addr == 0 { + return a.Bits() - b.Bits() + } + return addr +} + +// findDuplicatePrefix returns an error if there is a duplicate prefix in the pre-sorted input slice sortedPrefixes +func findDuplicatePrefix(sortedPrefixes []netip.Prefix) error { + if len(sortedPrefixes) < 2 { + return nil + } + for i := 1; i < len(sortedPrefixes); i++ { + if comparePrefix(sortedPrefixes[i], sortedPrefixes[i-1]) == 0 { + return NewErrInvalidCertificateProperties("duplicate network detected: %v", sortedPrefixes[i]) + } + } + return nil } diff --git a/cert/sign_test.go b/cert/sign_test.go new file mode 100644 index 0000000..2b8dbe8 --- /dev/null +++ b/cert/sign_test.go @@ -0,0 +1,90 @@ +package cert + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestCertificateV1_Sign(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + tbs := TBSCertificate{ + Version: Version1, + Name: "testing", + Networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + UnsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/24"), + }, + Groups: []string{"test-group1", "test-group2", "test-group3"}, + NotBefore: before, + NotAfter: after, + PublicKey: pubKey, + IsCA: false, + } + + pub, priv, err := ed25519.GenerateKey(rand.Reader) + c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv) + assert.Nil(t, err) + assert.NotNil(t, c) + assert.True(t, c.CheckSignature(pub)) + + b, err := c.Marshal() + assert.Nil(t, err) + uc, err := unmarshalCertificateV1(b, nil) + assert.Nil(t, err) + assert.NotNil(t, uc) +} + +func TestCertificateV1_SignP256(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab") + + tbs := TBSCertificate{ + Version: Version1, + Name: "testing", + Networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + UnsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), + }, + Groups: []string{"test-group1", "test-group2", "test-group3"}, + NotBefore: before, + NotAfter: after, + PublicKey: pubKey, + IsCA: false, + Curve: Curve_P256, + } + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.NoError(t, err) + pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) + rawPriv := priv.D.FillBytes(make([]byte, 32)) + + c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv) + assert.Nil(t, err) + assert.NotNil(t, c) + assert.True(t, c.CheckSignature(pub)) + + b, err := c.Marshal() + assert.Nil(t, err) + uc, err := unmarshalCertificateV1(b, nil) + assert.Nil(t, err) + assert.NotNil(t, uc) +} diff --git a/e2e/helpers.go b/cert_test/cert.go similarity index 51% rename from e2e/helpers.go rename to cert_test/cert.go index c0893ac..ebc6f52 100644 --- a/e2e/helpers.go +++ b/cert_test/cert.go @@ -1,6 +1,9 @@ -package e2e +package cert_test import ( + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "io" "net/netip" @@ -11,9 +14,26 @@ import ( "golang.org/x/crypto/ed25519" ) -// NewTestCaCert will generate a CA cert -func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { - pub, priv, err := ed25519.GenerateKey(rand.Reader) +// NewTestCaCert will create a new ca certificate +func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { + var err error + var pub, priv []byte + + switch curve { + case cert.Curve_CURVE25519: + pub, priv, err = ed25519.GenerateKey(rand.Reader) + case cert.Curve_P256: + privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + + pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y) + priv = privk.D.FillBytes(make([]byte, 32)) + default: + // There is no default to allow the underlying lib to respond with an error + } + if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) } @@ -22,7 +42,8 @@ func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Pre } t := &cert.TBSCertificate{ - Version: cert.Version1, + Curve: curve, + Version: version, Name: "test ca", NotBefore: time.Unix(before.Unix(), 0), NotAfter: time.Unix(after.Unix(), 0), @@ -33,7 +54,7 @@ func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Pre IsCA: true, } - c, err := t.Sign(nil, cert.Curve_CURVE25519, priv) + c, err := t.Sign(nil, curve, priv) if err != nil { panic(err) } @@ -48,7 +69,7 @@ func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Pre // NewTestCert will generate a signed certificate with the provided details. // Expiry times are defaulted if you do not pass them in -func NewTestCert(ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { +func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) } @@ -57,9 +78,19 @@ func NewTestCert(ca cert.Certificate, key []byte, name string, before, after tim after = time.Now().Add(time.Second * 60).Round(time.Second) } - pub, rawPriv := x25519Keypair() + var pub, priv []byte + switch curve { + case cert.Curve_CURVE25519: + pub, priv = X25519Keypair() + case cert.Curve_P256: + pub, priv = P256Keypair() + default: + panic("unknown curve") + } + nc := &cert.TBSCertificate{ - Version: cert.Version1, + Version: v, + Curve: curve, Name: name, Networks: networks, UnsafeNetworks: unsafeNetworks, @@ -80,10 +111,10 @@ func NewTestCert(ca cert.Certificate, key []byte, name string, before, after tim panic(err) } - return c, pub, cert.MarshalPrivateKeyToPEM(cert.Curve_CURVE25519, rawPriv), pem + return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem } -func x25519Keypair() ([]byte, []byte) { +func X25519Keypair() ([]byte, []byte) { privkey := make([]byte, 32) if _, err := io.ReadFull(rand.Reader, privkey); err != nil { panic(err) @@ -96,3 +127,12 @@ func x25519Keypair() ([]byte, []byte) { return pubkey, privkey } + +func P256Keypair() ([]byte, []byte) { + privkey, err := ecdh.P256().GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + pubkey := privkey.PublicKey() + return pubkey.Bytes(), privkey.Bytes() +} diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go index 90ea8ff..f83c94f 100644 --- a/cmd/nebula-cert/ca.go +++ b/cmd/nebula-cert/ca.go @@ -27,34 +27,43 @@ type caFlags struct { outCertPath *string outQRPath *string groups *string - ips *string - subnets *string + networks *string + unsafeNetworks *string argonMemory *uint argonIterations *uint argonParallelism *uint encryption *bool + version *uint curve *string p11url *string + + // Deprecated options + ips *string + subnets *string } func newCaFlags() *caFlags { cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)} cf.set.Usage = func() {} cf.name = cf.set.String("name", "", "Required: name of the certificate authority") + cf.version = cf.set.Uint("version", uint(cert.Version2), "Optional: version of the certificate format to use") cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"") cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to") cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to") cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate") cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use") - cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses") - cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets") + cf.networks = cf.set.String("networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks") + cf.unsafeNetworks = cf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks") cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase") cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase") cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase") cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format") cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)") cf.p11url = p11Flag(cf.set) + + cf.ips = cf.set.String("ips", "", "Deprecated, see -networks") + cf.subnets = cf.set.String("subnets", "", "Deprecated, see -unsafe-networks") return &cf } @@ -113,36 +122,51 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error } } - var ips []netip.Prefix - if *cf.ips != "" { - for _, rs := range strings.Split(*cf.ips, ",") { + version := cert.Version(*cf.version) + if version != cert.Version1 && version != cert.Version2 { + return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2) + } + + var networks []netip.Prefix + if *cf.networks == "" && *cf.ips != "" { + // Pull up deprecated -ips flag if needed + *cf.networks = *cf.ips + } + + if *cf.networks != "" { + for _, rs := range strings.Split(*cf.networks, ",") { rs := strings.Trim(rs, " ") if rs != "" { n, err := netip.ParsePrefix(rs) if err != nil { - return newHelpErrorf("invalid ip definition: %s", err) + return newHelpErrorf("invalid -networks definition: %s", rs) } - if !n.Addr().Is4() { - return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", rs) + if version == cert.Version1 && !n.Addr().Is4() { + return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4, have %s", rs) } - ips = append(ips, n) + networks = append(networks, n) } } } - var subnets []netip.Prefix - if *cf.subnets != "" { - for _, rs := range strings.Split(*cf.subnets, ",") { + var unsafeNetworks []netip.Prefix + if *cf.unsafeNetworks == "" && *cf.subnets != "" { + // Pull up deprecated -subnets flag if needed + *cf.unsafeNetworks = *cf.subnets + } + + if *cf.unsafeNetworks != "" { + for _, rs := range strings.Split(*cf.unsafeNetworks, ",") { rs := strings.Trim(rs, " ") if rs != "" { n, err := netip.ParsePrefix(rs) if err != nil { - return newHelpErrorf("invalid subnet definition: %s", err) + return newHelpErrorf("invalid -unsafe-networks definition: %s", rs) } - if !n.Addr().Is4() { - return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs) + if version == cert.Version1 && !n.Addr().Is4() { + return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4, have %s", rs) } - subnets = append(subnets, n) + unsafeNetworks = append(unsafeNetworks, n) } } } @@ -222,11 +246,11 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error } t := &cert.TBSCertificate{ - Version: cert.Version1, + Version: version, Name: *cf.name, Groups: groups, - Networks: ips, - UnsafeNetworks: subnets, + Networks: networks, + UnsafeNetworks: unsafeNetworks, NotBefore: time.Now(), NotAfter: time.Now().Add(*cf.duration), PublicKey: pub, @@ -248,7 +272,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error var b []byte if isP11 { - c, err = t.SignPkcs11(nil, curve, p11Client) + c, err = t.SignWith(nil, curve, p11Client.SignASN1) if err != nil { return fmt.Errorf("error while signing with PKCS#11: %w", err) } diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index 06a24ed..9da0ad4 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -16,8 +16,6 @@ import ( "github.com/stretchr/testify/assert" ) -//TODO: test file permissions - func Test_caSummary(t *testing.T) { assert.Equal(t, "ca : create a self signed certificate authority", caSummary()) } @@ -43,9 +41,11 @@ func Test_caHelp(t *testing.T) { " -groups string\n"+ " \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+ " -ips string\n"+ - " \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses\n"+ + " Deprecated, see -networks\n"+ " -name string\n"+ " \tRequired: name of the certificate authority\n"+ + " -networks string\n"+ + " \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks\n"+ " -out-crt string\n"+ " \tOptional: path to write the certificate to (default \"ca.crt\")\n"+ " -out-key string\n"+ @@ -54,7 +54,11 @@ func Test_caHelp(t *testing.T) { " \tOptional: output a qr code image (png) of the certificate\n"+ optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+ " -subnets string\n"+ - " \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets\n", + " \tDeprecated, see -unsafe-networks\n"+ + " -unsafe-networks string\n"+ + " \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks\n"+ + " -version uint\n"+ + " \tOptional: version of the certificate format to use (default 2)\n", ob.String(), ) } @@ -83,25 +87,25 @@ func Test_ca(t *testing.T) { // required args assertHelpError(t, ca( - []string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw, + []string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw, ), "-name is required") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // ipv4 only ips - assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") + assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // ipv4 only subnets - assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") + assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // failed key write ob.Reset() eb.Reset() - args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} + args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -114,7 +118,7 @@ func Test_ca(t *testing.T) { // failed cert write ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -128,7 +132,7 @@ func Test_ca(t *testing.T) { // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.Nil(t, ca(args, ob, eb, nopw)) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -161,7 +165,7 @@ func Test_ca(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.Nil(t, ca(args, ob, eb, testpw)) assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, "", eb.String()) @@ -189,7 +193,7 @@ func Test_ca(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.Error(t, ca(args, ob, eb, errpw)) assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, "", eb.String()) @@ -199,7 +203,7 @@ func Test_ca(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up assert.Equal(t, "", eb.String()) @@ -209,13 +213,13 @@ func Test_ca(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.Nil(t, ca(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -224,7 +228,7 @@ func Test_ca(t *testing.T) { os.Remove(keyF.Name()) ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) diff --git a/cmd/nebula-cert/keygen_test.go b/cmd/nebula-cert/keygen_test.go index 18ceb4b..fcfd77b 100644 --- a/cmd/nebula-cert/keygen_test.go +++ b/cmd/nebula-cert/keygen_test.go @@ -9,8 +9,6 @@ import ( "github.com/stretchr/testify/assert" ) -//TODO: test file permissions - func Test_keygenSummary(t *testing.T) { assert.Equal(t, "keygen : create a public/private key pair. the public key can be passed to `nebula-cert sign`", keygenSummary()) } diff --git a/cmd/nebula-cert/main_test.go b/cmd/nebula-cert/main_test.go index 2502824..f332895 100644 --- a/cmd/nebula-cert/main_test.go +++ b/cmd/nebula-cert/main_test.go @@ -11,8 +11,6 @@ import ( "github.com/stretchr/testify/assert" ) -//TODO: all flag parsing continueOnError will print to stderr on its own currently - func Test_help(t *testing.T) { expected := "Usage of " + os.Args[0] + " :\n" + " Global flags:\n" + diff --git a/cmd/nebula-cert/print.go b/cmd/nebula-cert/print.go index a62c223..30e0965 100644 --- a/cmd/nebula-cert/print.go +++ b/cmd/nebula-cert/print.go @@ -49,6 +49,8 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { var qrBytes []byte part := 0 + var jsonCerts []cert.Certificate + for { c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert) if err != nil { @@ -56,13 +58,10 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { } if *pf.json { - b, _ := json.Marshal(c) - out.Write(b) - out.Write([]byte("\n")) - + jsonCerts = append(jsonCerts, c) } else { - out.Write([]byte(c.String())) - out.Write([]byte("\n")) + _, _ = out.Write([]byte(c.String())) + _, _ = out.Write([]byte("\n")) } if *pf.outQRPath != "" { @@ -80,6 +79,12 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { part++ } + if *pf.json { + b, _ := json.Marshal(jsonCerts) + _, _ = out.Write(b) + _, _ = out.Write([]byte("\n")) + } + if *pf.outQRPath != "" { b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5) if err != nil { diff --git a/cmd/nebula-cert/print_test.go b/cmd/nebula-cert/print_test.go index 4c9a72d..86795e4 100644 --- a/cmd/nebula-cert/print_test.go +++ b/cmd/nebula-cert/print_test.go @@ -73,7 +73,7 @@ func Test_printCert(t *testing.T) { tf.Truncate(0) tf.Seek(0, 0) ca, caKey := NewTestCaCert("test ca", nil, nil, time.Time{}, time.Time{}, nil, nil, nil) - c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, []string{"hi"}) + c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}, nil, []string{"hi"}) p, _ := c.MarshalPEM() tf.Write(p) @@ -87,7 +87,71 @@ func Test_printCert(t *testing.T) { assert.Nil(t, err) assert.Equal( t, - "NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n", + //"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n", + `{ + "details": { + "curve": "CURVE25519", + "groups": [ + "hi" + ], + "isCa": false, + "issuer": "`+c.Issuer()+`", + "name": "test", + "networks": [ + "10.0.0.123/8" + ], + "notAfter": "0001-01-01T00:00:00Z", + "notBefore": "0001-01-01T00:00:00Z", + "publicKey": "`+pk+`", + "unsafeNetworks": [] + }, + "fingerprint": "`+fp+`", + "signature": "`+sig+`", + "version": 1 +} +{ + "details": { + "curve": "CURVE25519", + "groups": [ + "hi" + ], + "isCa": false, + "issuer": "`+c.Issuer()+`", + "name": "test", + "networks": [ + "10.0.0.123/8" + ], + "notAfter": "0001-01-01T00:00:00Z", + "notBefore": "0001-01-01T00:00:00Z", + "publicKey": "`+pk+`", + "unsafeNetworks": [] + }, + "fingerprint": "`+fp+`", + "signature": "`+sig+`", + "version": 1 +} +{ + "details": { + "curve": "CURVE25519", + "groups": [ + "hi" + ], + "isCa": false, + "issuer": "`+c.Issuer()+`", + "name": "test", + "networks": [ + "10.0.0.123/8" + ], + "notAfter": "0001-01-01T00:00:00Z", + "notBefore": "0001-01-01T00:00:00Z", + "publicKey": "`+pk+`", + "unsafeNetworks": [] + }, + "fingerprint": "`+fp+`", + "signature": "`+sig+`", + "version": 1 +} +`, ob.String(), ) assert.Equal(t, "", eb.String()) @@ -108,7 +172,8 @@ func Test_printCert(t *testing.T) { assert.Nil(t, err) assert.Equal( t, - "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n", + `[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}] +`, ob.String(), ) assert.Equal(t, "", eb.String()) @@ -153,6 +218,10 @@ func NewTestCert(ca cert.Certificate, signerKey []byte, name string, before, aft after = ca.NotAfter() } + if len(networks) == 0 { + networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")} + } + pub, rawPriv := x25519Keypair() nc := &cert.TBSCertificate{ Version: cert.Version1, diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index 13e807f..ebcb592 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -3,6 +3,7 @@ package main import ( "crypto/ecdh" "crypto/rand" + "errors" "flag" "fmt" "io" @@ -18,36 +19,46 @@ import ( ) type signFlags struct { - set *flag.FlagSet - caKeyPath *string - caCertPath *string - name *string - ip *string - duration *time.Duration - inPubPath *string - outKeyPath *string - outCertPath *string - outQRPath *string - groups *string - subnets *string - p11url *string + set *flag.FlagSet + version *uint + caKeyPath *string + caCertPath *string + name *string + networks *string + unsafeNetworks *string + duration *time.Duration + inPubPath *string + outKeyPath *string + outCertPath *string + outQRPath *string + groups *string + + p11url *string + + // Deprecated options + ip *string + subnets *string } func newSignFlags() *signFlags { sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)} sf.set.Usage = func() {} + sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.") sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key") sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert") sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname") - sf.ip = sf.set.String("ip", "", "Required: ipv4 address and network in CIDR notation to assign the cert") + sf.networks = sf.set.String("networks", "", "Required: comma separated list of ip address and network in CIDR notation to assign to this cert") + sf.unsafeNetworks = sf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for") sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"") sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key") sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to") sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to") sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate") sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups") - sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for") sf.p11url = p11Flag(sf.set) + + sf.ip = sf.set.String("ip", "", "Deprecated, see -networks") + sf.subnets = sf.set.String("subnets", "", "Deprecated, see -unsafe-networks") return &sf } @@ -71,13 +82,26 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) if err := mustFlagString("name", sf.name); err != nil { return err } - if err := mustFlagString("ip", sf.ip); err != nil { - return err - } if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" { return newHelpErrorf("cannot set both -in-pub and -out-key") } + var v4Networks []netip.Prefix + var v6Networks []netip.Prefix + if *sf.networks == "" && *sf.ip != "" { + // Pull up deprecated -ip flag if needed + *sf.networks = *sf.ip + } + + if len(*sf.networks) == 0 { + return newHelpErrorf("-networks is required") + } + + version := cert.Version(*sf.version) + if version != 0 && version != cert.Version1 && version != cert.Version2 { + return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2) + } + var curve cert.Curve var caKey []byte @@ -91,14 +115,14 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) // naively attempt to decode the private key as though it is not encrypted caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey) - if err == cert.ErrPrivateKeyEncrypted { + if errors.Is(err, cert.ErrPrivateKeyEncrypted) { // ask for a passphrase until we get one var passphrase []byte for i := 0; i < 5; i++ { out.Write([]byte("Enter passphrase: ")) passphrase, err = pr.ReadPassword() - if err == ErrNoTerminal { + if errors.Is(err, ErrNoTerminal) { return fmt.Errorf("ca-key is encrypted and must be decrypted interactively") } else if err != nil { return fmt.Errorf("error reading password: %s", err) @@ -146,12 +170,47 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) *sf.duration = time.Until(caCert.NotAfter()) - time.Second*1 } - network, err := netip.ParsePrefix(*sf.ip) - if err != nil { - return newHelpErrorf("invalid ip definition: %s", *sf.ip) + if *sf.networks != "" { + for _, rs := range strings.Split(*sf.networks, ",") { + rs := strings.Trim(rs, " ") + if rs != "" { + n, err := netip.ParsePrefix(rs) + if err != nil { + return newHelpErrorf("invalid -networks definition: %s", rs) + } + + if n.Addr().Is4() { + v4Networks = append(v4Networks, n) + } else { + v6Networks = append(v6Networks, n) + } + } + } } - if !network.Addr().Is4() { - return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip) + + var v4UnsafeNetworks []netip.Prefix + var v6UnsafeNetworks []netip.Prefix + if *sf.unsafeNetworks == "" && *sf.subnets != "" { + // Pull up deprecated -subnets flag if needed + *sf.unsafeNetworks = *sf.subnets + } + + if *sf.unsafeNetworks != "" { + for _, rs := range strings.Split(*sf.unsafeNetworks, ",") { + rs := strings.Trim(rs, " ") + if rs != "" { + n, err := netip.ParsePrefix(rs) + if err != nil { + return newHelpErrorf("invalid -unsafe-networks definition: %s", rs) + } + + if n.Addr().Is4() { + v4UnsafeNetworks = append(v4UnsafeNetworks, n) + } else { + v6UnsafeNetworks = append(v6UnsafeNetworks, n) + } + } + } } var groups []string @@ -164,23 +223,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } } - var subnets []netip.Prefix - if *sf.subnets != "" { - for _, rs := range strings.Split(*sf.subnets, ",") { - rs := strings.Trim(rs, " ") - if rs != "" { - s, err := netip.ParsePrefix(rs) - if err != nil { - return newHelpErrorf("invalid subnet definition: %s", rs) - } - if !s.Addr().Is4() { - return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs) - } - subnets = append(subnets, s) - } - } - } - var pub, rawPriv []byte var p11Client *pkclient.PKClient @@ -218,19 +260,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) pub, rawPriv = newKeypair(curve) } - t := &cert.TBSCertificate{ - Version: cert.Version1, - Name: *sf.name, - Networks: []netip.Prefix{network}, - Groups: groups, - UnsafeNetworks: subnets, - NotBefore: time.Now(), - NotAfter: time.Now().Add(*sf.duration), - PublicKey: pub, - IsCA: false, - Curve: curve, - } - if *sf.outKeyPath == "" { *sf.outKeyPath = *sf.name + ".key" } @@ -243,18 +272,85 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath) } - var c cert.Certificate + var crts []cert.Certificate - if p11Client == nil { - c, err = t.Sign(caCert, curve, caKey) - if err != nil { - return fmt.Errorf("error while signing: %w", err) + notBefore := time.Now() + notAfter := notBefore.Add(*sf.duration) + + if version == 0 || version == cert.Version1 { + // Make sure we at least have an ip + if len(v4Networks) != 1 { + return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address") } - } else { - c, err = t.SignPkcs11(caCert, curve, p11Client) - if err != nil { - return fmt.Errorf("error while signing with PKCS#11: %w", err) + + if version == cert.Version1 { + // If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses + if len(v6Networks) > 0 { + return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4") + } + + if len(v6UnsafeNetworks) > 0 { + return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4") + } } + + t := &cert.TBSCertificate{ + Version: cert.Version1, + Name: *sf.name, + Networks: []netip.Prefix{v4Networks[0]}, + Groups: groups, + UnsafeNetworks: v4UnsafeNetworks, + NotBefore: notBefore, + NotAfter: notAfter, + PublicKey: pub, + IsCA: false, + Curve: curve, + } + + var nc cert.Certificate + if p11Client == nil { + nc, err = t.Sign(caCert, curve, caKey) + if err != nil { + return fmt.Errorf("error while signing: %w", err) + } + } else { + nc, err = t.SignWith(caCert, curve, p11Client.SignASN1) + if err != nil { + return fmt.Errorf("error while signing with PKCS#11: %w", err) + } + } + + crts = append(crts, nc) + } + + if version == 0 || version == cert.Version2 { + t := &cert.TBSCertificate{ + Version: cert.Version2, + Name: *sf.name, + Networks: append(v4Networks, v6Networks...), + Groups: groups, + UnsafeNetworks: append(v4UnsafeNetworks, v6UnsafeNetworks...), + NotBefore: notBefore, + NotAfter: notAfter, + PublicKey: pub, + IsCA: false, + Curve: curve, + } + + var nc cert.Certificate + if p11Client == nil { + nc, err = t.Sign(caCert, curve, caKey) + if err != nil { + return fmt.Errorf("error while signing: %w", err) + } + } else { + nc, err = t.SignWith(caCert, curve, p11Client.SignASN1) + if err != nil { + return fmt.Errorf("error while signing with PKCS#11: %w", err) + } + } + + crts = append(crts, nc) } if !isP11 && *sf.inPubPath == "" { @@ -268,9 +364,13 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } } - b, err := c.MarshalPEM() - if err != nil { - return fmt.Errorf("error while marshalling certificate: %s", err) + var b []byte + for _, c := range crts { + sb, err := c.MarshalPEM() + if err != nil { + return fmt.Errorf("error while marshalling certificate: %s", err) + } + b = append(b, sb...) } err = os.WriteFile(*sf.outCertPath, b, 0600) diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index b68434d..466cb8c 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -16,8 +16,6 @@ import ( "golang.org/x/crypto/ed25519" ) -//TODO: test file permissions - func Test_signSummary(t *testing.T) { assert.Equal(t, "sign : create and sign a certificate", signSummary()) } @@ -39,9 +37,11 @@ func Test_signHelp(t *testing.T) { " -in-pub string\n"+ " \tOptional (if out-key not set): path to read a previously generated public key\n"+ " -ip string\n"+ - " \tRequired: ipv4 address and network in CIDR notation to assign the cert\n"+ + " \tDeprecated, see -networks\n"+ " -name string\n"+ " \tRequired: name of the cert, usually a hostname\n"+ + " -networks string\n"+ + " \tRequired: comma separated list of ip address and network in CIDR notation to assign to this cert\n"+ " -out-crt string\n"+ " \tOptional: path to write the certificate to\n"+ " -out-key string\n"+ @@ -50,7 +50,11 @@ func Test_signHelp(t *testing.T) { " \tOptional: output a qr code image (png) of the certificate\n"+ optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+ " -subnets string\n"+ - " \tOptional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for\n", + " \tDeprecated, see -unsafe-networks\n"+ + " -unsafe-networks string\n"+ + " \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+ + " -version uint\n"+ + " \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n", ob.String(), ) } @@ -77,20 +81,20 @@ func Test_signCert(t *testing.T) { // required args assertHelpError(t, signCert( - []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, + []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, ), "-name is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) assertHelpError(t, signCert( - []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, - ), "-ip is required") + []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, + ), "-networks is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // cannot set -in-pub and -out-key assertHelpError(t, signCert( - []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw, + []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw, ), "cannot set both -in-pub and -out-key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -98,7 +102,7 @@ func Test_signCert(t *testing.T) { // failed to read key ob.Reset() eb.Reset() - args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) // failed to unmarshal key @@ -108,7 +112,7 @@ func Test_signCert(t *testing.T) { assert.Nil(t, err) defer os.Remove(caKeyF.Name()) - args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -120,7 +124,7 @@ func Test_signCert(t *testing.T) { caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv)) // failed to read cert - args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -132,7 +136,7 @@ func Test_signCert(t *testing.T) { assert.Nil(t, err) defer os.Remove(caCrtF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -143,7 +147,7 @@ func Test_signCert(t *testing.T) { caCrtF.Write(b) // failed to read pub - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -155,7 +159,7 @@ func Test_signCert(t *testing.T) { assert.Nil(t, err) defer os.Remove(inPubF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -169,30 +173,37 @@ func Test_signCert(t *testing.T) { // bad ip cidr ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: a1.1.1.1/24") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: a1.1.1.1/24") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + ob.Reset() + eb.Reset() + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24,1.1.1.2/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // bad subnet cidr ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: a") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: a") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -205,7 +216,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -213,7 +224,7 @@ func Test_signCert(t *testing.T) { // failed key write ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -226,7 +237,7 @@ func Test_signCert(t *testing.T) { // failed cert write ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -240,7 +251,7 @@ func Test_signCert(t *testing.T) { // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -283,7 +294,7 @@ func Test_signCert(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -300,7 +311,7 @@ func Test_signCert(t *testing.T) { eb.Reset() os.Remove(keyF.Name()) os.Remove(crtF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -308,14 +319,14 @@ func Test_signCert(t *testing.T) { // create valid cert/key for overwrite tests os.Remove(keyF.Name()) os.Remove(crtF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Nil(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing key file os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -323,14 +334,14 @@ func Test_signCert(t *testing.T) { // create valid cert/key for overwrite tests os.Remove(keyF.Name()) os.Remove(crtF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Nil(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file os.Remove(keyF.Name()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -362,7 +373,7 @@ func Test_signCert(t *testing.T) { caCrtF.Write(b) // test with the proper password - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Nil(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -372,7 +383,7 @@ func Test_signCert(t *testing.T) { eb.Reset() testpw.password = []byte("invalid password") - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Error(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -381,7 +392,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Error(t, signCert(args, ob, eb, nopw)) // normally the user hitting enter on the prompt would add newlines between these assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String()) @@ -391,7 +402,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Error(t, signCert(args, ob, eb, errpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) diff --git a/cmd/nebula-cert/verify.go b/cmd/nebula-cert/verify.go index 80cfef3..bea4d1d 100644 --- a/cmd/nebula-cert/verify.go +++ b/cmd/nebula-cert/verify.go @@ -1,6 +1,7 @@ package main import ( + "errors" "flag" "fmt" "io" @@ -41,14 +42,14 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { rawCACert, err := os.ReadFile(*vf.caPath) if err != nil { - return fmt.Errorf("error while reading ca: %s", err) + return fmt.Errorf("error while reading ca: %w", err) } caPool := cert.NewCAPool() for { rawCACert, err = caPool.AddCAFromPEM(rawCACert) if err != nil { - return fmt.Errorf("error while adding ca cert to pool: %s", err) + return fmt.Errorf("error while adding ca cert to pool: %w", err) } if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" { @@ -58,20 +59,30 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { rawCert, err := os.ReadFile(*vf.certPath) if err != nil { - return fmt.Errorf("unable to read crt; %s", err) + return fmt.Errorf("unable to read crt: %w", err) + } + var errs []error + for { + if len(rawCert) == 0 { + break + } + c, extra, err := cert.UnmarshalCertificateFromPEM(rawCert) + if err != nil { + return fmt.Errorf("error while parsing crt: %w", err) + } + rawCert = extra + _, err = caPool.VerifyCertificate(time.Now(), c) + if err != nil { + switch { + case errors.Is(err, cert.ErrCaNotFound): + errs = append(errs, fmt.Errorf("error while verifying certificate v%d %s with issuer %s: %w", c.Version(), c.Name(), c.Issuer(), err)) + default: + errs = append(errs, fmt.Errorf("error while verifying certificate %+v: %w", c, err)) + } + } } - c, _, err := cert.UnmarshalCertificateFromPEM(rawCert) - if err != nil { - return fmt.Errorf("error while parsing crt: %s", err) - } - - _, err = caPool.VerifyCertificate(time.Now(), c) - if err != nil { - return err - } - - return nil + return errors.Join(errs...) } func verifySummary() string { @@ -80,7 +91,7 @@ func verifySummary() string { func verifyHelp(out io.Writer) { vf := newVerifyFlags() - out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n")) + _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n")) vf.set.SetOutput(out) vf.set.PrintDefaults() } diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index 204ff09..d94bd1f 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -3,10 +3,12 @@ package main import ( "bytes" "crypto/rand" + "errors" "os" "testing" "time" + "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" "golang.org/x/crypto/ed25519" ) @@ -76,7 +78,7 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "unable to read crt; open does_not_exist: "+NoSuchFileError) + assert.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError) // invalid crt at path ob.Reset() @@ -106,7 +108,7 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "certificate signature did not match") + assert.True(t, errors.Is(err, cert.ErrSignatureMismatch)) // verified cert at path crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) diff --git a/config/config_test.go b/config/config_test.go index fa94393..c3a1a73 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -38,9 +38,6 @@ func TestConfig_Load(t *testing.T) { "new": "hi", } assert.Equal(t, expected, c.Settings) - - //TODO: test symlinked file - //TODO: test symlinked directory } func TestConfig_Get(t *testing.T) { diff --git a/connection_manager.go b/connection_manager.go index 7718252..9d8d071 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -183,7 +183,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, case deleteTunnel: if n.hostMap.DeleteHostInfo(hostinfo) { // Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap - n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp) + n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs) } case closeTunnel: @@ -221,7 +221,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) relayFor := oldhostinfo.relayState.CopyAllRelayFor() for _, r := range relayFor { - existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp) + existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerAddr) var index uint32 var relayFrom netip.Addr @@ -235,11 +235,11 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) index = existing.LocalIndex switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnNet.Addr() - relayTo = existing.PeerIp + relayFrom = n.intf.myVpnAddrs[0] + relayTo = existing.PeerAddr case ForwardingType: - relayFrom = existing.PeerIp - relayTo = newhostinfo.vpnIp + relayFrom = existing.PeerAddr + relayTo = newhostinfo.vpnAddrs[0] default: // should never happen } @@ -253,45 +253,64 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) n.relayUsedLock.RUnlock() // The relay doesn't exist at all; create some relay state and send the request. var err error - index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested) + index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested) if err != nil { n.l.WithError(err).Error("failed to migrate relay to new hostinfo") continue } switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnNet.Addr() - relayTo = r.PeerIp + relayFrom = n.intf.myVpnAddrs[0] + relayTo = r.PeerAddr case ForwardingType: - relayFrom = r.PeerIp - relayTo = newhostinfo.vpnIp + relayFrom = r.PeerAddr + relayTo = newhostinfo.vpnAddrs[0] default: // should never happen } } - //TODO: IPV6-WORK - relayFromB := relayFrom.As4() - relayToB := relayTo.As4() - // Send a CreateRelayRequest to the peer. req := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: index, - RelayFromIp: binary.BigEndian.Uint32(relayFromB[:]), - RelayToIp: binary.BigEndian.Uint32(relayToB[:]), } + + switch newhostinfo.GetCert().Certificate.Version() { + case cert.Version1: + if !relayFrom.Is4() { + n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !relayTo.Is4() { + n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := relayFrom.As4() + req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = relayTo.As4() + req.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + req.RelayFromAddr = netAddrToProtoAddr(relayFrom) + req.RelayToAddr = netAddrToProtoAddr(relayTo) + default: + newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay") + continue + } + msg, err := req.Marshal() if err != nil { n.l.WithError(err).Error("failed to marshal Control message to migrate relay") } else { n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) n.l.WithFields(logrus.Fields{ - "relayFrom": req.RelayFromIp, - "relayTo": req.RelayToIp, + "relayFrom": req.RelayFromAddr, + "relayTo": req.RelayToAddr, "initiatorRelayIndex": req.InitiatorRelayIndex, "responderRelayIndex": req.ResponderRelayIndex, - "vpnIp": newhostinfo.vpnIp}). + "vpnAddrs": newhostinfo.vpnAddrs}). Info("send CreateRelayRequest") } } @@ -313,7 +332,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time return closeTunnel, hostinfo, nil } - primary := n.hostMap.Hosts[hostinfo.vpnIp] + primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]] mainHostInfo := true if primary != nil && primary != hostinfo { mainHostInfo = false @@ -407,21 +426,24 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // Let's sort this out. - if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 { - // Only one side should flip primary because if both flip then we may never resolve to a single tunnel. - // vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping. - // The remotes vpn ip is lower than mine. I will not flip. + // Only one side should swap because if both swap then we may never resolve to a single tunnel. + // vpn addr is static across all tunnels for this host pair so lets + // use that to determine if we should consider swapping. + if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 { + // Their primary vpn addr is less than mine. Do not swap. return false } - certState := n.intf.pki.GetCertState() - return bytes.Equal(current.ConnectionState.myCert.Signature(), certState.Certificate.Signature()) + crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version()) + // If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things + // settle down. + return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature()) } func (n *connectionManager) swapPrimary(current, primary *HostInfo) { n.hostMap.Lock() // Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake. - if n.hostMap.Hosts[current.vpnIp] == primary { + if n.hostMap.Hosts[current.vpnAddrs[0]] == primary { n.hostMap.unlockedMakePrimary(current) } n.hostMap.Unlock() @@ -473,14 +495,17 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { } func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { - certState := n.intf.pki.GetCertState() - if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), certState.Certificate.Signature()) { + cs := n.intf.pki.getCertState() + curCrt := hostinfo.ConnectionState.myCert + myCrt := cs.getCertificate(curCrt.Version()) + if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true { + // The current tunnel is using the latest certificate and version, no need to rehandshake. return } - n.l.WithField("vpnIp", hostinfo.vpnIp). + n.l.WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("reason", "local certificate is not current"). Info("Re-handshaking with remote") - n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil) + n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) } diff --git a/connection_manager_test.go b/connection_manager_test.go index 9f222c8..8e2ef15 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -34,20 +34,19 @@ func newTestLighthouse() *LightHouse { func Test_NewConnectionManagerTest(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") - vpncidr := netip.MustParsePrefix("172.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24") vpnIp := netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects - hostMap := newHostMap(l, vpncidr) + hostMap := newHostMap(l) hostMap.preferredRanges.Store(&preferredRanges) cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &dummyCert{}, - RawCertificateNoKey: []byte{}, + defaultVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() @@ -74,12 +73,12 @@ func Test_NewConnectionManagerTest(t *testing.T) { // Add an ip we have established a connection w/ to hostmap hostinfo := &HostInfo{ - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnIp}, localIndexId: 1099, remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - myCert: &dummyCert{}, + myCert: &dummyCert{version: cert.Version1}, H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -88,7 +87,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { nc.Out(hostinfo.localIndexId) nc.In(hostinfo.localIndexId) assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.out, hostinfo.localIndexId) @@ -105,32 +104,31 @@ func Test_NewConnectionManagerTest(t *testing.T) { assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) // Do a final traffic check tick, the host should now be removed nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId) } func Test_NewConnectionManagerTest2(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") - vpncidr := netip.MustParsePrefix("172.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24") vpnIp := netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects - hostMap := newHostMap(l, vpncidr) + hostMap := newHostMap(l) hostMap.preferredRanges.Store(&preferredRanges) cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &dummyCert{}, - RawCertificateNoKey: []byte{}, + defaultVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() @@ -157,12 +155,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // Add an ip we have established a connection w/ to hostmap hostinfo := &HostInfo{ - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnIp}, localIndexId: 1099, remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - myCert: &dummyCert{}, + myCert: &dummyCert{version: cert.Version1}, H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -170,8 +168,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // We saw traffic out to vpnIp nc.Out(hostinfo.localIndexId) nc.In(hostinfo.localIndexId) - assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0]) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded @@ -187,7 +185,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) // We saw traffic, should no longer be pending deletion nc.In(hostinfo.localIndexId) @@ -196,7 +194,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) } // Check if we can disconnect the peer. @@ -210,7 +208,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { localrange := netip.MustParsePrefix("10.1.1.1/24") vpnIp := netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} - hostMap := newHostMap(l, vpncidr) + hostMap := newHostMap(l) hostMap.preferredRanges.Store(&preferredRanges) // Generate keys for CA and peer's cert. @@ -244,10 +242,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert) cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &dummyCert{}, - RawCertificateNoKey: []byte{}, + privateKey: []byte{}, + v1Cert: &dummyCert{}, + v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() @@ -273,7 +270,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { ifce.connectionManager = nc hostinfo := &HostInfo{ - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnIp}, ConnectionState: &ConnectionState{ myCert: &dummyCert{}, peerCert: cachedPeerCert, diff --git a/connection_state.go b/connection_state.go index bcc9e5d..faee443 100644 --- a/connection_state.go +++ b/connection_state.go @@ -3,6 +3,7 @@ package nebula import ( "crypto/rand" "encoding/json" + "fmt" "sync" "sync/atomic" @@ -26,46 +27,46 @@ type ConnectionState struct { writeLock sync.Mutex } -func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { +func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) { var dhFunc noise.DHFunc - switch certState.Certificate.Curve() { + switch crt.Curve() { case cert.Curve_CURVE25519: dhFunc = noise.DH25519 case cert.Curve_P256: - if certState.pkcs11Backed { + if cs.pkcs11Backed { dhFunc = noiseutil.DHP256PKCS11 } else { dhFunc = noiseutil.DHP256 } default: - l.Errorf("invalid curve: %s", certState.Certificate.Curve()) - return nil + return nil, fmt.Errorf("invalid curve: %s", crt.Curve()) } - var cs noise.CipherSuite - if cipher == "chachapoly" { - cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) + var ncs noise.CipherSuite + if cs.cipher == "chachapoly" { + ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) } else { - cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) + ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) } - static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey} + static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()} b := NewBits(ReplayWindow) - // Clear out bit 0, we never transmit it and we don't want it showing as packet loss + // Clear out bit 0, we never transmit it, and we don't want it showing as packet loss b.Update(l, 0) hs, err := noise.NewHandshakeState(noise.Config{ - CipherSuite: cs, - Random: rand.Reader, - Pattern: pattern, - Initiator: initiator, - StaticKeypair: static, - PresharedKey: psk, - PresharedKeyPlacement: pskStage, + CipherSuite: ncs, + Random: rand.Reader, + Pattern: pattern, + Initiator: initiator, + StaticKeypair: static, + //NOTE: These should come from CertState (pki.go) when we finally implement it + PresharedKey: []byte{}, + PresharedKeyPlacement: 0, }) if err != nil { - return nil + return nil, fmt.Errorf("NewConnectionState: %s", err) } // The queue and ready params prevent a counter race that would happen when @@ -74,12 +75,12 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i H: hs, initiator: initiator, window: b, - myCert: certState.Certificate, + myCert: crt, } // always start the counter from 2, as packet 1 and packet 2 are handshake packets. ci.messageCounter.Add(2) - return ci + return ci, nil } func (cs *ConnectionState) MarshalJSON() ([]byte, error) { @@ -89,3 +90,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) { "message_counter": cs.messageCounter.Load(), }) } + +func (cs *ConnectionState) Curve() cert.Curve { + return cs.myCert.Curve() +} diff --git a/control.go b/control.go index 2615984..20dd7fe 100644 --- a/control.go +++ b/control.go @@ -19,9 +19,9 @@ import ( type controlEach func(h *HostInfo) type controlHostLister interface { - QueryVpnIp(vpnIp netip.Addr) *HostInfo + QueryVpnAddr(vpnAddr netip.Addr) *HostInfo ForEachIndex(each controlEach) - ForEachVpnIp(each controlEach) + ForEachVpnAddr(each controlEach) GetPreferredRanges() []netip.Prefix } @@ -37,7 +37,7 @@ type Control struct { } type ControlHostInfo struct { - VpnIp netip.Addr `json:"vpnIp"` + VpnAddrs []netip.Addr `json:"vpnAddrs"` LocalIndex uint32 `json:"localIndex"` RemoteIndex uint32 `json:"remoteIndex"` RemoteAddrs []netip.AddrPort `json:"remoteAddrs"` @@ -131,10 +131,13 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate { - if c.f.myVpnNet.Addr() == vpnIp { - return c.f.pki.GetCertState().Certificate.Copy() + _, found := c.f.myVpnAddrsTable.Lookup(vpnIp) + if found { + // Only returning the default certificate since its impossible + // for any other host but ourselves to have more than 1 + return c.f.pki.getCertState().GetDefaultCertificate().Copy() } - hi := c.f.hostMap.QueryVpnIp(vpnIp) + hi := c.f.hostMap.QueryVpnAddr(vpnIp) if hi == nil { return nil } @@ -148,7 +151,7 @@ func (c *Control) CreateTunnel(vpnIp netip.Addr) { // PrintTunnel creates a new tunnel to the given vpn ip. func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo { - hi := c.f.hostMap.QueryVpnIp(vpnIp) + hi := c.f.hostMap.QueryVpnAddr(vpnIp) if hi == nil { return nil } @@ -165,9 +168,9 @@ func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap { return hi.CopyCache() } -// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found +// GetHostInfoByVpnAddr returns a single tunnels hostInfo, or nil if not found // Caller should take care to Unmap() any 4in6 addresses prior to calling. -func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo { +func (c *Control) GetHostInfoByVpnAddr(vpnAddr netip.Addr, pending bool) *ControlHostInfo { var hl controlHostLister if pending { hl = c.f.handshakeManager @@ -175,7 +178,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos hl = c.f.hostMap } - h := hl.QueryVpnIp(vpnIp) + h := hl.QueryVpnAddr(vpnAddr) if h == nil { return nil } @@ -187,7 +190,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos // SetRemoteForTunnel forces a tunnel to use a specific remote // Caller should take care to Unmap() any 4in6 addresses prior to calling. func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo { - hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) + hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp) if hostInfo == nil { return nil } @@ -200,7 +203,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *Con // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well. // Caller should take care to Unmap() any 4in6 addresses prior to calling. func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool { - hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) + hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp) if hostInfo == nil { return false } @@ -224,19 +227,14 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool { // CloseAllTunnels is just like CloseTunnel except it goes through and shuts them all down, optionally you can avoid shutting down lighthouse tunnels // the int returned is a count of tunnels closed func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { - //TODO: this is probably better as a function in ConnectionManager or HostMap directly - lighthouses := c.f.lightHouse.GetLighthouses() - shutdown := func(h *HostInfo) { - if excludeLighthouses { - if _, ok := lighthouses[h.vpnIp]; ok { - return - } + if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) { + return } c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) c.f.closeTunnel(h) - c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote). + c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote). Debug("Sending close tunnel message") closed++ } @@ -246,7 +244,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { // Grab the hostMap lock to access the Relays map c.f.hostMap.Lock() for _, relayingHost := range c.f.hostMap.Relays { - relayingHosts[relayingHost.vpnIp] = relayingHost + relayingHosts[relayingHost.vpnAddrs[0]] = relayingHost } c.f.hostMap.Unlock() @@ -254,7 +252,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { // Grab the hostMap lock to access the Hosts map c.f.hostMap.Lock() for _, relayHost := range c.f.hostMap.Indexes { - if _, ok := relayingHosts[relayHost.vpnIp]; !ok { + if _, ok := relayingHosts[relayHost.vpnAddrs[0]]; !ok { hostInfos = append(hostInfos, relayHost) } } @@ -274,9 +272,8 @@ func (c *Control) Device() overlay.Device { } func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { - chi := ControlHostInfo{ - VpnIp: h.vpnIp, + VpnAddrs: make([]netip.Addr, len(h.vpnAddrs)), LocalIndex: h.localIndexId, RemoteIndex: h.remoteIndexId, RemoteAddrs: h.remotes.CopyAddrs(preferredRanges), @@ -285,6 +282,10 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { CurrentRemote: h.remote, } + for i, a := range h.vpnAddrs { + chi.VpnAddrs[i] = a + } + if h.ConnectionState != nil { chi.MessageCounter = h.ConnectionState.messageCounter.Load() } @@ -299,7 +300,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { func listHostMapHosts(hl controlHostLister) []ControlHostInfo { hosts := make([]ControlHostInfo, 0) pr := hl.GetPreferredRanges() - hl.ForEachVpnIp(func(hostinfo *HostInfo) { + hl.ForEachVpnAddr(func(hostinfo *HostInfo) { hosts = append(hosts, copyHostInfo(hostinfo, pr)) }) return hosts diff --git a/control_test.go b/control_test.go index fdfc0a5..6ce7083 100644 --- a/control_test.go +++ b/control_test.go @@ -13,13 +13,13 @@ import ( ) func TestControl_GetHostInfoByVpnIp(t *testing.T) { - //TODO: with multiple certificate versions we have a problem with this test + //TODO: CERT-V2 with multiple certificate versions we have a problem with this test // Some certs versions have different characteristics and each version implements their own Copy() func // which means this is not a good place to test for exposing memory l := test.NewLogger() // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // To properly ensure we are not exposing core memory to the caller - hm := newHostMap(l, netip.Prefix{}) + hm := newHostMap(l) hm.preferredRanges.Store(&[]netip.Prefix{}) remote1 := netip.MustParseAddrPort("0.0.0.100:4444") @@ -35,9 +35,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { Mask: net.IPMask{255, 255, 255, 0}, } - remotes := NewRemoteList(nil) - remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port())) - remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port())) + remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil) + remotes.unlockedPrependV4(netip.IPv4Unspecified(), netAddrToProtoV4AddrPort(remote1.Addr(), remote1.Port())) + remotes.unlockedPrependV6(netip.IPv4Unspecified(), netAddrToProtoV6AddrPort(remote2.Addr(), remote2.Port())) vpnIp, ok := netip.AddrFromSlice(ipNet.IP) assert.True(t, ok) @@ -51,11 +51,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, remoteIndexId: 200, localIndexId: 201, - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnIp}, relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, - relayForByIp: map[netip.Addr]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) @@ -70,11 +70,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, remoteIndexId: 200, localIndexId: 201, - vpnIp: vpnIp2, + vpnAddrs: []netip.Addr{vpnIp2}, relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, - relayForByIp: map[netip.Addr]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) @@ -85,10 +85,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { l: logrus.New(), } - thi := c.GetHostInfoByVpnIp(vpnIp, false) + thi := c.GetHostInfoByVpnAddr(vpnIp, false) expectedInfo := ControlHostInfo{ - VpnIp: vpnIp, + VpnAddrs: []netip.Addr{vpnIp}, LocalIndex: 201, RemoteIndex: 200, RemoteAddrs: []netip.AddrPort{remote2, remote1}, @@ -100,13 +100,13 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { } // Make sure we don't have any unexpected fields - assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) + assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) assert.EqualValues(t, &expectedInfo, thi) test.AssertDeepCopyEqual(t, &expectedInfo, thi) // Make sure we don't panic if the host info doesn't have a cert yet assert.NotPanics(t, func() { - thi = c.GetHostInfoByVpnIp(vpnIp2, false) + thi = c.GetHostInfoByVpnAddr(vpnIp2, false) }) } diff --git a/control_tester.go b/control_tester.go index fa87e53..451dac5 100644 --- a/control_tester.go +++ b/control_tester.go @@ -6,8 +6,6 @@ package nebula import ( "net/netip" - "github.com/slackhq/nebula/cert" - "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/slackhq/nebula/header" @@ -51,15 +49,15 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, // This is necessary if you did not configure static hosts or are not running a lighthouse func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) + remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp}) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() if toAddr.Addr().Is4() { - remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) + remoteList.unlockedPrependV4(vpnIp, netAddrToProtoV4AddrPort(toAddr.Addr(), toAddr.Port())) } else { - remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) + remoteList.unlockedPrependV6(vpnIp, netAddrToProtoV6AddrPort(toAddr.Addr(), toAddr.Port())) } } @@ -67,12 +65,12 @@ func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) // This is necessary to inform an initiator of possible relays for communicating with a responder func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) + remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp}) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() - remoteList.unlockedSetRelay(vpnIp, vpnIp, relayVpnIps) + remoteList.unlockedSetRelay(vpnIp, relayVpnIps) } // GetFromTun will pull a packet off the tun side of nebula @@ -99,21 +97,42 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) { } // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol -func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) { - //TODO: IPV6-WORK - ip := layers.IPv4{ - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, - SrcIP: c.f.inside.Cidr().Addr().Unmap().AsSlice(), - DstIP: toIp.Unmap().AsSlice(), +func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) { + serialize := make([]gopacket.SerializableLayer, 0) + var netLayer gopacket.NetworkLayer + if toAddr.Is6() { + if !fromAddr.Is6() { + panic("Cant send ipv6 to ipv4") + } + ip := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolUDP, + SrcIP: fromAddr.Unmap().AsSlice(), + DstIP: toAddr.Unmap().AsSlice(), + } + serialize = append(serialize, ip) + netLayer = ip + } else { + if !fromAddr.Is4() { + panic("Cant send ipv4 to ipv6") + } + + ip := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + SrcIP: fromAddr.Unmap().AsSlice(), + DstIP: toAddr.Unmap().AsSlice(), + } + serialize = append(serialize, ip) + netLayer = ip } udp := layers.UDP{ SrcPort: layers.UDPPort(fromPort), DstPort: layers.UDPPort(toPort), } - err := udp.SetNetworkLayerForChecksum(&ip) + err := udp.SetNetworkLayerForChecksum(netLayer) if err != nil { panic(err) } @@ -123,7 +142,9 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui ComputeChecksums: true, FixLengths: true, } - err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data)) + + serialize = append(serialize, &udp, gopacket.Payload(data)) + err = gopacket.SerializeLayers(buffer, opt, serialize...) if err != nil { panic(err) } @@ -131,8 +152,8 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui c.f.inside.(*overlay.TestTun).Send(buffer.Bytes()) } -func (c *Control) GetVpnIp() netip.Addr { - return c.f.myVpnNet.Addr() +func (c *Control) GetVpnAddrs() []netip.Addr { + return c.f.myVpnAddrs } func (c *Control) GetUDPAddr() netip.AddrPort { @@ -140,7 +161,7 @@ func (c *Control) GetUDPAddr() netip.AddrPort { } func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool { - hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp) + hostinfo := c.f.handshakeManager.QueryVpnAddr(vpnIp) if hostinfo == nil { return false } @@ -153,8 +174,8 @@ func (c *Control) GetHostmap() *HostMap { return c.f.hostMap } -func (c *Control) GetCert() cert.Certificate { - return c.f.pki.GetCertState().Certificate +func (c *Control) GetCertState() *CertState { + return c.f.pki.getCertState() } func (c *Control) ReHandshake(vpnIp netip.Addr) { diff --git a/dns_server.go b/dns_server.go index 7501231..710f6ed 100644 --- a/dns_server.go +++ b/dns_server.go @@ -8,6 +8,7 @@ import ( "strings" "sync" + "github.com/gaissmai/bart" "github.com/miekg/dns" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -21,24 +22,39 @@ var dnsAddr string type dnsRecords struct { sync.RWMutex - dnsMap map[string]string - hostMap *HostMap + l *logrus.Logger + dnsMap4 map[string]netip.Addr + dnsMap6 map[string]netip.Addr + hostMap *HostMap + myVpnAddrsTable *bart.Table[struct{}] } -func newDnsRecords(hostMap *HostMap) *dnsRecords { +func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords { return &dnsRecords{ - dnsMap: make(map[string]string), - hostMap: hostMap, + l: l, + dnsMap4: make(map[string]netip.Addr), + dnsMap6: make(map[string]netip.Addr), + hostMap: hostMap, + myVpnAddrsTable: cs.myVpnAddrsTable, } } -func (d *dnsRecords) Query(data string) string { +func (d *dnsRecords) Query(q uint16, data string) netip.Addr { + data = strings.ToLower(data) d.RLock() defer d.RUnlock() - if r, ok := d.dnsMap[strings.ToLower(data)]; ok { - return r + switch q { + case dns.TypeA: + if r, ok := d.dnsMap4[data]; ok { + return r + } + case dns.TypeAAAA: + if r, ok := d.dnsMap6[data]; ok { + return r + } } - return "" + + return netip.Addr{} } func (d *dnsRecords) QueryCert(data string) string { @@ -47,7 +63,7 @@ func (d *dnsRecords) QueryCert(data string) string { return "" } - hostinfo := d.hostMap.QueryVpnIp(ip) + hostinfo := d.hostMap.QueryVpnAddr(ip) if hostinfo == nil { return "" } @@ -64,38 +80,62 @@ func (d *dnsRecords) QueryCert(data string) string { return string(b) } -func (d *dnsRecords) Add(host, data string) { +// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host` +func (d *dnsRecords) Add(host string, addresses []netip.Addr) { + host = strings.ToLower(host) d.Lock() defer d.Unlock() - d.dnsMap[strings.ToLower(host)] = data + haveV4 := false + haveV6 := false + for _, addr := range addresses { + if addr.Is4() && !haveV4 { + d.dnsMap4[host] = addr + haveV4 = true + } else if addr.Is6() && !haveV6 { + d.dnsMap6[host] = addr + haveV6 = true + } + if haveV4 && haveV6 { + break + } + } } -func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { +func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool { + a, _, _ := net.SplitHostPort(addr) + b, err := netip.ParseAddr(a) + if err != nil { + return false + } + + if b.IsLoopback() { + return true + } + + _, found := d.myVpnAddrsTable.Lookup(b) + return found //if we found it in this table, it's good +} + +func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { for _, q := range m.Question { switch q.Qtype { - case dns.TypeA: - l.Debugf("Query for A %s", q.Name) - ip := dnsR.Query(q.Name) - if ip != "" { - rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip)) + case dns.TypeA, dns.TypeAAAA: + qType := dns.TypeToString[q.Qtype] + d.l.Debugf("Query for %s %s", qType, q.Name) + ip := d.Query(q.Qtype, q.Name) + if ip.IsValid() { + rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip)) if err == nil { m.Answer = append(m.Answer, rr) } } case dns.TypeTXT: - a, _, _ := net.SplitHostPort(w.RemoteAddr().String()) - b, err := netip.ParseAddr(a) - if err != nil { + // We only answer these queries from nebula nodes or localhost + if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) { return } - - // We don't answer these queries from non nebula nodes or localhost - //l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR) - if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" { - return - } - l.Debugf("Query for TXT %s", q.Name) - ip := dnsR.QueryCert(q.Name) + d.l.Debugf("Query for TXT %s", q.Name) + ip := d.QueryCert(q.Name) if ip != "" { rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip)) if err == nil { @@ -110,26 +150,24 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { } } -func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) { +func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) m.SetReply(r) m.Compress = false switch r.Opcode { case dns.OpcodeQuery: - parseQuery(l, m, w) + d.parseQuery(m, w) } w.WriteMsg(m) } -func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() { - dnsR = newDnsRecords(hostMap) +func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() { + dnsR = newDnsRecords(l, cs, hostMap) // attach request handler func - dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { - handleDnsRequest(l, w, r) - }) + dns.HandleFunc(".", dnsR.handleDnsRequest) c.RegisterReloadCallback(func(c *config.C) { reloadDns(l, c) diff --git a/dns_server_test.go b/dns_server_test.go index 69f6ae8..f4643a3 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -1,23 +1,38 @@ package nebula import ( + "net/netip" "testing" "github.com/miekg/dns" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/stretchr/testify/assert" ) func TestParsequery(t *testing.T) { - //TODO: This test is basically pointless + l := logrus.New() hostMap := &HostMap{} - ds := newDnsRecords(hostMap) - ds.Add("test.com.com", "1.2.3.4") + ds := newDnsRecords(l, &CertState{}, hostMap) + addrs := []netip.Addr{ + netip.MustParseAddr("1.2.3.4"), + netip.MustParseAddr("1.2.3.5"), + netip.MustParseAddr("fd01::24"), + netip.MustParseAddr("fd01::25"), + } + ds.Add("test.com.com", addrs) - m := new(dns.Msg) + m := &dns.Msg{} m.SetQuestion("test.com.com", dns.TypeA) + ds.parseQuery(m, nil) + assert.NotNil(t, m.Answer) + assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String()) - //parseQuery(m) + m = &dns.Msg{} + m.SetQuestion("test.com.com", dns.TypeAAAA) + ds.parseQuery(m, nil) + assert.NotNil(t, m.Answer) + assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String()) } func Test_getDnsServerAddr(t *testing.T) { diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index f6069bf..2e7e6e4 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -4,14 +4,17 @@ package e2e import ( - "fmt" "net/netip" "slices" "testing" "time" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" "github.com/sirupsen/logrus" "github.com/slackhq/nebula" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" @@ -20,12 +23,12 @@ import ( ) func BenchmarkHotPath(b *testing.B) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, _, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Start the servers myControl.Start() @@ -35,7 +38,7 @@ func BenchmarkHotPath(b *testing.B) { r.CancelFlowLogs() for n := 0; n < b.N; n++ { - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) _ = r.RouteForAllUntilTxTun(theirControl) } @@ -44,19 +47,19 @@ func BenchmarkHotPath(b *testing.B) { } func TestGoodHandshake(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Start the servers myControl.Start() theirControl.Start() t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) t.Log("Have them consume my stage 0 packet. They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) @@ -77,32 +80,31 @@ func TestGoodHandshake(t *testing.T) { myControl.WaitForType(1, 0, theirControl) t.Log("Make sure our host infos are correct") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) t.Log("Get that cached packet and make sure it looks right") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) t.Log("Do a bidirectional tunnel test") r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() theirControl.Stop() - //TODO: assert hostmaps } func TestWrongResponderHandshake(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil) - evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil) + evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil) // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl, evilControl) @@ -114,7 +116,7 @@ func TestWrongResponderHandshake(t *testing.T) { evilControl.Start() t.Log("Start the handshake process, we will route until we see the evil tunnel closed") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) h := &header.H{} r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { @@ -131,8 +133,8 @@ func TestWrongResponderHandshake(t *testing.T) { }) t.Log("Evil tunnel is closed, inject the correct udp addr for them") - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - pendingHi := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), true) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + pendingHi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true) assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr) t.Log("Route until we see the cached packet") @@ -149,24 +151,21 @@ func TestWrongResponderHandshake(t *testing.T) { return router.KeepRouting }) - //TODO: Assert pending hostmap - I should have a correct hostinfo for them now - t.Log("My cached packet should be received by them") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) t.Log("Test the tunnel with them") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Flush all packets from all controllers") r.FlushAll() t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), true), "My pending hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), false), "My main hostmap should not contain evil") - //TODO: assert hostmaps for everyone r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl) t.Log("Success!") myControl.Stop() @@ -174,19 +173,19 @@ func TestWrongResponderHandshake(t *testing.T) { } func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil) - evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil) + evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil) o := m{ "static_host_map": m{ - theirVpnIpNet.Addr().String(): []string{evilUdpAddr.String()}, + theirVpnIpNet[0].Addr().String(): []string{evilUdpAddr.String()}, }, } - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", o) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", o) // Put the evil udp addr in for their vpn addr, this is a case of a remote at a static entry changing its vpn addr. - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl, evilControl) @@ -198,7 +197,7 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { evilControl.Start() t.Log("Start the handshake process, we will route until we see the evil tunnel closed") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) h := &header.H{} r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { @@ -215,8 +214,8 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { }) t.Log("Evil tunnel is closed, inject the correct udp addr for them") - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - pendingHi := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), true) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + pendingHi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true) assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr) t.Log("Route until we see the cached packet") @@ -233,24 +232,22 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { return router.KeepRouting }) - //TODO: Assert pending hostmap - I should have a correct hostinfo for them now - t.Log("My cached packet should be received by them") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) t.Log("Test the tunnel with them") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Flush all packets from all controllers") r.FlushAll() t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), true), "My pending hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), false), "My main hostmap should not contain evil") + //NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete - //TODO: assert hostmaps for everyone r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl) t.Log("Success!") myControl.Stop() @@ -261,13 +258,13 @@ func TestStage1Race(t *testing.T) { // This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow // But will eventually collapse down to a single tunnel - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -278,8 +275,8 @@ func TestStage1Race(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake to start on both me and them") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) t.Log("Get both stage 1 handshake packets") myHsForThem := myControl.GetFromUDP(true) @@ -291,14 +288,14 @@ func TestStage1Race(t *testing.T) { r.Log("Route until they receive a message packet") myCachedPacket := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.Log("Their cached packet should be received by me") theirCachedPacket := r.RouteForAllUntilTxTun(myControl) - assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) r.Log("Do a bidirectional tunnel test") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myHostmapHosts := myControl.ListHostmapHosts(false) myHostmapIndexes := myControl.ListHostmapIndexes(false) @@ -316,7 +313,7 @@ func TestStage1Race(t *testing.T) { r.Log("Spin until connection manager tears down a tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } @@ -338,13 +335,13 @@ func TestStage1Race(t *testing.T) { } func TestUncleanShutdownRaceLoser(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -355,10 +352,10 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.Log("Nuke my hostmap") myHostmap := myControl.GetHostmap() @@ -366,17 +363,17 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { myHostmap.Indexes = map[uint32]*nebula.HostInfo{} myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me again")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again")) p = r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.Log("Assert the tunnel works") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(theirControl.GetHostmap().Indexes) for { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) if len(theirControl.GetHostmap().Indexes) < start { break } @@ -387,13 +384,13 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { } func TestUncleanShutdownRaceWinner(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -404,10 +401,10 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, theirControl) r.Log("Nuke my hostmap") @@ -416,18 +413,18 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirHostmap.Indexes = map[uint32]*nebula.HostInfo{} theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them again")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again")) p = r.RouteForAllUntilTxTun(myControl) - assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Derp hostmaps", myControl, theirControl) r.Log("Assert the tunnel works") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(myControl.GetHostmap().Indexes) for { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) if len(myControl.GetHostmap().Indexes) < start { break } @@ -438,15 +435,15 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { } func TestRelays(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -458,31 +455,161 @@ func TestRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) - //TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it +} + +func TestReestablishRelays(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + relayControl.Start() + theirControl.Start() + + t.Log("Trigger a handshake from me to them via the relay") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + p := r.RouteForAllUntilTxTun(theirControl) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) + + t.Log("Ensure packet traversal from them to me via the relay") + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + + p = r.RouteForAllUntilTxTun(myControl) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from them"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) + + // If we break the relay's connection to 'them', 'me' needs to detect and recover the connection + r.Log("Close the tunnel") + relayControl.CloseTunnel(theirVpnIpNet[0].Addr(), true) + + start := len(myControl.GetHostmap().Indexes) + curIndexes := len(myControl.GetHostmap().Indexes) + for curIndexes >= start { + curIndexes = len(myControl.GetHostmap().Indexes) + r.Logf("Wait for the dead index to go away:start=%v indexes, currnet=%v indexes", start, curIndexes) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail")) + + r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { + return router.RouteAndExit + }) + time.Sleep(2 * time.Second) + } + r.Log("Dead index went away. Woot!") + r.RenderHostmaps("Me removed hostinfo", myControl, relayControl, theirControl) + // Next packet should re-establish a relayed connection and work just great. + + t.Logf("Assert the tunnel...") + for { + t.Log("RouteForAllUntilTxTun") + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + p = r.RouteForAllUntilTxTun(theirControl) + r.Log("Assert the tunnel works") + packet := gopacket.NewPacket(p, layers.LayerTypeIPv4, gopacket.Lazy) + v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + if slices.Compare(v4.SrcIP, myVpnIpNet[0].Addr().AsSlice()) != 0 { + t.Logf("SrcIP is unexpected...this is not the packet I'm looking for. Keep looking") + continue + } + if slices.Compare(v4.DstIP, theirVpnIpNet[0].Addr().AsSlice()) != 0 { + t.Logf("DstIP is unexpected...this is not the packet I'm looking for. Keep looking") + continue + } + + udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) + if udp == nil { + t.Log("Not a UDP packet. This is not the packet I'm looking for. Keep looking") + continue + } + data := packet.ApplicationLayer() + if data == nil { + t.Log("No data found in packet. This is not the packet I'm looking for. Keep looking.") + continue + } + if string(data.Payload()) != "Hi from me" { + t.Logf("Unexpected payload: '%v', keep looking", string(data.Payload())) + continue + } + t.Log("I found my lost packet. I am so happy.") + break + } + t.Log("Assert the tunnel works the other way, too") + for { + t.Log("RouteForAllUntilTxTun") + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + + p = r.RouteForAllUntilTxTun(myControl) + r.Log("Assert the tunnel works") + packet := gopacket.NewPacket(p, layers.LayerTypeIPv4, gopacket.Lazy) + v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + if slices.Compare(v4.DstIP, myVpnIpNet[0].Addr().AsSlice()) != 0 { + t.Logf("Dst is unexpected...this is not the packet I'm looking for. Keep looking") + continue + } + if slices.Compare(v4.SrcIP, theirVpnIpNet[0].Addr().AsSlice()) != 0 { + t.Logf("SrcIP is unexpected...this is not the packet I'm looking for. Keep looking") + continue + } + + udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) + if udp == nil { + t.Log("Not a UDP packet. This is not the packet I'm looking for. Keep looking") + continue + } + data := packet.ApplicationLayer() + if data == nil { + t.Log("No data found in packet. This is not the packet I'm looking for. Keep looking.") + continue + } + if string(data.Payload()) != "Hi from them" { + t.Logf("Unexpected payload: '%v', keep looking", string(data.Payload())) + continue + } + t.Log("I found my lost packet. I am so happy.") + break + } + r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) + } func TestStage1RaceRelays(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -494,14 +621,14 @@ func TestStage1RaceRelays(t *testing.T) { theirControl.Start() r.Log("Get a tunnel between me and relay") - assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") - assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) r.Log("Wait for a packet from them to me") p := r.RouteForAllUntilTxTun(myControl) @@ -512,27 +639,25 @@ func TestStage1RaceRelays(t *testing.T) { myControl.Stop() theirControl.Stop() relayControl.Stop() - // - ////TODO: assert hostmaps } func TestStage1RaceRelays2(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) l := NewTestLogger() // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -545,16 +670,16 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Get a tunnel between me and relay") l.Info("Get a tunnel between me and relay") - assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") l.Info("Get a tunnel between them and relay") - assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") l.Info("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) @@ -567,7 +692,7 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Assert the tunnel works") l.Info("Assert the tunnel works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) t.Log("Wait until we remove extra tunnels") l.Info("Wait until we remove extra tunnels") @@ -587,7 +712,7 @@ func TestStage1RaceRelays2(t *testing.T) { "theirControl": len(theirControl.GetHostmap().Indexes), "relayControl": len(relayControl.GetHostmap().Indexes), }).Info("Waiting for hostinfos to be removed...") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) retries-- @@ -595,26 +720,23 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Assert the tunnel works") l.Info("Assert the tunnel works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) myControl.Stop() theirControl.Stop() relayControl.Stop() - - // - ////TODO: assert hostmaps } func TestRehandshakingRelays(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -626,17 +748,17 @@ func TestRehandshakingRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalPEM() if err != nil { @@ -654,8 +776,8 @@ func TestRehandshakingRelays(t *testing.T) { for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") - assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) - c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) + c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") @@ -667,8 +789,8 @@ func TestRehandshakingRelays(t *testing.T) { for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") - assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) - c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") @@ -679,13 +801,13 @@ func TestRehandshakingRelays(t *testing.T) { } r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -693,7 +815,7 @@ func TestRehandshakingRelays(t *testing.T) { for len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -701,7 +823,7 @@ func TestRehandshakingRelays(t *testing.T) { for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -710,15 +832,15 @@ func TestRehandshakingRelays(t *testing.T) { func TestRehandshakingRelaysPrimary(t *testing.T) { // This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -730,17 +852,17 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalPEM() if err != nil { @@ -758,8 +880,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") - assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) - c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) + c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") @@ -771,8 +893,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") - assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) - c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") @@ -783,13 +905,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -797,7 +919,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -805,7 +927,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -813,13 +935,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } func TestRehandshaking(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -830,12 +952,12 @@ func TestRehandshaking(t *testing.T) { theirControl.Start() t.Log("Stand up a tunnel between me and them") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew my certificate and spin until their sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{myVpnIpNet}, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalPEM() if err != nil { @@ -852,8 +974,8 @@ func TestRehandshaking(t *testing.T) { myConfig.ReloadConfigString(string(rc)) for { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) - c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now break @@ -880,19 +1002,19 @@ func TestRehandshaking(t *testing.T) { r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won - c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) + c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) assert.Contains(t, c.Cert.Groups(), "new group") // We should only have a single tunnel now on both sides @@ -910,13 +1032,13 @@ func TestRehandshaking(t *testing.T) { func TestRehandshakingLoser(t *testing.T) { // The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel // Should be the one with the new certificate - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -927,16 +1049,12 @@ func TestRehandshakingLoser(t *testing.T) { theirControl.Start() t.Log("Stand up a tunnel between me and them") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) - - tt1 := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) - tt2 := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) - fmt.Println(tt1.LocalIndex, tt2.LocalIndex) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew their certificate and spin until mine sees it") - _, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{theirVpnIpNet}, nil, []string{"their new group"}) + _, _, theirNextPrivKey, theirNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}) caB, err := ca.MarshalPEM() if err != nil { @@ -953,8 +1071,8 @@ func TestRehandshakingLoser(t *testing.T) { theirConfig.ReloadConfigString(string(rc)) for { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) - theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) if slices.Contains(theirCertInMe.Cert.Groups(), "their new group") { break @@ -980,19 +1098,19 @@ func TestRehandshakingLoser(t *testing.T) { r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won - theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) + theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) assert.Contains(t, theirCertInMe.Cert.Groups(), "their new group") // We should only have a single tunnel now on both sides @@ -1010,13 +1128,13 @@ func TestRaceRegression(t *testing.T) { // This test forces stage 1, stage 2, stage 1 to be received by me from them // We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which // caused a cross-linked hostinfo - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Start the servers myControl.Start() @@ -1030,8 +1148,8 @@ func TestRaceRegression(t *testing.T) { //them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089 t.Log("Start both handshakes") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) t.Log("Get both stage 1") myStage1ForThem := myControl.GetFromUDP(true) @@ -1042,7 +1160,6 @@ func TestRaceRegression(t *testing.T) { myControl.InjectUDPPacket(theirStage1ForMe) theirControl.InjectUDPPacket(myStage1ForThem) - //TODO: ensure stage 2 t.Log("Get both stage 2") myStage2ForThem := myControl.GetFromUDP(true) theirStage2ForMe := theirControl.GetFromUDP(true) @@ -1061,14 +1178,48 @@ func TestRaceRegression(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) t.Log("Make sure the tunnel still works") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myControl.Stop() theirControl.Stop() } -//TODO: test -// Race winner renews and handshakes -// Race loser renews and handshakes -// Does race winner repin the cert to old? -//TODO: add a test with many lies +func TestV2NonPrimaryWithLighthouse(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "10.128.0.1/24, ff::1/64", m{"lighthouse": m{"am_lighthouse": true}}) + + o := m{ + "static_host_map": m{ + lhVpnIpNet[1].Addr().String(): []string{lhUdpAddr.String()}, + }, + "lighthouse": m{ + "hosts": []string{lhVpnIpNet[1].Addr().String()}, + "local_allow_list": m{ + // Try and block our lighthouse updates from using the actual addresses assigned to this computer + // If we start discovering addresses the test router doesn't know about then test traffic cant flow + "10.0.0.0/24": true, + "::/0": false, + }, + }, + } + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o) + theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, lhControl, myControl, theirControl) + defer r.RenderFlow() + + // Start the servers + lhControl.Start() + myControl.Start() + theirControl.Start() + + t.Log("Stand up an ipv6 tunnel between me and them") + assert.True(t, myVpnIpNet[1].Addr().Is6()) + assert.True(t, theirVpnIpNet[1].Addr().Is6()) + assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r) + + lhControl.Stop() + myControl.Stop() + theirControl.Stop() +} diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 77996f3..e1b7ac2 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -8,6 +8,7 @@ import ( "io" "net/netip" "os" + "strings" "testing" "time" @@ -17,6 +18,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/e2e/router" "github.com/stretchr/testify/assert" @@ -26,25 +28,35 @@ import ( type m map[string]interface{} // newSimpleServer creates a nebula instance with many assumptions -func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) { +func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { l := NewTestLogger() - vpnIpNet, err := netip.ParsePrefix(sVpnIpNet) - if err != nil { - panic(err) + var vpnNetworks []netip.Prefix + for _, sn := range strings.Split(sVpnNetworks, ",") { + vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn)) + if err != nil { + panic(err) + } + vpnNetworks = append(vpnNetworks, vpnIpNet) + } + + if len(vpnNetworks) == 0 { + panic("no vpn networks") } var udpAddr netip.AddrPort - if vpnIpNet.Addr().Is4() { - budpIp := vpnIpNet.Addr().As4() + if vpnNetworks[0].Addr().Is4() { + budpIp := vpnNetworks[0].Addr().As4() budpIp[1] -= 128 udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242) } else { - budpIp := vpnIpNet.Addr().As16() - budpIp[13] -= 128 + budpIp := vpnNetworks[0].Addr().As16() + // beef for funsies + budpIp[2] = 190 + budpIp[3] = 239 udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) } - _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{vpnIpNet}, nil, []string{}) + _, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{}) caB, err := caCrt.MarshalPEM() if err != nil { @@ -88,11 +100,16 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNe } if overrides != nil { - err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice) + final := m{} + err = mergo.Merge(&final, overrides, mergo.WithAppendSlice) if err != nil { panic(err) } - mc = overrides + err = mergo.Merge(&final, mc, mergo.WithAppendSlice) + if err != nil { + panic(err) + } + mc = final } cb, err := yaml.Marshal(mc) @@ -109,7 +126,7 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNe panic(err) } - return control, vpnIpNet, udpAddr, c + return control, vpnNetworks, udpAddr, c } type doneCb func() @@ -132,27 +149,28 @@ func deadline(t *testing.T, seconds time.Duration) doneCb { func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { // Send a packet from them to me - controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B")) + controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B")) bPacket := r.RouteForAllUntilTxTun(controlA) assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80) // And once more from me to them - controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A")) + controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A")) aPacket := r.RouteForAllUntilTxTun(controlB) assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) } -func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) { +func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) { // Get both host infos - hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false) - assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA") + //TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things + hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false) + assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA") - hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false) - assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB") + hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false) + assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB") // Check that both vpn and real addr are correct - assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A") - assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B") + assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A") + assert.EqualValues(t, getAddrs(vpnNetsA), hAinB.VpnAddrs, "Host A VpnIp is wrong in control B") assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A") assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B") @@ -160,25 +178,36 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIp // Check that our indexes match assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index") assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index") - - //TODO: Would be nice to assert this memory - //checkIndexes := func(name string, hm *HostMap, hi *HostInfo) { - // hBbyIndex := hmA.Indexes[hBinA.localIndexId] - // assert.NotNil(t, hBbyIndex, "Could not host info by local index in %s", name) - // assert.Equal(t, &hBbyIndex, &hBinA, "%s Indexes map did not point to the right host info", name) - // - // //TODO: remote indexes are susceptible to collision - // hBbyRemoteIndex := hmA.RemoteIndexes[hBinA.remoteIndexId] - // assert.NotNil(t, hBbyIndex, "Could not host info by remote index in %s", name) - // assert.Equal(t, &hBbyRemoteIndex, &hBinA, "%s RemoteIndexes did not point to the right host info", name) - //} - // - //// Check hostmap indexes too - //checkIndexes("hmA", hmA, hBinA) - //checkIndexes("hmB", hmB, hAinB) } func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { + if toIp.Is6() { + assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort) + } else { + assertUdpPacket4(t, expected, b, fromIp, toIp, fromPort, toPort) + } +} + +func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { + packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy) + v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6) + assert.NotNil(t, v6, "No ipv6 data found") + + assert.Equal(t, fromIp.AsSlice(), []byte(v6.SrcIP), "Source ip was incorrect") + assert.Equal(t, toIp.AsSlice(), []byte(v6.DstIP), "Dest ip was incorrect") + + udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) + assert.NotNil(t, udp, "No udp data found") + + assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect") + assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect") + + data := packet.ApplicationLayer() + assert.NotNil(t, data) + assert.Equal(t, expected, data.Payload(), "Data was incorrect") +} + +func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy) v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) assert.NotNil(t, v4, "No ipv4 data found") @@ -197,6 +226,14 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, assert.Equal(t, expected, data.Payload(), "Data was incorrect") } +func getAddrs(ns []netip.Prefix) []netip.Addr { + var a []netip.Addr + for _, n := range ns { + a = append(a, n.Addr()) + } + return a +} + func NewTestLogger() *logrus.Logger { l := logrus.New() diff --git a/e2e/router/hostmap.go b/e2e/router/hostmap.go index 29fa959..f2805d0 100644 --- a/e2e/router/hostmap.go +++ b/e2e/router/hostmap.go @@ -58,8 +58,9 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { var lines []string var globalLines []*edge - clusterName := strings.Trim(c.GetCert().Name(), " ") - clusterVpnIp := c.GetCert().Networks()[0].Addr() + crt := c.GetCertState().GetDefaultCertificate() + clusterName := strings.Trim(crt.Name(), " ") + clusterVpnIp := crt.Networks()[0].Addr() r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp) hm := c.GetHostmap() @@ -101,7 +102,7 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { for _, idx := range indexes { hi, ok := hm.Indexes[idx] if ok { - r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp()) + r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnAddrs()) remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ") globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())}) _ = hi diff --git a/e2e/router/router.go b/e2e/router/router.go index 0890570..5e52ed7 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -10,8 +10,8 @@ import ( "os" "path/filepath" "reflect" + "regexp" "sort" - "strings" "sync" "testing" "time" @@ -136,7 +136,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { panic("Duplicate listen address: " + addr.String()) } - r.vpnControls[c.GetVpnIp()] = c + for _, vpnAddr := range c.GetVpnAddrs() { + r.vpnControls[vpnAddr] = c + } + r.controls[addr] = c } @@ -213,11 +216,11 @@ func (r *R) renderFlow() { continue } participants[addr] = struct{}{} - sanAddr := strings.Replace(addr.String(), ":", "-", 1) + sanAddr := normalizeName(addr.String()) participantsVals = append(participantsVals, sanAddr) fmt.Fprintf( f, " participant %s as Nebula: %s
UDP: %s\n", - sanAddr, e.packet.from.GetVpnIp(), sanAddr, + sanAddr, e.packet.from.GetVpnAddrs(), sanAddr, ) } @@ -250,9 +253,9 @@ func (r *R) renderFlow() { fmt.Fprintf(f, " %s%s%s: %s(%s), index %v, counter: %v\n", - strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1), + normalizeName(p.from.GetUDPAddr().String()), line, - strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), + normalizeName(p.to.GetUDPAddr().String()), h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter, ) } @@ -267,6 +270,11 @@ func (r *R) renderFlow() { } } +func normalizeName(s string) string { + rx := regexp.MustCompile("[\\[\\]\\:]") + return rx.ReplaceAllLiteralString(s, "_") +} + // IgnoreFlow tells the router to stop recording future flows that matches the provided criteria. // messageType and subType will target nebula underlay packets while tun will target nebula overlay packets // NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered @@ -303,7 +311,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) { func (r *R) renderHostmaps(title string) { c := maps.Values(r.controls) sort.SliceStable(c, func(i, j int) bool { - return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0 + return c[i].GetVpnAddrs()[0].Compare(c[j].GetVpnAddrs()[0]) > 0 }) s := renderHostmaps(c...) @@ -419,10 +427,11 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) [] // Nope, lets push the sender along case p := <-udpTx: r.Lock() - c := r.getControl(sender.GetUDPAddr(), p.To, p) + a := sender.GetUDPAddr() + c := r.getControl(a, p.To, p) if c == nil { r.Unlock() - panic("No control for udp tx") + panic("No control for udp tx " + a.String()) } fp := r.unlockedInjectFlow(sender, c, p, false) c.InjectUDPPacket(p) @@ -475,10 +484,11 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte { } else { // we are a udp tx, route and continue p := rx.Interface().(*udp.Packet) - c := r.getControl(cm[x].GetUDPAddr(), p.To, p) + a := cm[x].GetUDPAddr() + c := r.getControl(a, p.To, p) if c == nil { r.Unlock() - panic("No control for udp tx") + panic(fmt.Sprintf("No control for udp tx %s", p.To)) } fp := r.unlockedInjectFlow(cm[x], c, p, false) c.InjectUDPPacket(p) @@ -711,30 +721,42 @@ func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.C } func (r *R) formatUdpPacket(p *packet) string { - packet := gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy) - v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) - if v4 == nil { - panic("not an ipv4 packet") + var packet gopacket.Packet + var srcAddr netip.Addr + + packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv6, gopacket.Lazy) + if packet.ErrorLayer() == nil { + v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6) + if v6 == nil { + panic("not an ipv6 packet") + } + srcAddr, _ = netip.AddrFromSlice(v6.SrcIP) + } else { + packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy) + v6 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + if v6 == nil { + panic("not an ipv6 packet") + } + srcAddr, _ = netip.AddrFromSlice(v6.SrcIP) } from := "unknown" - srcAddr, _ := netip.AddrFromSlice(v4.SrcIP) if c, ok := r.vpnControls[srcAddr]; ok { from = c.GetUDPAddr().String() } - udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) - if udp == nil { + udpLayer := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) + if udpLayer == nil { panic("not a udp packet") } data := packet.ApplicationLayer() return fmt.Sprintf( " %s-->>%s: src port: %v
dest port: %v
data: \"%v\"\n", - strings.Replace(from, ":", "-", 1), - strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), - udp.SrcPort, - udp.DstPort, + normalizeName(from), + normalizeName(p.to.GetUDPAddr().String()), + udpLayer.SrcPort, + udpLayer.DstPort, string(data.Payload()), ) } diff --git a/examples/config.yml b/examples/config.yml index c74ffc6..1c3584e 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -13,6 +13,12 @@ pki: # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid. #disconnect_invalid: true + # default_version controls which certificate version is used in handshakes. + # This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`. + # Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`. + # After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed. + # default_version: 1 + # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network). # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel. # The syntax is: @@ -244,7 +250,6 @@ tun: # in nebula configuration files. Default false, not reloadable. #use_system_route_table: false -# TODO # Configure logging level logging: # panic, fatal, error, warning, info, or debug. Default is info and is reloadable. @@ -336,10 +341,12 @@ firewall: # host: `any` or a literal hostname, ie `test-host` # group: `any` or a literal group name, ie `default-group` # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass - # cidr: a remote CIDR, `0.0.0.0/0` is any. - # local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes. - # Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate - # if `default_local_cidr_any` is false, otherwise its `any`. + # cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. + # local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes. + # If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network. + # Otherwise the default is any vpn network assigned to via the certificate. + # `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release. + # If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation. # ca_name: An issuing CA name # ca_sha: An issuing CA shasum diff --git a/firewall.go b/firewall.go index 80a8280..d3b9eb6 100644 --- a/firewall.go +++ b/firewall.go @@ -22,7 +22,7 @@ import ( ) type FirewallInterface interface { - AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error + AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error } type conn struct { @@ -51,9 +51,12 @@ type Firewall struct { UDPTimeout time.Duration //linux: 180s max DefaultTimeout time.Duration //linux: 600s - // Used to ensure we don't emit local packets for ips we don't own - localIps *bart.Table[struct{}] - assignedCIDR netip.Prefix + // routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate. + // The vpn addresses are a full bit match while the unsafe networks only match the prefix + routableNetworks *bart.Table[struct{}] + + // assignedNetworks is a list of vpn networks assigned to us in the certificate. + assignedNetworks []netip.Prefix hasUnsafeNetworks bool rules string @@ -67,9 +70,9 @@ type Firewall struct { } type firewallMetrics struct { - droppedLocalIP metrics.Counter - droppedRemoteIP metrics.Counter - droppedNoRule metrics.Counter + droppedLocalAddr metrics.Counter + droppedRemoteAddr metrics.Counter + droppedNoRule metrics.Counter } type FirewallConntrack struct { @@ -126,84 +129,87 @@ type firewallLocalCIDR struct { } // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. +// The certificate provided should be the highest version loaded in memory. func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { //TODO: error on 0 duration - var min, max time.Duration + var tmin, tmax time.Duration if tcpTimeout < UDPTimeout { - min = tcpTimeout - max = UDPTimeout + tmin = tcpTimeout + tmax = UDPTimeout } else { - min = UDPTimeout - max = tcpTimeout + tmin = UDPTimeout + tmax = tcpTimeout } - if defaultTimeout < min { - min = defaultTimeout - } else if defaultTimeout > max { - max = defaultTimeout + if defaultTimeout < tmin { + tmin = defaultTimeout + } else if defaultTimeout > tmax { + tmax = defaultTimeout } - localIps := new(bart.Table[struct{}]) - var assignedCIDR netip.Prefix - var assignedSet bool + routableNetworks := new(bart.Table[struct{}]) + var assignedNetworks []netip.Prefix for _, network := range c.Networks() { nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) - localIps.Insert(nprefix, struct{}{}) - - if !assignedSet { - // Only grabbing the first one in the cert since any more than that currently has undefined behavior - assignedCIDR = nprefix - assignedSet = true - } + routableNetworks.Insert(nprefix, struct{}{}) + assignedNetworks = append(assignedNetworks, network) } hasUnsafeNetworks := false for _, n := range c.UnsafeNetworks() { - localIps.Insert(n, struct{}{}) + routableNetworks.Insert(n, struct{}{}) hasUnsafeNetworks = true } return &Firewall{ Conntrack: &FirewallConntrack{ Conns: make(map[firewall.Packet]*conn), - TimerWheel: NewTimerWheel[firewall.Packet](min, max), + TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax), }, InRules: newFirewallTable(), OutRules: newFirewallTable(), TCPTimeout: tcpTimeout, UDPTimeout: UDPTimeout, DefaultTimeout: defaultTimeout, - localIps: localIps, - assignedCIDR: assignedCIDR, + routableNetworks: routableNetworks, + assignedNetworks: assignedNetworks, hasUnsafeNetworks: hasUnsafeNetworks, l: l, incomingMetrics: firewallMetrics{ - droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil), - droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil), - droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil), + droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil), + droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_addr", nil), + droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil), }, outgoingMetrics: firewallMetrics{ - droppedLocalIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil), - droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil), - droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil), + droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_addr", nil), + droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_addr", nil), + droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil), }, } } -func NewFirewallFromConfig(l *logrus.Logger, nc cert.Certificate, c *config.C) (*Firewall, error) { +func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) { + certificate := cs.getCertificate(cert.Version2) + if certificate == nil { + certificate = cs.getCertificate(cert.Version1) + } + + if certificate == nil { + panic("No certificate available to reconfigure the firewall") + } + fw := NewFirewall( l, c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), - nc, + certificate, //TODO: max_connections ) - //TODO: Flip to false after v1.9 release - fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true) + fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false) inboundAction := c.GetString("firewall.inbound_action", "drop") switch inboundAction { @@ -283,7 +289,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort fp = ft.TCP case firewall.ProtoUDP: fp = ft.UDP - case firewall.ProtoICMP: + case firewall.ProtoICMP, firewall.ProtoICMPv6: fp = ft.ICMP case firewall.ProtoAny: fp = ft.AnyProto @@ -424,26 +430,24 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // Make sure remote address matches nebula certificate - if remoteCidr := h.remoteCidr; remoteCidr != nil { - //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different - _, ok := remoteCidr.Lookup(fp.RemoteIP) + if h.networks != nil { + _, ok := h.networks.Lookup(fp.RemoteAddr) if !ok { - f.metrics(incoming).droppedRemoteIP.Inc(1) + f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } } else { - // Simple case: Certificate has one IP and no subnets - if fp.RemoteIP != h.vpnIp { - f.metrics(incoming).droppedRemoteIP.Inc(1) + // Simple case: Certificate has one address and no unsafe networks + if h.vpnAddrs[0] != fp.RemoteAddr { + f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } } // Make sure we are supposed to be handling this local ip address - //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different - _, ok := f.localIps.Lookup(fp.LocalIP) + _, ok := f.routableNetworks.Lookup(fp.LocalAddr) if !ok { - f.metrics(incoming).droppedLocalIP.Inc(1) + f.metrics(incoming).droppedLocalAddr.Inc(1) return ErrInvalidLocalIP } @@ -629,7 +633,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC if ft.UDP.match(p, incoming, c, caPool) { return true } - case firewall.ProtoICMP: + case firewall.ProtoICMP, firewall.ProtoICMPv6: if ft.ICMP.match(p, incoming, c, caPool) { return true } @@ -859,9 +863,9 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool } matched := false - prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen()) + prefix := netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen()) fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool { - if prefix.Contains(p.RemoteIP) && val.match(p, c) { + if prefix.Contains(p.RemoteAddr) && val.match(p, c) { matched = true return false } @@ -877,9 +881,14 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { return nil } - localIp = f.assignedCIDR + for _, network := range f.assignedNetworks { + flc.LocalCIDR.Insert(network, struct{}{}) + } + return nil + } else if localIp.Bits() == 0 { flc.Any = true + return nil } flc.LocalCIDR.Insert(localIp, struct{}{}) @@ -895,7 +904,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate return true } - _, ok := flc.LocalCIDR.Lookup(p.LocalIP) + _, ok := flc.LocalCIDR.Lookup(p.LocalAddr) return ok } diff --git a/firewall/packet.go b/firewall/packet.go index 8954f4c..1d8f12a 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -9,18 +9,19 @@ import ( type m map[string]interface{} const ( - ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever - ProtoTCP = 6 - ProtoUDP = 17 - ProtoICMP = 1 + ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever + ProtoTCP = 6 + ProtoUDP = 17 + ProtoICMP = 1 + ProtoICMPv6 = 58 PortAny = 0 // Special value for matching `port: any` PortFragment = -1 // Special value for matching `port: fragment` ) type Packet struct { - LocalIP netip.Addr - RemoteIP netip.Addr + LocalAddr netip.Addr + RemoteAddr netip.Addr LocalPort uint16 RemotePort uint16 Protocol uint8 @@ -29,8 +30,8 @@ type Packet struct { func (fp *Packet) Copy() *Packet { return &Packet{ - LocalIP: fp.LocalIP, - RemoteIP: fp.RemoteIP, + LocalAddr: fp.LocalAddr, + RemoteAddr: fp.RemoteAddr, LocalPort: fp.LocalPort, RemotePort: fp.RemotePort, Protocol: fp.Protocol, @@ -51,8 +52,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) { proto = fmt.Sprintf("unknown %v", fp.Protocol) } return json.Marshal(m{ - "LocalIP": fp.LocalIP.String(), - "RemoteIP": fp.RemoteIP.String(), + "LocalAddr": fp.LocalAddr.String(), + "RemoteAddr": fp.RemoteAddr.String(), "LocalPort": fp.LocalPort, "RemotePort": fp.RemotePort, "Protocol": proto, diff --git a/firewall_test.go b/firewall_test.go index 57cd32a..4dd2c9a 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -13,6 +13,7 @@ import ( "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewFirewall(t *testing.T) { @@ -128,8 +129,8 @@ func TestFirewall_Drop(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: netip.MustParseAddr("1.2.3.4"), - RemoteIP: netip.MustParseAddr("1.2.3.4"), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -149,9 +150,9 @@ func TestFirewall_Drop(t *testing.T) { InvertedGroups: map[string]struct{}{"default-group": {}}, }, }, - vpnIp: netip.MustParseAddr("1.2.3.4"), + vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, } - h.CreateRemoteCIDR(&c) + h.buildNetworks(c.networks, c.unsafeNetworks) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -166,10 +167,10 @@ func TestFirewall_Drop(t *testing.T) { assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) // test remote mismatch - oldRemote := p.RemoteIP - p.RemoteIP = netip.MustParseAddr("1.2.3.10") + oldRemote := p.RemoteAddr + p.RemoteAddr = netip.MustParseAddr("1.2.3.10") assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) - p.RemoteIP = oldRemote + p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) @@ -235,7 +236,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { } ip := netip.MustParsePrefix("9.254.254.254/32") for n := 0; n < b.N; n++ { - assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp)) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp)) } }) @@ -261,7 +262,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { InvertedGroups: map[string]struct{}{"nope": {}}, } for n := 0; n < b.N; n++ { - assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) } }) @@ -285,7 +286,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { InvertedGroups: map[string]struct{}{"good-group": {}}, } for n := 0; n < b.N; n++ { - assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) + assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) } }) @@ -308,8 +309,8 @@ func TestFirewall_Drop2(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: netip.MustParseAddr("1.2.3.4"), - RemoteIP: netip.MustParseAddr("1.2.3.4"), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -329,9 +330,9 @@ func TestFirewall_Drop2(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: network.Addr(), + vpnAddrs: []netip.Addr{network.Addr()}, } - h.CreateRemoteCIDR(c.Certificate) + h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) c1 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -341,11 +342,12 @@ func TestFirewall_Drop2(t *testing.T) { InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}}, } h1 := HostInfo{ + vpnAddrs: []netip.Addr{network.Addr()}, ConnectionState: &ConnectionState{ peerCert: &c1, }, } - h1.CreateRemoteCIDR(c1.Certificate) + h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -364,8 +366,8 @@ func TestFirewall_Drop3(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: netip.MustParseAddr("1.2.3.4"), - RemoteIP: netip.MustParseAddr("1.2.3.4"), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 1, RemotePort: 1, Protocol: firewall.ProtoUDP, @@ -391,9 +393,9 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c1, }, - vpnIp: network.Addr(), + vpnAddrs: []netip.Addr{network.Addr()}, } - h1.CreateRemoteCIDR(c1.Certificate) + h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) c2 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -406,9 +408,9 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c2, }, - vpnIp: network.Addr(), + vpnAddrs: []netip.Addr{network.Addr()}, } - h2.CreateRemoteCIDR(c2.Certificate) + h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks()) c3 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -421,9 +423,9 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c3, }, - vpnIp: network.Addr(), + vpnAddrs: []netip.Addr{network.Addr()}, } - h3.CreateRemoteCIDR(c3.Certificate) + h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -446,8 +448,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: netip.MustParseAddr("1.2.3.4"), - RemoteIP: netip.MustParseAddr("1.2.3.4"), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -468,9 +470,9 @@ func TestFirewall_DropConntrackReload(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: network.Addr(), + vpnAddrs: []netip.Addr{network.Addr()}, } - h.CreateRemoteCIDR(c.Certificate) + h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -574,8 +576,6 @@ func BenchmarkLookup(b *testing.B) { ml(m, a) } }) - - //TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster } func Test_parsePort(t *testing.T) { @@ -622,55 +622,58 @@ func TestNewFirewallFromConfig(t *testing.T) { l := test.NewLogger() // Test a bad rule definition c := &dummyCert{} + cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) + require.NoError(t, err) + conf := config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} - _, err := NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") } diff --git a/go.mod b/go.mod index f464990..2ff9976 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,6 @@ require ( github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/sirupsen/logrus v1.9.3 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e - github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stretchr/testify v1.9.0 github.com/vishvananda/netlink v1.3.0 diff --git a/go.sum b/go.sum index dacc3d3..d0e9c55 100644 --- a/go.sum +++ b/go.sum @@ -137,8 +137,6 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= -github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= -github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw= github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/handshake_ix.go b/handshake_ix.go index 3add83d..9b8b3e9 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -2,10 +2,12 @@ package nebula import ( "net/netip" + "slices" "time" "github.com/flynn/noise" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" ) @@ -16,30 +18,60 @@ import ( func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { err := f.handshakeManager.allocateIndex(hh) if err != nil { - f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") return false } - certState := f.pki.GetCertState() - ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0) + // If we're connecting to a v6 address we must use a v2 cert + cs := f.pki.getCertState() + v := cs.defaultVersion + for _, a := range hh.hostinfo.vpnAddrs { + if a.Is6() { + v = cert.Version2 + break + } + } + + crt := cs.getCertificate(v) + if crt == nil { + f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", v). + Error("Unable to handshake with host because no certificate is available") + return false + } + + crtHs := cs.getHandshakeBytes(v) + if crtHs == nil { + f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", v). + Error("Unable to handshake with host because no certificate handshake bytes is available") + } + + ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX) + if err != nil { + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", v). + Error("Failed to create connection state") + return false + } hh.hostinfo.ConnectionState = ci - hsProto := &NebulaHandshakeDetails{ - InitiatorIndex: hh.hostinfo.localIndexId, - Time: uint64(time.Now().UnixNano()), - Cert: certState.RawCertificateNoKey, - } - - hsBytes := []byte{} - hs := &NebulaHandshake{ - Details: hsProto, + Details: &NebulaHandshakeDetails{ + InitiatorIndex: hh.hostinfo.localIndexId, + Time: uint64(time.Now().UnixNano()), + Cert: crtHs, + CertVersion: uint32(v), + }, } - hsBytes, err = hs.Marshal() + hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") return false } @@ -48,7 +80,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { msg, _, _, err := ci.H.WriteMessage(h, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") return false } @@ -63,30 +95,44 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { } func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { - certState := f.pki.GetCertState() - ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0) + cs := f.pki.getCertState() + crt := cs.GetDefaultCertificate() + if crt == nil { + f.l.WithField("udpAddr", addr). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", cs.defaultVersion). + Error("Unable to handshake with host because no certificate is available") + } + + ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) + if err != nil { + f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Failed to create connection state") + return + } + // Mark packet 1 as seen so it doesn't show up as missed ci.window.Update(f.l, 1) msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage") + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Failed to call noise.ReadMessage") return } hs := &NebulaHandshake{} err = hs.Unmarshal(msg) - /* - l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex) - */ if err != nil || hs.Details == nil { f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Failed unmarshal handshake message") return } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) + remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool()) if err != nil { e := f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) @@ -99,6 +145,20 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } + if remoteCert.Certificate.Version() != ci.myCert.Version() { + // We started off using the wrong certificate version, lets see if we can match the version that was sent to us + rc := cs.getCertificate(remoteCert.Certificate.Version()) + if rc == nil { + f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). + Info("Unable to handshake with host due to missing certificate version") + return + } + + // Record the certificate we are actually using + ci.myCert = rc + } + if len(remoteCert.Certificate.Networks()) == 0 { e := f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) @@ -111,30 +171,54 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } - vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap() + var vpnAddrs []netip.Addr + var filteredNetworks []netip.Prefix certName := remoteCert.Certificate.Name() fingerprint := remoteCert.Fingerprint issuer := remoteCert.Certificate.Issuer() - if vpnIp == f.myVpnNet.Addr() { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + for _, network := range remoteCert.Certificate.Networks() { + vpnAddr := network.Addr() + _, found := f.myVpnAddrsTable.Lookup(vpnAddr) + if found { + f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") + return + } + + // vpnAddrs outside our vpn networks are of no use to us, filter them out + if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok { + continue + } + + filteredNetworks = append(filteredNetworks, network) + vpnAddrs = append(vpnAddrs, vpnAddr) + } + + if len(vpnAddrs) == 0 { + f.l.WithError(err).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake") return } if addr.IsValid() { - if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + // addr can be invalid when the tunnel is being relayed. + // We only want to apply the remote allow list for direct tunnels here + if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) { + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } } myIndex, err := generateIndex(f.l) if err != nil { - f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -146,17 +230,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet ConnectionState: ci, localIndexId: myIndex, remoteIndexId: hs.Details.InitiatorIndex, - vpnIp: vpnIp, + vpnAddrs: vpnAddrs, HandshakePacket: make(map[uint8][]byte, 0), lastHandshakeTime: hs.Details.Time, relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, - relayForByIp: map[netip.Addr]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, } - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -165,13 +249,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet Info("Handshake message received") hs.Details.ResponderIndex = myIndex - hs.Details.Cert = certState.RawCertificateNoKey + hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version()) + if hs.Details.Cert == nil { + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + WithField("certVersion", ci.myCert.Version()). + Error("Unable to handshake with host because no certificate handshake bytes is available") + return + } + + hs.Details.CertVersion = uint32(ci.myCert.Version()) // Update the time in case their clock is way off from ours hs.Details.Time = uint64(time.Now().UnixNano()) hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -182,14 +279,14 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") return } else if dKey == nil || eKey == nil { - f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -213,9 +310,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet ci.dKey = NewNebulaCipherState(dKey) ci.eKey = NewNebulaCipherState(eKey) - hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) + hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) hostinfo.SetRemote(addr) - hostinfo.CreateRemoteCIDR(remoteCert.Certificate) + hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks()) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) if err != nil { @@ -225,7 +322,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if existing.SetRemoteIfPreferred(f.hostMap, addr) { // Send a test packet to ensure the other side has also switched to // the preferred remote - f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) } msg = existing.HandshakePacket[2] @@ -233,11 +330,11 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if addr.IsValid() { err := f.outside.WriteTo(msg, addr) if err != nil { - f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithError(err).Error("Failed to send handshake message") } else { - f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). Info("Handshake message sent") } @@ -247,16 +344,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet f.l.Error("Handshake send failed: both addr and via are nil.") return } - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp). + f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). Info("Handshake message sent") return } case ErrExistingHostInfo: // This means there was an existing tunnel and this handshake was older than the one we are currently based on - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("oldHandshakeTime", existing.lastHandshakeTime). WithField("newHandshakeTime", hostinfo.lastHandshakeTime). @@ -267,23 +364,23 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet Info("Handshake too old") // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues - f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) return case ErrLocalIndexCollision: // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp). + WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs). Error("Failed to add HostInfo due to localIndex collision") return default: // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // And we forget to update it here - f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -299,7 +396,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if addr.IsValid() { err = f.outside.WriteTo(msg, addr) if err != nil { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -307,7 +404,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithError(err).Error("Failed to send handshake") } else { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -320,9 +417,12 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet f.l.Error("Handshake send failed: both addr and via are nil.") return } - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) + // I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure + // it's correctly marked as working. + via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -349,8 +449,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hostinfo := hh.hostinfo if addr.IsValid() { - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) { - f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + // The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list. + if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) { + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return false } } @@ -358,7 +459,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha ci := hostinfo.ConnectionState msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). Error("Failed to call noise.ReadMessage") @@ -367,7 +468,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha // near future return false } else if dKey == nil || eKey == nil { - f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Error("Noise did not arrive at a key") @@ -379,16 +480,16 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hs := &NebulaHandshake{} err = hs.Unmarshal(msg) if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again return true } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) + remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool()) if err != nil { - e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) if f.l.Level > logrus.DebugLevel { @@ -409,11 +510,11 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha e = e.WithField("cert", remoteCert) } - e.Info("Invalid vpn ip from host") + e.Info("Empty networks from host") return true } - vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap() + vpnNetworks := remoteCert.Certificate.Networks() certName := remoteCert.Certificate.Name() fingerprint := remoteCert.Fingerprint issuer := remoteCert.Certificate.Issuer() @@ -430,12 +531,34 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha if addr.IsValid() { hostinfo.SetRemote(addr) } else { - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) + } + + var vpnAddrs []netip.Addr + var filteredNetworks []netip.Prefix + for _, network := range vpnNetworks { + // vpnAddrs outside our vpn networks are of no use to us, filter them out + vpnAddr := network.Addr() + if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok { + continue + } + + filteredNetworks = append(filteredNetworks, network) + vpnAddrs = append(vpnAddrs, vpnAddr) + } + + if len(vpnAddrs) == 0 { + f.l.WithError(err).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake") + return true } // Ensure the right host responded - if vpnIp != hostinfo.vpnIp { - f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp). + if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) { + f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). WithField("udpAddr", addr).WithField("certName", certName). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Info("Incorrect host responded to handshake") @@ -444,14 +567,13 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha f.handshakeManager.DeleteHostInfo(hostinfo) // Create a new hostinfo/handshake for the intended vpn ip - f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) { - //TODO: this doesnt know if its being added or is being used for caching a packet + f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { // Block the current used address newHH.hostinfo.remotes = hostinfo.remotes newHH.hostinfo.remotes.BlockRemote(addr) f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()). - WithField("vpnIp", newHH.hostinfo.vpnIp). + WithField("vpnNetworks", vpnNetworks). WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). Info("Blocked addresses for handshakes") @@ -459,11 +581,8 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha newHH.packetStore = hh.packetStore hh.packetStore = []*cachedPacket{} - // Get the correct remote list for the host we did handshake with - hostinfo.SetRemote(addr) - hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) - // Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down - hostinfo.vpnIp = vpnIp + // Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down + hostinfo.vpnAddrs = vpnAddrs f.sendCloseTunnel(hostinfo) }) @@ -474,7 +593,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha ci.window.Update(f.l, 2) duration := time.Since(hh.startTime).Nanoseconds() - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -485,9 +604,10 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha Info("Handshake message received") // Build up the radix for the firewall if we have subnets in the cert - hostinfo.CreateRemoteCIDR(remoteCert.Certificate) + hostinfo.vpnAddrs = vpnAddrs + hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks()) - // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp + // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here f.handshakeManager.Complete(hostinfo, f) f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) diff --git a/handshake_manager.go b/handshake_manager.go index 4834893..6d3ed12 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -13,6 +13,7 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" ) @@ -118,18 +119,18 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig } } -func (c *HandshakeManager) Run(ctx context.Context) { - clockSource := time.NewTicker(c.config.tryInterval) +func (hm *HandshakeManager) Run(ctx context.Context) { + clockSource := time.NewTicker(hm.config.tryInterval) defer clockSource.Stop() for { select { case <-ctx.Done(): return - case vpnIP := <-c.trigger: - c.handleOutbound(vpnIP, true) + case vpnIP := <-hm.trigger: + hm.handleOutbound(vpnIP, true) case now := <-clockSource.C: - c.NextOutboundHandshakeTimerTick(now) + hm.NextOutboundHandshakeTimerTick(now) } } } @@ -137,7 +138,7 @@ func (c *HandshakeManager) Run(ctx context.Context) { func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { // First remote allow list check before we know the vpnIp if addr.IsValid() { - if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) { + if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) { hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } @@ -159,14 +160,14 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, } } -func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { - c.OutboundHandshakeTimer.Advance(now) +func (hm *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { + hm.OutboundHandshakeTimer.Advance(now) for { - vpnIp, has := c.OutboundHandshakeTimer.Purge() + vpnIp, has := hm.OutboundHandshakeTimer.Purge() if !has { break } - c.handleOutbound(vpnIp, false) + hm.handleOutbound(vpnIp, false) } } @@ -208,7 +209,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered // NB ^ This comment doesn't jive. It's how the thing gets initialized. // It's the common path. Should it update every time, in case a future LH query/queries give us more info? if hostinfo.remotes == nil { - hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp) + hostinfo.remotes = hm.lightHouse.QueryCache([]netip.Addr{vpnIp}) } remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()) @@ -223,7 +224,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hh.lastRemotes = remotes - // TODO: this will generate a load of queries for hosts with only 1 ip + // This will generate a load of queries for hosts with only 1 ip // (such as ones registered to the lighthouse with only a private IP) // So we only do it one time after attempting 5 handshakes already. if len(remotes) <= 1 && hh.counter == 5 { @@ -267,59 +268,26 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") // Send a RelayRequest to all known Relay IP's for _, relay := range hostinfo.remotes.relays { - // Don't relay to myself, and don't relay through the host I'm trying to connect to - if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() { + // Don't relay to myself + if relay == vpnIp { continue } - relayHostInfo := hm.mainHostMap.QueryVpnIp(relay) + + // Don't relay through the host I'm trying to connect to + _, found := hm.f.myVpnAddrsTable.Lookup(relay) + if found { + continue + } + + relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay) if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") hm.f.Handshake(relay) continue } - // Check the relay HostInfo to see if we already established a relay through it - if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok { - switch existingRelay.State { - case Established: - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay") - hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) - case Requested: - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") - - //TODO: IPV6-WORK - myVpnIpB := hm.f.myVpnNet.Addr().As4() - theirVpnIpB := vpnIp.As4() - - // Re-send the CreateRelay request, in case the previous one was lost. - m := NebulaControl{ - Type: NebulaControl_CreateRelayRequest, - InitiatorRelayIndex: existingRelay.LocalIndex, - RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]), - RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]), - } - msg, err := m.Marshal() - if err != nil { - hostinfo.logger(hm.l). - WithError(err). - Error("Failed to marshal Control message to create relay") - } else { - // This must send over the hostinfo, not over hm.Hosts[ip] - hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.f.myVpnNet.Addr(), - "relayTo": vpnIp, - "initiatorRelayIndex": existingRelay.LocalIndex, - "relay": relay}). - Info("send CreateRelayRequest") - } - default: - hostinfo.logger(hm.l). - WithField("vpnIp", vpnIp). - WithField("state", existingRelay.State). - WithField("relay", relayHostInfo.vpnIp). - Errorf("Relay unexpected state") - } - } else { + // Check the relay HostInfo to see if we already established a relay through + existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp) + if !ok { // No relays exist or requested yet. if relayHostInfo.remote.IsValid() { idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) @@ -327,16 +295,35 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") } - //TODO: IPV6-WORK - myVpnIpB := hm.f.myVpnNet.Addr().As4() - theirVpnIpB := vpnIp.As4() - m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: idx, - RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]), - RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]), } + + switch relayHostInfo.GetCert().Certificate.Version() { + case cert.Version1: + if !hm.f.myVpnAddrs[0].Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !vpnIp.Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := hm.f.myVpnAddrs[0].As4() + m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = vpnIp.As4() + m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) + m.RelayToAddr = netAddrToProtoAddr(vpnIp) + default: + hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") + continue + } + msg, err := m.Marshal() if err != nil { hostinfo.logger(hm.l). @@ -345,13 +332,80 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered } else { hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.f.myVpnNet.Addr(), + "relayFrom": hm.f.myVpnAddrs[0], "relayTo": vpnIp, "initiatorRelayIndex": idx, "relay": relay}). Info("send CreateRelayRequest") } } + continue + } + + switch existingRelay.State { + case Established: + hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay") + hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) + case Disestablished: + // Mark this relay as 'requested' + relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) + fallthrough + case Requested: + hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") + // Re-send the CreateRelay request, in case the previous one was lost. + m := NebulaControl{ + Type: NebulaControl_CreateRelayRequest, + InitiatorRelayIndex: existingRelay.LocalIndex, + } + + switch relayHostInfo.GetCert().Certificate.Version() { + case cert.Version1: + if !hm.f.myVpnAddrs[0].Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !vpnIp.Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := hm.f.myVpnAddrs[0].As4() + m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = vpnIp.As4() + m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) + m.RelayToAddr = netAddrToProtoAddr(vpnIp) + default: + hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") + continue + } + msg, err := m.Marshal() + if err != nil { + hostinfo.logger(hm.l). + WithError(err). + Error("Failed to marshal Control message to create relay") + } else { + // This must send over the hostinfo, not over hm.Hosts[ip] + hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + hm.l.WithFields(logrus.Fields{ + "relayFrom": hm.f.myVpnAddrs[0], + "relayTo": vpnIp, + "initiatorRelayIndex": existingRelay.LocalIndex, + "relay": relay}). + Info("send CreateRelayRequest") + } + case PeerRequested: + // PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case. + fallthrough + default: + hostinfo.logger(hm.l). + WithField("vpnIp", vpnIp). + WithField("state", existingRelay.State). + WithField("relay", relay). + Errorf("Relay unexpected state") + } } } @@ -381,10 +435,10 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*Hands } // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip -func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo { +func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo { hm.Lock() - if hh, ok := hm.vpnIps[vpnIp]; ok { + if hh, ok := hm.vpnIps[vpnAddr]; ok { // We are already trying to handshake with this vpn ip if cacheCb != nil { cacheCb(hh) @@ -394,12 +448,12 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands } hostinfo := &HostInfo{ - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnAddr}, HandshakePacket: make(map[uint8][]byte, 0), relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, - relayForByIp: map[netip.Addr]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, } @@ -407,9 +461,9 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands hostinfo: hostinfo, startTime: time.Now(), } - hm.vpnIps[vpnIp] = hh + hm.vpnIps[vpnAddr] = hh hm.metricInitiated.Inc(1) - hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval) + hm.OutboundHandshakeTimer.Add(vpnAddr, hm.config.tryInterval) if cacheCb != nil { cacheCb(hh) @@ -417,21 +471,21 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands // If this is a static host, we don't need to wait for the HostQueryReply // We can trigger the handshake right now - _, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp] + _, doTrigger := hm.lightHouse.GetStaticHostList()[vpnAddr] if !doTrigger { // Add any calculated remotes, and trigger early handshake if one found - doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp) + doTrigger = hm.lightHouse.addCalculatedRemotes(vpnAddr) } if doTrigger { select { - case hm.trigger <- vpnIp: + case hm.trigger <- vpnAddr: default: } } hm.Unlock() - hm.lightHouse.QueryServer(vpnIp) + hm.lightHouse.QueryServer(vpnAddr) return hostinfo } @@ -452,14 +506,14 @@ var ( // // ErrLocalIndexCollision if we already have an entry in the main or pending // hostmap for the hostinfo.localIndexId. -func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) { - c.mainHostMap.Lock() - defer c.mainHostMap.Unlock() - c.Lock() - defer c.Unlock() +func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) { + hm.mainHostMap.Lock() + defer hm.mainHostMap.Unlock() + hm.Lock() + defer hm.Unlock() // Check if we already have a tunnel with this vpn ip - existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp] + existingHostInfo, found := hm.mainHostMap.Hosts[hostinfo.vpnAddrs[0]] if found && existingHostInfo != nil { testHostInfo := existingHostInfo for testHostInfo != nil { @@ -476,31 +530,31 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket return existingHostInfo, ErrExistingHostInfo } - existingHostInfo.logger(c.l).Info("Taking new handshake") + existingHostInfo.logger(hm.l).Info("Taking new handshake") } - existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId] + existingIndex, found := hm.mainHostMap.Indexes[hostinfo.localIndexId] if found { // We have a collision, but for a different hostinfo return existingIndex, ErrLocalIndexCollision } - existingPendingIndex, found := c.indexes[hostinfo.localIndexId] + existingPendingIndex, found := hm.indexes[hostinfo.localIndexId] if found && existingPendingIndex.hostinfo != hostinfo { // We have a collision, but for a different hostinfo return existingPendingIndex.hostinfo, ErrLocalIndexCollision } - existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] - if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp { + existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] + if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. - hostinfo.logger(c.l). - WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp). + hostinfo.logger(hm.l). + WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). Info("New host shadows existing host remoteIndex") } - c.mainHostMap.unlockedAddHostInfo(hostinfo, f) + hm.mainHostMap.unlockedAddHostInfo(hostinfo, f) return existingHostInfo, nil } @@ -518,7 +572,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. hostinfo.logger(hm.l). - WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp). + WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). Info("New host shadows existing host remoteIndex") } @@ -555,31 +609,34 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { return errors.New("failed to generate unique localIndexId") } -func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { - c.Lock() - defer c.Unlock() - c.unlockedDeleteHostInfo(hostinfo) +func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { + hm.Lock() + defer hm.Unlock() + hm.unlockedDeleteHostInfo(hostinfo) } -func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { - delete(c.vpnIps, hostinfo.vpnIp) - if len(c.vpnIps) == 0 { - c.vpnIps = map[netip.Addr]*HandshakeHostInfo{} +func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { + for _, addr := range hostinfo.vpnAddrs { + delete(hm.vpnIps, addr) } - delete(c.indexes, hostinfo.localIndexId) - if len(c.vpnIps) == 0 { - c.indexes = map[uint32]*HandshakeHostInfo{} + if len(hm.vpnIps) == 0 { + hm.vpnIps = map[netip.Addr]*HandshakeHostInfo{} } - if c.l.Level >= logrus.DebugLevel { - c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps), - "vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). + delete(hm.indexes, hostinfo.localIndexId) + if len(hm.indexes) == 0 { + hm.indexes = map[uint32]*HandshakeHostInfo{} + } + + if hm.l.Level >= logrus.DebugLevel { + hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps), + "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). Debug("Pending hostmap hostInfo deleted") } } -func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo { +func (hm *HandshakeManager) QueryVpnAddr(vpnIp netip.Addr) *HostInfo { hh := hm.queryVpnIp(vpnIp) if hh != nil { return hh.hostinfo @@ -608,37 +665,37 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo { return hm.indexes[index] } -func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix { - return c.mainHostMap.GetPreferredRanges() +func (hm *HandshakeManager) GetPreferredRanges() []netip.Prefix { + return hm.mainHostMap.GetPreferredRanges() } -func (c *HandshakeManager) ForEachVpnIp(f controlEach) { - c.RLock() - defer c.RUnlock() +func (hm *HandshakeManager) ForEachVpnAddr(f controlEach) { + hm.RLock() + defer hm.RUnlock() - for _, v := range c.vpnIps { + for _, v := range hm.vpnIps { f(v.hostinfo) } } -func (c *HandshakeManager) ForEachIndex(f controlEach) { - c.RLock() - defer c.RUnlock() +func (hm *HandshakeManager) ForEachIndex(f controlEach) { + hm.RLock() + defer hm.RUnlock() - for _, v := range c.indexes { + for _, v := range hm.indexes { f(v.hostinfo) } } -func (c *HandshakeManager) EmitStats() { - c.RLock() - hostLen := len(c.vpnIps) - indexLen := len(c.indexes) - c.RUnlock() +func (hm *HandshakeManager) EmitStats() { + hm.RLock() + hostLen := len(hm.vpnIps) + indexLen := len(hm.indexes) + hm.RUnlock() metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen)) metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen)) - c.mainHostMap.EmitStats() + hm.mainHostMap.EmitStats() } // Utility functions below diff --git a/handshake_manager_test.go b/handshake_manager_test.go index daa8675..7edc55b 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" @@ -13,21 +14,20 @@ import ( func Test_NewHandshakeManagerVpnIp(t *testing.T) { l := test.NewLogger() - vpncidr := netip.MustParsePrefix("172.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24") ip := netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} - mainHM := newHostMap(l, vpncidr) + mainHM := newHostMap(l) mainHM.preferredRanges.Store(&preferredRanges) lh := newTestLighthouse() cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &dummyCert{}, - RawCertificateNoKey: []byte{}, + defaultVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, } blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) @@ -41,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { i2 := blah.StartHandshake(ip, nil) assert.Same(t, i, i2) - i.remotes = NewRemoteList(nil) + i.remotes = NewRemoteList([]netip.Addr{}, nil) // Adding something to pending should not affect the main hostmap assert.Len(t, mainHM.Hosts, 0) @@ -79,16 +79,24 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) { type mockEncWriter struct { } -func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) { +func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) { return } -func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { +func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) { return } -func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { +func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) { return } -func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {} +func (mw *mockEncWriter) Handshake(_ netip.Addr) {} + +func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo { + return nil +} + +func (mw *mockEncWriter) GetCertState() *CertState { + return &CertState{defaultVersion: cert.Version2} +} diff --git a/hostmap.go b/hostmap.go index d83151e..f9e3c4e 100644 --- a/hostmap.go +++ b/hostmap.go @@ -35,6 +35,7 @@ const ( Requested = iota PeerRequested Established + Disestablished ) const ( @@ -48,7 +49,7 @@ type Relay struct { State int LocalIndex uint32 RemoteIndex uint32 - PeerIp netip.Addr + PeerAddr netip.Addr } type HostMap struct { @@ -58,7 +59,6 @@ type HostMap struct { RemoteIndexes map[uint32]*HostInfo Hosts map[netip.Addr]*HostInfo preferredRanges atomic.Pointer[[]netip.Prefix] - vpnCIDR netip.Prefix l *logrus.Logger } @@ -68,9 +68,12 @@ type HostMap struct { type RelayState struct { sync.RWMutex - relays map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer - relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info - relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info + relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer + // For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data, + // modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with + // the RelayState Lock held) + relayForByAddr map[netip.Addr]*Relay // Maps vpnAddr of peers for which this HostInfo is a relay to some Relay info + relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info } func (rs *RelayState) DeleteRelay(ip netip.Addr) { @@ -79,6 +82,28 @@ func (rs *RelayState) DeleteRelay(ip netip.Addr) { delete(rs.relays, ip) } +func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) { + rs.Lock() + defer rs.Unlock() + if r, ok := rs.relayForByAddr[vpnIp]; ok { + newRelay := *r + newRelay.State = state + rs.relayForByAddr[newRelay.PeerAddr] = &newRelay + rs.relayForByIdx[newRelay.LocalIndex] = &newRelay + } +} + +func (rs *RelayState) UpdateRelayForByIdxState(idx uint32, state int) { + rs.Lock() + defer rs.Unlock() + if r, ok := rs.relayForByIdx[idx]; ok { + newRelay := *r + newRelay.State = state + rs.relayForByAddr[newRelay.PeerAddr] = &newRelay + rs.relayForByIdx[newRelay.LocalIndex] = &newRelay + } +} + func (rs *RelayState) CopyAllRelayFor() []*Relay { rs.RLock() defer rs.RUnlock() @@ -89,10 +114,10 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay { return ret } -func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) { +func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() - r, ok := rs.relayForByIp[ip] + r, ok := rs.relayForByAddr[addr] return r, ok } @@ -115,8 +140,8 @@ func (rs *RelayState) CopyRelayIps() []netip.Addr { func (rs *RelayState) CopyRelayForIps() []netip.Addr { rs.RLock() defer rs.RUnlock() - currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp)) - for relayIp := range rs.relayForByIp { + currentRelays := make([]netip.Addr, 0, len(rs.relayForByAddr)) + for relayIp := range rs.relayForByAddr { currentRelays = append(currentRelays, relayIp) } return currentRelays @@ -135,7 +160,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 { func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool { rs.Lock() defer rs.Unlock() - r, ok := rs.relayForByIp[vpnIp] + r, ok := rs.relayForByAddr[vpnIp] if !ok { return false } @@ -143,7 +168,7 @@ func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool newRelay.State = Established newRelay.RemoteIndex = remoteIdx rs.relayForByIdx[r.LocalIndex] = &newRelay - rs.relayForByIp[r.PeerIp] = &newRelay + rs.relayForByAddr[r.PeerAddr] = &newRelay return true } @@ -158,14 +183,14 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re newRelay.State = Established newRelay.RemoteIndex = remoteIdx rs.relayForByIdx[r.LocalIndex] = &newRelay - rs.relayForByIp[r.PeerIp] = &newRelay + rs.relayForByAddr[r.PeerAddr] = &newRelay return &newRelay, true } func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() - r, ok := rs.relayForByIp[vpnIp] + r, ok := rs.relayForByAddr[vpnIp] return r, ok } @@ -179,7 +204,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) { func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) { rs.Lock() defer rs.Unlock() - rs.relayForByIp[ip] = r + rs.relayForByAddr[ip] = r rs.relayForByIdx[idx] = r } @@ -190,10 +215,16 @@ type HostInfo struct { ConnectionState *ConnectionState remoteIndexId uint32 localIndexId uint32 - vpnIp netip.Addr - recvError atomic.Uint32 - remoteCidr *bart.Table[struct{}] - relayState RelayState + + // vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks + // The host may have other vpn addresses that are outside our + // vpn networks but were removed because they are not usable + vpnAddrs []netip.Addr + recvError atomic.Uint32 + + // networks are both all vpn and unsafe networks assigned to this host + networks *bart.Table[struct{}] + relayState RelayState // HandshakePacket records the packets used to create this hostinfo // We need these to avoid replayed handshake packets creating new hostinfos which causes churn @@ -241,28 +272,26 @@ type cachedPacketMetrics struct { dropped metrics.Counter } -func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap { - hm := newHostMap(l, vpnCIDR) +func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap { + hm := newHostMap(l) hm.reload(c, true) c.RegisterReloadCallback(func(c *config.C) { hm.reload(c, false) }) - l.WithField("network", hm.vpnCIDR.String()). - WithField("preferredRanges", hm.GetPreferredRanges()). + l.WithField("preferredRanges", hm.GetPreferredRanges()). Info("Main HostMap created") return hm } -func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap { +func newHostMap(l *logrus.Logger) *HostMap { return &HostMap{ Indexes: map[uint32]*HostInfo{}, Relays: map[uint32]*HostInfo{}, RemoteIndexes: map[uint32]*HostInfo{}, Hosts: map[netip.Addr]*HostInfo{}, - vpnCIDR: vpnCIDR, l: l, } } @@ -305,17 +334,6 @@ func (hm *HostMap) EmitStats() { metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen)) } -func (hm *HostMap) RemoveRelay(localIdx uint32) { - hm.Lock() - _, ok := hm.Relays[localIdx] - if !ok { - hm.Unlock() - return - } - delete(hm.Relays, localIdx) - hm.Unlock() -} - // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool { // Delete the host itself, ensuring it's not modified anymore @@ -335,48 +353,73 @@ func (hm *HostMap) MakePrimary(hostinfo *HostInfo) { } func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) { - oldHostinfo := hm.Hosts[hostinfo.vpnIp] + // Get the current primary, if it exists + oldHostinfo := hm.Hosts[hostinfo.vpnAddrs[0]] + + // Every address in the hostinfo gets elevated to primary + for _, vpnAddr := range hostinfo.vpnAddrs { + //NOTE: It is possible that we leave a dangling hostinfo here but connection manager works on + // indexes so it should be fine. + hm.Hosts[vpnAddr] = hostinfo + } + + // If we are already primary then we won't bother re-linking if oldHostinfo == hostinfo { return } + // Unlink this hostinfo if hostinfo.prev != nil { hostinfo.prev.next = hostinfo.next } - if hostinfo.next != nil { hostinfo.next.prev = hostinfo.prev } - hm.Hosts[hostinfo.vpnIp] = hostinfo - + // If there wasn't a previous primary then clear out any links if oldHostinfo == nil { + hostinfo.next = nil + hostinfo.prev = nil return } + // Relink the hostinfo as primary hostinfo.next = oldHostinfo oldHostinfo.prev = hostinfo hostinfo.prev = nil } func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { - primary, ok := hm.Hosts[hostinfo.vpnIp] + for _, addr := range hostinfo.vpnAddrs { + h := hm.Hosts[addr] + for h != nil { + if h == hostinfo { + hm.unlockedInnerDeleteHostInfo(h, addr) + } + h = h.next + } + } +} + +func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Addr) { + primary, ok := hm.Hosts[addr] + isLastHostinfo := hostinfo.next == nil && hostinfo.prev == nil if ok && primary == hostinfo { - // The vpnIp pointer points to the same hostinfo as the local index id, we can remove it - delete(hm.Hosts, hostinfo.vpnIp) + // The vpn addr pointer points to the same hostinfo as the local index id, we can remove it + delete(hm.Hosts, addr) if len(hm.Hosts) == 0 { hm.Hosts = map[netip.Addr]*HostInfo{} } if hostinfo.next != nil { - // We had more than 1 hostinfo at this vpnip, promote the next in the list to primary - hm.Hosts[hostinfo.vpnIp] = hostinfo.next + // We had more than 1 hostinfo at this vpn addr, promote the next in the list to primary + hm.Hosts[addr] = hostinfo.next // It is primary, there is no previous hostinfo now hostinfo.next.prev = nil } } else { - // Relink if we were in the middle of multiple hostinfos for this vpn ip + // Relink if we were in the middle of multiple hostinfos for this vpn addr if hostinfo.prev != nil { hostinfo.prev.next = hostinfo.next } @@ -406,10 +449,16 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { if hm.l.Level >= logrus.DebugLevel { hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts), - "vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). + "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). Debug("Hostmap hostInfo deleted") } + if isLastHostinfo { + // I have lost connectivity to my peers. My relay tunnel is likely broken. Mark the next + // hops as 'Requested' so that new relay tunnels are created in the future. + hm.unlockedDisestablishVpnAddrRelayFor(hostinfo) + } + // Clean up any local relay indexes for which I am acting as a relay hop for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() { delete(hm.Relays, localRelayIdx) } @@ -448,11 +497,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo { } } -func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo { - return hm.queryVpnIp(vpnIp, nil) +func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo { + return hm.queryVpnAddr(vpnIp, nil) } -func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) { +func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) { hm.RLock() defer hm.RUnlock() @@ -460,17 +509,42 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostIn if !ok { return nil, nil, errors.New("unable to find host") } + for h != nil { - r, ok := h.relayState.QueryRelayForByIp(targetIp) - if ok && r.State == Established { - return h, r, nil + for _, targetIp := range targetIps { + r, ok := h.relayState.QueryRelayForByIp(targetIp) + if ok && r.State == Established { + return h, r, nil + } } h = h.next } + return nil, nil, errors.New("unable to find host with relay") } -func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo { +func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) { + for _, relayHostIp := range hi.relayState.CopyRelayIps() { + if h, ok := hm.Hosts[relayHostIp]; ok { + for h != nil { + h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished) + h = h.next + } + } + } + for _, rs := range hi.relayState.CopyAllRelayFor() { + if rs.Type == ForwardingType { + if h, ok := hm.Hosts[rs.PeerAddr]; ok { + for h != nil { + h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished) + h = h.next + } + } + } + } +} + +func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo { hm.RLock() if h, ok := hm.Hosts[vpnIp]; ok { hm.RUnlock() @@ -491,25 +565,30 @@ func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInf func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { if f.serveDns { remoteCert := hostinfo.ConnectionState.peerCert - dnsR.Add(remoteCert.Certificate.Name()+".", remoteCert.Certificate.Networks()[0].Addr().String()) + dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs) } - - existing := hm.Hosts[hostinfo.vpnIp] - hm.Hosts[hostinfo.vpnIp] = hostinfo - - if existing != nil { - hostinfo.next = existing - existing.prev = hostinfo + for _, addr := range hostinfo.vpnAddrs { + hm.unlockedInnerAddHostInfo(addr, hostinfo, f) } hm.Indexes[hostinfo.localIndexId] = hostinfo hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts), - "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}). + hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts), + "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}). Debug("Hostmap vpnIp added") } +} + +func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) { + existing := hm.Hosts[vpnAddr] + hm.Hosts[vpnAddr] = hostinfo + + if existing != nil && existing != hostinfo { + hostinfo.next = existing + existing.prev = hostinfo + } i := 1 check := hostinfo @@ -527,7 +606,7 @@ func (hm *HostMap) GetPreferredRanges() []netip.Prefix { return *hm.preferredRanges.Load() } -func (hm *HostMap) ForEachVpnIp(f controlEach) { +func (hm *HostMap) ForEachVpnAddr(f controlEach) { hm.RLock() defer hm.RUnlock() @@ -581,7 +660,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac } i.nextLHQuery.Store(now + ifce.reQueryWait.Load()) - ifce.lightHouse.QueryServer(i.vpnIp) + ifce.lightHouse.QueryServer(i.vpnAddrs[0]) } } @@ -596,7 +675,7 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) { // We copy here because we likely got this remote from a source that reuses the object if i.remote != remote { i.remote = remote - i.remotes.LearnRemote(i.vpnIp, remote) + i.remotes.LearnRemote(i.vpnAddrs[0], remote) } } @@ -647,21 +726,20 @@ func (i *HostInfo) RecvErrorExceeded() bool { return true } -func (i *HostInfo) CreateRemoteCIDR(c cert.Certificate) { - if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 { +func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) { + if len(networks) == 1 && len(unsafeNetworks) == 0 { // Simple case, no CIDRTree needed return } - remoteCidr := new(bart.Table[struct{}]) - for _, network := range c.Networks() { - remoteCidr.Insert(network, struct{}{}) + i.networks = new(bart.Table[struct{}]) + for _, network := range networks { + i.networks.Insert(network, struct{}{}) } - for _, network := range c.UnsafeNetworks() { - remoteCidr.Insert(network, struct{}{}) + for _, network := range unsafeNetworks { + i.networks.Insert(network, struct{}{}) } - i.remoteCidr = remoteCidr } func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { @@ -669,7 +747,7 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { return logrus.NewEntry(l) } - li := l.WithField("vpnIp", i.vpnIp). + li := l.WithField("vpnAddrs", i.vpnAddrs). WithField("localIndex", i.localIndexId). WithField("remoteIndex", i.remoteIndexId) @@ -684,9 +762,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { // Utility functions -func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { +func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { //FIXME: This function is pretty garbage - var ips []netip.Addr + var finalAddrs []netip.Addr ifaces, _ := net.Interfaces() for _, i := range ifaces { allow := allowList.AllowName(i.Name) @@ -698,39 +776,36 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { continue } addrs, _ := i.Addrs() - for _, addr := range addrs { - var ip net.IP - switch v := addr.(type) { + for _, rawAddr := range addrs { + var addr netip.Addr + switch v := rawAddr.(type) { case *net.IPNet: //continue - ip = v.IP + addr, _ = netip.AddrFromSlice(v.IP) case *net.IPAddr: - ip = v.IP + addr, _ = netip.AddrFromSlice(v.IP) } - nip, ok := netip.AddrFromSlice(ip) - if !ok { + if !addr.IsValid() { if l.Level >= logrus.DebugLevel { - l.WithField("localIp", ip).Debug("ip was invalid for netip") + l.WithField("localAddr", rawAddr).Debug("addr was invalid") } continue } - nip = nip.Unmap() + addr = addr.Unmap() - //TODO: Filtering out link local for now, this is probably the most correct thing - //TODO: Would be nice to filter out SLAAC MAC based ips as well - if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false { - allow := allowList.Allow(nip) + if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false { + isAllowed := allowList.Allow(addr) if l.Level >= logrus.TraceLevel { - l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow") + l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow") } - if !allow { + if !isAllowed { continue } - ips = append(ips, nip) + finalAddrs = append(finalAddrs, addr) } } } - return ips + return finalAddrs } diff --git a/hostmap_test.go b/hostmap_test.go index 7e2feb8..e974340 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -11,17 +11,14 @@ import ( func TestHostMap_MakePrimary(t *testing.T) { l := test.NewLogger() - hm := newHostMap( - l, - netip.MustParsePrefix("10.0.0.1/24"), - ) + hm := newHostMap(l) f := &Interface{} - h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} - h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} - h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} - h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} + h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1} + h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2} + h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3} + h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4} hm.unlockedAddHostInfo(h4, f) hm.unlockedAddHostInfo(h3, f) @@ -29,7 +26,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.unlockedAddHostInfo(h1, f) // Make sure we go h1 -> h2 -> h3 -> h4 - prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -44,7 +41,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h3) // Make sure we go h3 -> h1 -> h2 -> h4 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h3.localIndexId, prim.localIndexId) assert.Equal(t, h1.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -59,7 +56,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -74,7 +71,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -88,19 +85,16 @@ func TestHostMap_MakePrimary(t *testing.T) { func TestHostMap_DeleteHostInfo(t *testing.T) { l := test.NewLogger() - hm := newHostMap( - l, - netip.MustParsePrefix("10.0.0.1/24"), - ) + hm := newHostMap(l) f := &Interface{} - h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} - h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} - h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} - h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} - h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5} - h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6} + h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1} + h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2} + h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3} + h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4} + h5 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 5} + h6 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 6} hm.unlockedAddHostInfo(h6, f) hm.unlockedAddHostInfo(h5, f) @@ -116,7 +110,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h) // Make sure we go h1 -> h2 -> h3 -> h4 -> h5 - prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -135,7 +129,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h1.next) // Make sure we go h2 -> h3 -> h4 -> h5 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -153,7 +147,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h3.next) // Make sure we go h2 -> h4 -> h5 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -169,7 +163,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h5.next) // Make sure we go h2 -> h4 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -183,7 +177,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h2.next) // Make sure we only have h4 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Nil(t, prim.prev) assert.Nil(t, prim.next) @@ -195,7 +189,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h4.next) // Make sure we have nil - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Nil(t, prim) } @@ -203,11 +197,7 @@ func TestHostMap_reload(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - hm := NewHostMapFromConfig( - l, - netip.MustParsePrefix("10.0.0.1/24"), - c, - ) + hm := NewHostMapFromConfig(l, c) toS := func(ipn []netip.Prefix) []string { var s []string diff --git a/hostmap_tester.go b/hostmap_tester.go index b2d1d1b..fe40c53 100644 --- a/hostmap_tester.go +++ b/hostmap_tester.go @@ -9,8 +9,8 @@ import ( "net/netip" ) -func (i *HostInfo) GetVpnIp() netip.Addr { - return i.vpnIp +func (i *HostInfo) GetVpnAddrs() []netip.Addr { + return i.vpnAddrs } func (i *HostInfo) GetLocalIndex() uint32 { diff --git a/inside.go b/inside.go index 0ccd179..9629947 100644 --- a/inside.go +++ b/inside.go @@ -20,14 +20,18 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } // Ignore local broadcast packets - if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr { - return + if f.dropLocalBroadcast { + _, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr) + if found { + return + } } - if fwPacket.RemoteIP == f.myVpnNet.Addr() { + _, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr) + if found { // Immediately forward packets from self to self. // This should only happen on Darwin-based and FreeBSD hosts, which - // routes packets from the Nebula IP to the Nebula IP through the Nebula + // routes packets from the Nebula addr to the Nebula addr through the Nebula // TUN device. if immediatelyForwardToSelf { _, err := f.readers[q].Write(packet) @@ -36,25 +40,25 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } } // Otherwise, drop. On linux, we should never see these packets - Linux - // routes packets from the nebula IP to the nebula IP through the loopback device. + // routes packets from the nebula addr to the nebula addr through the loopback device. return } // Ignore multicast packets - if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() { + if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() { return } - hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) { + hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) }) if hostinfo == nil { f.rejectInside(packet, out, q) if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnIp", fwPacket.RemoteIP). + f.l.WithField("vpnAddr", fwPacket.RemoteAddr). WithField("fwPacket", fwPacket). - Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes") + Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks") } return } @@ -117,21 +121,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) } -func (f *Interface) Handshake(vpnIp netip.Addr) { - f.getOrHandshake(vpnIp, nil) +func (f *Interface) Handshake(vpnAddr netip.Addr) { + f.getOrHandshake(vpnAddr, nil) } -// getOrHandshake returns nil if the vpnIp is not routable. +// getOrHandshake returns nil if the vpnAddr is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel -func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { - if !f.myVpnNet.Contains(vpnIp) { - vpnIp = f.inside.RouteFor(vpnIp) - if !vpnIp.IsValid() { +func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { + _, found := f.myVpnNetworksTable.Lookup(vpnAddr) + if !found { + vpnAddr = f.inside.RouteFor(vpnAddr) + if !vpnAddr.IsValid() { return nil, false } } - return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback) + return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback) } func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { @@ -156,16 +161,16 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) } -// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp -func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) { - hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) { +// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr +func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) { + hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) }) if hostInfo == nil { if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnIp", vpnIp). - Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes") + f.l.WithField("vpnAddr", vpnAddr). + Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes") } return } @@ -258,7 +263,6 @@ func (f *Interface) SendVia(via *HostInfo, func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) { if ci.eKey == nil { - //TODO: log warning return } useRelay := !remote.IsValid() && !hostinfo.remote.IsValid() @@ -285,14 +289,14 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType f.connectionManager.Out(hostinfo.localIndexId) // Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against - // all our IPs and enable a faster roaming. + // all our addrs and enable a faster roaming. if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount { //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. - f.lightHouse.QueryServer(hostinfo.vpnIp) + f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) hostinfo.lastRebindCount = f.rebindCount if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter") + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") } } @@ -324,7 +328,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } else { // Try to send via a relay for _, relayIP := range hostinfo.relayState.CopyRelayIps() { - relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP) + relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) if err != nil { hostinfo.relayState.DeleteRelay(relayIP) hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") diff --git a/interface.go b/interface.go index 5d41a87..21e198c 100644 --- a/interface.go +++ b/interface.go @@ -2,17 +2,16 @@ package nebula import ( "context" - "encoding/binary" "errors" "fmt" "io" - "net" "net/netip" "os" "runtime" "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -29,7 +28,6 @@ type InterfaceConfig struct { Outside udp.Conn Inside overlay.Device pki *PKI - Cipher string Firewall *Firewall ServeDns bool HandshakeManager *HandshakeManager @@ -53,25 +51,27 @@ type InterfaceConfig struct { } type Interface struct { - hostMap *HostMap - outside udp.Conn - inside overlay.Device - pki *PKI - cipher string - firewall *Firewall - connectionManager *connectionManager - handshakeManager *HandshakeManager - serveDns bool - createTime time.Time - lightHouse *LightHouse - myBroadcastAddr netip.Addr - myVpnNet netip.Prefix - dropLocalBroadcast bool - dropMulticast bool - routines int - disconnectInvalid atomic.Bool - closed atomic.Bool - relayManager *relayManager + hostMap *HostMap + outside udp.Conn + inside overlay.Device + pki *PKI + firewall *Firewall + connectionManager *connectionManager + handshakeManager *HandshakeManager + serveDns bool + createTime time.Time + lightHouse *LightHouse + myBroadcastAddrsTable *bart.Table[struct{}] + myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate + myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate + myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate + myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate + dropLocalBroadcast bool + dropMulticast bool + routines int + disconnectInvalid atomic.Bool + closed atomic.Bool + relayManager *relayManager tryPromoteEvery atomic.Uint32 reQueryEvery atomic.Uint32 @@ -103,9 +103,11 @@ type EncWriter interface { out []byte, nocopy bool, ) - SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) + SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) - Handshake(vpnIp netip.Addr) + Handshake(vpnAddr netip.Addr) + GetHostInfo(vpnAddr netip.Addr) *HostInfo + GetCertState() *CertState } type sendRecvErrorConfig uint8 @@ -116,10 +118,10 @@ const ( sendRecvErrorPrivate ) -func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool { +func (s sendRecvErrorConfig) ShouldSendRecvError(endpoint netip.AddrPort) bool { switch s { case sendRecvErrorPrivate: - return ip.Addr().IsPrivate() + return endpoint.Addr().IsPrivate() case sendRecvErrorAlways: return true case sendRecvErrorNever: @@ -156,27 +158,29 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { return nil, errors.New("no firewall rules") } - certificate := c.pki.GetCertState().Certificate - + cs := c.pki.getCertState() ifce := &Interface{ - pki: c.pki, - hostMap: c.HostMap, - outside: c.Outside, - inside: c.Inside, - cipher: c.Cipher, - firewall: c.Firewall, - serveDns: c.ServeDns, - handshakeManager: c.HandshakeManager, - createTime: time.Now(), - lightHouse: c.lightHouse, - dropLocalBroadcast: c.DropLocalBroadcast, - dropMulticast: c.DropMulticast, - routines: c.routines, - version: c.version, - writers: make([]udp.Conn, c.routines), - readers: make([]io.ReadWriteCloser, c.routines), - myVpnNet: certificate.Networks()[0], - relayManager: c.relayManager, + pki: c.pki, + hostMap: c.HostMap, + outside: c.Outside, + inside: c.Inside, + firewall: c.Firewall, + serveDns: c.ServeDns, + handshakeManager: c.HandshakeManager, + createTime: time.Now(), + lightHouse: c.lightHouse, + dropLocalBroadcast: c.DropLocalBroadcast, + dropMulticast: c.DropMulticast, + routines: c.routines, + version: c.version, + writers: make([]udp.Conn, c.routines), + readers: make([]io.ReadWriteCloser, c.routines), + myVpnNetworks: cs.myVpnNetworks, + myVpnNetworksTable: cs.myVpnNetworksTable, + myVpnAddrs: cs.myVpnAddrs, + myVpnAddrsTable: cs.myVpnAddrsTable, + myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable, + relayManager: c.relayManager, conntrackCacheTimeout: c.ConntrackCacheTimeout, @@ -190,14 +194,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { l: c.l, } - if ifce.myVpnNet.Addr().Is4() { - maskedAddr := certificate.Networks()[0].Masked() - addr := maskedAddr.Addr().As4() - mask := net.CIDRMask(maskedAddr.Bits(), maskedAddr.Addr().BitLen()) - binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask)) - ifce.myBroadcastAddr = netip.AddrFrom4(addr) - } - ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryWait.Store(int64(c.reQueryWait)) @@ -218,7 +214,7 @@ func (f *Interface) activate() { f.l.WithError(err).Error("Failed to get udp listen address") } - f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()). + f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks). WithField("build", f.version).WithField("udpAddr", addr). WithField("boringcrypto", boringEnabled()). Info("Nebula interface is active") @@ -259,16 +255,22 @@ func (f *Interface) listenOut(i int) { runtime.LockOSThread() var li udp.Conn - // TODO clean this up with a coherent interface for each outside connection if i > 0 { li = f.writers[i] } else { li = f.outside } + ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() - conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) - li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i) + plaintext := make([]byte, udp.MTU) + h := &header.H{} + fwPacket := &firewall.Packet{} + nb := make([]byte, 12, 12) + + li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { + f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) + }) } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { @@ -325,7 +327,7 @@ func (f *Interface) reloadFirewall(c *config.C) { return } - fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c) + fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) if err != nil { f.l.WithError(err).Error("Error while creating firewall during reload") return @@ -408,6 +410,8 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { udpStats := udp.NewUDPStatsEmitter(f.writers) certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil) + certDefaultVersion := metrics.GetOrRegisterGauge("certificate.default_version", nil) + certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil) for { select { @@ -417,11 +421,30 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { f.firewall.EmitStats() f.handshakeManager.EmitStats() udpStats() - certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.NotAfter().Sub(time.Now()) / time.Second)) + + certState := f.pki.getCertState() + defaultCrt := certState.GetDefaultCertificate() + certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second)) + certDefaultVersion.Update(int64(defaultCrt.Version())) + + // Report the max certificate version we are capable of using + if certState.v2Cert != nil { + certMaxVersion.Update(int64(certState.v2Cert.Version())) + } else { + certMaxVersion.Update(int64(certState.v1Cert.Version())) + } } } } +func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo { + return f.hostMap.QueryVpnAddr(vpnIp) +} + +func (f *Interface) GetCertState() *CertState { + return f.pki.getCertState() +} + func (f *Interface) Close() error { f.closed.Store(true) diff --git a/iputil/packet.go b/iputil/packet.go index 719e034..b18e524 100644 --- a/iputil/packet.go +++ b/iputil/packet.go @@ -6,8 +6,6 @@ import ( "golang.org/x/net/ipv4" ) -//TODO: IPV6-WORK can probably delete this - const ( // Need 96 bytes for the largest reject packet: // - 20 byte ipv4 header diff --git a/lighthouse.go b/lighthouse.go index 62f4065..ce37023 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "net/netip" + "slices" "strconv" "sync" "sync/atomic" @@ -15,28 +16,28 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" ) -//TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake? -//TODO: nodes are roaming lighthouses, this is bad. How are they learning? - var ErrHostNotKnown = errors.New("host not known") type LightHouse struct { - //TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time + //TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time sync.RWMutex //Because we concurrently read and write to our maps ctx context.Context amLighthouse bool - myVpnNet netip.Prefix - punchConn udp.Conn - punchy *Punchy + + myVpnNetworks []netip.Prefix + myVpnNetworksTable *bart.Table[struct{}] + punchConn udp.Conn + punchy *Punchy // Local cache of answers from light houses - // map of vpn Ip to answers + // map of vpn addr to answers addrMap map[netip.Addr]*RemoteList // filters remote addresses allowed for each host @@ -64,12 +65,12 @@ type LightHouse struct { advertiseAddrs atomic.Pointer[[]netip.AddrPort] - // IP's of relays that can be used by peers to access me + // Addr's of relays that can be used by peers to access me relaysForMe atomic.Pointer[[]netip.Addr] queryChan chan netip.Addr - calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote + calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote metrics *MessageMetrics metricHolepunchTx metrics.Counter @@ -78,7 +79,7 @@ type LightHouse struct { // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // addrMap should be nil unless this is during a config reload -func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet netip.Prefix, pc udp.Conn, p *Punchy) (*LightHouse, error) { +func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) { amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) nebulaPort := uint32(c.GetInt("listen.port", 0)) if amLighthouse && nebulaPort == 0 { @@ -95,15 +96,16 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, } h := LightHouse{ - ctx: ctx, - amLighthouse: amLighthouse, - myVpnNet: myVpnNet, - addrMap: make(map[netip.Addr]*RemoteList), - nebulaPort: nebulaPort, - punchConn: pc, - punchy: p, - queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), - l: l, + ctx: ctx, + amLighthouse: amLighthouse, + myVpnNetworks: cs.myVpnNetworks, + myVpnNetworksTable: cs.myVpnNetworksTable, + addrMap: make(map[netip.Addr]*RemoteList), + nebulaPort: nebulaPort, + punchConn: pc, + punchy: p, + queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), + l: l, } lighthouses := make(map[netip.Addr]struct{}) h.lighthouses.Store(&lighthouses) @@ -180,11 +182,11 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { return util.NewContextualError("Unable to parse lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) } - ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", host) + addrs, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", host) if err != nil { return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) } - if len(ips) == 0 { + if len(addrs) == 0 { return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, nil) } @@ -197,15 +199,16 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { port = int(lh.nebulaPort) } - //TODO: we could technically insert all returned ips instead of just the first one if a dns lookup was used - ip := ips[0].Unmap() - if lh.myVpnNet.Contains(ip) { + //TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used + addr := addrs[0].Unmap() + _, found := lh.myVpnNetworksTable.Lookup(addr) + if found { lh.l.WithField("addr", rawAddr).WithField("entry", i+1). Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range") continue } - advAddrs = append(advAddrs, netip.AddrPortFrom(ip, uint16(port))) + advAddrs = append(advAddrs, netip.AddrPortFrom(addr, uint16(port))) } lh.advertiseAddrs.Store(&advAddrs) @@ -238,7 +241,6 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.remoteAllowList.Store(ral) if !initial { - //TODO: a diff will be annoyingly difficult lh.l.Info("lighthouse.remote_allow_list and/or lighthouse.remote_allow_ranges has changed") } } @@ -251,7 +253,6 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.localAllowList.Store(lal) if !initial { - //TODO: a diff will be annoyingly difficult lh.l.Info("lighthouse.local_allow_list has changed") } } @@ -264,7 +265,6 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.calculatedRemotes.Store(cr) if !initial { - //TODO: a diff will be annoyingly difficult lh.l.Info("lighthouse.calculated_remotes has changed") } } @@ -275,8 +275,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { // Entries no longer present must have their (possible) background DNS goroutines stopped. if existingStaticList := lh.staticList.Load(); existingStaticList != nil { lh.RLock() - for staticVpnIp := range *existingStaticList { - if am, ok := lh.addrMap[staticVpnIp]; ok && am != nil { + for staticVpnAddr := range *existingStaticList { + if am, ok := lh.addrMap[staticVpnAddr]; ok && am != nil { am.hr.Cancel() } } @@ -291,7 +291,6 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.staticList.Store(&staticList) if !initial { - //TODO: we should remove any remote list entries for static hosts that were removed/modified? if c.HasChanged("static_host_map") { lh.l.Info("static_host_map has changed") } @@ -333,11 +332,11 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { case false: relaysForMe := []netip.Addr{} for _, v := range c.GetStringSlice("relay.relays", nil) { - lh.l.WithField("relay", v).Info("Read relay from config") - configRIP, err := netip.ParseAddr(v) - //TODO: We could print the error here - if err == nil { + if err != nil { + lh.l.WithField("relay", v).WithError(err).Warn("Parse relay from config failed") + } else { + lh.l.WithField("relay", v).Info("Read relay from config") relaysForMe = append(relaysForMe, configRIP) } } @@ -355,14 +354,16 @@ func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{ } for i, host := range lhs { - ip, err := netip.ParseAddr(host) + addr, err := netip.ParseAddr(host) if err != nil { return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err) } - if !lh.myVpnNet.Contains(ip) { - return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": lh.myVpnNet}, nil) + + _, found := lh.myVpnNetworksTable.Lookup(addr) + if !found { + return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil) } - lhMap[ip] = struct{}{} + lhMap[addr] = struct{}{} } if !lh.amLighthouse && len(lhMap) == 0 { @@ -370,9 +371,9 @@ func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{ } staticList := lh.GetStaticHostList() - for lhIP, _ := range lhMap { - if _, ok := staticList[lhIP]; !ok { - return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhIP) + for lhAddr, _ := range lhMap { + if _, ok := staticList[lhAddr]; !ok { + return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr) } } @@ -425,13 +426,14 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc i := 0 for k, v := range shm { - vpnIp, err := netip.ParseAddr(fmt.Sprintf("%v", k)) + vpnAddr, err := netip.ParseAddr(fmt.Sprintf("%v", k)) if err != nil { return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err) } - if !lh.myVpnNet.Contains(vpnIp) { - return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": lh.myVpnNet, "entry": i + 1}, nil) + _, found := lh.myVpnNetworksTable.Lookup(vpnAddr) + if !found { + return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil) } vals, ok := v.([]interface{}) @@ -443,7 +445,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v)) } - err = lh.addStaticRemotes(i, d, network, lookupTimeout, vpnIp, remoteAddrs, staticList) + err = lh.addStaticRemotes(i, d, network, lookupTimeout, vpnAddr, remoteAddrs, staticList) if err != nil { return err } @@ -453,12 +455,12 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc return nil } -func (lh *LightHouse) Query(ip netip.Addr) *RemoteList { - if !lh.IsLighthouseIP(ip) { - lh.QueryServer(ip) +func (lh *LightHouse) Query(vpnAddr netip.Addr) *RemoteList { + if !lh.IsLighthouseAddr(vpnAddr) { + lh.QueryServer(vpnAddr) } lh.RLock() - if v, ok := lh.addrMap[ip]; ok { + if v, ok := lh.addrMap[vpnAddr]; ok { lh.RUnlock() return v } @@ -467,18 +469,18 @@ func (lh *LightHouse) Query(ip netip.Addr) *RemoteList { } // QueryServer is asynchronous so no reply should be expected -func (lh *LightHouse) QueryServer(ip netip.Addr) { - // Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses - if lh.amLighthouse || lh.IsLighthouseIP(ip) { +func (lh *LightHouse) QueryServer(vpnAddr netip.Addr) { + // Don't put lighthouse addrs in the query channel because we can't query lighthouses about lighthouses + if lh.amLighthouse || lh.IsLighthouseAddr(vpnAddr) { return } - lh.queryChan <- ip + lh.queryChan <- vpnAddr } -func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList { +func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList { lh.RLock() - if v, ok := lh.addrMap[ip]; ok { + if v, ok := lh.addrMap[vpnAddrs[0]]; ok { lh.RUnlock() return v } @@ -487,24 +489,27 @@ func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList { lh.Lock() defer lh.Unlock() // Add an entry if we don't already have one - return lh.unlockedGetRemoteList(ip) + return lh.unlockedGetRemoteList(vpnAddrs) } // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing -// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp +// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnAddr // If one is found then f() is called with proper locking, f() must return result of n.MarshalTo() -func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, error)) (bool, int, error) { +func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (int, error)) (bool, int, error) { lh.RLock() // Do we have an entry in the main cache? - if v, ok := lh.addrMap[vpnIp]; ok { + if v, ok := lh.addrMap[vpnAddr]; ok { // Swap lh lock for remote list lock v.RLock() defer v.RUnlock() lh.RUnlock() - // vpnIp should also be the owner here since we are a lighthouse. - c := v.cache[vpnIp] + // We may be asking about a non primary address so lets get the primary address + if slices.Contains(v.vpnAddrs, vpnAddr) { + vpnAddr = v.vpnAddrs[0] + } + c := v.cache[vpnAddr] // Make sure we have if c != nil { n, err := f(c) @@ -516,112 +521,140 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, return false, 0, nil } -func (lh *LightHouse) DeleteVpnIp(vpnIp netip.Addr) { +func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) { // First we check the static mapping // and do nothing if it is there - if _, ok := lh.GetStaticHostList()[vpnIp]; ok { + if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok { return } lh.Lock() - //l.Debugln(lh.addrMap) - delete(lh.addrMap, vpnIp) - - if lh.l.Level >= logrus.DebugLevel { - lh.l.Debugf("deleting %s from lighthouse.", vpnIp) + rm, ok := lh.addrMap[allVpnAddrs[0]] + if ok { + for _, addr := range allVpnAddrs { + srm := lh.addrMap[addr] + if srm == rm { + delete(lh.addrMap, addr) + if lh.l.Level >= logrus.DebugLevel { + lh.l.Debugf("deleting %s from lighthouse.", addr) + } + } + } } - lh.Unlock() } -// AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner +// AddStaticRemote adds a static host entry for vpnAddr as ourselves as the owner // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it -func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error { +func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnAddr netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error { lh.Lock() - am := lh.unlockedGetRemoteList(vpnIp) + am := lh.unlockedGetRemoteList([]netip.Addr{vpnAddr}) am.Lock() defer am.Unlock() ctx := lh.ctx lh.Unlock() hr, err := NewHostnameResults(ctx, lh.l, d, network, timeout, toAddrs, func() { - // This callback runs whenever the DNS hostname resolver finds a different set of IP's + // This callback runs whenever the DNS hostname resolver finds a different set of addr's // in its resolution for hostnames. am.Lock() defer am.Unlock() am.shouldRebuild = true }) if err != nil { - return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err) + return util.NewContextualError("Static host address could not be parsed", m{"vpnAddr": vpnAddr, "entry": i + 1}, err) } am.unlockedSetHostnamesResults(hr) - for _, addrPort := range hr.GetIPs() { - if !lh.shouldAdd(vpnIp, addrPort.Addr()) { + for _, addrPort := range hr.GetAddrs() { + if !lh.shouldAdd(vpnAddr, addrPort.Addr()) { continue } switch { case addrPort.Addr().Is4(): - am.unlockedPrependV4(lh.myVpnNet.Addr(), NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port())) + am.unlockedPrependV4(lh.myVpnNetworks[0].Addr(), netAddrToProtoV4AddrPort(addrPort.Addr(), addrPort.Port())) case addrPort.Addr().Is6(): - am.unlockedPrependV6(lh.myVpnNet.Addr(), NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port())) + am.unlockedPrependV6(lh.myVpnNetworks[0].Addr(), netAddrToProtoV6AddrPort(addrPort.Addr(), addrPort.Port())) } } // Mark it as static in the caller provided map - staticList[vpnIp] = struct{}{} + staticList[vpnAddr] = struct{}{} return nil } // addCalculatedRemotes adds any calculated remotes based on the // lighthouse.calculated_remotes configuration. It returns true if any // calculated remotes were added -func (lh *LightHouse) addCalculatedRemotes(vpnIp netip.Addr) bool { +func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool { tree := lh.getCalculatedRemotes() if tree == nil { return false } - calculatedRemotes, ok := tree.Lookup(vpnIp) + calculatedRemotes, ok := tree.Lookup(vpnAddr) if !ok { return false } - var calculated []*Ip4AndPort + var calculatedV4 []*V4AddrPort + var calculatedV6 []*V6AddrPort for _, cr := range calculatedRemotes { - c := cr.Apply(vpnIp) - if c != nil { - calculated = append(calculated, c) + if vpnAddr.Is4() { + c := cr.ApplyV4(vpnAddr) + if c != nil { + calculatedV4 = append(calculatedV4, c) + } + } else if vpnAddr.Is6() { + c := cr.ApplyV6(vpnAddr) + if c != nil { + calculatedV6 = append(calculatedV6, c) + } } } lh.Lock() - am := lh.unlockedGetRemoteList(vpnIp) + am := lh.unlockedGetRemoteList([]netip.Addr{vpnAddr}) am.Lock() defer am.Unlock() lh.Unlock() - am.unlockedSetV4(lh.myVpnNet.Addr(), vpnIp, calculated, lh.unlockedShouldAddV4) + if len(calculatedV4) > 0 { + am.unlockedSetV4(lh.myVpnNetworks[0].Addr(), vpnAddr, calculatedV4, lh.unlockedShouldAddV4) + } - return len(calculated) > 0 + if len(calculatedV6) > 0 { + am.unlockedSetV6(lh.myVpnNetworks[0].Addr(), vpnAddr, calculatedV6, lh.unlockedShouldAddV6) + } + + return len(calculatedV4) > 0 || len(calculatedV6) > 0 } -// unlockedGetRemoteList assumes you have the lh lock -func (lh *LightHouse) unlockedGetRemoteList(vpnIp netip.Addr) *RemoteList { - am, ok := lh.addrMap[vpnIp] +// unlockedGetRemoteList +// assumes you have the lh lock +func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList { + am, ok := lh.addrMap[allAddrs[0]] if !ok { - am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) }) - lh.addrMap[vpnIp] = am + am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) }) + for _, addr := range allAddrs { + lh.addrMap[addr] = am + } } return am } -func (lh *LightHouse) shouldAdd(vpnIp netip.Addr, to netip.Addr) bool { - allow := lh.GetRemoteAllowList().Allow(vpnIp, to) +func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool { + allow := lh.GetRemoteAllowList().Allow(vpnAddr, to) if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") + lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow). + Trace("remoteAllowList.Allow") } - if !allow || lh.myVpnNet.Contains(to) { + if !allow { + return false + } + + _, found := lh.myVpnNetworksTable.Lookup(to) + if found { return false } @@ -629,14 +662,20 @@ func (lh *LightHouse) shouldAdd(vpnIp netip.Addr, to netip.Addr) bool { } // unlockedShouldAddV4 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV4(vpnIp netip.Addr, to *Ip4AndPort) bool { - ip := AddrPortFromIp4AndPort(to) - allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr()) +func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool { + udpAddr := protoV4AddrPortToNetAddrPort(to) + allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr()) if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") + lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow). + Trace("remoteAllowList.Allow") } - if !allow || lh.myVpnNet.Contains(ip.Addr()) { + if !allow { + return false + } + + _, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr()) + if found { return false } @@ -644,78 +683,43 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp netip.Addr, to *Ip4AndPort) bool } // unlockedShouldAddV6 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV6(vpnIp netip.Addr, to *Ip6AndPort) bool { - ip := AddrPortFromIp6AndPort(to) - allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr()) +func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool { + udpAddr := protoV6AddrPortToNetAddrPort(to) + allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr()) if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow") + lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow). + Trace("remoteAllowList.Allow") } - if !allow || lh.myVpnNet.Contains(ip.Addr()) { + if !allow { + return false + } + + _, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr()) + if found { return false } return true } -func lhIp6ToIp(v *Ip6AndPort) net.IP { - ip := make(net.IP, 16) - binary.BigEndian.PutUint64(ip[:8], v.Hi) - binary.BigEndian.PutUint64(ip[8:], v.Lo) - return ip -} - -func (lh *LightHouse) IsLighthouseIP(vpnIp netip.Addr) bool { - if _, ok := lh.GetLighthouses()[vpnIp]; ok { +func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool { + if _, ok := lh.GetLighthouses()[vpnAddr]; ok { return true } return false } -func NewLhQueryByInt(vpnIp netip.Addr) *NebulaMeta { - if vpnIp.Is6() { - //TODO: need to support ipv6 - panic("ipv6 is not yet supported") - } - - b := vpnIp.As4() - return &NebulaMeta{ - Type: NebulaMeta_HostQuery, - Details: &NebulaMetaDetails{ - VpnIp: binary.BigEndian.Uint32(b[:]), - }, - } -} - -func AddrPortFromIp4AndPort(ip *Ip4AndPort) netip.AddrPort { - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], ip.Ip) - return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ip.Port)) -} - -func AddrPortFromIp6AndPort(ip *Ip6AndPort) netip.AddrPort { - b := [16]byte{} - binary.BigEndian.PutUint64(b[:8], ip.Hi) - binary.BigEndian.PutUint64(b[8:], ip.Lo) - return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ip.Port)) -} - -func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort { - v4Addr := ip.As4() - return &Ip4AndPort{ - Ip: binary.BigEndian.Uint32(v4Addr[:]), - Port: uint32(port), - } -} - -// TODO: IPV6-WORK we can delete some more of these -func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort { - ip6Addr := ip.As16() - return &Ip6AndPort{ - Hi: binary.BigEndian.Uint64(ip6Addr[:8]), - Lo: binary.BigEndian.Uint64(ip6Addr[8:]), - Port: uint32(port), +// TODO: CERT-V2 IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake +// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially +func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool { + l := lh.GetLighthouses() + for _, a := range vpnAddr { + if _, ok := l[a]; ok { + return true + } } + return false } func (lh *LightHouse) startQueryWorker() { @@ -731,31 +735,85 @@ func (lh *LightHouse) startQueryWorker() { select { case <-lh.ctx.Done(): return - case ip := <-lh.queryChan: - lh.innerQueryServer(ip, nb, out) + case addr := <-lh.queryChan: + lh.innerQueryServer(addr, nb, out) } } }() } -func (lh *LightHouse) innerQueryServer(ip netip.Addr, nb, out []byte) { - if lh.IsLighthouseIP(ip) { +func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { + if lh.IsLighthouseAddr(addr) { return } - // Send a query to the lighthouses and hope for the best next time - query, err := NewLhQueryByInt(ip).Marshal() - if err != nil { - lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload") - return + msg := &NebulaMeta{ + Type: NebulaMeta_HostQuery, + Details: &NebulaMetaDetails{}, } + var v1Query, v2Query []byte + var err error + var v cert.Version + queried := 0 lighthouses := lh.GetLighthouses() - lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses))) - for n := range lighthouses { - lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out) + for lhVpnAddr := range lighthouses { + hi := lh.ifce.GetHostInfo(lhVpnAddr) + if hi != nil { + v = hi.ConnectionState.myCert.Version() + } else { + v = lh.ifce.GetCertState().defaultVersion + } + + if v == cert.Version1 { + if !addr.Is4() { + lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr). + Error("Can't query lighthouse for v6 address using a v1 protocol") + continue + } + + if v1Query == nil { + b := addr.As4() + msg.Details.VpnAddr = nil + msg.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) + + v1Query, err = msg.Marshal() + if err != nil { + lh.l.WithError(err).WithField("queryVpnAddr", addr). + WithField("lighthouseAddr", lhVpnAddr). + Error("Failed to marshal lighthouse v1 query payload") + continue + } + } + + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Query, nb, out) + queried++ + + } else if v == cert.Version2 { + if v2Query == nil { + msg.Details.OldVpnAddr = 0 + msg.Details.VpnAddr = netAddrToProtoAddr(addr) + + v2Query, err = msg.Marshal() + if err != nil { + lh.l.WithError(err).WithField("queryVpnAddr", addr). + WithField("lighthouseAddr", lhVpnAddr). + Error("Failed to marshal lighthouse v2 query payload") + continue + } + } + + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Query, nb, out) + queried++ + + } else { + lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v) + continue + } } + + lh.metricTx(NebulaMeta_HostQuery, int64(queried)) } func (lh *LightHouse) StartUpdateWorker() { @@ -785,65 +843,120 @@ func (lh *LightHouse) StartUpdateWorker() { } func (lh *LightHouse) SendUpdate() { - var v4 []*Ip4AndPort - var v6 []*Ip6AndPort + var v4 []*V4AddrPort + var v6 []*V6AddrPort for _, e := range lh.GetAdvertiseAddrs() { if e.Addr().Is4() { - v4 = append(v4, NewIp4AndPortFromNetIP(e.Addr(), e.Port())) + v4 = append(v4, netAddrToProtoV4AddrPort(e.Addr(), e.Port())) } else { - v6 = append(v6, NewIp6AndPortFromNetIP(e.Addr(), e.Port())) + v6 = append(v6, netAddrToProtoV6AddrPort(e.Addr(), e.Port())) } } lal := lh.GetLocalAllowList() - for _, e := range localIps(lh.l, lal) { - if lh.myVpnNet.Contains(e) { + for _, e := range localAddrs(lh.l, lal) { + _, found := lh.myVpnNetworksTable.Lookup(e) + if found { continue } - // Only add IPs that aren't my VPN/tun IP + // Only add addrs that aren't my VPN/tun networks if e.Is4() { - v4 = append(v4, NewIp4AndPortFromNetIP(e, uint16(lh.nebulaPort))) + v4 = append(v4, netAddrToProtoV4AddrPort(e, uint16(lh.nebulaPort))) } else { - v6 = append(v6, NewIp6AndPortFromNetIP(e, uint16(lh.nebulaPort))) + v6 = append(v6, netAddrToProtoV6AddrPort(e, uint16(lh.nebulaPort))) } } - var relays []uint32 - for _, r := range lh.GetRelaysForMe() { - //TODO: IPV6-WORK both relays and vpnip need ipv6 support - b := r.As4() - relays = append(relays, binary.BigEndian.Uint32(b[:])) - } - - //TODO: IPV6-WORK both relays and vpnip need ipv6 support - b := lh.myVpnNet.Addr().As4() - - m := &NebulaMeta{ - Type: NebulaMeta_HostUpdateNotification, - Details: &NebulaMetaDetails{ - VpnIp: binary.BigEndian.Uint32(b[:]), - Ip4AndPorts: v4, - Ip6AndPorts: v6, - RelayVpnIp: relays, - }, - } - - lighthouses := lh.GetLighthouses() - lh.metricTx(NebulaMeta_HostUpdateNotification, int64(len(lighthouses))) nb := make([]byte, 12, 12) out := make([]byte, mtu) - mm, err := m.Marshal() - if err != nil { - lh.l.WithError(err).Error("Error while marshaling for lighthouse update") - return + var v1Update, v2Update []byte + var err error + updated := 0 + lighthouses := lh.GetLighthouses() + + for lhVpnAddr := range lighthouses { + var v cert.Version + hi := lh.ifce.GetHostInfo(lhVpnAddr) + if hi != nil { + v = hi.ConnectionState.myCert.Version() + } else { + v = lh.ifce.GetCertState().defaultVersion + } + if v == cert.Version1 { + if v1Update == nil { + if !lh.myVpnNetworks[0].Addr().Is4() { + lh.l.WithField("lighthouseAddr", lhVpnAddr). + Warn("cannot update lighthouse using v1 protocol without an IPv4 address") + continue + } + var relays []uint32 + for _, r := range lh.GetRelaysForMe() { + if !r.Is4() { + continue + } + b := r.As4() + relays = append(relays, binary.BigEndian.Uint32(b[:])) + } + b := lh.myVpnNetworks[0].Addr().As4() + msg := NebulaMeta{ + Type: NebulaMeta_HostUpdateNotification, + Details: &NebulaMetaDetails{ + V4AddrPorts: v4, + V6AddrPorts: v6, + OldRelayVpnAddrs: relays, + OldVpnAddr: binary.BigEndian.Uint32(b[:]), + }, + } + + v1Update, err = msg.Marshal() + if err != nil { + lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). + Error("Error while marshaling for lighthouse v1 update") + continue + } + } + + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Update, nb, out) + updated++ + + } else if v == cert.Version2 { + if v2Update == nil { + var relays []*Addr + for _, r := range lh.GetRelaysForMe() { + relays = append(relays, netAddrToProtoAddr(r)) + } + + msg := NebulaMeta{ + Type: NebulaMeta_HostUpdateNotification, + Details: &NebulaMetaDetails{ + V4AddrPorts: v4, + V6AddrPorts: v6, + RelayVpnAddrs: relays, + VpnAddr: netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()), + }, + } + + v2Update, err = msg.Marshal() + if err != nil { + lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). + Error("Error while marshaling for lighthouse v2 update") + continue + } + } + + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Update, nb, out) + updated++ + + } else { + lh.l.Debugf("Can not update lighthouse using unknown protocol version: %v", v) + continue + } } - for vpnIp := range lighthouses { - lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out) - } + lh.metricTx(NebulaMeta_HostUpdateNotification, int64(updated)) } type LightHouseHandler struct { @@ -886,34 +999,29 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { lhh.meta.Reset() // Keep the array memory around - details.Ip4AndPorts = details.Ip4AndPorts[:0] - details.Ip6AndPorts = details.Ip6AndPorts[:0] - details.RelayVpnIp = details.RelayVpnIp[:0] + details.V4AddrPorts = details.V4AddrPorts[:0] + details.V6AddrPorts = details.V6AddrPorts[:0] + details.RelayVpnAddrs = details.RelayVpnAddrs[:0] + details.OldRelayVpnAddrs = details.OldRelayVpnAddrs[:0] + details.OldVpnAddr = 0 + details.VpnAddr = nil lhh.meta.Details = details return lhh.meta } -func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc { - return func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) { - lhh.HandleRequest(rAddr, vpnIp, p, f) - } -} - -func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte, w EncWriter) { +func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs []netip.Addr, p []byte, w EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr). + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). Error("Failed to unmarshal lighthouse packet") - //TODO: send recv_error? return } if n.Details == nil { - lhh.l.WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr). + lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). Error("Invalid lighthouse update") - //TODO: send recv_error? return } @@ -921,24 +1029,24 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnIp netip.Ad switch n.Type { case NebulaMeta_HostQuery: - lhh.handleHostQuery(n, vpnIp, rAddr, w) + lhh.handleHostQuery(n, fromVpnAddrs, rAddr, w) case NebulaMeta_HostQueryReply: - lhh.handleHostQueryReply(n, vpnIp) + lhh.handleHostQueryReply(n, fromVpnAddrs) case NebulaMeta_HostUpdateNotification: - lhh.handleHostUpdateNotification(n, vpnIp, w) + lhh.handleHostUpdateNotification(n, fromVpnAddrs, w) case NebulaMeta_HostMovedNotification: case NebulaMeta_HostPunchNotification: - lhh.handleHostPunchNotification(n, vpnIp, w) + lhh.handleHostPunchNotification(n, fromVpnAddrs, w) case NebulaMeta_HostUpdateNotificationAck: // noop } } -func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, addr netip.AddrPort, w EncWriter) { +func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { @@ -947,21 +1055,37 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, a return } - //TODO: we can DRY this further - reqVpnIp := n.Details.VpnIp + useVersion := cert.Version1 + var queryVpnAddr netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + queryVpnAddr = netip.AddrFrom4(b) + useVersion = 1 + } else if n.Details.VpnAddr != nil { + queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + useVersion = 2 + } else { + if lhh.l.Level >= logrus.DebugLevel { + lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery") + } + return + } - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - queryVpnIp := netip.AddrFrom4(b) - - //TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data - found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnIp, func(c *cache) (int, error) { + found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnAddr, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostQueryReply - n.Details.VpnIp = reqVpnIp + if useVersion == cert.Version1 { + if !queryVpnAddr.Is4() { + return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery") + } + b := queryVpnAddr.As4() + n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) + } else { + n.Details.VpnAddr = netAddrToProtoAddr(queryVpnAddr) + } - lhh.coalesceAnswers(c, n) + lhh.coalesceAnswers(useVersion, c, n) return n.MarshalTo(lhh.pb) }) @@ -971,21 +1095,51 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, a } if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host query reply") + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply") return } lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) - // This signals the other side to punch some zero byte udp packets - found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) { + lhh.sendHostPunchNotification(n, fromVpnAddrs, queryVpnAddr, w) +} + +// sendHostPunchNotification signals the other side to punch some zero byte udp packets +func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, punchNotifDest netip.Addr, w EncWriter) { + whereToPunch := fromVpnAddrs[0] + found, ln, err := lhh.lh.queryAndPrepMessage(whereToPunch, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostPunchNotification - //TODO: IPV6-WORK - b = vpnIp.As4() - n.Details.VpnIp = binary.BigEndian.Uint32(b[:]) - lhh.coalesceAnswers(c, n) + targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest) + var useVersion cert.Version + if targetHI == nil { + useVersion = lhh.lh.ifce.GetCertState().defaultVersion + } else { + crt := targetHI.GetCert().Certificate + useVersion = crt.Version() + // we can only retarget if we have a hostinfo + newDest, ok := findNetworkUnion(crt.Networks(), fromVpnAddrs) + if ok { + whereToPunch = newDest + } else { + //TODO: CERT-V2 this means the destination will have no addresses in common with the punch-ee + //choosing to do nothing for now, but maybe we return an error? + } + } + + if useVersion == cert.Version1 { + if !whereToPunch.Is4() { + return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery") + } + b := whereToPunch.As4() + n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) + } else if useVersion == cert.Version2 { + n.Details.VpnAddr = netAddrToProtoAddr(whereToPunch) + } else { + return 0, errors.New("unsupported version") + } + lhh.coalesceAnswers(useVersion, c, n) return n.MarshalTo(lhh.pb) }) @@ -995,139 +1149,169 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, a } if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host was queried for") + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host was queried for") return } lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1) - - //TODO: IPV6-WORK - binary.BigEndian.PutUint32(b[:], reqVpnIp) - sendTo := netip.AddrFrom4(b) - w.SendMessageToVpnIp(header.LightHouse, 0, sendTo, lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnAddr(header.LightHouse, 0, punchNotifDest, lhh.pb[:ln], lhh.nb, lhh.out[:0]) } -func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) { +func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *NebulaMeta) { if c.v4 != nil { if c.v4.learned != nil { - n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.learned) + n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.learned) } if c.v4.reported != nil && len(c.v4.reported) > 0 { - n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.reported...) + n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.reported...) } } if c.v6 != nil { if c.v6.learned != nil { - n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.learned) + n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.learned) } if c.v6.reported != nil && len(c.v6.reported) > 0 { - n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.reported...) + n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.reported...) } } if c.relay != nil { - //TODO: IPV6-WORK - relays := make([]uint32, len(c.relay.relay)) - b := [4]byte{} - for i, _ := range relays { - b = c.relay.relay[i].As4() - relays[i] = binary.BigEndian.Uint32(b[:]) + if v == cert.Version1 { + b := [4]byte{} + for _, r := range c.relay.relay { + if !r.Is4() { + continue + } + + b = r.As4() + n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:])) + } + + } else if v == cert.Version2 { + for _, r := range c.relay.relay { + n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r)) + } + + } else { + //TODO: CERT-V2 don't panic + panic("unsupported version") } - n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, relays...) } } -func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp netip.Addr) { - if !lhh.lh.IsLighthouseIP(vpnIp) { +func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs []netip.Addr) { + if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) { return } lhh.lh.Lock() - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - certVpnIp := netip.AddrFrom4(b) - am := lhh.lh.unlockedGetRemoteList(certVpnIp) + + var certVpnAddr netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + certVpnAddr = netip.AddrFrom4(b) + } else if n.Details.VpnAddr != nil { + certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + } + relays := n.Details.GetRelays() + + am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr}) am.Lock() lhh.lh.Unlock() - //TODO: IPV6-WORK - am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) - - //TODO: IPV6-WORK - relays := make([]netip.Addr, len(n.Details.RelayVpnIp)) - for i, _ := range n.Details.RelayVpnIp { - binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i]) - relays[i] = netip.AddrFrom4(b) - } - am.unlockedSetRelay(vpnIp, certVpnIp, relays) + am.unlockedSetV4(fromVpnAddrs[0], certVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(fromVpnAddrs[0], certVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetRelay(fromVpnAddrs[0], relays) am.Unlock() // Non-blocking attempt to trigger, skip if it would block select { - case lhh.lh.handshakeTrigger <- certVpnIp: + case lhh.lh.handshakeTrigger <- certVpnAddr: default: } } -func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) { +func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp) + lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs) } return } - //Simple check that the host sent this not someone else - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - detailsVpnIp := netip.AddrFrom4(b) - if detailsVpnIp != vpnIp { + var detailsVpnAddr netip.Addr + useVersion := cert.Version1 + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + detailsVpnAddr = netip.AddrFrom4(b) + useVersion = cert.Version1 + } else if n.Details.VpnAddr != nil { + detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + useVersion = cert.Version2 + } else { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnIp", vpnIp).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update") + lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification") } return } + //TODO: CERT-V2 hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4? + //TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right? + //Simple check that the host sent this not someone else + if !slices.Contains(fromVpnAddrs, detailsVpnAddr) { + if lhh.l.Level >= logrus.DebugLevel { + lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update") + } + return + } + + relays := n.Details.GetRelays() + lhh.lh.Lock() - am := lhh.lh.unlockedGetRemoteList(vpnIp) + am := lhh.lh.unlockedGetRemoteList(fromVpnAddrs) am.Lock() lhh.lh.Unlock() - am.unlockedSetV4(vpnIp, detailsVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, detailsVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) - - //TODO: IPV6-WORK - relays := make([]netip.Addr, len(n.Details.RelayVpnIp)) - for i, _ := range n.Details.RelayVpnIp { - binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i]) - relays[i] = netip.AddrFrom4(b) - } - am.unlockedSetRelay(vpnIp, detailsVpnIp, relays) + am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetRelay(fromVpnAddrs[0], relays) am.Unlock() n = lhh.resetMeta() n.Type = NebulaMeta_HostUpdateNotificationAck - //TODO: IPV6-WORK - vpnIpB := vpnIp.As4() - n.Details.VpnIp = binary.BigEndian.Uint32(vpnIpB[:]) - ln, err := n.MarshalTo(lhh.pb) + if useVersion == cert.Version1 { + if !fromVpnAddrs[0].Is4() { + lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message") + return + } + vpnAddrB := fromVpnAddrs[0].As4() + n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:]) + } else if useVersion == cert.Version2 { + n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0]) + } else { + lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version") + return + } + ln, err := n.MarshalTo(lhh.pb) if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host update ack") + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host update ack") return } lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) } -func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) { - if !lhh.lh.IsLighthouseIP(vpnIp) { +func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { + //It's possible the lighthouse is communicating with us using a non primary vpn addr, + //which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs. + //maybe one day we'll have a better idea, if it matters. + if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) { return } @@ -1144,39 +1328,123 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp n }() if lhh.l.Level >= logrus.DebugLevel { - //TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp)) - //TODO: IPV6-WORK, make this debug line not suck - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - lhh.l.Debugf("Punching on %d for %v", vpnPeer.Port(), netip.AddrFrom4(b)) + var logVpnAddr netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + logVpnAddr = netip.AddrFrom4(b) + } else if n.Details.VpnAddr != nil { + logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + } + lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr) } } - for _, a := range n.Details.Ip4AndPorts { - punch(AddrPortFromIp4AndPort(a)) + for _, a := range n.Details.V4AddrPorts { + punch(protoV4AddrPortToNetAddrPort(a)) } - for _, a := range n.Details.Ip6AndPorts { - punch(AddrPortFromIp6AndPort(a)) + for _, a := range n.Details.V6AddrPorts { + punch(protoV6AddrPortToNetAddrPort(a)) } // This sends a nebula test packet to the host trying to contact us. In the case // of a double nat or other difficult scenario, this may help establish // a tunnel. if lhh.lh.punchy.GetRespond() { - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - queryVpnIp := netip.AddrFrom4(b) + var queryVpnAddr netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + queryVpnAddr = netip.AddrFrom4(b) + } else if n.Details.VpnAddr != nil { + queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + } + go func() { time.Sleep(lhh.lh.punchy.GetRespondDelay()) if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugf("Sending a nebula test packet to vpn ip %s", queryVpnIp) + lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr) } //NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine // for each punchBack packet. We should move this into a timerwheel or a single goroutine // managed by a channel. - w.SendMessageToVpnIp(header.Test, header.TestRequest, queryVpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) }() } } + +func protoAddrToNetAddr(addr *Addr) netip.Addr { + b := [16]byte{} + binary.BigEndian.PutUint64(b[:8], addr.Hi) + binary.BigEndian.PutUint64(b[8:], addr.Lo) + return netip.AddrFrom16(b).Unmap() +} + +func protoV4AddrPortToNetAddrPort(ap *V4AddrPort) netip.AddrPort { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], ap.Addr) + return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ap.Port)) +} + +func protoV6AddrPortToNetAddrPort(ap *V6AddrPort) netip.AddrPort { + b := [16]byte{} + binary.BigEndian.PutUint64(b[:8], ap.Hi) + binary.BigEndian.PutUint64(b[8:], ap.Lo) + return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ap.Port)) +} + +func netAddrToProtoAddr(addr netip.Addr) *Addr { + b := addr.As16() + return &Addr{ + Hi: binary.BigEndian.Uint64(b[:8]), + Lo: binary.BigEndian.Uint64(b[8:]), + } +} + +func netAddrToProtoV4AddrPort(addr netip.Addr, port uint16) *V4AddrPort { + v4Addr := addr.As4() + return &V4AddrPort{ + Addr: binary.BigEndian.Uint32(v4Addr[:]), + Port: uint32(port), + } +} + +func netAddrToProtoV6AddrPort(addr netip.Addr, port uint16) *V6AddrPort { + v6Addr := addr.As16() + return &V6AddrPort{ + Hi: binary.BigEndian.Uint64(v6Addr[:8]), + Lo: binary.BigEndian.Uint64(v6Addr[8:]), + Port: uint32(port), + } +} + +func (d *NebulaMetaDetails) GetRelays() []netip.Addr { + var relays []netip.Addr + if len(d.OldRelayVpnAddrs) > 0 { + b := [4]byte{} + for _, r := range d.OldRelayVpnAddrs { + binary.BigEndian.PutUint32(b[:], r) + relays = append(relays, netip.AddrFrom4(b)) + } + } + + if len(d.RelayVpnAddrs) > 0 { + for _, r := range d.RelayVpnAddrs { + relays = append(relays, protoAddrToNetAddr(r)) + } + } + return relays +} + +// FindNetworkUnion returns the first netip.Addr contained in the list of provided netip.Prefix, if able +func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr, bool) { + for i := range prefixes { + for j := range addrs { + if prefixes[i].Contains(addrs[j]) { + return addrs[j], true + } + } + } + return netip.Addr{}, false +} diff --git a/lighthouse_test.go b/lighthouse_test.go index 2599f5f..d5947aa 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -7,6 +7,8 @@ import ( "net/netip" "testing" + "github.com/gaissmai/bart" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/test" @@ -14,62 +16,51 @@ import ( "gopkg.in/yaml.v2" ) -//TODO: Add a test to ensure udpAddr is copied and not reused - func TestOldIPv4Only(t *testing.T) { // This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility b := []byte{8, 129, 130, 132, 80, 16, 10} - var m Ip4AndPort + var m V4AddrPort err := m.Unmarshal(b) assert.NoError(t, err) ip := netip.MustParseAddr("10.1.1.1") bp := ip.As4() - assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp()) -} - -func TestNewLhQuery(t *testing.T) { - myIp, err := netip.ParseAddr("192.1.1.1") - assert.NoError(t, err) - - // Generating a new lh query should work - a := NewLhQueryByInt(myIp) - - // The result should be a nebulameta protobuf - assert.IsType(t, &NebulaMeta{}, a) - - // It should also Marshal fine - b, err := a.Marshal() - assert.Nil(t, err) - - // and then Unmarshal fine - n := &NebulaMeta{} - err = n.Unmarshal(b) - assert.Nil(t, err) - + assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr()) } func Test_lhStaticMapping(t *testing.T) { l := test.NewLogger() myVpnNet := netip.MustParsePrefix("10.128.0.1/16") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } lh1 := "10.128.0.2" c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} - _, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) + _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) assert.Nil(t, err) lh2 := "10.128.0.3" c = config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} - _, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) + _, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") } func TestReloadLighthouseInterval(t *testing.T) { l := test.NewLogger() myVpnNet := netip.MustParsePrefix("10.128.0.1/16") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } lh1 := "10.128.0.2" c := config.NewC(l) @@ -79,7 +70,7 @@ func TestReloadLighthouseInterval(t *testing.T) { } c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) assert.NoError(t, err) lh.ifce = &mockEncWriter{} @@ -99,9 +90,15 @@ func TestReloadLighthouseInterval(t *testing.T) { func BenchmarkLighthouseHandleRequest(b *testing.B) { l := test.NewLogger() myVpnNet := netip.MustParsePrefix("10.128.0.1/0") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } c := config.NewC(l) - lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) if !assert.NoError(b, err) { b.Fatal() } @@ -110,46 +107,47 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346") vpnIp3 := netip.MustParseAddr("0.0.0.3") - lh.addrMap[vpnIp3] = NewRemoteList(nil) + lh.addrMap[vpnIp3] = NewRemoteList([]netip.Addr{vpnIp3}, nil) lh.addrMap[vpnIp3].unlockedSetV4( vpnIp3, vpnIp3, - []*Ip4AndPort{ - NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()), - NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()), + []*V4AddrPort{ + netAddrToProtoV4AddrPort(hAddr.Addr(), hAddr.Port()), + netAddrToProtoV4AddrPort(hAddr2.Addr(), hAddr2.Port()), }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) rAddr := netip.MustParseAddrPort("1.2.2.3:12345") rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346") vpnIp2 := netip.MustParseAddr("0.0.0.3") - lh.addrMap[vpnIp2] = NewRemoteList(nil) + lh.addrMap[vpnIp2] = NewRemoteList([]netip.Addr{vpnIp2}, nil) lh.addrMap[vpnIp2].unlockedSetV4( vpnIp3, vpnIp3, - []*Ip4AndPort{ - NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()), - NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()), + []*V4AddrPort{ + netAddrToProtoV4AddrPort(rAddr.Addr(), rAddr.Port()), + netAddrToProtoV4AddrPort(rAddr2.Addr(), rAddr2.Port()), }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) mw := &mockEncWriter{} + hi := []netip.Addr{vpnIp2} b.Run("notfound", func(b *testing.B) { lhh := lh.NewRequestHandler() req := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: 4, - Ip4AndPorts: nil, + OldVpnAddr: 4, + V4AddrPorts: nil, }, } p, err := req.Marshal() assert.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, vpnIp2, p, mw) + lhh.HandleRequest(rAddr, hi, p, mw) } }) b.Run("found", func(b *testing.B) { @@ -157,15 +155,15 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { req := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: 3, - Ip4AndPorts: nil, + OldVpnAddr: 3, + V4AddrPorts: nil, }, } p, err := req.Marshal() assert.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, vpnIp2, p, mw) + lhh.HandleRequest(rAddr, hi, p, mw) } }) } @@ -197,40 +195,49 @@ func TestLighthouse_Memory(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil) + + myVpnNet := netip.MustParsePrefix("10.128.0.1/24") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh.ifce = &mockEncWriter{} assert.NoError(t, err) lhh := lh.NewRequestHandler() // Test that my first update responds with just that newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh) r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2) // Ensure we don't accumulate addresses newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr3) // Grow it back to 2 newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4) // Update a different host and ask about it newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) // Have both hosts ask about the other r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4) r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) // Make sure we didn't get changed r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4) // Ensure proper ordering and limiting // Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe) @@ -255,7 +262,7 @@ func TestLighthouse_Memory(t *testing.T) { r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray( t, - r.msg.Details.Ip4AndPorts, + r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9, ) @@ -265,7 +272,7 @@ func TestLighthouse_Memory(t *testing.T) { good := netip.MustParseAddrPort("1.128.0.99:4242") newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, good) } func TestLighthouse_reload(t *testing.T) { @@ -273,7 +280,16 @@ func TestLighthouse_reload(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil) + + myVpnNet := netip.MustParsePrefix("10.128.0.1/24") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } + + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) assert.NoError(t, err) nc := map[interface{}]interface{}{ @@ -290,13 +306,16 @@ func TestLighthouse_reload(t *testing.T) { } func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply { - //TODO: IPV6-WORK - bip := queryVpnIp.As4() req := &NebulaMeta{ - Type: NebulaMeta_HostQuery, - Details: &NebulaMetaDetails{ - VpnIp: binary.BigEndian.Uint32(bip[:]), - }, + Type: NebulaMeta_HostQuery, + Details: &NebulaMetaDetails{}, + } + + if queryVpnIp.Is4() { + bip := queryVpnIp.As4() + req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:]) + } else { + req.Details.VpnAddr = netAddrToProtoAddr(queryVpnIp) } b, err := req.Marshal() @@ -308,23 +327,29 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l w := &testEncWriter{ metaFilter: &filter, } - lhh.HandleRequest(fromAddr, myVpnIp, b, w) + lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w) return w.lastReply } func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) { - //TODO: IPV6-WORK - bip := vpnIp.As4() req := &NebulaMeta{ - Type: NebulaMeta_HostUpdateNotification, - Details: &NebulaMetaDetails{ - VpnIp: binary.BigEndian.Uint32(bip[:]), - Ip4AndPorts: make([]*Ip4AndPort, len(addrs)), - }, + Type: NebulaMeta_HostUpdateNotification, + Details: &NebulaMetaDetails{}, } - for k, v := range addrs { - req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port()) + if vpnIp.Is4() { + bip := vpnIp.As4() + req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:]) + } else { + req.Details.VpnAddr = netAddrToProtoAddr(vpnIp) + } + + for _, v := range addrs { + if v.Addr().Is4() { + req.Details.V4AddrPorts = append(req.Details.V4AddrPorts, netAddrToProtoV4AddrPort(v.Addr(), v.Port())) + } else { + req.Details.V6AddrPorts = append(req.Details.V6AddrPorts, netAddrToProtoV6AddrPort(v.Addr(), v.Port())) + } } b, err := req.Marshal() @@ -333,75 +358,9 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad } w := &testEncWriter{} - lhh.HandleRequest(fromAddr, vpnIp, b, w) + lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w) } -//TODO: this is a RemoteList test -//func Test_lhRemoteAllowList(t *testing.T) { -// l := NewLogger() -// c := NewConfig(l) -// c.Settings["remoteallowlist"] = map[interface{}]interface{}{ -// "10.20.0.0/12": false, -// } -// allowList, err := c.GetAllowList("remoteallowlist", false) -// assert.Nil(t, err) -// -// lh1 := "10.128.0.2" -// lh1IP := net.ParseIP(lh1) -// -// udpServer, _ := NewListener(l, "0.0.0.0", 0, true) -// -// lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) -// lh.SetRemoteAllowList(allowList) -// -// // A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap -// remote1IP := net.ParseIP("10.20.0.3") -// remotes := lh.unlockedGetRemoteList(ip2int(remote1IP)) -// remotes.unlockedPrependV4(ip2int(remote1IP), NewIp4AndPort(remote1IP, 4242)) -// assert.NotNil(t, lh.addrMap[ip2int(remote1IP)]) -// assert.Empty(t, lh.addrMap[ip2int(remote1IP)].CopyAddrs([]*net.IPNet{})) -// -// // Make sure a good ip enters the cache and addrMap -// remote2IP := net.ParseIP("10.128.0.3") -// remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242)) -// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote2UDPAddr.IP, uint32(remote2UDPAddr.Port)), false, false) -// assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr) -// -// // Another good ip gets into the cache, ordering is inverted -// remote3IP := net.ParseIP("10.128.0.4") -// remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243)) -// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote3UDPAddr.IP, uint32(remote3UDPAddr.Port)), false, false) -// assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr, remote3UDPAddr) -// -// // If we exceed the length limit we should only have the most recent addresses -// addedAddrs := []*udpAddr{} -// for i := 0; i < 11; i++ { -// remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i)) -// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remoteUDPAddr.IP, uint32(remoteUDPAddr.Port)), false, false) -// // The first entry here is a duplicate, don't add it to the assert list -// if i != 0 { -// addedAddrs = append(addedAddrs, remoteUDPAddr) -// } -// } -// -// // We should only have the last 10 of what we tried to add -// assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses") -// assertUdpAddrInArray( -// t, -// lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), -// addedAddrs[0], -// addedAddrs[1], -// addedAddrs[2], -// addedAddrs[3], -// addedAddrs[4], -// addedAddrs[5], -// addedAddrs[6], -// addedAddrs[7], -// addedAddrs[8], -// addedAddrs[9], -// ) -//} - type testLhReply struct { nebType header.MessageType nebSubType header.MessageSubType @@ -410,8 +369,9 @@ type testLhReply struct { } type testEncWriter struct { - lastReply testLhReply - metaFilter *NebulaMeta_MessageType + lastReply testLhReply + metaFilter *NebulaMeta_MessageType + protocolVersion cert.Version } func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { @@ -426,7 +386,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M tw.lastReply = testLhReply{ nebType: t, nebSubType: st, - vpnIp: hostinfo.vpnIp, + vpnIp: hostinfo.vpnAddrs[0], msg: msg, } } @@ -436,7 +396,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M } } -func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) { +func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) { msg := &NebulaMeta{} err := msg.Unmarshal(p) if tw.metaFilter == nil || msg.Type == *tw.metaFilter { @@ -453,17 +413,84 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess } } +func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo { + return nil +} + +func (tw *testEncWriter) GetCertState() *CertState { + return &CertState{defaultVersion: tw.protocolVersion} +} + // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match -func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) { +func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) { if !assert.Len(t, have, len(want)) { return } for k, w := range want { - //TODO: IPV6-WORK - h := AddrPortFromIp4AndPort(have[k]) + h := protoV4AddrPortToNetAddrPort(have[k]) if !(h == w) { assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h)) } } } + +func Test_findNetworkUnion(t *testing.T) { + var out netip.Addr + var ok bool + + tenDot := netip.MustParsePrefix("10.0.0.0/8") + oneSevenTwo := netip.MustParsePrefix("172.16.0.0/16") + fe80 := netip.MustParsePrefix("fe80::/8") + fc00 := netip.MustParsePrefix("fc00::/7") + + a1 := netip.MustParseAddr("10.0.0.1") + afe81 := netip.MustParseAddr("fe80::1") + + //simple + out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + + //mixed lengths + out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1, afe81}) + assert.True(t, ok) + assert.Equal(t, out, a1) + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo}, []netip.Addr{a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + + //mixed family + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1, afe81}) + assert.True(t, ok) + assert.Equal(t, out, a1) + + //ordering + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81, a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + out, ok = findNetworkUnion([]netip.Prefix{fe80, tenDot, oneSevenTwo}, []netip.Addr{afe81, a1}) + assert.True(t, ok) + assert.Equal(t, out, afe81) + + //some mismatches + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81}) + assert.True(t, ok) + assert.Equal(t, out, afe81) + out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1, afe81}) + assert.True(t, ok) + assert.Equal(t, out, afe81) + + //falsey cases + out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1}) + assert.False(t, ok) + out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1}) + assert.False(t, ok) + out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81}) + assert.False(t, ok) + out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81}) + assert.False(t, ok) +} diff --git a/main.go b/main.go index 8f45359..7e94c32 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,6 @@ package nebula import ( "context" - "encoding/binary" "fmt" "net" "net/netip" @@ -61,15 +60,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) } - certificate := pki.GetCertState().Certificate - fw, err := NewFirewallFromConfig(l, certificate, c) + fw, err := NewFirewallFromConfig(l, pki.getCertState(), c) if err != nil { return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) } l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") - tunCidr := certificate.Networks()[0] - ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) if err != nil { return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err) @@ -132,7 +128,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg deviceFactory = overlay.NewDeviceFromConfig } - tun, err = deviceFactory(c, l, tunCidr, routines) + tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err) } @@ -187,9 +183,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } } - hostMap := NewHostMapFromConfig(l, tunCidr, c) + hostMap := NewHostMapFromConfig(l, c) punchy := NewPunchyFromConfig(l, c) - lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy) + lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err) } @@ -232,7 +228,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg Inside: tun, Outside: udpConns[0], pki: pki, - Cipher: c.GetString("cipher", "aes"), Firewall: fw, ServeDns: serveDns, HandshakeManager: handshakeManager, @@ -254,15 +249,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg l: l, } - switch ifConfig.Cipher { - case "aes": - noiseEndianness = binary.BigEndian - case "chachapoly": - noiseEndianness = binary.LittleEndian - default: - return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher) - } - var ifce *Interface if !configTest { ifce, err = NewInterface(ctx, ifConfig) @@ -270,8 +256,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg return nil, fmt.Errorf("failed to initialize interface: %s", err) } - // TODO: Better way to attach these, probably want a new interface in InterfaceConfig - // I don't want to make this initial commit too far-reaching though ifce.writers = udpConns lightHouse.ifce = ifce @@ -283,8 +267,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg go handshakeManager.Run(ctx) } - // TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept - // a context so that they can exit when the context is Done. statsStart, err := startStats(l, c, buildVersion, configTest) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err) @@ -294,7 +276,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg return nil, nil } - //TODO: check if we _should_ be emitting stats go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10)) attachCommands(l, c, ssh, ifce) @@ -303,7 +284,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg var dnsStart func() if lightHouse.amLighthouse && serveDns { l.Debugln("Starting dns server") - dnsStart = dnsMain(l, hostMap, c) + dnsStart = dnsMain(l, pki.getCertState(), hostMap, c) } return &Control{ diff --git a/message_metrics.go b/message_metrics.go index 94bb02f..10e8472 100644 --- a/message_metrics.go +++ b/message_metrics.go @@ -7,8 +7,6 @@ import ( "github.com/slackhq/nebula/header" ) -//TODO: this can probably move into the header package - type MessageMetrics struct { rx [][]metrics.Counter tx [][]metrics.Counter diff --git a/nebula.pb.go b/nebula.pb.go index b3c723a..2fd2ff6 100644 --- a/nebula.pb.go +++ b/nebula.pb.go @@ -96,7 +96,7 @@ func (x NebulaPing_MessageType) String() string { } func (NebulaPing_MessageType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{4, 0} + return fileDescriptor_2d65afa7693df5ef, []int{5, 0} } type NebulaControl_MessageType int32 @@ -124,7 +124,7 @@ func (x NebulaControl_MessageType) String() string { } func (NebulaControl_MessageType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{7, 0} + return fileDescriptor_2d65afa7693df5ef, []int{8, 0} } type NebulaMeta struct { @@ -180,11 +180,13 @@ func (m *NebulaMeta) GetDetails() *NebulaMetaDetails { } type NebulaMetaDetails struct { - VpnIp uint32 `protobuf:"varint,1,opt,name=VpnIp,proto3" json:"VpnIp,omitempty"` - Ip4AndPorts []*Ip4AndPort `protobuf:"bytes,2,rep,name=Ip4AndPorts,proto3" json:"Ip4AndPorts,omitempty"` - Ip6AndPorts []*Ip6AndPort `protobuf:"bytes,4,rep,name=Ip6AndPorts,proto3" json:"Ip6AndPorts,omitempty"` - RelayVpnIp []uint32 `protobuf:"varint,5,rep,packed,name=RelayVpnIp,proto3" json:"RelayVpnIp,omitempty"` - Counter uint32 `protobuf:"varint,3,opt,name=counter,proto3" json:"counter,omitempty"` + OldVpnAddr uint32 `protobuf:"varint,1,opt,name=OldVpnAddr,proto3" json:"OldVpnAddr,omitempty"` // Deprecated: Do not use. + VpnAddr *Addr `protobuf:"bytes,6,opt,name=VpnAddr,proto3" json:"VpnAddr,omitempty"` + OldRelayVpnAddrs []uint32 `protobuf:"varint,5,rep,packed,name=OldRelayVpnAddrs,proto3" json:"OldRelayVpnAddrs,omitempty"` // Deprecated: Do not use. + RelayVpnAddrs []*Addr `protobuf:"bytes,7,rep,name=RelayVpnAddrs,proto3" json:"RelayVpnAddrs,omitempty"` + V4AddrPorts []*V4AddrPort `protobuf:"bytes,2,rep,name=V4AddrPorts,proto3" json:"V4AddrPorts,omitempty"` + V6AddrPorts []*V6AddrPort `protobuf:"bytes,4,rep,name=V6AddrPorts,proto3" json:"V6AddrPorts,omitempty"` + Counter uint32 `protobuf:"varint,3,opt,name=counter,proto3" json:"counter,omitempty"` } func (m *NebulaMetaDetails) Reset() { *m = NebulaMetaDetails{} } @@ -220,30 +222,46 @@ func (m *NebulaMetaDetails) XXX_DiscardUnknown() { var xxx_messageInfo_NebulaMetaDetails proto.InternalMessageInfo -func (m *NebulaMetaDetails) GetVpnIp() uint32 { +// Deprecated: Do not use. +func (m *NebulaMetaDetails) GetOldVpnAddr() uint32 { if m != nil { - return m.VpnIp + return m.OldVpnAddr } return 0 } -func (m *NebulaMetaDetails) GetIp4AndPorts() []*Ip4AndPort { +func (m *NebulaMetaDetails) GetVpnAddr() *Addr { if m != nil { - return m.Ip4AndPorts + return m.VpnAddr } return nil } -func (m *NebulaMetaDetails) GetIp6AndPorts() []*Ip6AndPort { +// Deprecated: Do not use. +func (m *NebulaMetaDetails) GetOldRelayVpnAddrs() []uint32 { if m != nil { - return m.Ip6AndPorts + return m.OldRelayVpnAddrs } return nil } -func (m *NebulaMetaDetails) GetRelayVpnIp() []uint32 { +func (m *NebulaMetaDetails) GetRelayVpnAddrs() []*Addr { if m != nil { - return m.RelayVpnIp + return m.RelayVpnAddrs + } + return nil +} + +func (m *NebulaMetaDetails) GetV4AddrPorts() []*V4AddrPort { + if m != nil { + return m.V4AddrPorts + } + return nil +} + +func (m *NebulaMetaDetails) GetV6AddrPorts() []*V6AddrPort { + if m != nil { + return m.V6AddrPorts } return nil } @@ -255,23 +273,23 @@ func (m *NebulaMetaDetails) GetCounter() uint32 { return 0 } -type Ip4AndPort struct { - Ip uint32 `protobuf:"varint,1,opt,name=Ip,proto3" json:"Ip,omitempty"` - Port uint32 `protobuf:"varint,2,opt,name=Port,proto3" json:"Port,omitempty"` +type Addr struct { + Hi uint64 `protobuf:"varint,1,opt,name=Hi,proto3" json:"Hi,omitempty"` + Lo uint64 `protobuf:"varint,2,opt,name=Lo,proto3" json:"Lo,omitempty"` } -func (m *Ip4AndPort) Reset() { *m = Ip4AndPort{} } -func (m *Ip4AndPort) String() string { return proto.CompactTextString(m) } -func (*Ip4AndPort) ProtoMessage() {} -func (*Ip4AndPort) Descriptor() ([]byte, []int) { +func (m *Addr) Reset() { *m = Addr{} } +func (m *Addr) String() string { return proto.CompactTextString(m) } +func (*Addr) ProtoMessage() {} +func (*Addr) Descriptor() ([]byte, []int) { return fileDescriptor_2d65afa7693df5ef, []int{2} } -func (m *Ip4AndPort) XXX_Unmarshal(b []byte) error { +func (m *Addr) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) } -func (m *Ip4AndPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { +func (m *Addr) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { if deterministic { - return xxx_messageInfo_Ip4AndPort.Marshal(b, m, deterministic) + return xxx_messageInfo_Addr.Marshal(b, m, deterministic) } else { b = b[:cap(b)] n, err := m.MarshalToSizedBuffer(b) @@ -281,86 +299,138 @@ func (m *Ip4AndPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return b[:n], nil } } -func (m *Ip4AndPort) XXX_Merge(src proto.Message) { - xxx_messageInfo_Ip4AndPort.Merge(m, src) +func (m *Addr) XXX_Merge(src proto.Message) { + xxx_messageInfo_Addr.Merge(m, src) } -func (m *Ip4AndPort) XXX_Size() int { +func (m *Addr) XXX_Size() int { return m.Size() } -func (m *Ip4AndPort) XXX_DiscardUnknown() { - xxx_messageInfo_Ip4AndPort.DiscardUnknown(m) +func (m *Addr) XXX_DiscardUnknown() { + xxx_messageInfo_Addr.DiscardUnknown(m) } -var xxx_messageInfo_Ip4AndPort proto.InternalMessageInfo +var xxx_messageInfo_Addr proto.InternalMessageInfo -func (m *Ip4AndPort) GetIp() uint32 { - if m != nil { - return m.Ip - } - return 0 -} - -func (m *Ip4AndPort) GetPort() uint32 { - if m != nil { - return m.Port - } - return 0 -} - -type Ip6AndPort struct { - Hi uint64 `protobuf:"varint,1,opt,name=Hi,proto3" json:"Hi,omitempty"` - Lo uint64 `protobuf:"varint,2,opt,name=Lo,proto3" json:"Lo,omitempty"` - Port uint32 `protobuf:"varint,3,opt,name=Port,proto3" json:"Port,omitempty"` -} - -func (m *Ip6AndPort) Reset() { *m = Ip6AndPort{} } -func (m *Ip6AndPort) String() string { return proto.CompactTextString(m) } -func (*Ip6AndPort) ProtoMessage() {} -func (*Ip6AndPort) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{3} -} -func (m *Ip6AndPort) XXX_Unmarshal(b []byte) error { - return m.Unmarshal(b) -} -func (m *Ip6AndPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - if deterministic { - return xxx_messageInfo_Ip6AndPort.Marshal(b, m, deterministic) - } else { - b = b[:cap(b)] - n, err := m.MarshalToSizedBuffer(b) - if err != nil { - return nil, err - } - return b[:n], nil - } -} -func (m *Ip6AndPort) XXX_Merge(src proto.Message) { - xxx_messageInfo_Ip6AndPort.Merge(m, src) -} -func (m *Ip6AndPort) XXX_Size() int { - return m.Size() -} -func (m *Ip6AndPort) XXX_DiscardUnknown() { - xxx_messageInfo_Ip6AndPort.DiscardUnknown(m) -} - -var xxx_messageInfo_Ip6AndPort proto.InternalMessageInfo - -func (m *Ip6AndPort) GetHi() uint64 { +func (m *Addr) GetHi() uint64 { if m != nil { return m.Hi } return 0 } -func (m *Ip6AndPort) GetLo() uint64 { +func (m *Addr) GetLo() uint64 { if m != nil { return m.Lo } return 0 } -func (m *Ip6AndPort) GetPort() uint32 { +type V4AddrPort struct { + Addr uint32 `protobuf:"varint,1,opt,name=Addr,proto3" json:"Addr,omitempty"` + Port uint32 `protobuf:"varint,2,opt,name=Port,proto3" json:"Port,omitempty"` +} + +func (m *V4AddrPort) Reset() { *m = V4AddrPort{} } +func (m *V4AddrPort) String() string { return proto.CompactTextString(m) } +func (*V4AddrPort) ProtoMessage() {} +func (*V4AddrPort) Descriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{3} +} +func (m *V4AddrPort) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *V4AddrPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_V4AddrPort.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *V4AddrPort) XXX_Merge(src proto.Message) { + xxx_messageInfo_V4AddrPort.Merge(m, src) +} +func (m *V4AddrPort) XXX_Size() int { + return m.Size() +} +func (m *V4AddrPort) XXX_DiscardUnknown() { + xxx_messageInfo_V4AddrPort.DiscardUnknown(m) +} + +var xxx_messageInfo_V4AddrPort proto.InternalMessageInfo + +func (m *V4AddrPort) GetAddr() uint32 { + if m != nil { + return m.Addr + } + return 0 +} + +func (m *V4AddrPort) GetPort() uint32 { + if m != nil { + return m.Port + } + return 0 +} + +type V6AddrPort struct { + Hi uint64 `protobuf:"varint,1,opt,name=Hi,proto3" json:"Hi,omitempty"` + Lo uint64 `protobuf:"varint,2,opt,name=Lo,proto3" json:"Lo,omitempty"` + Port uint32 `protobuf:"varint,3,opt,name=Port,proto3" json:"Port,omitempty"` +} + +func (m *V6AddrPort) Reset() { *m = V6AddrPort{} } +func (m *V6AddrPort) String() string { return proto.CompactTextString(m) } +func (*V6AddrPort) ProtoMessage() {} +func (*V6AddrPort) Descriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{4} +} +func (m *V6AddrPort) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *V6AddrPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_V6AddrPort.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *V6AddrPort) XXX_Merge(src proto.Message) { + xxx_messageInfo_V6AddrPort.Merge(m, src) +} +func (m *V6AddrPort) XXX_Size() int { + return m.Size() +} +func (m *V6AddrPort) XXX_DiscardUnknown() { + xxx_messageInfo_V6AddrPort.DiscardUnknown(m) +} + +var xxx_messageInfo_V6AddrPort proto.InternalMessageInfo + +func (m *V6AddrPort) GetHi() uint64 { + if m != nil { + return m.Hi + } + return 0 +} + +func (m *V6AddrPort) GetLo() uint64 { + if m != nil { + return m.Lo + } + return 0 +} + +func (m *V6AddrPort) GetPort() uint32 { if m != nil { return m.Port } @@ -376,7 +446,7 @@ func (m *NebulaPing) Reset() { *m = NebulaPing{} } func (m *NebulaPing) String() string { return proto.CompactTextString(m) } func (*NebulaPing) ProtoMessage() {} func (*NebulaPing) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{4} + return fileDescriptor_2d65afa7693df5ef, []int{5} } func (m *NebulaPing) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -428,7 +498,7 @@ func (m *NebulaHandshake) Reset() { *m = NebulaHandshake{} } func (m *NebulaHandshake) String() string { return proto.CompactTextString(m) } func (*NebulaHandshake) ProtoMessage() {} func (*NebulaHandshake) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{5} + return fileDescriptor_2d65afa7693df5ef, []int{6} } func (m *NebulaHandshake) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -477,13 +547,14 @@ type NebulaHandshakeDetails struct { ResponderIndex uint32 `protobuf:"varint,3,opt,name=ResponderIndex,proto3" json:"ResponderIndex,omitempty"` Cookie uint64 `protobuf:"varint,4,opt,name=Cookie,proto3" json:"Cookie,omitempty"` Time uint64 `protobuf:"varint,5,opt,name=Time,proto3" json:"Time,omitempty"` + CertVersion uint32 `protobuf:"varint,8,opt,name=CertVersion,proto3" json:"CertVersion,omitempty"` } func (m *NebulaHandshakeDetails) Reset() { *m = NebulaHandshakeDetails{} } func (m *NebulaHandshakeDetails) String() string { return proto.CompactTextString(m) } func (*NebulaHandshakeDetails) ProtoMessage() {} func (*NebulaHandshakeDetails) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{6} + return fileDescriptor_2d65afa7693df5ef, []int{7} } func (m *NebulaHandshakeDetails) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -547,19 +618,28 @@ func (m *NebulaHandshakeDetails) GetTime() uint64 { return 0 } +func (m *NebulaHandshakeDetails) GetCertVersion() uint32 { + if m != nil { + return m.CertVersion + } + return 0 +} + type NebulaControl struct { Type NebulaControl_MessageType `protobuf:"varint,1,opt,name=Type,proto3,enum=nebula.NebulaControl_MessageType" json:"Type,omitempty"` InitiatorRelayIndex uint32 `protobuf:"varint,2,opt,name=InitiatorRelayIndex,proto3" json:"InitiatorRelayIndex,omitempty"` ResponderRelayIndex uint32 `protobuf:"varint,3,opt,name=ResponderRelayIndex,proto3" json:"ResponderRelayIndex,omitempty"` - RelayToIp uint32 `protobuf:"varint,4,opt,name=RelayToIp,proto3" json:"RelayToIp,omitempty"` - RelayFromIp uint32 `protobuf:"varint,5,opt,name=RelayFromIp,proto3" json:"RelayFromIp,omitempty"` + OldRelayToAddr uint32 `protobuf:"varint,4,opt,name=OldRelayToAddr,proto3" json:"OldRelayToAddr,omitempty"` // Deprecated: Do not use. + OldRelayFromAddr uint32 `protobuf:"varint,5,opt,name=OldRelayFromAddr,proto3" json:"OldRelayFromAddr,omitempty"` // Deprecated: Do not use. + RelayToAddr *Addr `protobuf:"bytes,6,opt,name=RelayToAddr,proto3" json:"RelayToAddr,omitempty"` + RelayFromAddr *Addr `protobuf:"bytes,7,opt,name=RelayFromAddr,proto3" json:"RelayFromAddr,omitempty"` } func (m *NebulaControl) Reset() { *m = NebulaControl{} } func (m *NebulaControl) String() string { return proto.CompactTextString(m) } func (*NebulaControl) ProtoMessage() {} func (*NebulaControl) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{7} + return fileDescriptor_2d65afa7693df5ef, []int{8} } func (m *NebulaControl) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -609,28 +689,45 @@ func (m *NebulaControl) GetResponderRelayIndex() uint32 { return 0 } -func (m *NebulaControl) GetRelayToIp() uint32 { +// Deprecated: Do not use. +func (m *NebulaControl) GetOldRelayToAddr() uint32 { if m != nil { - return m.RelayToIp + return m.OldRelayToAddr } return 0 } -func (m *NebulaControl) GetRelayFromIp() uint32 { +// Deprecated: Do not use. +func (m *NebulaControl) GetOldRelayFromAddr() uint32 { if m != nil { - return m.RelayFromIp + return m.OldRelayFromAddr } return 0 } +func (m *NebulaControl) GetRelayToAddr() *Addr { + if m != nil { + return m.RelayToAddr + } + return nil +} + +func (m *NebulaControl) GetRelayFromAddr() *Addr { + if m != nil { + return m.RelayFromAddr + } + return nil +} + func init() { proto.RegisterEnum("nebula.NebulaMeta_MessageType", NebulaMeta_MessageType_name, NebulaMeta_MessageType_value) proto.RegisterEnum("nebula.NebulaPing_MessageType", NebulaPing_MessageType_name, NebulaPing_MessageType_value) proto.RegisterEnum("nebula.NebulaControl_MessageType", NebulaControl_MessageType_name, NebulaControl_MessageType_value) proto.RegisterType((*NebulaMeta)(nil), "nebula.NebulaMeta") proto.RegisterType((*NebulaMetaDetails)(nil), "nebula.NebulaMetaDetails") - proto.RegisterType((*Ip4AndPort)(nil), "nebula.Ip4AndPort") - proto.RegisterType((*Ip6AndPort)(nil), "nebula.Ip6AndPort") + proto.RegisterType((*Addr)(nil), "nebula.Addr") + proto.RegisterType((*V4AddrPort)(nil), "nebula.V4AddrPort") + proto.RegisterType((*V6AddrPort)(nil), "nebula.V6AddrPort") proto.RegisterType((*NebulaPing)(nil), "nebula.NebulaPing") proto.RegisterType((*NebulaHandshake)(nil), "nebula.NebulaHandshake") proto.RegisterType((*NebulaHandshakeDetails)(nil), "nebula.NebulaHandshakeDetails") @@ -640,52 +737,57 @@ func init() { func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) } var fileDescriptor_2d65afa7693df5ef = []byte{ - // 707 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x54, 0x4d, 0x6f, 0xda, 0x4a, - 0x14, 0xc5, 0xc6, 0x7c, 0x5d, 0x02, 0xf1, 0xbb, 0x79, 0x8f, 0x07, 0x4f, 0xaf, 0x16, 0xf5, 0xa2, - 0x62, 0x45, 0x22, 0x92, 0x46, 0x5d, 0x36, 0xa5, 0xaa, 0x20, 0x4a, 0x22, 0x3a, 0x4a, 0x5b, 0xa9, - 0x9b, 0x6a, 0x62, 0xa6, 0xc1, 0x02, 0x3c, 0x8e, 0x3d, 0x54, 0xe1, 0x5f, 0xf4, 0xc7, 0xe4, 0x47, - 0x74, 0xd7, 0x2c, 0xbb, 0xac, 0x92, 0x65, 0x97, 0xfd, 0x03, 0xd5, 0x8c, 0xc1, 0x36, 0x84, 0x76, - 0x37, 0xe7, 0xde, 0x73, 0x66, 0xce, 0x9c, 0xb9, 0x36, 0x6c, 0x79, 0xec, 0x62, 0x36, 0xa1, 0x6d, - 0x3f, 0xe0, 0x82, 0x63, 0x3e, 0x42, 0xf6, 0x0f, 0x1d, 0xe0, 0x4c, 0x2d, 0x4f, 0x99, 0xa0, 0xd8, - 0x01, 0xe3, 0x7c, 0xee, 0xb3, 0xba, 0xd6, 0xd4, 0x5a, 0xd5, 0x8e, 0xd5, 0x5e, 0x68, 0x12, 0x46, - 0xfb, 0x94, 0x85, 0x21, 0xbd, 0x64, 0x92, 0x45, 0x14, 0x17, 0xf7, 0xa1, 0xf0, 0x92, 0x09, 0xea, - 0x4e, 0xc2, 0xba, 0xde, 0xd4, 0x5a, 0xe5, 0x4e, 0xe3, 0xa1, 0x6c, 0x41, 0x20, 0x4b, 0xa6, 0xfd, - 0x53, 0x83, 0x72, 0x6a, 0x2b, 0x2c, 0x82, 0x71, 0xc6, 0x3d, 0x66, 0x66, 0xb0, 0x02, 0xa5, 0x1e, - 0x0f, 0xc5, 0xeb, 0x19, 0x0b, 0xe6, 0xa6, 0x86, 0x08, 0xd5, 0x18, 0x12, 0xe6, 0x4f, 0xe6, 0xa6, - 0x8e, 0xff, 0x41, 0x4d, 0xd6, 0xde, 0xf8, 0x43, 0x2a, 0xd8, 0x19, 0x17, 0xee, 0x47, 0xd7, 0xa1, - 0xc2, 0xe5, 0x9e, 0x99, 0xc5, 0x06, 0xfc, 0x23, 0x7b, 0xa7, 0xfc, 0x13, 0x1b, 0xae, 0xb4, 0x8c, - 0x65, 0x6b, 0x30, 0xf3, 0x9c, 0xd1, 0x4a, 0x2b, 0x87, 0x55, 0x00, 0xd9, 0x7a, 0x37, 0xe2, 0x74, - 0xea, 0x9a, 0x79, 0xdc, 0x81, 0xed, 0x04, 0x47, 0xc7, 0x16, 0xa4, 0xb3, 0x01, 0x15, 0xa3, 0xee, - 0x88, 0x39, 0x63, 0xb3, 0x28, 0x9d, 0xc5, 0x30, 0xa2, 0x94, 0xf0, 0x11, 0x34, 0x36, 0x3b, 0x3b, - 0x72, 0xc6, 0x26, 0xd8, 0x5f, 0x35, 0xf8, 0xeb, 0x41, 0x28, 0xf8, 0x37, 0xe4, 0xde, 0xfa, 0x5e, - 0xdf, 0x57, 0xa9, 0x57, 0x48, 0x04, 0xf0, 0x00, 0xca, 0x7d, 0xff, 0xe0, 0xc8, 0x1b, 0x0e, 0x78, - 0x20, 0x64, 0xb4, 0xd9, 0x56, 0xb9, 0x83, 0xcb, 0x68, 0x93, 0x16, 0x49, 0xd3, 0x22, 0xd5, 0x61, - 0xac, 0x32, 0xd6, 0x55, 0x87, 0x29, 0x55, 0x4c, 0x43, 0x0b, 0x80, 0xb0, 0x09, 0x9d, 0x47, 0x36, - 0x72, 0xcd, 0x6c, 0xab, 0x42, 0x52, 0x15, 0xac, 0x43, 0xc1, 0xe1, 0x33, 0x4f, 0xb0, 0xa0, 0x9e, - 0x55, 0x1e, 0x97, 0xd0, 0xde, 0x03, 0x48, 0x8e, 0xc7, 0x2a, 0xe8, 0xf1, 0x35, 0xf4, 0xbe, 0x8f, - 0x08, 0x86, 0xac, 0xab, 0xb9, 0xa8, 0x10, 0xb5, 0xb6, 0x9f, 0x4b, 0xc5, 0x61, 0x4a, 0xd1, 0x73, - 0x95, 0xc2, 0x20, 0x7a, 0xcf, 0x95, 0xf8, 0x84, 0x2b, 0xbe, 0x41, 0xf4, 0x13, 0x1e, 0xef, 0x90, - 0x4d, 0xed, 0x70, 0xbd, 0x1c, 0xd9, 0x81, 0xeb, 0x5d, 0xfe, 0x79, 0x64, 0x25, 0x63, 0xc3, 0xc8, - 0x22, 0x18, 0xe7, 0xee, 0x94, 0x2d, 0xce, 0x51, 0x6b, 0xdb, 0x7e, 0x30, 0x90, 0x52, 0x6c, 0x66, - 0xb0, 0x04, 0xb9, 0xe8, 0x79, 0x35, 0xfb, 0x03, 0x6c, 0x47, 0xfb, 0xf6, 0xa8, 0x37, 0x0c, 0x47, - 0x74, 0xcc, 0xf0, 0x59, 0x32, 0xfd, 0x9a, 0x9a, 0xfe, 0x35, 0x07, 0x31, 0x73, 0xfd, 0x13, 0x90, - 0x26, 0x7a, 0x53, 0xea, 0x28, 0x13, 0x5b, 0x44, 0xad, 0xed, 0x1b, 0x0d, 0x6a, 0x9b, 0x75, 0x92, - 0xde, 0x65, 0x81, 0x50, 0xa7, 0x6c, 0x11, 0xb5, 0xc6, 0x27, 0x50, 0xed, 0x7b, 0xae, 0x70, 0xa9, - 0xe0, 0x41, 0xdf, 0x1b, 0xb2, 0xeb, 0x45, 0xd2, 0x6b, 0x55, 0xc9, 0x23, 0x2c, 0xf4, 0xb9, 0x37, - 0x64, 0x0b, 0x5e, 0x94, 0xe7, 0x5a, 0x15, 0x6b, 0x90, 0xef, 0x72, 0x3e, 0x76, 0x59, 0xdd, 0x50, - 0xc9, 0x2c, 0x50, 0x9c, 0x57, 0x2e, 0xc9, 0xeb, 0xd8, 0x28, 0xe6, 0xcd, 0xc2, 0xb1, 0x51, 0x2c, - 0x98, 0x45, 0xfb, 0x46, 0x87, 0x4a, 0x64, 0xbb, 0xcb, 0x3d, 0x11, 0xf0, 0x09, 0x3e, 0x5d, 0x79, - 0x95, 0xc7, 0xab, 0x99, 0x2c, 0x48, 0x1b, 0x1e, 0x66, 0x0f, 0x76, 0x62, 0xeb, 0x6a, 0xfe, 0xd2, - 0xb7, 0xda, 0xd4, 0x92, 0x8a, 0xf8, 0x12, 0x29, 0x45, 0x74, 0xbf, 0x4d, 0x2d, 0xfc, 0x1f, 0x4a, - 0x0a, 0x9d, 0xf3, 0xbe, 0xaf, 0xee, 0x59, 0x21, 0x49, 0x01, 0x9b, 0x50, 0x56, 0xe0, 0x55, 0xc0, - 0xa7, 0xea, 0x5b, 0x90, 0xfd, 0x74, 0xc9, 0xee, 0xfd, 0xee, 0xcf, 0x55, 0x03, 0xec, 0x06, 0x8c, - 0x0a, 0xa6, 0xd8, 0x84, 0x5d, 0xcd, 0x58, 0x28, 0x4c, 0x0d, 0xff, 0x85, 0x9d, 0x95, 0xba, 0xb4, - 0x14, 0x32, 0x53, 0x7f, 0xb1, 0xff, 0xe5, 0xce, 0xd2, 0x6e, 0xef, 0x2c, 0xed, 0xfb, 0x9d, 0xa5, - 0x7d, 0xbe, 0xb7, 0x32, 0xb7, 0xf7, 0x56, 0xe6, 0xdb, 0xbd, 0x95, 0x79, 0xdf, 0xb8, 0x74, 0xc5, - 0x68, 0x76, 0xd1, 0x76, 0xf8, 0x74, 0x37, 0x9c, 0x50, 0x67, 0x3c, 0xba, 0xda, 0x8d, 0x22, 0xbc, - 0xc8, 0xab, 0x1f, 0xf8, 0xfe, 0xaf, 0x00, 0x00, 0x00, 0xff, 0xff, 0x17, 0x56, 0x28, 0x74, 0xd0, - 0x05, 0x00, 0x00, + // 785 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x55, 0xcd, 0x6e, 0xeb, 0x44, + 0x14, 0x8e, 0x1d, 0x27, 0x4e, 0x4f, 0x7e, 0xae, 0x39, 0x15, 0xc1, 0x41, 0x22, 0x0a, 0x5e, 0x54, + 0x57, 0x2c, 0x72, 0x51, 0x5a, 0xae, 0x58, 0x72, 0x1b, 0x84, 0xd2, 0xaa, 0x3f, 0x61, 0x54, 0x8a, + 0xc4, 0x06, 0xb9, 0xf6, 0xd0, 0x58, 0x71, 0x3c, 0xa9, 0x3d, 0x41, 0xcd, 0x5b, 0xf0, 0x30, 0x3c, + 0x04, 0xec, 0xba, 0x42, 0x2c, 0x51, 0xbb, 0x64, 0xc9, 0x0b, 0xa0, 0x19, 0xff, 0x27, 0x86, 0xbb, + 0x9b, 0x73, 0xbe, 0xef, 0x3b, 0x73, 0xe6, 0xf3, 0x9c, 0x31, 0x74, 0x02, 0x7a, 0xb7, 0xf1, 0xed, + 0xf1, 0x3a, 0x64, 0x9c, 0x61, 0x33, 0x8e, 0xac, 0xbf, 0x55, 0x80, 0x2b, 0xb9, 0xbc, 0xa4, 0xdc, + 0xc6, 0x09, 0x68, 0x37, 0xdb, 0x35, 0x35, 0x95, 0x91, 0xf2, 0xba, 0x37, 0x19, 0x8e, 0x13, 0x4d, + 0xce, 0x18, 0x5f, 0xd2, 0x28, 0xb2, 0xef, 0xa9, 0x60, 0x11, 0xc9, 0xc5, 0x63, 0xd0, 0xbf, 0xa6, + 0xdc, 0xf6, 0xfc, 0xc8, 0x54, 0x47, 0xca, 0xeb, 0xf6, 0x64, 0xb0, 0x2f, 0x4b, 0x08, 0x24, 0x65, + 0x5a, 0xff, 0x28, 0xd0, 0x2e, 0x94, 0xc2, 0x16, 0x68, 0x57, 0x2c, 0xa0, 0x46, 0x0d, 0xbb, 0x70, + 0x30, 0x63, 0x11, 0xff, 0x76, 0x43, 0xc3, 0xad, 0xa1, 0x20, 0x42, 0x2f, 0x0b, 0x09, 0x5d, 0xfb, + 0x5b, 0x43, 0xc5, 0x8f, 0xa1, 0x2f, 0x72, 0xdf, 0xad, 0x5d, 0x9b, 0xd3, 0x2b, 0xc6, 0xbd, 0x9f, + 0x3c, 0xc7, 0xe6, 0x1e, 0x0b, 0x8c, 0x3a, 0x0e, 0xe0, 0x43, 0x81, 0x5d, 0xb2, 0x9f, 0xa9, 0x5b, + 0x82, 0xb4, 0x14, 0x9a, 0x6f, 0x02, 0x67, 0x51, 0x82, 0x1a, 0xd8, 0x03, 0x10, 0xd0, 0xf7, 0x0b, + 0x66, 0xaf, 0x3c, 0xa3, 0x89, 0x87, 0xf0, 0x2a, 0x8f, 0xe3, 0x6d, 0x75, 0xd1, 0xd9, 0xdc, 0xe6, + 0x8b, 0xe9, 0x82, 0x3a, 0x4b, 0xa3, 0x25, 0x3a, 0xcb, 0xc2, 0x98, 0x72, 0x80, 0x9f, 0xc0, 0xa0, + 0xba, 0xb3, 0x77, 0xce, 0xd2, 0x00, 0xeb, 0x77, 0x15, 0x3e, 0xd8, 0x33, 0x05, 0x2d, 0x80, 0x6b, + 0xdf, 0xbd, 0x5d, 0x07, 0xef, 0x5c, 0x37, 0x94, 0xd6, 0x77, 0x4f, 0x55, 0x53, 0x21, 0x85, 0x2c, + 0x1e, 0x81, 0x9e, 0x12, 0x9a, 0xd2, 0xe4, 0x4e, 0x6a, 0xb2, 0xc8, 0x91, 0x14, 0xc4, 0x31, 0x18, + 0xd7, 0xbe, 0x4b, 0xa8, 0x6f, 0x6f, 0x93, 0x54, 0x64, 0x36, 0x46, 0xf5, 0xa4, 0xe2, 0x1e, 0x86, + 0x13, 0xe8, 0x96, 0xc9, 0xfa, 0xa8, 0xbe, 0x57, 0xbd, 0x4c, 0xc1, 0x13, 0x68, 0xdf, 0x9e, 0x88, + 0xe5, 0x9c, 0x85, 0x5c, 0x7c, 0x74, 0xa1, 0xc0, 0x54, 0x91, 0x43, 0xa4, 0x48, 0x93, 0xaa, 0xb7, + 0xb9, 0x4a, 0xdb, 0x51, 0xbd, 0x2d, 0xa8, 0x72, 0x1a, 0x9a, 0xa0, 0x3b, 0x6c, 0x13, 0x70, 0x1a, + 0x9a, 0x75, 0x61, 0x0c, 0x49, 0x43, 0xeb, 0x08, 0x34, 0x79, 0xe2, 0x1e, 0xa8, 0x33, 0x4f, 0xba, + 0xa6, 0x11, 0x75, 0xe6, 0x89, 0xf8, 0x82, 0xc9, 0x9b, 0xa8, 0x11, 0xf5, 0x82, 0x59, 0x27, 0x00, + 0x79, 0x1b, 0x88, 0xb1, 0x2a, 0x76, 0x99, 0xc4, 0x15, 0x10, 0x34, 0x81, 0x49, 0x4d, 0x97, 0xc8, + 0xb5, 0xf5, 0x15, 0x40, 0xde, 0xc6, 0xfb, 0xf6, 0xc8, 0x2a, 0xd4, 0x0b, 0x15, 0x1e, 0xd3, 0xc1, + 0x9a, 0x7b, 0xc1, 0xfd, 0xff, 0x0f, 0x96, 0x60, 0x54, 0x0c, 0x16, 0x82, 0x76, 0xe3, 0xad, 0x68, + 0xb2, 0x8f, 0x5c, 0x5b, 0xd6, 0xde, 0xd8, 0x08, 0xb1, 0x51, 0xc3, 0x03, 0x68, 0xc4, 0x97, 0x50, + 0xb1, 0x7e, 0x84, 0x57, 0x71, 0xdd, 0x99, 0x1d, 0xb8, 0xd1, 0xc2, 0x5e, 0x52, 0xfc, 0x32, 0x9f, + 0x51, 0x45, 0x5e, 0x9f, 0x9d, 0x0e, 0x32, 0xe6, 0xee, 0xa0, 0x8a, 0x26, 0x66, 0x2b, 0xdb, 0x91, + 0x4d, 0x74, 0x88, 0x5c, 0x5b, 0x7f, 0x28, 0xd0, 0xaf, 0xd6, 0x09, 0xfa, 0x94, 0x86, 0x5c, 0xee, + 0xd2, 0x21, 0x72, 0x8d, 0x47, 0xd0, 0x3b, 0x0b, 0x3c, 0xee, 0xd9, 0x9c, 0x85, 0x67, 0x81, 0x4b, + 0x1f, 0x13, 0xa7, 0x77, 0xb2, 0x82, 0x47, 0x68, 0xb4, 0x66, 0x81, 0x4b, 0x13, 0x5e, 0xec, 0xe7, + 0x4e, 0x16, 0xfb, 0xd0, 0x9c, 0x32, 0xb6, 0xf4, 0xa8, 0xa9, 0x49, 0x67, 0x92, 0x28, 0xf3, 0xab, + 0x91, 0xfb, 0x85, 0x23, 0x68, 0x8b, 0x1e, 0x6e, 0x69, 0x18, 0x79, 0x2c, 0x30, 0x5b, 0xb2, 0x60, + 0x31, 0x75, 0xae, 0xb5, 0x9a, 0x86, 0x7e, 0xae, 0xb5, 0x74, 0xa3, 0x65, 0xfd, 0x5a, 0x87, 0x6e, + 0x7c, 0xb0, 0x29, 0x0b, 0x78, 0xc8, 0x7c, 0xfc, 0xa2, 0xf4, 0xdd, 0x3e, 0x2d, 0xbb, 0x96, 0x90, + 0x2a, 0x3e, 0xdd, 0xe7, 0x70, 0x98, 0x1d, 0x4e, 0x0e, 0x4f, 0xf1, 0xdc, 0x55, 0x90, 0x50, 0x64, + 0xc7, 0x2c, 0x28, 0x62, 0x07, 0xaa, 0x20, 0xfc, 0x0c, 0x7a, 0xe9, 0x38, 0xdf, 0x30, 0x79, 0xa9, + 0xb5, 0xec, 0xe9, 0xd8, 0x41, 0x8a, 0xcf, 0xc2, 0x37, 0x21, 0x5b, 0x49, 0x76, 0x23, 0x63, 0xef, + 0x61, 0x38, 0x86, 0x76, 0xb1, 0x70, 0xd5, 0x93, 0x53, 0x24, 0x64, 0xcf, 0x48, 0x56, 0x5c, 0xaf, + 0x50, 0x94, 0x29, 0xd6, 0xec, 0xbf, 0xfe, 0x00, 0x7d, 0xc0, 0x69, 0x48, 0x6d, 0x4e, 0x25, 0x9f, + 0xd0, 0x87, 0x0d, 0x8d, 0xb8, 0xa1, 0xe0, 0x47, 0x70, 0x58, 0xca, 0x0b, 0x4b, 0x22, 0x6a, 0xa8, + 0xa7, 0xc7, 0xbf, 0x3d, 0x0f, 0x95, 0xa7, 0xe7, 0xa1, 0xf2, 0xd7, 0xf3, 0x50, 0xf9, 0xe5, 0x65, + 0x58, 0x7b, 0x7a, 0x19, 0xd6, 0xfe, 0x7c, 0x19, 0xd6, 0x7e, 0x18, 0xdc, 0x7b, 0x7c, 0xb1, 0xb9, + 0x1b, 0x3b, 0x6c, 0xf5, 0x26, 0xf2, 0x6d, 0x67, 0xb9, 0x78, 0x78, 0x13, 0xb7, 0x74, 0xd7, 0x94, + 0x3f, 0xc2, 0xe3, 0x7f, 0x03, 0x00, 0x00, 0xff, 0xff, 0xea, 0x6f, 0xbc, 0x50, 0x18, 0x07, 0x00, + 0x00, } func (m *NebulaMeta) Marshal() (dAtA []byte, err error) { @@ -748,28 +850,54 @@ func (m *NebulaMetaDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l - if len(m.RelayVpnIp) > 0 { - dAtA3 := make([]byte, len(m.RelayVpnIp)*10) - var j2 int - for _, num := range m.RelayVpnIp { - for num >= 1<<7 { - dAtA3[j2] = uint8(uint64(num)&0x7f | 0x80) - num >>= 7 - j2++ + if len(m.RelayVpnAddrs) > 0 { + for iNdEx := len(m.RelayVpnAddrs) - 1; iNdEx >= 0; iNdEx-- { + { + size, err := m.RelayVpnAddrs[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) } - dAtA3[j2] = uint8(num) - j2++ + i-- + dAtA[i] = 0x3a } - i -= j2 - copy(dAtA[i:], dAtA3[:j2]) - i = encodeVarintNebula(dAtA, i, uint64(j2)) + } + if m.VpnAddr != nil { + { + size, err := m.VpnAddr.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x32 + } + if len(m.OldRelayVpnAddrs) > 0 { + dAtA4 := make([]byte, len(m.OldRelayVpnAddrs)*10) + var j3 int + for _, num := range m.OldRelayVpnAddrs { + for num >= 1<<7 { + dAtA4[j3] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j3++ + } + dAtA4[j3] = uint8(num) + j3++ + } + i -= j3 + copy(dAtA[i:], dAtA4[:j3]) + i = encodeVarintNebula(dAtA, i, uint64(j3)) i-- dAtA[i] = 0x2a } - if len(m.Ip6AndPorts) > 0 { - for iNdEx := len(m.Ip6AndPorts) - 1; iNdEx >= 0; iNdEx-- { + if len(m.V6AddrPorts) > 0 { + for iNdEx := len(m.V6AddrPorts) - 1; iNdEx >= 0; iNdEx-- { { - size, err := m.Ip6AndPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + size, err := m.V6AddrPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } @@ -785,10 +913,10 @@ func (m *NebulaMetaDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { i-- dAtA[i] = 0x18 } - if len(m.Ip4AndPorts) > 0 { - for iNdEx := len(m.Ip4AndPorts) - 1; iNdEx >= 0; iNdEx-- { + if len(m.V4AddrPorts) > 0 { + for iNdEx := len(m.V4AddrPorts) - 1; iNdEx >= 0; iNdEx-- { { - size, err := m.Ip4AndPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + size, err := m.V4AddrPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } @@ -799,15 +927,15 @@ func (m *NebulaMetaDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { dAtA[i] = 0x12 } } - if m.VpnIp != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.VpnIp)) + if m.OldVpnAddr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.OldVpnAddr)) i-- dAtA[i] = 0x8 } return len(dAtA) - i, nil } -func (m *Ip4AndPort) Marshal() (dAtA []byte, err error) { +func (m *Addr) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) @@ -817,12 +945,45 @@ func (m *Ip4AndPort) Marshal() (dAtA []byte, err error) { return dAtA[:n], nil } -func (m *Ip4AndPort) MarshalTo(dAtA []byte) (int, error) { +func (m *Addr) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } -func (m *Ip4AndPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { +func (m *Addr) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Lo != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.Lo)) + i-- + dAtA[i] = 0x10 + } + if m.Hi != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.Hi)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func (m *V4AddrPort) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *V4AddrPort) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *V4AddrPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int @@ -832,15 +993,15 @@ func (m *Ip4AndPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { i-- dAtA[i] = 0x10 } - if m.Ip != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.Ip)) + if m.Addr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.Addr)) i-- dAtA[i] = 0x8 } return len(dAtA) - i, nil } -func (m *Ip6AndPort) Marshal() (dAtA []byte, err error) { +func (m *V6AddrPort) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) @@ -850,12 +1011,12 @@ func (m *Ip6AndPort) Marshal() (dAtA []byte, err error) { return dAtA[:n], nil } -func (m *Ip6AndPort) MarshalTo(dAtA []byte) (int, error) { +func (m *V6AddrPort) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } -func (m *Ip6AndPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { +func (m *V6AddrPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int @@ -973,6 +1134,11 @@ func (m *NebulaHandshakeDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) _ = i var l int _ = l + if m.CertVersion != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.CertVersion)) + i-- + dAtA[i] = 0x40 + } if m.Time != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.Time)) i-- @@ -1023,13 +1189,37 @@ func (m *NebulaControl) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l - if m.RelayFromIp != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.RelayFromIp)) + if m.RelayFromAddr != nil { + { + size, err := m.RelayFromAddr.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x3a + } + if m.RelayToAddr != nil { + { + size, err := m.RelayToAddr.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x32 + } + if m.OldRelayFromAddr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.OldRelayFromAddr)) i-- dAtA[i] = 0x28 } - if m.RelayToIp != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.RelayToIp)) + if m.OldRelayToAddr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.OldRelayToAddr)) i-- dAtA[i] = 0x20 } @@ -1084,11 +1274,11 @@ func (m *NebulaMetaDetails) Size() (n int) { } var l int _ = l - if m.VpnIp != 0 { - n += 1 + sovNebula(uint64(m.VpnIp)) + if m.OldVpnAddr != 0 { + n += 1 + sovNebula(uint64(m.OldVpnAddr)) } - if len(m.Ip4AndPorts) > 0 { - for _, e := range m.Ip4AndPorts { + if len(m.V4AddrPorts) > 0 { + for _, e := range m.V4AddrPorts { l = e.Size() n += 1 + l + sovNebula(uint64(l)) } @@ -1096,30 +1286,55 @@ func (m *NebulaMetaDetails) Size() (n int) { if m.Counter != 0 { n += 1 + sovNebula(uint64(m.Counter)) } - if len(m.Ip6AndPorts) > 0 { - for _, e := range m.Ip6AndPorts { + if len(m.V6AddrPorts) > 0 { + for _, e := range m.V6AddrPorts { l = e.Size() n += 1 + l + sovNebula(uint64(l)) } } - if len(m.RelayVpnIp) > 0 { + if len(m.OldRelayVpnAddrs) > 0 { l = 0 - for _, e := range m.RelayVpnIp { + for _, e := range m.OldRelayVpnAddrs { l += sovNebula(uint64(e)) } n += 1 + sovNebula(uint64(l)) + l } + if m.VpnAddr != nil { + l = m.VpnAddr.Size() + n += 1 + l + sovNebula(uint64(l)) + } + if len(m.RelayVpnAddrs) > 0 { + for _, e := range m.RelayVpnAddrs { + l = e.Size() + n += 1 + l + sovNebula(uint64(l)) + } + } return n } -func (m *Ip4AndPort) Size() (n int) { +func (m *Addr) Size() (n int) { if m == nil { return 0 } var l int _ = l - if m.Ip != 0 { - n += 1 + sovNebula(uint64(m.Ip)) + if m.Hi != 0 { + n += 1 + sovNebula(uint64(m.Hi)) + } + if m.Lo != 0 { + n += 1 + sovNebula(uint64(m.Lo)) + } + return n +} + +func (m *V4AddrPort) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Addr != 0 { + n += 1 + sovNebula(uint64(m.Addr)) } if m.Port != 0 { n += 1 + sovNebula(uint64(m.Port)) @@ -1127,7 +1342,7 @@ func (m *Ip4AndPort) Size() (n int) { return n } -func (m *Ip6AndPort) Size() (n int) { +func (m *V6AddrPort) Size() (n int) { if m == nil { return 0 } @@ -1199,6 +1414,9 @@ func (m *NebulaHandshakeDetails) Size() (n int) { if m.Time != 0 { n += 1 + sovNebula(uint64(m.Time)) } + if m.CertVersion != 0 { + n += 1 + sovNebula(uint64(m.CertVersion)) + } return n } @@ -1217,11 +1435,19 @@ func (m *NebulaControl) Size() (n int) { if m.ResponderRelayIndex != 0 { n += 1 + sovNebula(uint64(m.ResponderRelayIndex)) } - if m.RelayToIp != 0 { - n += 1 + sovNebula(uint64(m.RelayToIp)) + if m.OldRelayToAddr != 0 { + n += 1 + sovNebula(uint64(m.OldRelayToAddr)) } - if m.RelayFromIp != 0 { - n += 1 + sovNebula(uint64(m.RelayFromIp)) + if m.OldRelayFromAddr != 0 { + n += 1 + sovNebula(uint64(m.OldRelayFromAddr)) + } + if m.RelayToAddr != nil { + l = m.RelayToAddr.Size() + n += 1 + l + sovNebula(uint64(l)) + } + if m.RelayFromAddr != nil { + l = m.RelayFromAddr.Size() + n += 1 + l + sovNebula(uint64(l)) } return n } @@ -1368,9 +1594,9 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { switch fieldNum { case 1: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field VpnIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldVpnAddr", wireType) } - m.VpnIp = 0 + m.OldVpnAddr = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -1380,14 +1606,14 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.VpnIp |= uint32(b&0x7F) << shift + m.OldVpnAddr |= uint32(b&0x7F) << shift if b < 0x80 { break } } case 2: if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Ip4AndPorts", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field V4AddrPorts", wireType) } var msglen int for shift := uint(0); ; shift += 7 { @@ -1414,8 +1640,8 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - m.Ip4AndPorts = append(m.Ip4AndPorts, &Ip4AndPort{}) - if err := m.Ip4AndPorts[len(m.Ip4AndPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + m.V4AddrPorts = append(m.V4AddrPorts, &V4AddrPort{}) + if err := m.V4AddrPorts[len(m.V4AddrPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex @@ -1440,7 +1666,7 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } case 4: if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Ip6AndPorts", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field V6AddrPorts", wireType) } var msglen int for shift := uint(0); ; shift += 7 { @@ -1467,8 +1693,8 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - m.Ip6AndPorts = append(m.Ip6AndPorts, &Ip6AndPort{}) - if err := m.Ip6AndPorts[len(m.Ip6AndPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + m.V6AddrPorts = append(m.V6AddrPorts, &V6AddrPort{}) + if err := m.V6AddrPorts[len(m.V6AddrPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex @@ -1489,7 +1715,7 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { break } } - m.RelayVpnIp = append(m.RelayVpnIp, v) + m.OldRelayVpnAddrs = append(m.OldRelayVpnAddrs, v) } else if wireType == 2 { var packedLen int for shift := uint(0); ; shift += 7 { @@ -1524,8 +1750,8 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } } elementCount = count - if elementCount != 0 && len(m.RelayVpnIp) == 0 { - m.RelayVpnIp = make([]uint32, 0, elementCount) + if elementCount != 0 && len(m.OldRelayVpnAddrs) == 0 { + m.OldRelayVpnAddrs = make([]uint32, 0, elementCount) } for iNdEx < postIndex { var v uint32 @@ -1543,11 +1769,81 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { break } } - m.RelayVpnIp = append(m.RelayVpnIp, v) + m.OldRelayVpnAddrs = append(m.OldRelayVpnAddrs, v) } } else { - return fmt.Errorf("proto: wrong wireType = %d for field RelayVpnIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldRelayVpnAddrs", wireType) } + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field VpnAddr", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.VpnAddr == nil { + m.VpnAddr = &Addr{} + } + if err := m.VpnAddr.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 7: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RelayVpnAddrs", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.RelayVpnAddrs = append(m.RelayVpnAddrs, &Addr{}) + if err := m.RelayVpnAddrs[len(m.RelayVpnAddrs)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) @@ -1569,7 +1865,7 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } return nil } -func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { +func (m *Addr) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -1592,17 +1888,17 @@ func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: Ip4AndPort: wiretype end group for non-group") + return fmt.Errorf("proto: Addr: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: Ip4AndPort: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: Addr: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field Ip", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field Hi", wireType) } - m.Ip = 0 + m.Hi = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -1612,7 +1908,95 @@ func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.Ip |= uint32(b&0x7F) << shift + m.Hi |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Lo", wireType) + } + m.Lo = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Lo |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipNebula(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthNebula + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *V4AddrPort) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: V4AddrPort: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: V4AddrPort: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Addr", wireType) + } + m.Addr = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Addr |= uint32(b&0x7F) << shift if b < 0x80 { break } @@ -1657,7 +2041,7 @@ func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { } return nil } -func (m *Ip6AndPort) Unmarshal(dAtA []byte) error { +func (m *V6AddrPort) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -1680,10 +2064,10 @@ func (m *Ip6AndPort) Unmarshal(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: Ip6AndPort: wiretype end group for non-group") + return fmt.Errorf("proto: V6AddrPort: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: Ip6AndPort: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: V6AddrPort: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: @@ -2111,6 +2495,25 @@ func (m *NebulaHandshakeDetails) Unmarshal(dAtA []byte) error { break } } + case 8: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field CertVersion", wireType) + } + m.CertVersion = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.CertVersion |= uint32(b&0x7F) << shift + if b < 0x80 { + break + } + } default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) @@ -2220,9 +2623,9 @@ func (m *NebulaControl) Unmarshal(dAtA []byte) error { } case 4: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field RelayToIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldRelayToAddr", wireType) } - m.RelayToIp = 0 + m.OldRelayToAddr = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -2232,16 +2635,16 @@ func (m *NebulaControl) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.RelayToIp |= uint32(b&0x7F) << shift + m.OldRelayToAddr |= uint32(b&0x7F) << shift if b < 0x80 { break } } case 5: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field RelayFromIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldRelayFromAddr", wireType) } - m.RelayFromIp = 0 + m.OldRelayFromAddr = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -2251,11 +2654,83 @@ func (m *NebulaControl) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.RelayFromIp |= uint32(b&0x7F) << shift + m.OldRelayFromAddr |= uint32(b&0x7F) << shift if b < 0x80 { break } } + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RelayToAddr", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.RelayToAddr == nil { + m.RelayToAddr = &Addr{} + } + if err := m.RelayToAddr.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 7: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RelayFromAddr", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.RelayFromAddr == nil { + m.RelayFromAddr = &Addr{} + } + if err := m.RelayFromAddr.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) diff --git a/nebula.proto b/nebula.proto index 88e33b7..ea10233 100644 --- a/nebula.proto +++ b/nebula.proto @@ -23,19 +23,28 @@ message NebulaMeta { } message NebulaMetaDetails { - uint32 VpnIp = 1; - repeated Ip4AndPort Ip4AndPorts = 2; - repeated Ip6AndPort Ip6AndPorts = 4; - repeated uint32 RelayVpnIp = 5; + uint32 OldVpnAddr = 1 [deprecated = true]; + Addr VpnAddr = 6; + + repeated uint32 OldRelayVpnAddrs = 5 [deprecated = true]; + repeated Addr RelayVpnAddrs = 7; + + repeated V4AddrPort V4AddrPorts = 2; + repeated V6AddrPort V6AddrPorts = 4; uint32 counter = 3; } -message Ip4AndPort { - uint32 Ip = 1; +message Addr { + uint64 Hi = 1; + uint64 Lo = 2; +} + +message V4AddrPort { + uint32 Addr = 1; uint32 Port = 2; } -message Ip6AndPort { +message V6AddrPort { uint64 Hi = 1; uint64 Lo = 2; uint32 Port = 3; @@ -62,6 +71,7 @@ message NebulaHandshakeDetails { uint32 ResponderIndex = 3; uint64 Cookie = 4; uint64 Time = 5; + uint32 CertVersion = 8; // reserved for WIP multiport reserved 6, 7; } @@ -76,6 +86,10 @@ message NebulaControl { uint32 InitiatorRelayIndex = 2; uint32 ResponderRelayIndex = 3; - uint32 RelayToIp = 4; - uint32 RelayFromIp = 5; + + uint32 OldRelayToAddr = 4 [deprecated = true]; + uint32 OldRelayFromAddr = 5 [deprecated = true]; + + Addr RelayToAddr = 6; + Addr RelayFromAddr = 7; } diff --git a/outside.go b/outside.go index 6a71fe7..1e9cde1 100644 --- a/outside.go +++ b/outside.go @@ -3,16 +3,15 @@ package nebula import ( "encoding/binary" "errors" - "fmt" "net/netip" "time" - "github.com/flynn/noise" + "github.com/google/gopacket/layers" + "golang.org/x/net/ipv6" + "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/udp" "golang.org/x/net/ipv4" ) @@ -20,28 +19,9 @@ const ( minFwPacketLen = 4 ) -// TODO: IPV6-WORK this can likely be removed now -func readOutsidePackets(f *Interface) udp.EncReader { - return func( - addr netip.AddrPort, - out []byte, - packet []byte, - header *header.H, - fwPacket *firewall.Packet, - lhh udp.LightHouseHandlerFunc, - nb []byte, - q int, - localCache firewall.ConntrackCache, - ) { - f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache) - } -} - -func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { - // TODO: best if we return this and let caller log - // TODO: Might be better to send the literal []byte("holepunch") packet and ignore that? // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors if len(packet) > 1 { f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err) @@ -51,7 +31,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] //l.Error("in packet ", header, packet[HeaderLen:]) if ip.IsValid() { - if f.myVpnNet.Contains(ip.Addr()) { + _, found := f.myVpnNetworksTable.Lookup(ip.Addr()) + if found { if f.l.Level >= logrus.DebugLevel { f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") } @@ -108,7 +89,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] if !ok { // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing // its internal mapping. This should never happen. - hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnIp": hostinfo.vpnIp, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") + hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") return } @@ -120,9 +101,9 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] return case ForwardingType: // Find the target HostInfo relay object - targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp) + targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) if err != nil { - hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip") + hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip") return } @@ -138,7 +119,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") } } else { - hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp, "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") + hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") return } } @@ -155,13 +136,10 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). WithField("packet", packet). Error("Failed to decrypt lighthouse packet") - - //TODO: maybe after build 64 is out? 06/14/2018 - NB - //f.sendRecvError(net.Addr(addr), header.RemoteIndex) return } - lhf(ip, hostinfo.vpnIp, d) + lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f) // Fallthrough to the bottom to record incoming traffic @@ -176,9 +154,6 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). WithField("packet", packet). Error("Failed to decrypt test packet") - - //TODO: maybe after build 64 is out? 06/14/2018 - NB - //f.sendRecvError(net.Addr(addr), header.RemoteIndex) return } @@ -228,14 +203,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] Error("Failed to decrypt Control packet") return } - m := &NebulaControl{} - err = m.Unmarshal(d) - if err != nil { - hostinfo.logger(f.l).WithError(err).Error("Failed to unmarshal control message") - break - } - f.relayManager.HandleControlMsg(hostinfo, m, f) + f.relayManager.HandleControlMsg(hostinfo, d, f) default: f.messageMetrics.Rx(h.Type, h.Subtype, 1) @@ -252,8 +221,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] func (f *Interface) closeTunnel(hostInfo *HostInfo) { final := f.hostMap.DeleteHostInfo(hostInfo) if final { - // We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage - f.lightHouse.DeleteVpnIp(hostInfo.vpnIp) + // We no longer have any tunnels with this vpn addr, clear learned lighthouse state to lower memory usage + f.lightHouse.DeleteVpnAddrs(hostInfo.vpnAddrs) } } @@ -262,25 +231,26 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) } -func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) { - if ip.IsValid() && hostinfo.remote != ip { - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) { - hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming") +func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) { + if udpAddr.IsValid() && hostinfo.remote != udpAddr { + if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) { + hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming") return } - if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { + + if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr). Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) } return } - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr). Info("Host roamed to new udp ip/port.") hostinfo.lastRoam = time.Now() hostinfo.lastRoamRemote = hostinfo.remote - hostinfo.SetRemote(ip) + hostinfo.SetRemote(udpAddr) } } @@ -300,24 +270,141 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h return true } +var ( + ErrPacketTooShort = errors.New("packet is too short") + ErrUnknownIPVersion = errors.New("packet is an unknown ip version") + ErrIPv4InvalidHeaderLength = errors.New("invalid ipv4 header length") + ErrIPv4PacketTooShort = errors.New("ipv4 packet is too short") + ErrIPv6PacketTooShort = errors.New("ipv6 packet is too short") + ErrIPv6CouldNotFindPayload = errors.New("could not find payload in ipv6 packet") +) + // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { - // Do we at least have an ipv4 header worth of data? - if len(data) < ipv4.HeaderLen { - return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen) + if len(data) < 1 { + return ErrPacketTooShort } - // Is it an ipv4 packet? - if int((data[0]>>4)&0x0f) != 4 { - return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f)) + version := int((data[0] >> 4) & 0x0f) + switch version { + case ipv4.Version: + return parseV4(data, incoming, fp) + case ipv6.Version: + return parseV6(data, incoming, fp) + } + return ErrUnknownIPVersion +} + +func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { + dataLen := len(data) + if dataLen < ipv6.HeaderLen { + return ErrIPv6PacketTooShort + } + + if incoming { + fp.RemoteAddr, _ = netip.AddrFromSlice(data[8:24]) + fp.LocalAddr, _ = netip.AddrFromSlice(data[24:40]) + } else { + fp.LocalAddr, _ = netip.AddrFromSlice(data[8:24]) + fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40]) + } + + protoAt := 6 // NextHeader is at 6 bytes into the ipv6 header + offset := ipv6.HeaderLen // Start at the end of the ipv6 header + next := 0 + for { + if dataLen < offset { + break + } + + proto := layers.IPProtocol(data[protoAt]) + //fmt.Println(proto, protoAt) + switch proto { + case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader: + fp.Protocol = uint8(proto) + fp.RemotePort = 0 + fp.LocalPort = 0 + fp.Fragment = false + return nil + + case layers.IPProtocolTCP, layers.IPProtocolUDP: + if dataLen < offset+4 { + return ErrIPv6PacketTooShort + } + + fp.Protocol = uint8(proto) + if incoming { + fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2]) + fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + } else { + fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2]) + fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + } + + fp.Fragment = false + return nil + + case layers.IPProtocolIPv6Fragment: + // Fragment header is 8 bytes, need at least offset+4 to read the offset field + if dataLen < offset+8 { + return ErrIPv6PacketTooShort + } + + // Check if this is the first fragment + fragmentOffset := binary.BigEndian.Uint16(data[offset+2:offset+4]) &^ uint16(0x7) // Remove the reserved and M flag bits + if fragmentOffset != 0 { + // Non-first fragment, use what we have now and stop processing + fp.Protocol = data[offset] + fp.Fragment = true + fp.RemotePort = 0 + fp.LocalPort = 0 + return nil + } + + // The next loop should be the transport layer since we are the first fragment + next = 8 // Fragment headers are always 8 bytes + + case layers.IPProtocolAH: + // Auth headers, used by IPSec, have a different meaning for header length + if dataLen < offset+1 { + break + } + + next = int(data[offset+1]+2) << 2 + + default: + // Normal ipv6 header length processing + if dataLen < offset+1 { + break + } + + next = int(data[offset+1]+1) << 3 + } + + if next <= 0 { + // Safety check, each ipv6 header has to be at least 8 bytes + next = 8 + } + + protoAt = offset + offset = offset + next + } + + return ErrIPv6CouldNotFindPayload +} + +func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { + // Do we at least have an ipv4 header worth of data? + if len(data) < ipv4.HeaderLen { + return ErrIPv4PacketTooShort } // Adjust our start position based on the advertised ip header length ihl := int(data[0]&0x0f) << 2 - // Well formed ip header length? + // Well-formed ip header length? if ihl < ipv4.HeaderLen { - return fmt.Errorf("packet had an invalid header length: %v", ihl) + return ErrIPv4InvalidHeaderLength } // Check if this is the second or further fragment of a fragmented packet. @@ -333,14 +420,13 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { minLen += minFwPacketLen } if len(data) < minLen { - return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl) + return ErrIPv4InvalidHeaderLength } // Firewall packets are locally oriented if incoming { - //TODO: IPV6-WORK - fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16]) - fp.LocalIP, _ = netip.AddrFromSlice(data[16:20]) + fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16]) + fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 @@ -349,9 +435,8 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) } } else { - //TODO: IPV6-WORK - fp.LocalIP, _ = netip.AddrFromSlice(data[12:16]) - fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20]) + fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16]) + fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 @@ -386,8 +471,6 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) if err != nil { hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") - //TODO: maybe after build 64 is out? 06/14/2018 - NB - //f.sendRecvError(hostinfo.remote, header.RemoteIndex) return false } @@ -434,9 +517,8 @@ func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) { func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) { f.messageMetrics.Tx(header.RecvError, 0, 1) - //TODO: this should be a signed message so we can trust that we should drop the index b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0) - f.outside.WriteTo(b, endpoint) + _ = f.outside.WriteTo(b, endpoint) if f.l.Level >= logrus.DebugLevel { f.l.WithField("index", index). WithField("udpAddr", endpoint). @@ -470,49 +552,3 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) { // We also delete it from pending hostmap to allow for fast reconnect. f.handshakeManager.DeleteHostInfo(hostinfo) } - -/* -func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *NebulaMeta) { - if ci.eKey != nil { - //TODO: log error? - return - } - - msg, err := proto.Marshal(meta) - if err != nil { - l.Debugln("failed to encode header") - } - - c := ci.messageCounter - b := HeaderEncode(nil, Version, uint8(metadata), 0, hostinfo.remoteIndexId, c) - ci.messageCounter++ - - msg := ci.eKey.EncryptDanger(b, nil, msg, c) - //msg := ci.eKey.EncryptDanger(b, nil, []byte(fmt.Sprintf("%d", counter)), c) - f.outside.WriteTo(msg, endpoint) -} -*/ - -func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.CAPool) (*cert.CachedCertificate, error) { - pk := h.PeerStatic() - - if pk == nil { - return nil, errors.New("no peer static key was present") - } - - if rawCertBytes == nil { - return nil, errors.New("provided payload was empty") - } - - c, err := cert.UnmarshalCertificateFromHandshake(rawCertBytes, pk) - if err != nil { - return nil, fmt.Errorf("error unmarshaling cert: %w", err) - } - - cc, err := caPool.VerifyCertificate(time.Now(), c) - if err != nil { - return nil, fmt.Errorf("certificate validation failed: %w", err) - } - - return cc, nil -} diff --git a/outside_test.go b/outside_test.go index f9d4bfa..f197594 100644 --- a/outside_test.go +++ b/outside_test.go @@ -1,10 +1,15 @@ package nebula import ( + "bytes" + "encoding/binary" "net" "net/netip" "testing" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/slackhq/nebula/firewall" "github.com/stretchr/testify/assert" "golang.org/x/net/ipv4" @@ -13,9 +18,15 @@ import ( func Test_newPacket(t *testing.T) { p := &firewall.Packet{} - // length fail - err := newPacket([]byte{0, 1}, true, p) - assert.EqualError(t, err, "packet is less than 20 bytes") + // length fails + err := newPacket([]byte{}, true, p) + assert.ErrorIs(t, err, ErrPacketTooShort) + + err = newPacket([]byte{0x40}, true, p) + assert.ErrorIs(t, err, ErrIPv4PacketTooShort) + + err = newPacket([]byte{0x60}, true, p) + assert.ErrorIs(t, err, ErrIPv6PacketTooShort) // length fail with ip options h := ipv4.Header{ @@ -28,16 +39,15 @@ func Test_newPacket(t *testing.T) { b, _ := h.Marshal() err = newPacket(b, true, p) - - assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24") + assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) // not an ipv4 packet err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.EqualError(t, err, "packet is not ipv4, type: 0") + assert.ErrorIs(t, err, ErrUnknownIPVersion) // invalid ihl err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.EqualError(t, err, "packet had an invalid header length: 8") + assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) // account for variable ip header length - incoming h = ipv4.Header{ @@ -54,11 +64,12 @@ func Test_newPacket(t *testing.T) { err = newPacket(b, true, p) assert.Nil(t, err) - assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP)) - assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2")) - assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1")) - assert.Equal(t, p.RemotePort, uint16(3)) - assert.Equal(t, p.LocalPort, uint16(4)) + assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr) + assert.Equal(t, uint16(3), p.RemotePort) + assert.Equal(t, uint16(4), p.LocalPort) + assert.False(t, p.Fragment) // account for variable ip header length - outgoing h = ipv4.Header{ @@ -75,9 +86,506 @@ func Test_newPacket(t *testing.T) { err = newPacket(b, false, p) assert.Nil(t, err) - assert.Equal(t, p.Protocol, uint8(2)) - assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1")) - assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2")) - assert.Equal(t, p.RemotePort, uint16(6)) - assert.Equal(t, p.LocalPort, uint16(5)) + assert.Equal(t, uint8(2), p.Protocol) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr) + assert.Equal(t, uint16(6), p.RemotePort) + assert.Equal(t, uint16(5), p.LocalPort) + assert.False(t, p.Fragment) +} + +func Test_newPacket_v6(t *testing.T) { + p := &firewall.Packet{} + + // invalid ipv6 + ip := layers.IPv6{ + Version: 6, + HopLimit: 128, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + buffer := gopacket.NewSerializeBuffer() + opt := gopacket.SerializeOptions{ + ComputeChecksums: false, + FixLengths: false, + } + err := gopacket.SerializeLayers(buffer, opt, &ip) + assert.NoError(t, err) + + err = newPacket(buffer.Bytes(), true, p) + assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + + // A good ICMP packet + ip = layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolICMPv6, + HopLimit: 128, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + icmp := layers.ICMPv6{} + + buffer.Clear() + err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp) + if err != nil { + panic(err) + } + + err = newPacket(buffer.Bytes(), true, p) + assert.Nil(t, err) + assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(0), p.RemotePort) + assert.Equal(t, uint16(0), p.LocalPort) + assert.False(t, p.Fragment) + + // A good ESP packet + b := buffer.Bytes() + b[6] = byte(layers.IPProtocolESP) + err = newPacket(b, true, p) + assert.Nil(t, err) + assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(0), p.RemotePort) + assert.Equal(t, uint16(0), p.LocalPort) + assert.False(t, p.Fragment) + + // A good None packet + b = buffer.Bytes() + b[6] = byte(layers.IPProtocolNoNextHeader) + err = newPacket(b, true, p) + assert.Nil(t, err) + assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(0), p.RemotePort) + assert.Equal(t, uint16(0), p.LocalPort) + assert.False(t, p.Fragment) + + // An unknown protocol packet + b = buffer.Bytes() + b[6] = 255 // 255 is a reserved protocol number + err = newPacket(b, true, p) + assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + + // A good UDP packet + ip = layers.IPv6{ + Version: 6, + NextHeader: firewall.ProtoUDP, + HopLimit: 128, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + udp := layers.UDP{ + SrcPort: layers.UDPPort(36123), + DstPort: layers.UDPPort(22), + } + err = udp.SetNetworkLayerForChecksum(&ip) + assert.NoError(t, err) + + buffer.Clear() + err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef})) + if err != nil { + panic(err) + } + b = buffer.Bytes() + + // incoming + err = newPacket(b, true, p) + assert.Nil(t, err) + assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(36123), p.RemotePort) + assert.Equal(t, uint16(22), p.LocalPort) + assert.False(t, p.Fragment) + + // outgoing + err = newPacket(b, false, p) + assert.Nil(t, err) + assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) + assert.Equal(t, uint16(36123), p.LocalPort) + assert.Equal(t, uint16(22), p.RemotePort) + assert.False(t, p.Fragment) + + // Too short UDP packet + err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes + assert.ErrorIs(t, err, ErrIPv6PacketTooShort) + + // A good TCP packet + b[6] = byte(layers.IPProtocolTCP) + + // incoming + err = newPacket(b, true, p) + assert.Nil(t, err) + assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(36123), p.RemotePort) + assert.Equal(t, uint16(22), p.LocalPort) + assert.False(t, p.Fragment) + + // outgoing + err = newPacket(b, false, p) + assert.Nil(t, err) + assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) + assert.Equal(t, uint16(36123), p.LocalPort) + assert.Equal(t, uint16(22), p.RemotePort) + assert.False(t, p.Fragment) + + // Too short TCP packet + err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes + assert.ErrorIs(t, err, ErrIPv6PacketTooShort) + + // A good UDP packet with an AH header + ip = layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolAH, + HopLimit: 128, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + ah := layers.IPSecAH{ + AuthenticationData: []byte{0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef}, + } + ah.NextHeader = layers.IPProtocolUDP + + udpHeader := []byte{ + 0x8d, 0x1b, // Source port 36123 + 0x00, 0x16, // Destination port 22 + 0x00, 0x00, // Length + 0x00, 0x00, // Checksum + } + + buffer.Clear() + err = ip.SerializeTo(buffer, opt) + if err != nil { + panic(err) + } + + b = buffer.Bytes() + ahb := serializeAH(&ah) + b = append(b, ahb...) + b = append(b, udpHeader...) + + err = newPacket(b, true, p) + assert.Nil(t, err) + assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(36123), p.RemotePort) + assert.Equal(t, uint16(22), p.LocalPort) + assert.False(t, p.Fragment) + + // Invalid AH header + b = buffer.Bytes() + err = newPacket(b, true, p) + assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) +} + +func Test_newPacket_ipv6Fragment(t *testing.T) { + p := &firewall.Packet{} + + ip := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolIPv6Fragment, + HopLimit: 64, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + // First fragment + fragHeader1 := []byte{ + uint8(layers.IPProtocolUDP), // Next Header (UDP) + 0x00, // Reserved + 0x00, // Fragment Offset high byte (0) + 0x01, // Fragment Offset low byte & flags (M=1) + 0x00, 0x00, 0x00, 0x01, // Identification + } + + udpHeader := []byte{ + 0x8d, 0x1b, // Source port 36123 + 0x00, 0x16, // Destination port 22 + 0x00, 0x00, // Length + 0x00, 0x00, // Checksum + } + + buffer := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + err := ip.SerializeTo(buffer, opts) + if err != nil { + t.Fatal(err) + } + + firstFrag := buffer.Bytes() + firstFrag = append(firstFrag, fragHeader1...) + firstFrag = append(firstFrag, udpHeader...) + firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) + + // Test first fragment incoming + err = newPacket(firstFrag, true, p) + assert.NoError(t, err) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) + assert.Equal(t, uint16(36123), p.RemotePort) + assert.Equal(t, uint16(22), p.LocalPort) + assert.False(t, p.Fragment) + + // Test first fragment outgoing + err = newPacket(firstFrag, false, p) + assert.NoError(t, err) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) + assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) + assert.Equal(t, uint16(36123), p.LocalPort) + assert.Equal(t, uint16(22), p.RemotePort) + assert.False(t, p.Fragment) + + // Second fragment + fragHeader2 := []byte{ + uint8(layers.IPProtocolUDP), // Next Header (UDP) + 0x00, // Reserved + 0xb9, // Fragment Offset high byte (185) + 0x01, // Fragment Offset low byte & flags (M=1) + 0x00, 0x00, 0x00, 0x01, // Identification + } + + buffer.Clear() + err = ip.SerializeTo(buffer, opts) + if err != nil { + t.Fatal(err) + } + + secondFrag := buffer.Bytes() + secondFrag = append(secondFrag, fragHeader2...) + secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) + + // Test second fragment incoming + err = newPacket(secondFrag, true, p) + assert.NoError(t, err) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) + assert.Equal(t, uint16(0), p.RemotePort) + assert.Equal(t, uint16(0), p.LocalPort) + assert.True(t, p.Fragment) + + // Test second fragment outgoing + err = newPacket(secondFrag, false, p) + assert.NoError(t, err) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) + assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) + assert.Equal(t, uint16(0), p.LocalPort) + assert.Equal(t, uint16(0), p.RemotePort) + assert.True(t, p.Fragment) + + // Too short of a fragment packet + err = newPacket(secondFrag[:len(secondFrag)-10], false, p) + assert.ErrorIs(t, err, ErrIPv6PacketTooShort) +} + +func BenchmarkParseV6(b *testing.B) { + // Regular UDP packet + ip := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolUDP, + HopLimit: 64, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + udp := &layers.UDP{ + SrcPort: layers.UDPPort(36123), + DstPort: layers.UDPPort(22), + } + + buffer := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: false, + FixLengths: true, + } + + err := gopacket.SerializeLayers(buffer, opts, ip, udp) + if err != nil { + b.Fatal(err) + } + normalPacket := buffer.Bytes() + + // First Fragment packet + ipFrag := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolIPv6Fragment, + HopLimit: 64, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + fragHeader := []byte{ + uint8(layers.IPProtocolUDP), // Next Header (UDP) + 0x00, // Reserved + 0x00, // Fragment Offset high byte (0) + 0x01, // Fragment Offset low byte & flags (M=1) + 0x00, 0x00, 0x00, 0x01, // Identification + } + + udpHeader := []byte{ + 0x8d, 0x7b, // Source port 36123 + 0x00, 0x16, // Destination port 22 + 0x00, 0x00, // Length + 0x00, 0x00, // Checksum + } + + buffer.Clear() + err = ipFrag.SerializeTo(buffer, opts) + if err != nil { + b.Fatal(err) + } + + firstFrag := buffer.Bytes() + firstFrag = append(firstFrag, fragHeader...) + firstFrag = append(firstFrag, udpHeader...) + firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) + + // Second Fragment packet + fragHeader[2] = 0xb9 // offset 185 + buffer.Clear() + err = ipFrag.SerializeTo(buffer, opts) + if err != nil { + b.Fatal(err) + } + + secondFrag := buffer.Bytes() + secondFrag = append(secondFrag, fragHeader...) + secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) + + fp := &firewall.Packet{} + + b.Run("Normal", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if err = parseV6(normalPacket, true, fp); err != nil { + b.Fatal(err) + } + } + }) + + b.Run("FirstFragment", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if err = parseV6(firstFrag, true, fp); err != nil { + b.Fatal(err) + } + } + }) + + b.Run("SecondFragment", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if err = parseV6(secondFrag, true, fp); err != nil { + b.Fatal(err) + } + } + }) + + // Evil packet + evilPacket := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolIPv6HopByHop, + HopLimit: 64, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + hopHeader := []byte{ + uint8(layers.IPProtocolIPv6HopByHop), // Next Header (HopByHop) + 0x00, // Length + 0x00, 0x00, // Options and padding + 0x00, 0x00, 0x00, 0x00, // More options and padding + } + + lastHopHeader := []byte{ + uint8(layers.IPProtocolUDP), // Next Header (UDP) + 0x00, // Length + 0x00, 0x00, // Options and padding + 0x00, 0x00, 0x00, 0x00, // More options and padding + } + + buffer.Clear() + err = evilPacket.SerializeTo(buffer, opts) + if err != nil { + b.Fatal(err) + } + + evilBytes := buffer.Bytes() + for i := 0; i < 200; i++ { + evilBytes = append(evilBytes, hopHeader...) + } + evilBytes = append(evilBytes, lastHopHeader...) + evilBytes = append(evilBytes, udpHeader...) + evilBytes = append(evilBytes, []byte{0xde, 0xad, 0xbe, 0xef}...) + + b.Run("200 HopByHop headers", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if err = parseV6(evilBytes, false, fp); err != nil { + b.Fatal(err) + } + } + }) +} + +// Ensure authentication data is a multiple of 8 bytes by padding if necessary +func padAuthData(authData []byte) []byte { + // Length of Authentication Data must be a multiple of 8 bytes + paddingLength := (8 - (len(authData) % 8)) % 8 // Only pad if necessary + if paddingLength > 0 { + authData = append(authData, make([]byte, paddingLength)...) + } + return authData +} + +// Custom function to manually serialize IPSecAH for both IPv4 and IPv6 +func serializeAH(ah *layers.IPSecAH) []byte { + buf := new(bytes.Buffer) + + // Ensure Authentication Data is a multiple of 8 bytes + ah.AuthenticationData = padAuthData(ah.AuthenticationData) + // Calculate Payload Length (in 32-bit words, minus 2) + payloadLen := uint8((12+len(ah.AuthenticationData))/4) - 2 + + // Serialize fields + if err := binary.Write(buf, binary.BigEndian, ah.NextHeader); err != nil { + panic(err) + } + if err := binary.Write(buf, binary.BigEndian, payloadLen); err != nil { + panic(err) + } + if err := binary.Write(buf, binary.BigEndian, ah.Reserved); err != nil { + panic(err) + } + if err := binary.Write(buf, binary.BigEndian, ah.SPI); err != nil { + panic(err) + } + if err := binary.Write(buf, binary.BigEndian, ah.Seq); err != nil { + panic(err) + } + if len(ah.AuthenticationData) > 0 { + if err := binary.Write(buf, binary.BigEndian, ah.AuthenticationData); err != nil { + panic(err) + } + } + + return buf.Bytes() } diff --git a/overlay/device.go b/overlay/device.go index 50ad6ad..da8cbe9 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -8,7 +8,7 @@ import ( type Device interface { io.ReadWriteCloser Activate() error - Cidr() netip.Prefix + Networks() []netip.Prefix Name() string RouteFor(netip.Addr) netip.Addr NewMultiQueueReader() (io.ReadWriteCloser, error) diff --git a/overlay/route.go b/overlay/route.go index 8ccc994..687cc11 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -61,7 +61,7 @@ func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table return routeTree, nil } -func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) { +func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.routes") @@ -117,12 +117,20 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) { return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err) } - if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() { + found := false + for _, network := range networks { + if network.Contains(r.Cidr.Addr()) && r.Cidr.Bits() >= network.Bits() { + found = true + break + } + } + + if !found { return nil, fmt.Errorf( - "entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v", + "entry %v.route in tun.routes is not contained within the configured vpn networks; route: %v, networks: %v", i+1, r.Cidr.String(), - network.String(), + networks, ) } @@ -132,7 +140,7 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) { return routes, nil } -func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) { +func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.unsafe_routes") @@ -229,13 +237,15 @@ func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) { return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err) } - if network.Contains(r.Cidr.Addr()) { - return nil, fmt.Errorf( - "entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v", - i+1, - r.Cidr.String(), - network.String(), - ) + for _, network := range networks { + if network.Contains(r.Cidr.Addr()) { + return nil, fmt.Errorf( + "entry %v.route in tun.unsafe_routes is contained within the configured vpn networks; route: %v, network: %v", + i+1, + r.Cidr.String(), + network.String(), + ) + } } routes[i] = r diff --git a/overlay/route_test.go b/overlay/route_test.go index d791389..c60e4c2 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -17,76 +17,82 @@ func Test_parseRoutes(t *testing.T) { assert.NoError(t, err) // test no routes config - routes, err := parseRoutes(c, n) + routes, err := parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 0) // not an array c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "tun.routes is not an array") // no routes c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 0) // weird route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1 in tun.routes is invalid") // no mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present") // bad mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") // missing route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.routes is not present") // unparsable route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // below network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24") + assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]") // above network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24") + assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]") + + // Not in multiple ranges + c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}} + routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")}) + assert.Nil(t, routes) + assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]") // happy case c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{ map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"}, map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"}, }} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 2) @@ -116,31 +122,31 @@ func Test_parseUnsafeRoutes(t *testing.T) { assert.NoError(t, err) // test no routes config - routes, err := parseUnsafeRoutes(c, n) + routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 0) // not an array c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "tun.unsafe_routes is not an array") // no routes c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 0) // weird route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") // no via c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") @@ -149,68 +155,68 @@ func Test_parseUnsafeRoutes(t *testing.T) { 127, false, nil, 1.0, []string{"1", "2"}, } { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue)) } // unparsable via c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") // missing route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") // unparsable route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // within network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24") + assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24") // below network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) assert.Nil(t, err) // above network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) assert.Nil(t, err) // no mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) assert.Equal(t, 0, routes[0].MTU) // bad mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") // bad install c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") @@ -221,7 +227,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1}, map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, }} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 4) @@ -260,7 +266,7 @@ func Test_makeRouteTree(t *testing.T) { map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"}, }} - routes, err := parseUnsafeRoutes(c, n) + routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) assert.NoError(t, err) assert.Len(t, routes, 2) routeTree, err := makeRouteTree(l, routes, true) diff --git a/overlay/tun.go b/overlay/tun.go index 12460da..4a6377d 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -11,36 +11,36 @@ import ( const DefaultMTU = 1300 // TODO: We may be able to remove routines -type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) +type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { +func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { switch { case c.GetBool("tun.disabled", false): - tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) + tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) return tun, nil default: - return newTun(c, l, tunCidr, routines > 1) + return newTun(c, l, vpnNetworks, routines > 1) } } func NewFdDeviceFromConfig(fd *int) DeviceFactory { - return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { - return newTunFromFd(c, l, *fd, tunCidr) + return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { + return newTunFromFd(c, l, *fd, vpnNetworks) } } -func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) { +func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) { if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") { return false, nil, nil } - routes, err := parseRoutes(c, cidr) + routes, err := parseRoutes(c, vpnNetworks) if err != nil { return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err) } - unsafeRoutes, err := parseUnsafeRoutes(c, cidr) + unsafeRoutes, err := parseUnsafeRoutes(c, vpnNetworks) if err != nil { return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index 98ad9b4..72a6565 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -18,14 +18,14 @@ import ( type tun struct { io.ReadWriteCloser - fd int - cidr netip.Prefix - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + fd int + vpnNetworks []netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") @@ -33,7 +33,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix t := &tun{ ReadWriteCloser: file, fd: deviceFd, - cidr: cidr, + vpnNetworks: vpnNetworks, l: l, } @@ -52,7 +52,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix return t, nil } -func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } @@ -66,7 +66,7 @@ func (t tun) Activate() error { } func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -86,8 +86,8 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 0b573e6..1a02b49 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -24,56 +24,62 @@ import ( type tun struct { io.ReadWriteCloser - Device string - cidr netip.Prefix - DefaultMTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - linkAddr *netroute.LinkAddr - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + DefaultMTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + linkAddr *netroute.LinkAddr + l *logrus.Logger // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte } -type sockaddrCtl struct { - scLen uint8 - scFamily uint8 - ssSysaddr uint16 - scID uint32 - scUnit uint32 - scReserved [5]uint32 -} - type ifReq struct { - Name [16]byte + Name [unix.IFNAMSIZ]byte Flags uint16 pad [8]byte } -var sockaddrCtlSize uintptr = 32 - const ( - _SYSPROTO_CONTROL = 2 //define SYSPROTO_CONTROL 2 /* kernel control protocol */ - _AF_SYS_CONTROL = 2 //#define AF_SYS_CONTROL 2 /* corresponding sub address type */ - _PF_SYSTEM = unix.AF_SYSTEM //#define PF_SYSTEM AF_SYSTEM - _CTLIOCGINFO = 3227799043 //#define CTLIOCGINFO _IOWR('N', 3, struct ctl_info) - utunControlName = "com.apple.net.utun_control" + _SIOCAIFADDR_IN6 = 2155899162 + _UTUN_OPT_IFNAME = 2 + _IN6_IFF_NODAD = 0x0020 + _IN6_IFF_SECURED = 0x0400 + utunControlName = "com.apple.net.utun_control" ) -type ifreqAddr struct { - Name [16]byte - Addr unix.RawSockaddrInet4 - pad [8]byte -} - type ifreqMTU struct { Name [16]byte MTU int32 pad [8]byte } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { +type addrLifetime struct { + Expire float64 + Preferred float64 + Vltime uint32 + Pltime uint32 +} + +type ifreqAlias4 struct { + Name [unix.IFNAMSIZ]byte + Addr unix.RawSockaddrInet4 + DstAddr unix.RawSockaddrInet4 + MaskAddr unix.RawSockaddrInet4 +} + +type ifreqAlias6 struct { + Name [unix.IFNAMSIZ]byte + Addr unix.RawSockaddrInet6 + DstAddr unix.RawSockaddrInet6 + PrefixMask unix.RawSockaddrInet6 + Flags uint32 + Lifetime addrLifetime +} + +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { name := c.GetString("tun.dev", "") ifIndex := -1 if name != "" && name != "utun" { @@ -86,66 +92,41 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err } } - fd, err := unix.Socket(_PF_SYSTEM, unix.SOCK_DGRAM, _SYSPROTO_CONTROL) + fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL) if err != nil { return nil, fmt.Errorf("system socket: %v", err) } - var ctlInfo = &struct { - ctlID uint32 - ctlName [96]byte - }{} + var ctlInfo = &unix.CtlInfo{} + copy(ctlInfo.Name[:], utunControlName) - copy(ctlInfo.ctlName[:], utunControlName) - - err = ioctl(uintptr(fd), uintptr(_CTLIOCGINFO), uintptr(unsafe.Pointer(ctlInfo))) + err = unix.IoctlCtlInfo(fd, ctlInfo) if err != nil { return nil, fmt.Errorf("CTLIOCGINFO: %v", err) } - sc := sockaddrCtl{ - scLen: uint8(sockaddrCtlSize), - scFamily: unix.AF_SYSTEM, - ssSysaddr: _AF_SYS_CONTROL, - scID: ctlInfo.ctlID, - scUnit: uint32(ifIndex) + 1, + err = unix.Connect(fd, &unix.SockaddrCtl{ + ID: ctlInfo.Id, + Unit: uint32(ifIndex) + 1, + }) + if err != nil { + return nil, fmt.Errorf("SYS_CONNECT: %v", err) } - _, _, errno := unix.RawSyscall( - unix.SYS_CONNECT, - uintptr(fd), - uintptr(unsafe.Pointer(&sc)), - sockaddrCtlSize, - ) - if errno != 0 { - return nil, fmt.Errorf("SYS_CONNECT: %v", errno) + name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME) + if err != nil { + return nil, fmt.Errorf("failed to retrieve tun name: %w", err) } - var ifName struct { - name [16]byte - } - ifNameSize := uintptr(len(ifName.name)) - _, _, errno = syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd), - 2, // SYSPROTO_CONTROL - 2, // UTUN_OPT_IFNAME - uintptr(unsafe.Pointer(&ifName)), - uintptr(unsafe.Pointer(&ifNameSize)), 0) - if errno != 0 { - return nil, fmt.Errorf("SYS_GETSOCKOPT: %v", errno) - } - name = string(ifName.name[:ifNameSize-1]) - - err = syscall.SetNonblock(fd, true) + err = unix.SetNonblock(fd, true) if err != nil { return nil, fmt.Errorf("SetNonblock: %v", err) } - file := os.NewFile(uintptr(fd), "") - t := &tun{ - ReadWriteCloser: file, + ReadWriteCloser: os.NewFile(uintptr(fd), ""), Device: name, - cidr: cidr, + vpnNetworks: vpnNetworks, DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -172,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } @@ -186,16 +167,6 @@ func (t *tun) Close() error { func (t *tun) Activate() error { devName := t.deviceBytes() - var addr, mask [4]byte - - if !t.cidr.Addr().Is4() { - //TODO: IPV6-WORK - panic("need ipv6") - } - - addr = t.cidr.Addr().As4() - copy(mask[:], prefixToMask(t.cidr)) - s, err := unix.Socket( unix.AF_INET, unix.SOCK_DGRAM, @@ -208,66 +179,18 @@ func (t *tun) Activate() error { fd := uintptr(s) - ifra := ifreqAddr{ - Name: devName, - Addr: unix.RawSockaddrInet4{ - Family: unix.AF_INET, - Addr: addr, - }, - } - - // Set the device ip address - if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun address: %s", err) - } - - // Set the device network - ifra.Addr.Addr = mask - if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun netmask: %s", err) - } - - // Set the device name - ifrf := ifReq{Name: devName} - if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { - return fmt.Errorf("failed to set tun device name: %s", err) - } - // Set the MTU on the device ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)} if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { return fmt.Errorf("failed to set tun mtu: %v", err) } - /* - // Set the transmit queue length - ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)} - if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { - // If we can't set the queue length nebula will still work but it may lead to packet loss - l.WithError(err).Error("Failed to set tun tx queue length") - } - */ - - // Bring up the interface - ifrf.Flags = ifrf.Flags | unix.IFF_UP - if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { - return fmt.Errorf("failed to bring the tun device up: %s", err) + // Get the device flags + ifrf := ifReq{Name: devName} + if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { + return fmt.Errorf("failed to get tun flags: %s", err) } - routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) - if err != nil { - return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) - } - defer func() { - unix.Shutdown(routeSock, unix.SHUT_RDWR) - err := unix.Close(routeSock) - if err != nil { - t.l.WithError(err).Error("failed to close AF_ROUTE socket") - } - }() - - routeAddr := &netroute.Inet4Addr{} - maskAddr := &netroute.Inet4Addr{} linkAddr, err := getLinkAddr(t.Device) if err != nil { return err @@ -277,14 +200,18 @@ func (t *tun) Activate() error { } t.linkAddr = linkAddr - copy(routeAddr.IP[:], addr[:]) - copy(maskAddr.IP[:], mask[:]) - err = addRoute(routeSock, routeAddr, maskAddr, linkAddr) - if err != nil { - if errors.Is(err, unix.EEXIST) { - err = fmt.Errorf("unable to add tun route, identical route already exists: %s", t.cidr) + for _, network := range t.vpnNetworks { + if network.Addr().Is4() { + err = t.activate4(network) + if err != nil { + return err + } + } else { + err = t.activate6(network) + if err != nil { + return err + } } - return err } // Run the interface @@ -297,8 +224,89 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) activate4(network netip.Prefix) error { + s, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + unix.IPPROTO_IP, + ) + if err != nil { + return err + } + defer unix.Close(s) + + ifr := ifreqAlias4{ + Name: t.deviceBytes(), + Addr: unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: network.Addr().As4(), + }, + DstAddr: unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: network.Addr().As4(), + }, + MaskAddr: unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: prefixToMask(network).As4(), + }, + } + + if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil { + return fmt.Errorf("failed to set tun v4 address: %s", err) + } + + err = addRoute(network, t.linkAddr) + if err != nil { + return err + } + + return nil +} + +func (t *tun) activate6(network netip.Prefix) error { + s, err := unix.Socket( + unix.AF_INET6, + unix.SOCK_DGRAM, + unix.IPPROTO_IP, + ) + if err != nil { + return err + } + defer unix.Close(s) + + ifr := ifreqAlias6{ + Name: t.deviceBytes(), + Addr: unix.RawSockaddrInet6{ + Len: unix.SizeofSockaddrInet6, + Family: unix.AF_INET6, + Addr: network.Addr().As16(), + }, + PrefixMask: unix.RawSockaddrInet6{ + Len: unix.SizeofSockaddrInet6, + Family: unix.AF_INET6, + Addr: prefixToMask(network).As16(), + }, + Lifetime: addrLifetime{ + // never expires + Vltime: 0xffffffff, + Pltime: 0xffffffff, + }, + //TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address? + Flags: _IN6_IFF_NODAD, + } + + if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil { + return fmt.Errorf("failed to set tun address: %s", err) + } + + return nil +} + func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -343,7 +351,7 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { } // Get the LinkAddr for the interface of the given name -// TODO: Is there an easier way to fetch this when we create the interface? +// Is there an easier way to fetch this when we create the interface? // Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers. func getLinkAddr(name string) (*netroute.LinkAddr, error) { rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0) @@ -371,38 +379,15 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) { } func (t *tun) addRoutes(logErrors bool) error { - routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) - if err != nil { - return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) - } - - defer func() { - unix.Shutdown(routeSock, unix.SHUT_RDWR) - err := unix.Close(routeSock) - if err != nil { - t.l.WithError(err).Error("failed to close AF_ROUTE socket") - } - }() - - routeAddr := &netroute.Inet4Addr{} - maskAddr := &netroute.Inet4Addr{} routes := *t.Routes.Load() + for _, r := range routes { if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - if !r.Cidr.Addr().Is4() { - //TODO: implement ipv6 - panic("Cant handle ipv6 routes yet") - } - - routeAddr.IP = r.Cidr.Addr().As4() - //TODO: we could avoid the copy - copy(maskAddr.IP[:], prefixToMask(r.Cidr)) - - err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr) + err := addRoute(r.Cidr, t.linkAddr) if err != nil { if errors.Is(err, unix.EEXIST) { t.l.WithField("route", r.Cidr). @@ -424,36 +409,12 @@ func (t *tun) addRoutes(logErrors bool) error { } func (t *tun) removeRoutes(routes []Route) error { - routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) - if err != nil { - return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) - } - - defer func() { - unix.Shutdown(routeSock, unix.SHUT_RDWR) - err := unix.Close(routeSock) - if err != nil { - t.l.WithError(err).Error("failed to close AF_ROUTE socket") - } - }() - - routeAddr := &netroute.Inet4Addr{} - maskAddr := &netroute.Inet4Addr{} - for _, r := range routes { if !r.Install { continue } - if r.Cidr.Addr().Is6() { - //TODO: implement ipv6 - panic("Cant handle ipv6 routes yet") - } - - routeAddr.IP = r.Cidr.Addr().As4() - copy(maskAddr.IP[:], prefixToMask(r.Cidr)) - - err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr) + err := delRoute(r.Cidr, t.linkAddr) if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else { @@ -463,23 +424,39 @@ func (t *tun) removeRoutes(routes []Route) error { return nil } -func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error { - r := netroute.RouteMessage{ +func addRoute(prefix netip.Prefix, gateway netroute.Addr) error { + sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + defer unix.Close(sock) + + route := &netroute.RouteMessage{ Version: unix.RTM_VERSION, Type: unix.RTM_ADD, Flags: unix.RTF_UP, Seq: 1, - Addrs: []netroute.Addr{ - unix.RTAX_DST: addr, - unix.RTAX_GATEWAY: link, - unix.RTAX_NETMASK: mask, - }, } - data, err := r.Marshal() + if prefix.Addr().Is4() { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, + unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, + unix.RTAX_GATEWAY: gateway, + } + } else { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, + unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, + unix.RTAX_GATEWAY: gateway, + } + } + + data, err := route.Marshal() if err != nil { return fmt.Errorf("failed to create route.RouteMessage: %w", err) } + _, err = unix.Write(sock, data[:]) if err != nil { return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) @@ -488,19 +465,34 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) return nil } -func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error { - r := netroute.RouteMessage{ +func delRoute(prefix netip.Prefix, gateway netroute.Addr) error { + sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + defer unix.Close(sock) + + route := netroute.RouteMessage{ Version: unix.RTM_VERSION, Type: unix.RTM_DELETE, Seq: 1, - Addrs: []netroute.Addr{ - unix.RTAX_DST: addr, - unix.RTAX_GATEWAY: link, - unix.RTAX_NETMASK: mask, - }, } - data, err := r.Marshal() + if prefix.Addr().Is4() { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, + unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, + unix.RTAX_GATEWAY: gateway, + } + } else { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, + unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, + unix.RTAX_GATEWAY: gateway, + } + } + + data, err := route.Marshal() if err != nil { return fmt.Errorf("failed to create route.RouteMessage: %w", err) } @@ -513,7 +505,6 @@ func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) } func (t *tun) Read(to []byte) (int, error) { - buf := make([]byte, len(to)+4) n, err := t.ReadWriteCloser.Read(buf) @@ -551,8 +542,8 @@ func (t *tun) Write(from []byte) (int, error) { return n - 4, err } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { @@ -563,10 +554,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") } -func prefixToMask(prefix netip.Prefix) []byte { +func prefixToMask(prefix netip.Prefix) netip.Addr { pLen := 128 if prefix.Addr().Is4() { pLen = 32 } - return net.CIDRMask(prefix.Bits(), pLen) + + addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen)) + return addr } diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index 130f8f9..cfbf17d 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -12,8 +12,8 @@ import ( ) type disabledTun struct { - read chan []byte - cidr netip.Prefix + read chan []byte + vpnNetworks []netip.Prefix // Track these metrics since we don't have the tun device to do it for us tx metrics.Counter @@ -21,11 +21,11 @@ type disabledTun struct { l *logrus.Logger } -func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { +func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { tun := &disabledTun{ - cidr: cidr, - read: make(chan []byte, queueLen), - l: l, + vpnNetworks: vpnNetworks, + read: make(chan []byte, queueLen), + l: l, } if metricsEnabled { @@ -47,8 +47,8 @@ func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr { return netip.Addr{} } -func (t *disabledTun) Cidr() netip.Prefix { - return t.cidr +func (t *disabledTun) Networks() []netip.Prefix { + return t.vpnNetworks } func (*disabledTun) Name() string { diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index bdfeb58..69690e9 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -46,12 +46,12 @@ type ifreqDestroy struct { } type tun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger io.ReadWriteCloser } @@ -78,11 +78,11 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device var file *os.File var err error @@ -150,7 +150,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err t := &tun{ ReadWriteCloser: file, Device: deviceName, - cidr: cidr, + vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -170,16 +170,16 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err return t, nil } -func (t *tun) Activate() error { +func (t *tun) addIp(cidr netip.Prefix) error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device) + cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -195,8 +195,18 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) Activate() error { + for i := range t.vpnNetworks { + err := t.addIp(t.vpnNetworks[i]) + if err != nil { + return err + } + } + return nil +} + func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -237,8 +247,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { return r } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 20981f0..e99d447 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -21,20 +21,20 @@ import ( type tun struct { io.ReadWriteCloser - cidr netip.Prefix - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + vpnNetworks []netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger } -func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ - cidr: cidr, + vpnNetworks: vpnNetworks, ReadWriteCloser: &tunReadCloser{f: file}, l: l, } @@ -59,7 +59,7 @@ func (t *tun) Activate() error { } func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -142,8 +142,8 @@ func (tr *tunReadCloser) Close() error { return tr.f.Close() } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 0e7e20d..993bd4a 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -11,6 +11,7 @@ import ( "os" "strings" "sync/atomic" + "time" "unsafe" "github.com/gaissmai/bart" @@ -25,7 +26,7 @@ type tun struct { io.ReadWriteCloser fd int Device string - cidr netip.Prefix + vpnNetworks []netip.Prefix MaxMTU int DefaultMTU int TXQueueLen int @@ -40,18 +41,16 @@ type tun struct { l *logrus.Logger } +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks +} + type ifReq struct { Name [16]byte Flags uint16 pad [8]byte } -type ifreqAddr struct { - Name [16]byte - Addr unix.RawSockaddrInet4 - pad [8]byte -} - type ifreqMTU struct { Name [16]byte MTU int32 @@ -64,10 +63,10 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, cidr) + t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { return nil, err } @@ -77,7 +76,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix return t, nil } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { // If /dev/net/tun doesn't exist, try to create it (will happen in docker) @@ -112,7 +111,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) ( name := strings.Trim(string(req.Name[:]), "\x00") file := os.NewFile(uintptr(fd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, cidr) + t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { return nil, err } @@ -122,11 +121,11 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) ( return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) { +func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { t := &tun{ ReadWriteCloser: file, fd: int(file.Fd()), - cidr: cidr, + vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), l: l, @@ -148,7 +147,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Pref } func (t *tun) reload(c *config.C, initial bool) error { - routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -190,11 +189,13 @@ func (t *tun) reload(c *config.C, initial bool) error { } if oldDefaultMTU != newDefaultMTU { - err := t.setDefaultRoute() - if err != nil { - t.l.Warn(err) - } else { - t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) + for i := range t.vpnNetworks { + err := t.setDefaultRoute(t.vpnNetworks[i]) + if err != nil { + t.l.Warn(err) + } else { + t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) + } } } @@ -237,10 +238,10 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { func (t *tun) Write(b []byte) (int, error) { var nn int - max := len(b) + maximum := len(b) for { - n, err := unix.Write(t.fd, b[nn:max]) + n, err := unix.Write(t.fd, b[nn:maximum]) if n > 0 { nn += n } @@ -265,6 +266,58 @@ func (t *tun) deviceBytes() (o [16]byte) { return } +func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool { + for i := range al { + if al[i].Equal(x) { + return true + } + } + return false +} + +// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there +func (t *tun) addIPs(link netlink.Link) error { + newAddrs := make([]*netlink.Addr, len(t.vpnNetworks)) + for i := range t.vpnNetworks { + newAddrs[i] = &netlink.Addr{ + IPNet: &net.IPNet{ + IP: t.vpnNetworks[i].Addr().AsSlice(), + Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()), + }, + Label: t.vpnNetworks[i].Addr().Zone(), + } + } + + //add all new addresses + for i := range newAddrs { + //TODO: CERT-V2 do we want to stack errors and try as many ops as possible? + //AddrReplace still adds new IPs, but if their properties change it will change them as well + if err := netlink.AddrReplace(link, newAddrs[i]); err != nil { + return err + } + } + + //iterate over remainder, remove whoever shouldn't be there + al, err := netlink.AddrList(link, netlink.FAMILY_ALL) + if err != nil { + return fmt.Errorf("failed to get tun address list: %s", err) + } + + for i := range al { + if hasNetlinkAddr(newAddrs, al[i]) { + continue + } + err = netlink.AddrDel(link, &al[i]) + if err != nil { + t.l.WithError(err).Error("failed to remove address from tun address list") + } else { + t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)") + } + } + + return nil +} + func (t *tun) Activate() error { devName := t.deviceBytes() @@ -272,15 +325,8 @@ func (t *tun) Activate() error { t.watchRoutes() } - var addr, mask [4]byte - - //TODO: IPV6-WORK - addr = t.cidr.Addr().As4() - tmask := net.CIDRMask(t.cidr.Bits(), 32) - copy(mask[:], tmask) - s, err := unix.Socket( - unix.AF_INET, + unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine unix.SOCK_DGRAM, unix.IPPROTO_IP, ) @@ -289,31 +335,19 @@ func (t *tun) Activate() error { } t.ioctlFd = uintptr(s) - ifra := ifreqAddr{ - Name: devName, - Addr: unix.RawSockaddrInet4{ - Family: unix.AF_INET, - Addr: addr, - }, - } - - // Set the device ip address - if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun address: %s", err) - } - - // Set the device network - ifra.Addr.Addr = mask - if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun netmask: %s", err) - } - // Set the device name ifrf := ifReq{Name: devName} if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { return fmt.Errorf("failed to set tun device name: %s", err) } + link, err := netlink.LinkByName(t.Device) + if err != nil { + return fmt.Errorf("failed to get tun device link: %s", err) + } + + t.deviceIndex = link.Attrs().Index + // Setup our default MTU t.setMTU() @@ -324,20 +358,21 @@ func (t *tun) Activate() error { t.l.WithError(err).Error("Failed to set tun tx queue length") } + if err = t.addIPs(link); err != nil { + return err + } + // Bring up the interface ifrf.Flags = ifrf.Flags | unix.IFF_UP if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { return fmt.Errorf("failed to bring the tun device up: %s", err) } - link, err := netlink.LinkByName(t.Device) - if err != nil { - return fmt.Errorf("failed to get tun device link: %s", err) - } - t.deviceIndex = link.Attrs().Index - - if err = t.setDefaultRoute(); err != nil { - return err + //set route MTU + for i := range t.vpnNetworks { + if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil { + return fmt.Errorf("failed to set default route MTU: %w", err) + } } // Set the routes @@ -363,12 +398,10 @@ func (t *tun) setMTU() { } } -func (t *tun) setDefaultRoute() error { - // Default route - +func (t *tun) setDefaultRoute(cidr netip.Prefix) error { dr := &net.IPNet{ - IP: t.cidr.Masked().Addr().AsSlice(), - Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()), + IP: cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()), } nr := netlink.Route{ @@ -377,14 +410,27 @@ func (t *tun) setDefaultRoute() error { MTU: t.DefaultMTU, AdvMSS: t.advMSS(Route{}), Scope: unix.RT_SCOPE_LINK, - Src: net.IP(t.cidr.Addr().AsSlice()), + Src: net.IP(cidr.Addr().AsSlice()), Protocol: unix.RTPROT_KERNEL, Table: unix.RT_TABLE_MAIN, Type: unix.RTN_UNICAST, } err := netlink.RouteReplace(&nr) if err != nil { - return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err) + t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying") + //retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument` + for i := 0; i < 2; i++ { + time.Sleep(100 * time.Millisecond) + err = netlink.RouteReplace(&nr) + if err == nil { + break + } else { + t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying") + } + } + if err != nil { + return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err) + } } return nil @@ -463,10 +509,6 @@ func (t *tun) removeRoutes(routes []Route) { } } -func (t *tun) Cidr() netip.Prefix { - return t.cidr -} - func (t *tun) Name() string { return t.Device } @@ -515,7 +557,6 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { return } - //TODO: IPV6-WORK what if not ok? gwAddr, ok := netip.AddrFromSlice(r.Gw) if !ok { t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address") @@ -523,15 +564,16 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { } gwAddr = gwAddr.Unmap() - if !t.cidr.Contains(gwAddr) { - // Gateway isn't in our overlay network, ignore - t.l.WithField("route", r).Debug("Ignoring route update, not in our network") - return + withinNetworks := false + for i := range t.vpnNetworks { + if t.vpnNetworks[i].Contains(gwAddr) { + withinNetworks = true + break + } } - - if x := r.Dst.IP.To4(); x == nil { - // Nebula only handles ipv4 on the overlay currently - t.l.WithField("route", r).Debug("Ignoring route update, destination is not ipv4") + if !withinNetworks { + // Gateway isn't in our overlay network, ignore + t.l.WithField("route", r).Debug("Ignoring route update, not in our networks") return } @@ -563,11 +605,11 @@ func (t *tun) Close() error { } if t.ReadWriteCloser != nil { - t.ReadWriteCloser.Close() + _ = t.ReadWriteCloser.Close() } if t.ioctlFd > 0 { - os.NewFile(t.ioctlFd, "ioctlFd").Close() + _ = os.NewFile(t.ioctlFd, "ioctlFd").Close() } return nil diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 24ab24f..f7586cb 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -27,12 +27,12 @@ type ifreqDestroy struct { } type tun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger io.ReadWriteCloser } @@ -58,13 +58,13 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var file *os.File var err error @@ -84,7 +84,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err t := &tun{ ReadWriteCloser: file, Device: deviceName, - cidr: cidr, + vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -104,17 +104,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err return t, nil } -func (t *tun) Activate() error { +func (t *tun) addIp(cidr netip.Prefix) error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String()) + cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -130,8 +130,18 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) Activate() error { + for i := range t.vpnNetworks { + err := t.addIp(t.vpnNetworks[i]) + if err != nil { + return err + } + } + return nil +} + func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -172,8 +182,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { return r } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { @@ -192,7 +202,7 @@ func (t *tun) addRoutes(logErrors bool) error { continue } - cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) @@ -213,7 +223,8 @@ func (t *tun) removeRoutes(routes []Route) error { continue } - cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String()) + //TODO: CERT-V2 is this right? + cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 6463ccb..a2fd184 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -21,12 +21,12 @@ import ( ) type tun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger io.ReadWriteCloser @@ -42,13 +42,13 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { deviceName := c.GetString("tun.dev", "") if deviceName == "" { return nil, fmt.Errorf("a device name in the format of tunN must be specified") @@ -66,7 +66,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err t := &tun{ ReadWriteCloser: file, Device: deviceName, - cidr: cidr, + vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -87,7 +87,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err } func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -123,10 +123,10 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) Activate() error { +func (t *tun) addIp(cidr netip.Prefix) error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) @@ -138,7 +138,7 @@ func (t *tun) Activate() error { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String()) + cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -148,6 +148,16 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) Activate() error { + for i := range t.vpnNetworks { + err := t.addIp(t.vpnNetworks[i]) + if err != nil { + return err + } + } + return nil +} + func (t *tun) RouteFor(ip netip.Addr) netip.Addr { r, _ := t.routeTree.Load().Lookup(ip) return r @@ -160,8 +170,8 @@ func (t *tun) addRoutes(logErrors bool) error { // We don't allow route MTUs so only install routes with a via continue } - - cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String()) + //TODO: CERT-V2 is this right? + cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) @@ -181,8 +191,8 @@ func (t *tun) removeRoutes(routes []Route) error { if !r.Install { continue } - - cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String()) + //TODO: CERT-V2 is this right? + cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") @@ -193,8 +203,8 @@ func (t *tun) removeRoutes(routes []Route) error { return nil } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index ba15723..cc3942f 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -16,19 +16,19 @@ import ( ) type TestTun struct { - Device string - cidr netip.Prefix - Routes []Route - routeTree *bart.Table[netip.Addr] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + Routes []Route + routeTree *bart.Table[netip.Addr] + l *logrus.Logger closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) { - _, routes, err := getAllRoutesFromConfig(c, cidr, true) +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { + _, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true) if err != nil { return nil, err } @@ -38,17 +38,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, } return &TestTun{ - Device: c.GetString("tun.dev", ""), - cidr: cidr, - Routes: routes, - routeTree: routeTree, - l: l, - rxPackets: make(chan []byte, 10), - TxPackets: make(chan []byte, 10), + Device: c.GetString("tun.dev", ""), + vpnNetworks: vpnNetworks, + Routes: routes, + routeTree: routeTree, + l: l, + rxPackets: make(chan []byte, 10), + TxPackets: make(chan []byte, 10), }, nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } @@ -95,8 +95,8 @@ func (t *TestTun) Activate() error { return nil } -func (t *TestTun) Cidr() netip.Prefix { - return t.cidr +func (t *TestTun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *TestTun) Name() string { diff --git a/overlay/tun_water_windows.go b/overlay/tun_water_windows.go deleted file mode 100644 index d78f564..0000000 --- a/overlay/tun_water_windows.go +++ /dev/null @@ -1,208 +0,0 @@ -package overlay - -import ( - "fmt" - "io" - "net" - "net/netip" - "os/exec" - "strconv" - "sync/atomic" - - "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/util" - "github.com/songgao/water" -) - -type waterTun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger - f *net.Interface - *water.Interface -} - -func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) { - // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() - t := &waterTun{ - cidr: cidr, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, - } - - err := t.reload(c, true) - if err != nil { - return nil, err - } - - c.RegisterReloadCallback(func(c *config.C) { - err := t.reload(c, false) - if err != nil { - util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) - } - }) - - return t, nil -} - -func (t *waterTun) Activate() error { - var err error - t.Interface, err = water.New(water.Config{ - DeviceType: water.TUN, - PlatformSpecificParams: water.PlatformSpecificParams{ - ComponentID: "tap0901", - Network: t.cidr.String(), - }, - }) - if err != nil { - return fmt.Errorf("activate failed: %v", err) - } - - t.Device = t.Interface.Name() - - // TODO use syscalls instead of exec.Command - err = exec.Command( - `C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address", - fmt.Sprintf("name=%s", t.Device), - "source=static", - fmt.Sprintf("addr=%s", t.cidr.Addr()), - fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())), - "gateway=none", - ).Run() - if err != nil { - return fmt.Errorf("failed to run 'netsh' to set address: %s", err) - } - err = exec.Command( - `C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "interface", - t.Device, - fmt.Sprintf("mtu=%d", t.MTU), - ).Run() - if err != nil { - return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err) - } - - t.f, err = net.InterfaceByName(t.Device) - if err != nil { - return fmt.Errorf("failed to find interface named %s: %v", t.Device, err) - } - - err = t.addRoutes(false) - if err != nil { - return err - } - - return nil -} - -func (t *waterTun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) - if err != nil { - return err - } - - if !initial && !change { - return nil - } - - routeTree, err := makeRouteTree(t.l, routes, false) - if err != nil { - return err - } - - // Teach nebula how to handle the routes before establishing them in the system table - oldRoutes := t.Routes.Swap(&routes) - t.routeTree.Store(routeTree) - - if !initial { - // Remove first, if the system removes a wanted route hopefully it will be re-added next - t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) - - // Ensure any routes we actually want are installed - err = t.addRoutes(true) - if err != nil { - // Catch any stray logs - util.LogWithContextIfNeeded("Failed to set routes", err, t.l) - } else { - for _, r := range findRemovedRoutes(routes, *oldRoutes) { - t.l.WithField("route", r).Info("Removed route") - } - } - } - - return nil -} - -func (t *waterTun) addRoutes(logErrors bool) error { - // Path routes - routes := *t.Routes.Load() - for _, r := range routes { - if !r.Via.IsValid() || !r.Install { - // We don't allow route MTUs so only install routes with a via - continue - } - - err := exec.Command( - "C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric), - ).Run() - - if err != nil { - retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) - if logErrors { - retErr.Log(t.l) - } else { - return retErr - } - } else { - t.l.WithField("route", r).Info("Added route") - } - } - - return nil -} - -func (t *waterTun) removeRoutes(routes []Route) { - for _, r := range routes { - if !r.Install { - continue - } - - err := exec.Command( - "C:\\Windows\\System32\\route.exe", "delete", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric), - ).Run() - if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") - } else { - t.l.WithField("route", r).Info("Removed route") - } - } -} - -func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr { - r, _ := t.routeTree.Load().Lookup(ip) - return r -} - -func (t *waterTun) Cidr() netip.Prefix { - return t.cidr -} - -func (t *waterTun) Name() string { - return t.Device -} - -func (t *waterTun) Close() error { - if t.Interface == nil { - return nil - } - - return t.Interface.Close() -} - -func (t *waterTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") -} diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 3d88309..289999d 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -4,41 +4,268 @@ package overlay import ( + "crypto" "fmt" + "io" "net/netip" "os" "path/filepath" "runtime" + "sync/atomic" "syscall" + "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" + "github.com/slackhq/nebula/wintun" + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) { +const tunGUIDLabel = "Fixed Nebula Windows GUID v1" + +type winTun struct { + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger + + tun *wintun.NativeTun +} + +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) { - useWintun := true - if err := checkWinTunExists(); err != nil { - l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver") - useWintun = false - } - - if useWintun { - device, err := newWinTun(c, l, cidr, multiqueue) - if err != nil { - return nil, fmt.Errorf("create Wintun interface failed, %w", err) - } - return device, nil - } - - device, err := newWaterTun(c, l, cidr, multiqueue) +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { + err := checkWinTunExists() if err != nil { - return nil, fmt.Errorf("create wintap driver failed, %w", err) + return nil, fmt.Errorf("can not load the wintun driver: %w", err) } - return device, nil + + deviceName := c.GetString("tun.dev", "") + guid, err := generateGUIDByDeviceName(deviceName) + if err != nil { + return nil, fmt.Errorf("generate GUID failed: %w", err) + } + + t := &winTun{ + Device: deviceName, + vpnNetworks: vpnNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, + } + + err = t.reload(c, true) + if err != nil { + return nil, err + } + + var tunDevice wintun.Device + tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) + if err != nil { + // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device. + // Trying a second time resolves the issue. + l.WithError(err).Debug("Failed to create wintun device, retrying") + tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) + if err != nil { + return nil, fmt.Errorf("create TUN device failed: %w", err) + } + } + t.tun = tunDevice.(*wintun.NativeTun) + + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, nil +} + +func (t *winTun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) + if err != nil { + return err + } + + if !initial && !change { + return nil + } + + routeTree, err := makeRouteTree(t.l, routes, false) + if err != nil { + return err + } + + // Teach nebula how to handle the routes before establishing them in the system table + oldRoutes := t.Routes.Swap(&routes) + t.routeTree.Store(routeTree) + + if !initial { + // Remove first, if the system removes a wanted route hopefully it will be re-added next + err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) + if err != nil { + util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) + } + + // Ensure any routes we actually want are installed + err = t.addRoutes(true) + if err != nil { + // Catch any stray logs + util.LogWithContextIfNeeded("Failed to add routes", err, t.l) + } + } + + return nil +} + +func (t *winTun) Activate() error { + luid := winipcfg.LUID(t.tun.LUID()) + + err := luid.SetIPAddresses(t.vpnNetworks) + if err != nil { + return fmt.Errorf("failed to set address: %w", err) + } + + err = t.addRoutes(false) + if err != nil { + return err + } + + return nil +} + +func (t *winTun) addRoutes(logErrors bool) error { + luid := winipcfg.LUID(t.tun.LUID()) + routes := *t.Routes.Load() + foundDefault4 := false + + for _, r := range routes { + if !r.Via.IsValid() || !r.Install { + // We don't allow route MTUs so only install routes with a via + continue + } + + // Add our unsafe route + err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric)) + if err != nil { + retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) + if logErrors { + retErr.Log(t.l) + continue + } else { + return retErr + } + } else { + t.l.WithField("route", r).Info("Added route") + } + + if !foundDefault4 { + if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 { + foundDefault4 = true + } + } + } + + ipif, err := luid.IPInterface(windows.AF_INET) + if err != nil { + return fmt.Errorf("failed to get ip interface: %w", err) + } + + ipif.NLMTU = uint32(t.MTU) + if foundDefault4 { + ipif.UseAutomaticMetric = false + ipif.Metric = 0 + } + + if err := ipif.Set(); err != nil { + return fmt.Errorf("failed to set ip interface: %w", err) + } + return nil +} + +func (t *winTun) removeRoutes(routes []Route) error { + luid := winipcfg.LUID(t.tun.LUID()) + + for _, r := range routes { + if !r.Install { + continue + } + + err := luid.DeleteRoute(r.Cidr, r.Via) + if err != nil { + t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + } else { + t.l.WithField("route", r).Info("Removed route") + } + } + return nil +} + +func (t *winTun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) + return r +} + +func (t *winTun) Networks() []netip.Prefix { + return t.vpnNetworks +} + +func (t *winTun) Name() string { + return t.Device +} + +func (t *winTun) Read(b []byte) (int, error) { + return t.tun.Read(b, 0) +} + +func (t *winTun) Write(b []byte) (int, error) { + return t.tun.Write(b, 0) +} + +func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { + return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") +} + +func (t *winTun) Close() error { + // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes, + // so to be certain, just remove everything before destroying. + luid := winipcfg.LUID(t.tun.LUID()) + _ = luid.FlushRoutes(windows.AF_INET) + _ = luid.FlushIPAddresses(windows.AF_INET) + + _ = luid.FlushRoutes(windows.AF_INET6) + _ = luid.FlushIPAddresses(windows.AF_INET6) + + _ = luid.FlushDNS(windows.AF_INET) + _ = luid.FlushDNS(windows.AF_INET6) + + return t.tun.Close() +} + +func generateGUIDByDeviceName(name string) (*windows.GUID, error) { + // GUID is 128 bit + hash := crypto.MD5.New() + + _, err := hash.Write([]byte(tunGUIDLabel)) + if err != nil { + return nil, err + } + + _, err = hash.Write([]byte(name)) + if err != nil { + return nil, err + } + + sum := hash.Sum(nil) + + return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil } func checkWinTunExists() error { diff --git a/overlay/tun_wintun_windows.go b/overlay/tun_wintun_windows.go deleted file mode 100644 index d010387..0000000 --- a/overlay/tun_wintun_windows.go +++ /dev/null @@ -1,252 +0,0 @@ -package overlay - -import ( - "crypto" - "fmt" - "io" - "net/netip" - "sync/atomic" - "unsafe" - - "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/util" - "github.com/slackhq/nebula/wintun" - "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" -) - -const tunGUIDLabel = "Fixed Nebula Windows GUID v1" - -type winTun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger - - tun *wintun.NativeTun -} - -func generateGUIDByDeviceName(name string) (*windows.GUID, error) { - // GUID is 128 bit - hash := crypto.MD5.New() - - _, err := hash.Write([]byte(tunGUIDLabel)) - if err != nil { - return nil, err - } - - _, err = hash.Write([]byte(name)) - if err != nil { - return nil, err - } - - sum := hash.Sum(nil) - - return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil -} - -func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) { - deviceName := c.GetString("tun.dev", "") - guid, err := generateGUIDByDeviceName(deviceName) - if err != nil { - return nil, fmt.Errorf("generate GUID failed: %w", err) - } - - t := &winTun{ - Device: deviceName, - cidr: cidr, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, - } - - err = t.reload(c, true) - if err != nil { - return nil, err - } - - var tunDevice wintun.Device - tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) - if err != nil { - // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device. - // Trying a second time resolves the issue. - l.WithError(err).Debug("Failed to create wintun device, retrying") - tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) - if err != nil { - return nil, fmt.Errorf("create TUN device failed: %w", err) - } - } - t.tun = tunDevice.(*wintun.NativeTun) - - c.RegisterReloadCallback(func(c *config.C) { - err := t.reload(c, false) - if err != nil { - util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) - } - }) - - return t, nil -} - -func (t *winTun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) - if err != nil { - return err - } - - if !initial && !change { - return nil - } - - routeTree, err := makeRouteTree(t.l, routes, false) - if err != nil { - return err - } - - // Teach nebula how to handle the routes before establishing them in the system table - oldRoutes := t.Routes.Swap(&routes) - t.routeTree.Store(routeTree) - - if !initial { - // Remove first, if the system removes a wanted route hopefully it will be re-added next - err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) - if err != nil { - util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) - } - - // Ensure any routes we actually want are installed - err = t.addRoutes(true) - if err != nil { - // Catch any stray logs - util.LogWithContextIfNeeded("Failed to add routes", err, t.l) - } - } - - return nil -} - -func (t *winTun) Activate() error { - luid := winipcfg.LUID(t.tun.LUID()) - - err := luid.SetIPAddresses([]netip.Prefix{t.cidr}) - if err != nil { - return fmt.Errorf("failed to set address: %w", err) - } - - err = t.addRoutes(false) - if err != nil { - return err - } - - return nil -} - -func (t *winTun) addRoutes(logErrors bool) error { - luid := winipcfg.LUID(t.tun.LUID()) - routes := *t.Routes.Load() - foundDefault4 := false - - for _, r := range routes { - if !r.Via.IsValid() || !r.Install { - // We don't allow route MTUs so only install routes with a via - continue - } - - // Add our unsafe route - err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric)) - if err != nil { - retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) - if logErrors { - retErr.Log(t.l) - continue - } else { - return retErr - } - } else { - t.l.WithField("route", r).Info("Added route") - } - - if !foundDefault4 { - if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 { - foundDefault4 = true - } - } - } - - ipif, err := luid.IPInterface(windows.AF_INET) - if err != nil { - return fmt.Errorf("failed to get ip interface: %w", err) - } - - ipif.NLMTU = uint32(t.MTU) - if foundDefault4 { - ipif.UseAutomaticMetric = false - ipif.Metric = 0 - } - - if err := ipif.Set(); err != nil { - return fmt.Errorf("failed to set ip interface: %w", err) - } - return nil -} - -func (t *winTun) removeRoutes(routes []Route) error { - luid := winipcfg.LUID(t.tun.LUID()) - - for _, r := range routes { - if !r.Install { - continue - } - - err := luid.DeleteRoute(r.Cidr, r.Via) - if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") - } else { - t.l.WithField("route", r).Info("Removed route") - } - } - return nil -} - -func (t *winTun) RouteFor(ip netip.Addr) netip.Addr { - r, _ := t.routeTree.Load().Lookup(ip) - return r -} - -func (t *winTun) Cidr() netip.Prefix { - return t.cidr -} - -func (t *winTun) Name() string { - return t.Device -} - -func (t *winTun) Read(b []byte) (int, error) { - return t.tun.Read(b, 0) -} - -func (t *winTun) Write(b []byte) (int, error) { - return t.tun.Write(b, 0) -} - -func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") -} - -func (t *winTun) Close() error { - // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes, - // so to be certain, just remove everything before destroying. - luid := winipcfg.LUID(t.tun.LUID()) - _ = luid.FlushRoutes(windows.AF_INET) - _ = luid.FlushIPAddresses(windows.AF_INET) - /* We don't support IPV6 yet - _ = luid.FlushRoutes(windows.AF_INET6) - _ = luid.FlushIPAddresses(windows.AF_INET6) - */ - _ = luid.FlushDNS(windows.AF_INET) - - return t.tun.Close() -} diff --git a/overlay/user.go b/overlay/user.go index 1bb4ef5..ae665f3 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -8,16 +8,16 @@ import ( "github.com/slackhq/nebula/config" ) -func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { - return NewUserDevice(tunCidr) +func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { + return NewUserDevice(vpnNetworks) } -func NewUserDevice(tunCidr netip.Prefix) (Device, error) { +func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) { // these pipes guarantee each write/read will match 1:1 or, ow := io.Pipe() ir, iw := io.Pipe() return &UserDevice{ - tunCidr: tunCidr, + vpnNetworks: vpnNetworks, outboundReader: or, outboundWriter: ow, inboundReader: ir, @@ -26,7 +26,7 @@ func NewUserDevice(tunCidr netip.Prefix) (Device, error) { } type UserDevice struct { - tunCidr netip.Prefix + vpnNetworks []netip.Prefix outboundReader *io.PipeReader outboundWriter *io.PipeWriter @@ -38,7 +38,7 @@ type UserDevice struct { func (d *UserDevice) Activate() error { return nil } -func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr } +func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks } func (d *UserDevice) Name() string { return "faketun0" } func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip } func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { diff --git a/pki.go b/pki.go index fe64ea5..888da7c 100644 --- a/pki.go +++ b/pki.go @@ -1,13 +1,19 @@ package nebula import ( + "encoding/binary" + "encoding/json" "errors" "fmt" + "net" + "net/netip" "os" + "slices" "strings" "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" @@ -21,12 +27,22 @@ type PKI struct { } type CertState struct { - Certificate cert.Certificate - RawCertificate []byte - RawCertificateNoKey []byte - PublicKey []byte - PrivateKey []byte - pkcs11Backed bool + v1Cert cert.Certificate + v1HandshakeBytes []byte + + v2Cert cert.Certificate + v2HandshakeBytes []byte + + defaultVersion cert.Version + privateKey []byte + pkcs11Backed bool + cipher string + + myVpnNetworks []netip.Prefix + myVpnNetworksTable *bart.Table[struct{}] + myVpnAddrs []netip.Addr + myVpnAddrsTable *bart.Table[struct{}] + myVpnBroadcastAddrsTable *bart.Table[struct{}] } func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { @@ -46,16 +62,16 @@ func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { return pki, nil } -func (p *PKI) GetCertState() *CertState { - return p.cs.Load() -} - func (p *PKI) GetCAPool() *cert.CAPool { return p.caPool.Load() } +func (p *PKI) getCertState() *CertState { + return p.cs.Load() +} + func (p *PKI) reload(c *config.C, initial bool) error { - err := p.reloadCert(c, initial) + err := p.reloadCerts(c, initial) if err != nil { if initial { return err @@ -74,33 +90,94 @@ func (p *PKI) reload(c *config.C, initial bool) error { return nil } -func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError { - cs, err := newCertStateFromConfig(c) +func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { + newState, err := newCertStateFromConfig(c) if err != nil { return util.NewContextualError("Could not load client cert", nil, err) } if !initial { - //TODO: include check for mask equality as well + currentState := p.cs.Load() + if newState.v1Cert != nil { + if currentState.v1Cert == nil { + return util.NewContextualError("v1 certificate was added, restart required", nil, err) + } - // did IP in cert change? if so, don't set - currentCert := p.cs.Load().Certificate - oldIPs := currentCert.Networks() - newIPs := cs.Certificate.Networks() - if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { + // did IP in cert change? if so, don't set + if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) { + return util.NewContextualError( + "Networks in new cert was different from old", + m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()}, + nil, + ) + } + + if currentState.v1Cert.Curve() != newState.v1Cert.Curve() { + return util.NewContextualError( + "Curve in new cert was different from old", + m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()}, + nil, + ) + } + + } else if currentState.v1Cert != nil { + //TODO: CERT-V2 we should be able to tear this down + return util.NewContextualError("v1 certificate was removed, restart required", nil, err) + } + + if newState.v2Cert != nil { + if currentState.v2Cert == nil { + return util.NewContextualError("v2 certificate was added, restart required", nil, err) + } + + // did IP in cert change? if so, don't set + if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) { + return util.NewContextualError( + "Networks in new cert was different from old", + m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()}, + nil, + ) + } + + if currentState.v2Cert.Curve() != newState.v2Cert.Curve() { + return util.NewContextualError( + "Curve in new cert was different from old", + m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()}, + nil, + ) + } + + } else if currentState.v2Cert != nil { + return util.NewContextualError("v2 certificate was removed, restart required", nil, err) + } + + // Cipher cant be hot swapped so just leave it at what it was before + newState.cipher = currentState.cipher + + } else { + newState.cipher = c.GetString("cipher", "aes") + //TODO: this sucks and we should make it not a global + switch newState.cipher { + case "aes": + noiseEndianness = binary.BigEndian + case "chachapoly": + noiseEndianness = binary.LittleEndian + default: return util.NewContextualError( - "Networks in new cert was different from old", - m{"new_network": newIPs[0], "old_network": oldIPs[0]}, + "unknown cipher", + m{"cipher": newState.cipher}, nil, ) } } - p.cs.Store(cs) + p.cs.Store(newState) + + //TODO: CERT-V2 newState needs a stringer that does json if initial { - p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate") + p.l.WithField("cert", newState).Debug("Client nebula certificate(s)") } else { - p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk") + p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk") } return nil } @@ -116,55 +193,65 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError { return nil } -func newCertState(certificate cert.Certificate, pkcs11backed bool, privateKey []byte) (*CertState, error) { - // Marshal the certificate to ensure it is valid - rawCertificate, err := certificate.Marshal() - if err != nil { - return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err) +func (cs *CertState) GetDefaultCertificate() cert.Certificate { + c := cs.getCertificate(cs.defaultVersion) + if c == nil { + panic("No default certificate found") } - - publicKey := certificate.PublicKey() - cs := &CertState{ - RawCertificate: rawCertificate, - Certificate: certificate, - PrivateKey: privateKey, - PublicKey: publicKey, - pkcs11Backed: pkcs11backed, - } - - rawCertNoKey, err := cs.Certificate.MarshalForHandshakes() - if err != nil { - return nil, fmt.Errorf("error marshalling certificate no key: %s", err) - } - cs.RawCertificateNoKey = rawCertNoKey - - return cs, nil + return c } -func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) { - var pemPrivateKey []byte - if strings.Contains(privPathOrPEM, "-----BEGIN") { - pemPrivateKey = []byte(privPathOrPEM) - privPathOrPEM = "" - rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) - if err != nil { - return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) - } - } else if strings.HasPrefix(privPathOrPEM, "pkcs11:") { - rawKey = []byte(privPathOrPEM) - return rawKey, cert.Curve_P256, true, nil - } else { - pemPrivateKey, err = os.ReadFile(privPathOrPEM) - if err != nil { - return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) - } - rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) - if err != nil { - return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) - } +func (cs *CertState) getCertificate(v cert.Version) cert.Certificate { + switch v { + case cert.Version1: + return cs.v1Cert + case cert.Version2: + return cs.v2Cert } - return + return nil +} + +// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version. +// Callers must check if the return []byte is nil. +func (cs *CertState) getHandshakeBytes(v cert.Version) []byte { + switch v { + case cert.Version1: + return cs.v1HandshakeBytes + case cert.Version2: + return cs.v2HandshakeBytes + default: + return nil + } +} + +func (cs *CertState) String() string { + b, err := cs.MarshalJSON() + if err != nil { + return fmt.Sprintf("error marshaling certificate state: %v", err) + } + return string(b) +} + +func (cs *CertState) MarshalJSON() ([]byte, error) { + msg := []json.RawMessage{} + if cs.v1Cert != nil { + b, err := cs.v1Cert.MarshalJSON() + if err != nil { + return nil, err + } + msg = append(msg, b) + } + + if cs.v2Cert != nil { + b, err := cs.v2Cert.MarshalJSON() + if err != nil { + return nil, err + } + msg = append(msg, b) + } + + return json.Marshal(msg) } func newCertStateFromConfig(c *config.C) (*CertState, error) { @@ -198,24 +285,197 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { } } - nebulaCert, _, err := cert.UnmarshalCertificateFromPEM(rawCert) + var crt, v1, v2 cert.Certificate + for { + // Load the certificate + crt, rawCert, err = loadCertificate(rawCert) + if err != nil { + return nil, err + } + + switch crt.Version() { + case cert.Version1: + if v1 != nil { + return nil, fmt.Errorf("v1 certificate already found in pki.cert") + } + v1 = crt + case cert.Version2: + if v2 != nil { + return nil, fmt.Errorf("v2 certificate already found in pki.cert") + } + v2 = crt + default: + return nil, fmt.Errorf("unknown certificate version %v", crt.Version()) + } + + if len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" { + break + } + } + + if v1 == nil && v2 == nil { + return nil, errors.New("no certificates found in pki.cert") + } + + useDefaultVersion := uint32(1) + if v1 == nil { + // The only condition that requires v2 as the default is if only a v2 certificate is present + // We do this to avoid having to configure it specifically in the config file + useDefaultVersion = 2 + } + + rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion) + var defaultVersion cert.Version + switch rawDefaultVersion { + case 1: + if v1 == nil { + return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert") + } + defaultVersion = cert.Version1 + case 2: + defaultVersion = cert.Version2 + default: + return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion) + } + + return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey) +} + +func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { + cs := CertState{ + privateKey: privateKey, + pkcs11Backed: pkcs11backed, + myVpnNetworksTable: new(bart.Table[struct{}]), + myVpnAddrsTable: new(bart.Table[struct{}]), + myVpnBroadcastAddrsTable: new(bart.Table[struct{}]), + } + + if v1 != nil && v2 != nil { + if !slices.Equal(v1.PublicKey(), v2.PublicKey()) { + return nil, util.NewContextualError("v1 and v2 public keys are not the same, ignoring", nil, nil) + } + + if v1.Curve() != v2.Curve() { + return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil) + } + + //TODO: CERT-V2 make sure v2 has v1s address + + cs.defaultVersion = dv + } + + if v1 != nil { + if pkcs11backed { + //NOTE: We do not currently have a method to verify a public private key pair when the private key is in an hsm + } else { + if err := v1.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil { + return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + } + } + + v1hs, err := v1.MarshalForHandshakes() + if err != nil { + return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) + } + cs.v1Cert = v1 + cs.v1HandshakeBytes = v1hs + + if cs.defaultVersion == 0 { + cs.defaultVersion = cert.Version1 + } + } + + if v2 != nil { + if pkcs11backed { + //NOTE: We do not currently have a method to verify a public private key pair when the private key is in an hsm + } else { + if err := v2.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil { + return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + } + } + + v2hs, err := v2.MarshalForHandshakes() + if err != nil { + return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) + } + cs.v2Cert = v2 + cs.v2HandshakeBytes = v2hs + + if cs.defaultVersion == 0 { + cs.defaultVersion = cert.Version2 + } + } + + var crt cert.Certificate + crt = cs.getCertificate(cert.Version2) + if crt == nil { + // v2 certificates are a superset, only look at v1 if its all we have + crt = cs.getCertificate(cert.Version1) + } + + for _, network := range crt.Networks() { + cs.myVpnNetworks = append(cs.myVpnNetworks, network) + cs.myVpnNetworksTable.Insert(network, struct{}{}) + + cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr()) + cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{}) + + if network.Addr().Is4() { + addr := network.Masked().Addr().As4() + mask := net.CIDRMask(network.Bits(), network.Addr().BitLen()) + binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask)) + cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{}) + } + } + + return &cs, nil +} + +func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) { + var pemPrivateKey []byte + if strings.Contains(privPathOrPEM, "-----BEGIN") { + pemPrivateKey = []byte(privPathOrPEM) + privPathOrPEM = "" + rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) + if err != nil { + return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + } + } else if strings.HasPrefix(privPathOrPEM, "pkcs11:") { + rawKey = []byte(privPathOrPEM) + return rawKey, cert.Curve_P256, true, nil + } else { + pemPrivateKey, err = os.ReadFile(privPathOrPEM) + if err != nil { + return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) + } + rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) + if err != nil { + return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + } + } + + return +} + +func loadCertificate(b []byte) (cert.Certificate, []byte, error) { + c, b, err := cert.UnmarshalCertificateFromPEM(b) if err != nil { - return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err) + return nil, b, fmt.Errorf("error while unmarshaling pki.cert: %w", err) } - if nebulaCert.Expired(time.Now()) { - return nil, fmt.Errorf("nebula certificate for this host is expired") + if c.Expired(time.Now()) { + return nil, b, fmt.Errorf("nebula certificate for this host is expired") } - if len(nebulaCert.Networks()) == 0 { - return nil, fmt.Errorf("no networks encoded in certificate") + if len(c.Networks()) == 0 { + return nil, b, fmt.Errorf("no networks encoded in certificate") } - if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil { - return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + if c.IsCA() { + return nil, b, fmt.Errorf("host certificate is a CA certificate") } - return newCertState(nebulaCert, isPkcs11, rawKey) + return c, b, nil } func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { diff --git a/relay_manager.go b/relay_manager.go index 1a3a4d4..7565350 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -9,6 +9,7 @@ import ( "sync/atomic" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" ) @@ -72,7 +73,7 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti Type: relayType, State: state, LocalIndex: index, - PeerIp: vpnIp, + PeerAddr: vpnIp, } if remoteIdx != nil { @@ -91,40 +92,71 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) { relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex) if !ok { - rm.l.WithFields(logrus.Fields{"relay": relayHostInfo.vpnIp, + fields := logrus.Fields{ + "relay": relayHostInfo.vpnAddrs[0], "initiatorRelayIndex": m.InitiatorRelayIndex, - "relayFrom": m.RelayFromIp, - "relayTo": m.RelayToIp}).Info("relayManager failed to update relay") + } + + if m.RelayFromAddr == nil { + fields["relayFrom"] = m.OldRelayFromAddr + } else { + fields["relayFrom"] = m.RelayFromAddr + } + + if m.RelayToAddr == nil { + fields["relayTo"] = m.OldRelayToAddr + } else { + fields["relayTo"] = m.RelayToAddr + } + + rm.l.WithFields(fields).Info("relayManager failed to update relay") return nil, fmt.Errorf("unknown relay") } return relay, nil } -func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Interface) { - - switch m.Type { - case NebulaControl_CreateRelayRequest: - rm.handleCreateRelayRequest(h, f, m) - case NebulaControl_CreateRelayResponse: - rm.handleCreateRelayResponse(h, f, m) +func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) { + msg := &NebulaControl{} + err := msg.Unmarshal(d) + if err != nil { + h.logger(f.l).WithError(err).Error("Failed to unmarshal control message") + return } + var v cert.Version + if msg.OldRelayFromAddr > 0 || msg.OldRelayToAddr > 0 { + v = cert.Version1 + + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], msg.OldRelayFromAddr) + msg.RelayFromAddr = netAddrToProtoAddr(netip.AddrFrom4(b)) + + binary.BigEndian.PutUint32(b[:], msg.OldRelayToAddr) + msg.RelayToAddr = netAddrToProtoAddr(netip.AddrFrom4(b)) + } else { + v = cert.Version2 + } + + switch msg.Type { + case NebulaControl_CreateRelayRequest: + rm.handleCreateRelayRequest(v, h, f, msg) + case NebulaControl_CreateRelayResponse: + rm.handleCreateRelayResponse(v, h, f, msg) + } } -func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) { +func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) { rm.l.WithFields(logrus.Fields{ - "relayFrom": m.RelayFromIp, - "relayTo": m.RelayToIp, + "relayFrom": protoAddrToNetAddr(m.RelayFromAddr), + "relayTo": protoAddrToNetAddr(m.RelayToAddr), "initiatorRelayIndex": m.InitiatorRelayIndex, "responderRelayIndex": m.ResponderRelayIndex, - "vpnIp": h.vpnIp}). + "vpnAddrs": h.vpnAddrs}). Info("handleCreateRelayResponse") - target := m.RelayToIp - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], m.RelayToIp) - targetAddr := netip.AddrFrom4(b) + + target := m.RelayToAddr + targetAddr := protoAddrToNetAddr(target) relay, err := rm.EstablishRelay(h, m) if err != nil { @@ -136,68 +168,88 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * return } // I'm the middle man. Let the initiator know that the I've established the relay they requested. - peerHostInfo := rm.hostmap.QueryVpnIp(relay.PeerIp) + peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr) if peerHostInfo == nil { - rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer") + rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer") return } peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr) if !ok { - rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo") + rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo") return } - if peerRelay.State == PeerRequested { - //TODO: IPV6-WORK - b = peerHostInfo.vpnIp.As4() - peerRelay.State = Established + switch peerRelay.State { + case Requested: + // I initiated the request to this peer, but haven't heard back from the peer yet. I must wait for this peer + // to respond to complete the connection. + case PeerRequested, Disestablished, Established: + peerHostInfo.relayState.UpdateRelayForByIpState(targetAddr, Established) resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: peerRelay.LocalIndex, InitiatorRelayIndex: peerRelay.RemoteIndex, - RelayFromIp: binary.BigEndian.Uint32(b[:]), - RelayToIp: uint32(target), } + + if v == cert.Version1 { + peer := peerHostInfo.vpnAddrs[0] + if !peer.Is4() { + rm.l.WithField("relayFrom", peer). + WithField("relayTo", target). + WithField("initiatorRelayIndex", resp.InitiatorRelayIndex). + WithField("responderRelayIndex", resp.ResponderRelayIndex). + WithField("vpnAddrs", peerHostInfo.vpnAddrs). + Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address") + return + } + + b := peer.As4() + resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = targetAddr.As4() + resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + } else { + resp.RelayFromAddr = netAddrToProtoAddr(peerHostInfo.vpnAddrs[0]) + resp.RelayToAddr = target + } + msg, err := resp.Marshal() if err != nil { - rm.l. - WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") + rm.l.WithError(err). + Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") } else { f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": resp.RelayFromIp, - "relayTo": resp.RelayToIp, + "relayFrom": resp.RelayFromAddr, + "relayTo": resp.RelayToAddr, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, - "vpnIp": peerHostInfo.vpnIp}). + "vpnAddrs": peerHostInfo.vpnAddrs}). Info("send CreateRelayResponse") } } } -func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) { - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], m.RelayFromIp) - from := netip.AddrFrom4(b) - - binary.BigEndian.PutUint32(b[:], m.RelayToIp) - target := netip.AddrFrom4(b) +func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) { + from := protoAddrToNetAddr(m.RelayFromAddr) + target := protoAddrToNetAddr(m.RelayToAddr) logMsg := rm.l.WithFields(logrus.Fields{ "relayFrom": from, "relayTo": target, "initiatorRelayIndex": m.InitiatorRelayIndex, - "vpnIp": h.vpnIp}) + "vpnAddrs": h.vpnAddrs}) logMsg.Info("handleCreateRelayRequest") // Is the source of the relay me? This should never happen, but did happen due to // an issue migrating relays over to newly re-handshaked host info objects. - if from == f.myVpnNet.Addr() { + _, found := f.myVpnAddrsTable.Lookup(from) + if found { logMsg.WithField("myIP", from).Error("Discarding relay request from myself") return } + // Is the target of the relay me? - if target == f.myVpnNet.Addr() { + _, found = f.myVpnAddrsTable.Lookup(target) + if found { existingRelay, ok := h.relayState.QueryRelayForByIp(from) if ok { switch existingRelay.State { @@ -215,6 +267,21 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") return } + case Disestablished: + if existingRelay.RemoteIndex != m.InitiatorRelayIndex { + // We got a brand new Relay request, because its index is different than what we saw before. + // This should never happen. The peer should never change an index, once created. + logMsg.WithFields(logrus.Fields{ + "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") + return + } + // Mark the relay as 'Established' because it's safe to use again + h.relayState.UpdateRelayForByIpState(from, Established) + case PeerRequested: + // I should never be in this state, because I am terminal, not forwarding. + logMsg.WithFields(logrus.Fields{ + "existingRemoteIndex": existingRelay.RemoteIndex, + "state": existingRelay.State}).Error("Unexpected Relay State found") } } else { _, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established) @@ -226,21 +293,26 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N relay, ok := h.relayState.QueryRelayForByIp(from) if !ok { - logMsg.Error("Relay State not found") + logMsg.WithField("from", from).Error("Relay State not found") return } - //TODO: IPV6-WORK - fromB := from.As4() - targetB := target.As4() - resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: relay.LocalIndex, InitiatorRelayIndex: relay.RemoteIndex, - RelayFromIp: binary.BigEndian.Uint32(fromB[:]), - RelayToIp: binary.BigEndian.Uint32(targetB[:]), } + + if v == cert.Version1 { + b := from.As4() + resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = target.As4() + resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + } else { + resp.RelayFromAddr = netAddrToProtoAddr(from) + resp.RelayToAddr = netAddrToProtoAddr(target) + } + msg, err := resp.Marshal() if err != nil { logMsg. @@ -248,12 +320,11 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } else { f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - //TODO: IPV6-WORK, this used to use the resp object but I am getting lazy now "relayFrom": from, "relayTo": target, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, - "vpnIp": h.vpnIp}). + "vpnAddrs": h.vpnAddrs}). Info("send CreateRelayResponse") } return @@ -262,7 +333,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N if !rm.GetAmRelay() { return } - peer := rm.hostmap.QueryVpnIp(target) + peer := rm.hostmap.QueryVpnAddr(target) if peer == nil { // Try to establish a connection to this host. If we get a future relay request, // we'll be ready! @@ -273,104 +344,69 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N // Only create relays to peers for whom I have a direct connection return } - sendCreateRequest := false var index uint32 var err error targetRelay, ok := peer.relayState.QueryRelayForByIp(from) if ok { index = targetRelay.LocalIndex - if targetRelay.State == Requested { - sendCreateRequest = true - } } else { // Allocate an index in the hostMap for this relay peer index, err = AddRelay(rm.l, peer, f.hostMap, from, nil, ForwardingType, Requested) if err != nil { return } - sendCreateRequest = true } - if sendCreateRequest { - //TODO: IPV6-WORK - fromB := h.vpnIp.As4() - targetB := target.As4() + peer.relayState.UpdateRelayForByIpState(from, Requested) + // Send a CreateRelayRequest to the peer. + req := NebulaControl{ + Type: NebulaControl_CreateRelayRequest, + InitiatorRelayIndex: index, + } - // Send a CreateRelayRequest to the peer. - req := NebulaControl{ - Type: NebulaControl_CreateRelayRequest, - InitiatorRelayIndex: index, - RelayFromIp: binary.BigEndian.Uint32(fromB[:]), - RelayToIp: binary.BigEndian.Uint32(targetB[:]), - } - msg, err := req.Marshal() - if err != nil { - logMsg. - WithError(err).Error("relayManager Failed to marshal Control message to create relay") - } else { - f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - //TODO: IPV6-WORK another lazy used to use the req object - "relayFrom": h.vpnIp, - "relayTo": target, - "initiatorRelayIndex": req.InitiatorRelayIndex, - "responderRelayIndex": req.ResponderRelayIndex, - "vpnIp": target}). - Info("send CreateRelayRequest") + if v == cert.Version1 { + if !h.vpnAddrs[0].Is4() { + rm.l.WithField("relayFrom", h.vpnAddrs[0]). + WithField("relayTo", target). + WithField("initiatorRelayIndex", req.InitiatorRelayIndex). + WithField("responderRelayIndex", req.ResponderRelayIndex). + WithField("vpnAddr", target). + Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address") + return } + + b := h.vpnAddrs[0].As4() + req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = target.As4() + req.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + } else { + req.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0]) + req.RelayToAddr = netAddrToProtoAddr(target) } + + msg, err := req.Marshal() + if err != nil { + logMsg. + WithError(err).Error("relayManager Failed to marshal Control message to create relay") + } else { + f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) + rm.l.WithFields(logrus.Fields{ + "relayFrom": h.vpnAddrs[0], + "relayTo": target, + "initiatorRelayIndex": req.InitiatorRelayIndex, + "responderRelayIndex": req.ResponderRelayIndex, + "vpnAddr": target}). + Info("send CreateRelayRequest") + } + // Also track the half-created Relay state just received - relay, ok := h.relayState.QueryRelayForByIp(target) + _, ok = h.relayState.QueryRelayForByIp(target) if !ok { - // Add the relay - state := PeerRequested - if targetRelay != nil && targetRelay.State == Established { - state = Established - } - _, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, state) + _, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested) if err != nil { logMsg. WithError(err).Error("relayManager Failed to allocate a local index for relay") return } - } else { - switch relay.State { - case Established: - if relay.RemoteIndex != m.InitiatorRelayIndex { - // We got a brand new Relay request, because its index is different than what we saw before. - // This should never happen. The peer should never change an index, once created. - logMsg.WithFields(logrus.Fields{ - "existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") - return - } - //TODO: IPV6-WORK - fromB := h.vpnIp.As4() - targetB := target.As4() - resp := NebulaControl{ - Type: NebulaControl_CreateRelayResponse, - ResponderRelayIndex: relay.LocalIndex, - InitiatorRelayIndex: relay.RemoteIndex, - RelayFromIp: binary.BigEndian.Uint32(fromB[:]), - RelayToIp: binary.BigEndian.Uint32(targetB[:]), - } - msg, err := resp.Marshal() - if err != nil { - rm.l. - WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") - } else { - f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - //TODO: IPV6-WORK more lazy, used to use resp object - "relayFrom": h.vpnIp, - "relayTo": target, - "initiatorRelayIndex": resp.InitiatorRelayIndex, - "responderRelayIndex": resp.ResponderRelayIndex, - "vpnIp": h.vpnIp}). - Info("send CreateRelayResponse") - } - - case Requested: - // Keep waiting for the other relay to complete - } } } } diff --git a/remote_list.go b/remote_list.go index 94db8f2..6baed29 100644 --- a/remote_list.go +++ b/remote_list.go @@ -4,6 +4,7 @@ import ( "context" "net" "net/netip" + "slices" "sort" "strconv" "sync" @@ -17,8 +18,8 @@ import ( type forEachFunc func(addr netip.AddrPort, preferred bool) // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate) -type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool -type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool +type checkFuncV4 func(vpnIp netip.Addr, to *V4AddrPort) bool +type checkFuncV6 func(vpnIp netip.Addr, to *V6AddrPort) bool // CacheMap is a struct that better represents the lighthouse cache for humans // The string key is the owners vpnIp @@ -32,9 +33,6 @@ type Cache struct { Relay []netip.Addr `json:"relay"` } -//TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion -// We will never clean learned/reported information for them as it stands today - // cache is an internal struct that splits v4 and v6 addresses inside the cache map type cache struct { v4 *cacheV4 @@ -48,14 +46,14 @@ type cacheRelay struct { // cacheV4 stores learned and reported ipv4 records under cache type cacheV4 struct { - learned *Ip4AndPort - reported []*Ip4AndPort + learned *V4AddrPort + reported []*V4AddrPort } // cacheV4 stores learned and reported ipv6 records under cache type cacheV6 struct { - learned *Ip6AndPort - reported []*Ip6AndPort + learned *V6AddrPort + reported []*V6AddrPort } type hostnamePort struct { @@ -170,7 +168,7 @@ func (hr *hostnamesResults) Cancel() { } } -func (hr *hostnamesResults) GetIPs() []netip.AddrPort { +func (hr *hostnamesResults) GetAddrs() []netip.AddrPort { var retSlice []netip.AddrPort if hr != nil { p := hr.ips.Load() @@ -189,6 +187,9 @@ type RemoteList struct { // Every interaction with internals requires a lock! sync.RWMutex + // The full list of vpn addresses assigned to this host + vpnAddrs []netip.Addr + // A deduplicated set of addresses. Any accessor should lock beforehand. addrs []netip.AddrPort @@ -212,13 +213,16 @@ type RemoteList struct { } // NewRemoteList creates a new empty RemoteList -func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList { - return &RemoteList{ +func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList { + r := &RemoteList{ + vpnAddrs: make([]netip.Addr, len(vpnAddrs)), addrs: make([]netip.AddrPort, 0), relays: make([]netip.Addr, 0), cache: make(map[netip.Addr]*cache), shouldAdd: shouldAdd, } + copy(r.vpnAddrs, vpnAddrs) + return r } func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) { @@ -268,14 +272,13 @@ func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort // LearnRemote locks and sets the learned slot for the owner vpn ip to the provided addr // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming. // It will mark the deduplicated address list as dirty, so do not call it unless new information is available -// TODO: this needs to support the allow list list func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) { r.Lock() defer r.Unlock() if remote.Addr().Is4() { - r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port())) + r.unlockedSetLearnedV4(ownerVpnIp, netAddrToProtoV4AddrPort(remote.Addr(), remote.Port())) } else { - r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port())) + r.unlockedSetLearnedV6(ownerVpnIp, netAddrToProtoV6AddrPort(remote.Addr(), remote.Port())) } } @@ -304,21 +307,21 @@ func (r *RemoteList) CopyCache() *CacheMap { if mc.v4 != nil { if mc.v4.learned != nil { - c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned)) + c.Learned = append(c.Learned, protoV4AddrPortToNetAddrPort(mc.v4.learned)) } for _, a := range mc.v4.reported { - c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a)) + c.Reported = append(c.Reported, protoV4AddrPortToNetAddrPort(a)) } } if mc.v6 != nil { if mc.v6.learned != nil { - c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned)) + c.Learned = append(c.Learned, protoV6AddrPortToNetAddrPort(mc.v6.learned)) } for _, a := range mc.v6.reported { - c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a)) + c.Reported = append(c.Reported, protoV6AddrPortToNetAddrPort(a)) } } @@ -379,7 +382,6 @@ func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) { defer r.Unlock() // Only rebuild if the cache changed - //TODO: shouldRebuild is probably pointless as we don't check for actual change when lighthouse updates come in if r.shouldRebuild { r.unlockedCollect() r.shouldRebuild = false @@ -401,14 +403,14 @@ func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool { // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { +func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *V4AddrPort) { r.shouldRebuild = true r.unlockedGetOrMakeV4(ownerVpnIp).learned = to } // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) { +func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*V4AddrPort, check checkFuncV4) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) @@ -423,7 +425,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPor } } -func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.Addr) { +func (r *RemoteList) unlockedSetRelay(ownerVpnIp netip.Addr, to []netip.Addr) { r.shouldRebuild = true c := r.unlockedGetOrMakeRelay(ownerVpnIp) @@ -436,12 +438,12 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.A // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { +func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *V4AddrPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) // We are doing the easy append because this is rarely called - c.reported = append([]*Ip4AndPort{to}, c.reported...) + c.reported = append([]*V4AddrPort{to}, c.reported...) if len(c.reported) > MaxRemotes { c.reported = c.reported[:MaxRemotes] } @@ -449,14 +451,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { +func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *V6AddrPort) { r.shouldRebuild = true r.unlockedGetOrMakeV6(ownerVpnIp).learned = to } // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) { +func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*V6AddrPort, check checkFuncV6) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) @@ -473,12 +475,12 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPor // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { +func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *V6AddrPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) // We are doing the easy append because this is rarely called - c.reported = append([]*Ip6AndPort{to}, c.reported...) + c.reported = append([]*V6AddrPort{to}, c.reported...) if len(c.reported) > MaxRemotes { c.reported = c.reported[:MaxRemotes] } @@ -536,14 +538,14 @@ func (r *RemoteList) unlockedCollect() { for _, c := range r.cache { if c.v4 != nil { if c.v4.learned != nil { - u := AddrPortFromIp4AndPort(c.v4.learned) + u := protoV4AddrPortToNetAddrPort(c.v4.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v4.reported { - u := AddrPortFromIp4AndPort(v) + u := protoV4AddrPortToNetAddrPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } @@ -552,14 +554,14 @@ func (r *RemoteList) unlockedCollect() { if c.v6 != nil { if c.v6.learned != nil { - u := AddrPortFromIp6AndPort(c.v6.learned) + u := protoV6AddrPortToNetAddrPort(c.v6.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v6.reported { - u := AddrPortFromIp6AndPort(v) + u := protoV6AddrPortToNetAddrPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } @@ -573,7 +575,7 @@ func (r *RemoteList) unlockedCollect() { } } - dnsAddrs := r.hr.GetIPs() + dnsAddrs := r.hr.GetAddrs() for _, addr := range dnsAddrs { if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { if !r.unlockedIsBad(addr) { @@ -589,6 +591,21 @@ func (r *RemoteList) unlockedCollect() { // unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) { + // Use a map to deduplicate any relay addresses + dedupedRelays := map[netip.Addr]struct{}{} + for _, relay := range r.relays { + dedupedRelays[relay] = struct{}{} + } + r.relays = r.relays[:0] + for relay := range dedupedRelays { + r.relays = append(r.relays, relay) + } + // Put them in a somewhat consistent order after de-duplication + slices.SortFunc(r.relays, func(a, b netip.Addr) int { + return a.Compare(b) + }) + + // Now the addrs n := len(r.addrs) if n < 2 { return @@ -687,7 +704,6 @@ func minInt(a, b int) int { // isPreferred returns true of the ip is contained in the preferredRanges list func isPreferred(ip netip.Addr, preferredRanges []netip.Prefix) bool { - //TODO: this would be better in a CIDR6Tree for _, p := range preferredRanges { if p.Contains(ip) { return true diff --git a/remote_list_test.go b/remote_list_test.go index 62a892b..0caf86a 100644 --- a/remote_list_test.go +++ b/remote_list_test.go @@ -9,11 +9,11 @@ import ( ) func TestRemoteList_Rebuild(t *testing.T) { - rl := NewRemoteList(nil) + rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip4AndPort{ + []*V4AddrPort{ newIp4AndPortFromString("70.199.182.92:1475"), // this is duped newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.1.1:10101"), // this is duped @@ -25,20 +25,30 @@ func TestRemoteList_Rebuild(t *testing.T) { newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) rl.unlockedSetV6( netip.MustParseAddr("0.0.0.1"), netip.MustParseAddr("0.0.0.1"), - []*Ip6AndPort{ + []*V6AddrPort{ newIp6AndPortFromString("[1::1]:1"), // this is duped newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1::1]:1"), // this is a dupe newIp6AndPortFromString("[1::1]:2"), // this is a dupe }, - func(netip.Addr, *Ip6AndPort) bool { return true }, + func(netip.Addr, *V6AddrPort) bool { return true }, + ) + + rl.unlockedSetRelay( + netip.MustParseAddr("0.0.0.1"), + []netip.Addr{ + netip.MustParseAddr("1::1"), + netip.MustParseAddr("1.2.3.4"), + netip.MustParseAddr("1.2.3.4"), + netip.MustParseAddr("1::1"), + }, ) rl.Rebuild([]netip.Prefix{}) @@ -76,6 +86,11 @@ func TestRemoteList_Rebuild(t *testing.T) { assert.Equal(t, "[1::1]:2", rl.addrs[8].String()) assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String()) + // assert relay deduplicated + assert.Len(t, rl.relays, 2) + assert.Equal(t, "1.2.3.4", rl.relays[0].String()) + assert.Equal(t, "1::1", rl.relays[1].String()) + // Ensure we can hoist a specific ipv4 range over anything else rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("172.17.0.0/16")}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") @@ -98,11 +113,11 @@ func TestRemoteList_Rebuild(t *testing.T) { } func BenchmarkFullRebuild(b *testing.B) { - rl := NewRemoteList(nil) + rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip4AndPort{ + []*V4AddrPort{ newIp4AndPortFromString("70.199.182.92:1475"), newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.1.1:10101"), @@ -112,19 +127,19 @@ func BenchmarkFullRebuild(b *testing.B) { newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) rl.unlockedSetV6( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip6AndPort{ + []*V6AddrPort{ newIp6AndPortFromString("[1::1]:1"), newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, - func(netip.Addr, *Ip6AndPort) bool { return true }, + func(netip.Addr, *V6AddrPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { @@ -160,11 +175,11 @@ func BenchmarkFullRebuild(b *testing.B) { } func BenchmarkSortRebuild(b *testing.B) { - rl := NewRemoteList(nil) + rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip4AndPort{ + []*V4AddrPort{ newIp4AndPortFromString("70.199.182.92:1475"), newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.1.1:10101"), @@ -174,19 +189,19 @@ func BenchmarkSortRebuild(b *testing.B) { newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) rl.unlockedSetV6( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip6AndPort{ + []*V6AddrPort{ newIp6AndPortFromString("[1::1]:1"), newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, - func(netip.Addr, *Ip6AndPort) bool { return true }, + func(netip.Addr, *V6AddrPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { @@ -224,19 +239,19 @@ func BenchmarkSortRebuild(b *testing.B) { }) } -func newIp4AndPortFromString(s string) *Ip4AndPort { +func newIp4AndPortFromString(s string) *V4AddrPort { a := netip.MustParseAddrPort(s) v4Addr := a.Addr().As4() - return &Ip4AndPort{ - Ip: binary.BigEndian.Uint32(v4Addr[:]), + return &V4AddrPort{ + Addr: binary.BigEndian.Uint32(v4Addr[:]), Port: uint32(a.Port()), } } -func newIp6AndPortFromString(s string) *Ip6AndPort { +func newIp6AndPortFromString(s string) *V6AddrPort { a := netip.MustParseAddrPort(s) v6Addr := a.Addr().As16() - return &Ip6AndPort{ + return &V6AddrPort{ Hi: binary.BigEndian.Uint64(v6Addr[:8]), Lo: binary.BigEndian.Uint64(v6Addr[8:]), Port: uint32(a.Port()), diff --git a/service/service.go b/service/service.go index 4ddd301..4339677 100644 --- a/service/service.go +++ b/service/service.go @@ -90,9 +90,9 @@ func New(config *config.C) (*Service, error) { }, }) - ipNet := device.Cidr() + ipNet := device.Networks() pa := tcpip.ProtocolAddress{ - AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(), + AddressWithPrefix: tcpip.AddrFromSlice(ipNet[0].Addr().AsSlice()).WithPrefix(), Protocol: ipv4.ProtocolNumber, } if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{ diff --git a/service/service_test.go b/service/service_test.go index e9fceef..613758e 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -10,8 +10,8 @@ import ( "dario.cat/mergo" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/e2e" "golang.org/x/sync/errgroup" "gopkg.in/yaml.v2" ) @@ -19,7 +19,7 @@ import ( type m map[string]interface{} func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { - _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{}) + _, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{}) caB, err := caCrt.MarshalPEM() if err != nil { panic(err) @@ -79,7 +79,7 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n } func TestService(t *testing.T) { - ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ "static_host_map": m{}, "lighthouse": m{ diff --git a/ssh.go b/ssh.go index 881ee46..203166c 100644 --- a/ssh.go +++ b/ssh.go @@ -77,9 +77,6 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) { // that callers may invoke to run the configured ssh server. On // failure, it returns nil, error. func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) { - //TODO conntrack list - //TODO print firewall rules or hash? - listen := c.GetString("sshd.listen", "") if listen == "" { return nil, fmt.Errorf("sshd.listen must be provided") @@ -93,7 +90,6 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro return nil, fmt.Errorf("sshd.listen can not use port 22") } - //TODO: no good way to reload this right now hostKeyPathOrKey := c.GetString("sshd.host_key", "") if hostKeyPathOrKey == "" { return nil, fmt.Errorf("sshd.host_key must be provided") @@ -320,7 +316,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "print-cert", - ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn ip", + ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn addr", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshPrintCertFlags{} @@ -336,7 +332,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "print-tunnel", - ShortDescription: "Prints json details about a tunnel for the provided vpn ip", + ShortDescription: "Prints json details about a tunnel for the provided vpn addr", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshPrintTunnelFlags{} @@ -364,7 +360,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "change-remote", - ShortDescription: "Changes the remote address used in the tunnel for the provided vpn ip", + ShortDescription: "Changes the remote address used in the tunnel for the provided vpn addr", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshChangeRemoteFlags{} @@ -378,7 +374,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "close-tunnel", - ShortDescription: "Closes a tunnel for the provided vpn ip", + ShortDescription: "Closes a tunnel for the provided vpn addr", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshCloseTunnelFlags{} @@ -392,7 +388,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "create-tunnel", - ShortDescription: "Creates a tunnel for the provided vpn ip and address", + ShortDescription: "Creates a tunnel for the provided vpn address", Help: "The lighthouses will be queried for real addresses but you can provide one as well.", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) @@ -407,8 +403,8 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "query-lighthouse", - ShortDescription: "Query the lighthouses for the provided vpn ip", - Help: "This command is asynchronous. Only currently known udp ips will be printed.", + ShortDescription: "Query the lighthouses for the provided vpn address", + Help: "This command is asynchronous. Only currently known udp addresses will be printed.", Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { return sshQueryLighthouse(f, fs, a, w) }, @@ -418,7 +414,6 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error { fs, ok := a.(*sshListHostMapFlags) if !ok { - //TODO: error return nil } @@ -430,7 +425,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er } sort.Slice(hm, func(i, j int) bool { - return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0 + return hm[i].VpnAddrs[0].Compare(hm[j].VpnAddrs[0]) < 0 }) if fs.Json || fs.Pretty { @@ -441,13 +436,12 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er err := js.Encode(hm) if err != nil { - //TODO return nil } } else { for _, v := range hm { - err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, v.RemoteAddrs)) + err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddrs, v.RemoteAddrs)) if err != nil { return err } @@ -460,13 +454,12 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWriter) error { fs, ok := a.(*sshListHostMapFlags) if !ok { - //TODO: error return nil } type lighthouseInfo struct { - VpnIp string `json:"vpnIp"` - Addrs *CacheMap `json:"addrs"` + VpnAddr string `json:"vpnAddr"` + Addrs *CacheMap `json:"addrs"` } lightHouse.RLock() @@ -474,15 +467,15 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr x := 0 for k, v := range lightHouse.addrMap { addrMap[x] = lighthouseInfo{ - VpnIp: k.String(), - Addrs: v.CopyCache(), + VpnAddr: k.String(), + Addrs: v.CopyCache(), } x++ } lightHouse.RUnlock() sort.Slice(addrMap, func(i, j int) bool { - return strings.Compare(addrMap[i].VpnIp, addrMap[j].VpnIp) < 0 + return strings.Compare(addrMap[i].VpnAddr, addrMap[j].VpnAddr) < 0 }) if fs.Json || fs.Pretty { @@ -493,7 +486,6 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr err := js.Encode(addrMap) if err != nil { - //TODO return nil } @@ -503,7 +495,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr if err != nil { return err } - err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, string(b))) + err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddr, string(b))) if err != nil { return err } @@ -541,20 +533,20 @@ func sshVersion(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } var cm *CacheMap - rl := ifce.lightHouse.Query(vpnIp) + rl := ifce.lightHouse.Query(vpnAddr) if rl != nil { cm = rl.CopyCache() } @@ -564,26 +556,25 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { flags, ok := fs.(*sshCloseTunnelFlags) if !ok { - //TODO: error return nil } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0])) } if !flags.LocalOnly { @@ -605,29 +596,28 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { flags, ok := fs.(*sshCreateTunnelFlags) if !ok { - //TODO: error return nil } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already exists")) } - hostInfo = ifce.handshakeManager.QueryVpnIp(vpnIp) + hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnAddr) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) } @@ -640,7 +630,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW } } - hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil) + hostInfo = ifce.handshakeManager.StartHandshake(vpnAddr, nil) if addr.IsValid() { hostInfo.SetRemote(addr) } @@ -651,12 +641,11 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { flags, ok := fs.(*sshChangeRemoteFlags) if !ok { - //TODO: error return nil } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } if flags.Address == "" { @@ -668,18 +657,18 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine("Address could not be parsed") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0])) } hostInfo.SetRemote(addr) @@ -781,24 +770,23 @@ func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWri func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { args, ok := fs.(*sshPrintCertFlags) if !ok { - //TODO: error return nil } - cert := ifce.pki.GetCertState().Certificate + cert := ifce.pki.getCertState().GetDefaultCertificate() if len(a) > 0 { - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0])) } cert = hostInfo.GetCert().Certificate @@ -807,7 +795,6 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit if args.Json || args.Pretty { b, err := cert.MarshalJSON() if err != nil { - //TODO: handle it return nil } @@ -816,7 +803,6 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit err := json.Indent(buf, b, "", " ") b = buf.Bytes() if err != nil { - //TODO: handle it return nil } } @@ -827,7 +813,6 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit if args.Raw { b, err := cert.MarshalPEM() if err != nil { - //TODO: handle it return nil } @@ -840,7 +825,6 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { args, ok := fs.(*sshPrintTunnelFlags) if !ok { - //TODO: error w.WriteLine(fmt.Sprintf("sshPrintRelays failed to convert args type")) return nil } @@ -856,15 +840,15 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr Error error Type string State string - PeerIp netip.Addr + PeerAddr netip.Addr LocalIndex uint32 RemoteIndex uint32 RelayedThrough []netip.Addr } type RelayOutput struct { - NebulaIp netip.Addr - RelayForIps []RelayFor + NebulaAddr netip.Addr + RelayForAddrs []RelayFor } type CmdOutput struct { @@ -880,16 +864,16 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr } for k, v := range relays { - ro := RelayOutput{NebulaIp: v.vpnIp} + ro := RelayOutput{NebulaAddr: v.vpnAddrs[0]} co.Relays = append(co.Relays, &ro) - relayHI := ifce.hostMap.QueryVpnIp(v.vpnIp) + relayHI := ifce.hostMap.QueryVpnAddr(v.vpnAddrs[0]) if relayHI == nil { - ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")}) + ro.RelayForAddrs = append(ro.RelayForAddrs, RelayFor{Error: errors.New("could not find hostinfo")}) continue } - for _, vpnIp := range relayHI.relayState.CopyRelayForIps() { + for _, vpnAddr := range relayHI.relayState.CopyRelayForIps() { rf := RelayFor{Error: nil} - r, ok := relayHI.relayState.GetRelayForByIp(vpnIp) + r, ok := relayHI.relayState.GetRelayForByAddr(vpnAddr) if ok { t := "" switch r.Type { @@ -913,19 +897,19 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr rf.LocalIndex = r.LocalIndex rf.RemoteIndex = r.RemoteIndex - rf.PeerIp = r.PeerIp + rf.PeerAddr = r.PeerAddr rf.Type = t rf.State = s if rf.LocalIndex != k { rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k) } } - relayedHI := ifce.hostMap.QueryVpnIp(vpnIp) + relayedHI := ifce.hostMap.QueryVpnAddr(vpnAddr) if relayedHI != nil { rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...) } - ro.RelayForIps = append(ro.RelayForIps, rf) + ro.RelayForAddrs = append(ro.RelayForAddrs, rf) } } err := enc.Encode(co) @@ -938,26 +922,25 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { args, ok := fs.(*sshPrintTunnelFlags) if !ok { - //TODO: error return nil } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0])) } enc := json.NewEncoder(w.GetWriter()) @@ -971,13 +954,15 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error { data := struct { - Name string `json:"name"` - Cidr string `json:"cidr"` + Name string `json:"name"` + Cidr []netip.Prefix `json:"cidr"` }{ Name: ifce.inside.Name(), - Cidr: ifce.inside.Cidr().String(), + Cidr: make([]netip.Prefix, len(ifce.inside.Networks())), } + copy(data.Cidr, ifce.inside.Networks()) + flags, ok := fs.(*sshDeviceInfoFlags) if !ok { return fmt.Errorf("internal error: expected flags to be sshDeviceInfoFlags but was %+v", fs) diff --git a/sshd/command.go b/sshd/command.go index 900b01e..66646a6 100644 --- a/sshd/command.go +++ b/sshd/command.go @@ -57,7 +57,6 @@ func execCommand(c *Command, args []string, w StringWriter) error { func dumpCommands(c *radix.Tree, w StringWriter) { err := w.WriteLine("Available commands:") if err != nil { - //TODO: log return } @@ -67,10 +66,7 @@ func dumpCommands(c *radix.Tree, w StringWriter) { } sort.Strings(cmds) - err = w.Write(strings.Join(cmds, "\n") + "\n\n") - if err != nil { - //TODO: log - } + _ = w.Write(strings.Join(cmds, "\n") + "\n\n") } func lookupCommand(c *radix.Tree, sCmd string) (*Command, error) { @@ -119,8 +115,6 @@ func helpCallback(commands *radix.Tree, a []string, w StringWriter) (err error) // We are printing a specific commands help text cmd, err := lookupCommand(commands, a[0]) if err != nil { - //TODO: handle error - //TODO: message the user return } diff --git a/sshd/server.go b/sshd/server.go index 9e8c721..c151f91 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -80,9 +80,7 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { s.config = &ssh.ServerConfig{ PublicKeyCallback: cc.Authenticate, - //TODO: AuthLogCallback: s.authAttempt, - //TODO: version string - ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"), + ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"), } s.RegisterCommand(&Command{ diff --git a/sshd/session.go b/sshd/session.go index bba2a55..7c5869e 100644 --- a/sshd/session.go +++ b/sshd/session.go @@ -62,7 +62,6 @@ func (s *session) handleChannels(chans <-chan ssh.NewChannel) { func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { for req := range in { var err error - //TODO: maybe support window sizing? switch req.Type { case "shell": if s.term == nil { @@ -89,9 +88,7 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { req.Reply(true, nil) s.dispatchCommand(payload.Value, &stringWriter{channel}) - //TODO: Fix error handling and report the proper status back status := struct{ Status uint32 }{uint32(0)} - //TODO: I think this is how we shut down a shell as well? channel.SendRequest("exit-status", false, ssh.Marshal(status)) channel.Close() return @@ -110,7 +107,6 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { } func (s *session) createTerm(channel ssh.Channel) *terminal.Terminal { - //TODO: PS1 with nebula cert name term := terminal.NewTerminal(channel, s.c.User()+"@nebula > ") term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) { // key 9 is tab @@ -137,7 +133,6 @@ func (s *session) handleInput(channel ssh.Channel) { for { line, err := s.term.ReadLine() if err != nil { - //TODO: log break } @@ -148,7 +143,6 @@ func (s *session) handleInput(channel ssh.Channel) { func (s *session) dispatchCommand(line string, w StringWriter) { args, err := shlex.Split(line, true) if err != nil { - //todo: LOG IT return } @@ -159,13 +153,11 @@ func (s *session) dispatchCommand(line string, w StringWriter) { c, err := lookupCommand(s.commands, args[0]) if err != nil { - //TODO: handle the error return } if c == nil { err := w.WriteLine(fmt.Sprintf("did not understand: %s", line)) - //TODO: log error _ = err dumpCommands(s.commands, w) @@ -177,10 +169,7 @@ func (s *session) dispatchCommand(line string, w StringWriter) { return } - err = execCommand(c, args[1:], w) - if err != nil { - //TODO: log the error - } + _ = execCommand(c, args[1:], w) return } diff --git a/test/tun.go b/test/tun.go index fbf5829..b29d61c 100644 --- a/test/tun.go +++ b/test/tun.go @@ -16,8 +16,8 @@ func (NoopTun) Activate() error { return nil } -func (NoopTun) Cidr() netip.Prefix { - return netip.Prefix{} +func (NoopTun) Networks() []netip.Prefix { + return []netip.Prefix{} } func (NoopTun) Name() string { diff --git a/timeout_test.go b/timeout_test.go index 4c6364e..db36fec 100644 --- a/timeout_test.go +++ b/timeout_test.go @@ -116,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) { assert.Equal(t, 0, tw.current) fps := []firewall.Packet{ - {LocalIP: netip.MustParseAddr("0.0.0.1")}, - {LocalIP: netip.MustParseAddr("0.0.0.2")}, - {LocalIP: netip.MustParseAddr("0.0.0.3")}, - {LocalIP: netip.MustParseAddr("0.0.0.4")}, + {LocalAddr: netip.MustParseAddr("0.0.0.1")}, + {LocalAddr: netip.MustParseAddr("0.0.0.2")}, + {LocalAddr: netip.MustParseAddr("0.0.0.3")}, + {LocalAddr: netip.MustParseAddr("0.0.0.4")}, } tw.Add(fps[0], time.Second*1) diff --git a/udp/conn.go b/udp/conn.go index fa4e443..895b0df 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -4,28 +4,19 @@ import ( "net/netip" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" ) const MTU = 9001 type EncReader func( addr netip.AddrPort, - out []byte, - packet []byte, - header *header.H, - fwPacket *firewall.Packet, - lhh LightHouseHandlerFunc, - nb []byte, - q int, - localCache firewall.ConntrackCache, + payload []byte, ) type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) - ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) + ListenOut(r EncReader) WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) Close() error @@ -39,7 +30,7 @@ func (NoopConn) Rebind() error { func (NoopConn) LocalAddr() (netip.AddrPort, error) { return netip.AddrPort{}, nil } -func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) { +func (NoopConn) ListenOut(_ EncReader) { return } func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { diff --git a/udp/temp.go b/udp/temp.go deleted file mode 100644 index b281906..0000000 --- a/udp/temp.go +++ /dev/null @@ -1,10 +0,0 @@ -package udp - -import ( - "net/netip" -) - -//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare - -// TODO: IPV6-WORK this can likely be removed now -type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 2d84536..06a4d53 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -15,8 +15,6 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" ) type GenericConn struct { @@ -60,7 +58,7 @@ func (u *GenericConn) LocalAddr() (netip.AddrPort, error) { } func (u *GenericConn) ReloadConfig(c *config.C) { - // TODO + } func NewUDPStatsEmitter(udpConns []Conn) func() { @@ -72,12 +70,8 @@ type rawMessage struct { Len uint32 } -func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) +func (u *GenericConn) ListenOut(r EncReader) { buffer := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) for { // Just read one packet at a time @@ -87,16 +81,6 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f return } - r( - netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), - plaintext[:0], - buffer[:n], - h, - fwPacket, - lhf, - nb, - q, - cache.Get(u.l), - ) + r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 2eee76e..32a567e 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -14,13 +14,9 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" "golang.org/x/sys/unix" ) -//TODO: make it support reload as best you can! - type StdConn struct { sysFd int isV4 bool @@ -59,7 +55,6 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in } } - //TODO: support multiple listening IPs (for limiting ipv6) var sa unix.Sockaddr if ip.Is4() { sa4 := &unix.SockaddrInet4{Port: port} @@ -74,11 +69,6 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in return nil, fmt.Errorf("unable to bind to socket: %s", err) } - //TODO: this may be useful for forcing threads into specific cores - //unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU, x) - //v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU) - //l.Println(v, err) - return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err } @@ -120,15 +110,9 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) { } } -func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} +func (u *StdConn) ListenOut(r EncReader) { var ip netip.Addr - nb := make([]byte, 12, 12) - //TODO: should we track this? - //metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015)) msgs, buffers, names := u.PrepareRawMessages(u.batch) read := u.ReadMulti if u.batch == 1 { @@ -142,26 +126,14 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew return } - //metric.Update(int64(n)) for i := 0; i < n; i++ { + // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic if u.isV4 { ip, _ = netip.AddrFromSlice(names[i][4:8]) - //TODO: IPV6-WORK what is not ok? } else { ip, _ = netip.AddrFromSlice(names[i][8:24]) - //TODO: IPV6-WORK what is not ok? } - r( - netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), - plaintext[:0], - buffers[i][:msgs[i].Len], - h, - fwPacket, - lhf, - nb, - q, - cache.Get(u.l), - ) + r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) } } } @@ -235,8 +207,6 @@ func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { return &net.OpError{Op: "sendto", Err: err} } - //TODO: handle incomplete writes - return nil } } @@ -266,8 +236,6 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { return &net.OpError{Op: "sendto", Err: err} } - //TODO: handle incomplete writes - return nil } } @@ -314,7 +282,6 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { } func (u *StdConn) Close() error { - //TODO: this will not interrupt the read loop return syscall.Close(u.sysFd) } diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go index 523968c..de8f1cd 100644 --- a/udp/udp_linux_32.go +++ b/udp/udp_linux_32.go @@ -39,7 +39,6 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { buffers[i] = make([]byte, MTU) names[i] = make([]byte, unix.SizeofSockaddrInet6) - //TODO: this is still silly, no need for an array vs := []iovec{ {Base: &buffers[i][0], Len: uint32(len(buffers[i]))}, } diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 87a0de7..48c5a97 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -42,7 +42,6 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { buffers[i] = make([]byte, MTU) names[i] = make([]byte, unix.SizeofSockaddrInet6) - //TODO: this is still silly, no need for an array vs := []iovec{ {Base: &buffers[i][0], Len: uint64(len(buffers[i]))}, } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index ee7e1e0..585b642 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -18,9 +18,6 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" - "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/conn/winrio" ) @@ -118,12 +115,8 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error { return nil } -func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) +func (u *RIOConn) ListenOut(r EncReader) { buffer := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) for { // Just read one packet at a time @@ -133,17 +126,7 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew return } - r( - netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), - plaintext[:0], - buffer[:n], - h, - fwPacket, - lhf, - nb, - q, - cache.Get(u.l), - ) + r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n]) } } diff --git a/udp/udp_tester.go b/udp/udp_tester.go index f03a353..8d5e6c1 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -10,7 +10,6 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" ) @@ -107,18 +106,13 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { return nil } -func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) - +func (u *TesterConn) ListenOut(r EncReader) { for { p, ok := <-u.RxPackets if !ok { return } - r(p.From, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(p.From, p.Data) } } From 351dbd60596da90466b6e27a4f6d9051d7bd5c4a Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Thu, 6 Mar 2025 12:29:38 -0500 Subject: [PATCH 37/67] smoke-extra: support Ubuntu 24.04 (#1311) Ubuntu 24.04 doesn't include vagrant anymore, so add the hashicorp source --- .github/workflows/smoke-extra.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/smoke-extra.yml b/.github/workflows/smoke-extra.yml index 2b5e6e9..de582de 100644 --- a/.github/workflows/smoke-extra.yml +++ b/.github/workflows/smoke-extra.yml @@ -27,6 +27,9 @@ jobs: go-version-file: 'go.mod' check-latest: true + - name: add hashicorp source + run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list + - name: install vagrant run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox From 32d3a6e09178bc861f26be04932b4bba4a109ee3 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Thu, 6 Mar 2025 12:54:20 -0500 Subject: [PATCH 38/67] build with go1.23 (#1198) * make boringcrypto: add checklinkname flag for go1.23 Starting with go1.23, we need to set -checklinkname=0 when building for boringcrypto because we need to use go:linkname to access `newGCMTLS`. Note that this does break builds when using a go version less than go1.23.0. We can probably assume that someone using this Makefile and manually building is using the latest release of Go though. See: - https://go.dev/doc/go1.23#linker * build with go1.23 This doesn't change our go.mod, which still only requires go1.22 as a minimum, only changes our builds to use go1.23 so we have the latest improvements. * fix `make test-boringcrypto` as well * also fix boringcrypto e2e test --- .github/workflows/gofmt.yml | 2 +- .github/workflows/release.yml | 6 +++--- .github/workflows/smoke.yml | 2 +- .github/workflows/test.yml | 8 ++++---- Makefile | 4 +++- 5 files changed, 12 insertions(+), 10 deletions(-) diff --git a/.github/workflows/gofmt.yml b/.github/workflows/gofmt.yml index e0d41ae..20a39cf 100644 --- a/.github/workflows/gofmt.yml +++ b/.github/workflows/gofmt.yml @@ -18,7 +18,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.23' check-latest: true - name: Install goimports diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 31987db..392f71b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,7 +14,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.23' check-latest: true - name: Build @@ -37,7 +37,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.23' check-latest: true - name: Build @@ -70,7 +70,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.23' check-latest: true - name: Import certificates diff --git a/.github/workflows/smoke.yml b/.github/workflows/smoke.yml index 54833bd..3f63008 100644 --- a/.github/workflows/smoke.yml +++ b/.github/workflows/smoke.yml @@ -22,7 +22,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.23' check-latest: true - name: build diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2b27f52..4f3f2ed 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,7 +22,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.23' check-latest: true - name: Build @@ -55,7 +55,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.23' check-latest: true - name: Build @@ -65,7 +65,7 @@ jobs: run: make test-boringcrypto - name: End 2 end - run: make e2evv GOEXPERIMENT=boringcrypto CGO_ENABLED=1 + run: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0" test-linux-pkcs11: name: Build and test on linux with pkcs11 @@ -97,7 +97,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.23' check-latest: true - name: Build nebula diff --git a/Makefile b/Makefile index d3fbcaa..0b199a5 100644 --- a/Makefile +++ b/Makefile @@ -137,6 +137,8 @@ build/linux-mips-softfloat/%: LDFLAGS += -s -w # boringcrypto build/linux-amd64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1 build/linux-arm64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1 +build/linux-amd64-boringcrypto/%: LDFLAGS += -checklinkname=0 +build/linux-arm64-boringcrypto/%: LDFLAGS += -checklinkname=0 build/%/nebula: .FORCE GOOS=$(firstword $(subst -, , $*)) \ @@ -170,7 +172,7 @@ test: go test -v ./... test-boringcrypto: - GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -v ./... + GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -ldflags "-checklinkname=0" -v ./... test-pkcs11: CGO_ENABLED=1 go test -v -tags pkcs11 ./... From 750e4a81bf95bfa2244364efe14cba19fe7bf5fe Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 6 Mar 2025 12:57:05 -0500 Subject: [PATCH 39/67] Bump the golang-x-dependencies group across 1 directory with 5 updates (#1303) Bumps the golang-x-dependencies group with 3 updates in the / directory: [golang.org/x/crypto](https://github.com/golang/crypto), [golang.org/x/net](https://github.com/golang/net) and [golang.org/x/sync](https://github.com/golang/sync). Updates `golang.org/x/crypto` from 0.28.0 to 0.32.0 - [Commits](https://github.com/golang/crypto/compare/v0.28.0...v0.32.0) Updates `golang.org/x/net` from 0.30.0 to 0.34.0 - [Commits](https://github.com/golang/net/compare/v0.30.0...v0.34.0) Updates `golang.org/x/sync` from 0.8.0 to 0.10.0 - [Commits](https://github.com/golang/sync/compare/v0.8.0...v0.10.0) Updates `golang.org/x/sys` from 0.26.0 to 0.29.0 - [Commits](https://github.com/golang/sys/compare/v0.26.0...v0.29.0) Updates `golang.org/x/term` from 0.25.0 to 0.28.0 - [Commits](https://github.com/golang/term/compare/v0.25.0...v0.28.0) --- updated-dependencies: - dependency-name: golang.org/x/crypto dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/net dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/sync dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/sys dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/term dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 10 +++++----- go.sum | 20 ++++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index 2ff9976..afb613a 100644 --- a/go.mod +++ b/go.mod @@ -24,12 +24,12 @@ require ( github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stretchr/testify v1.9.0 github.com/vishvananda/netlink v1.3.0 - golang.org/x/crypto v0.28.0 + golang.org/x/crypto v0.32.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 - golang.org/x/net v0.30.0 - golang.org/x/sync v0.8.0 - golang.org/x/sys v0.26.0 - golang.org/x/term v0.25.0 + golang.org/x/net v0.34.0 + golang.org/x/sync v0.10.0 + golang.org/x/sys v0.29.0 + golang.org/x/term v0.28.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 diff --git a/go.sum b/go.sum index d0e9c55..cf358f0 100644 --- a/go.sum +++ b/go.sum @@ -158,8 +158,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= -golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= +golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= +golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= @@ -178,8 +178,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= -golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -187,8 +187,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -206,11 +206,11 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= -golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= From 9feda811a6a423e78623289bd5b401629784b891 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Thu, 6 Mar 2025 13:21:49 -0500 Subject: [PATCH 40/67] bump go.mod to go1.23 (#1342) * bump go.mod to go1.23 * 1.23.6 --- go.mod | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index afb613a..cbeddb0 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module github.com/slackhq/nebula -go 1.22.0 +go 1.23.6 -toolchain go1.22.2 +toolchain go1.23.7 require ( dario.cat/mergo v1.0.1 From 8a090e59d7311ab54da3d2a1af05fbe45f0a6747 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 6 Mar 2025 13:26:29 -0600 Subject: [PATCH 41/67] Bump github.com/gaissmai/bart from 0.13.0 to 0.18.1 (#1341) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Nate Brown --- firewall.go | 15 ++++++--------- firewall_test.go | 5 +++++ go.mod | 3 +-- go.sum | 6 ++---- 4 files changed, 14 insertions(+), 15 deletions(-) diff --git a/firewall.go b/firewall.go index d3b9eb6..e9f454d 100644 --- a/firewall.go +++ b/firewall.go @@ -862,16 +862,13 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool } } - matched := false - prefix := netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen()) - fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool { - if prefix.Contains(p.RemoteAddr) && val.match(p, c) { - matched = true - return false + for _, v := range fr.CIDR.Supernets(netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())) { + if v.match(p, c) { + return true } - return true - }) - return matched + } + + return false } func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { diff --git a/firewall_test.go b/firewall_test.go index 4dd2c9a..8d32369 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -440,6 +440,11 @@ func TestFirewall_Drop3(t *testing.T) { // c3 should fail because no match resetConntrack(fw) assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule) + + // Test a remote address match + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", "")) + assert.NoError(t, fw.Drop(p, true, &h1, cp, nil)) } func TestFirewall_DropConntrackReload(t *testing.T) { diff --git a/go.mod b/go.mod index cbeddb0..22792cb 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/flynn/noise v1.1.0 - github.com/gaissmai/bart v0.13.0 + github.com/gaissmai/bart v0.18.1 github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.2 @@ -40,7 +40,6 @@ require ( require ( github.com/beorn7/perks v1.0.1 // indirect - github.com/bits-and-blooms/bitset v1.14.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/btree v1.1.2 // indirect diff --git a/go.sum b/go.sum index cf358f0..3b1c8a6 100644 --- a/go.sum +++ b/go.sum @@ -14,8 +14,6 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/bits-and-blooms/bitset v1.14.3 h1:Gd2c8lSNf9pKXom5JtD7AaKO8o7fGQ2LtFj1436qilA= -github.com/bits-and-blooms/bitset v1.14.3/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -26,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= -github.com/gaissmai/bart v0.13.0 h1:pItEhXDVVebUa+i978FfQ7ye8xZc1FrMgs8nJPPWAgA= -github.com/gaissmai/bart v0.13.0/go.mod h1:qSes2fnJ8hB410BW0ymHUN/eQkuGpTYyJcN8sKMYpJU= +github.com/gaissmai/bart v0.18.1 h1:bX2j560JC1MJpoEDevBGvXL5OZ1mkls320Vl8Igb5QQ= +github.com/gaissmai/bart v0.18.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= From 13799f425da5059eaa910ee2984afe1a6f8b4cf1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 6 Mar 2025 14:30:20 -0500 Subject: [PATCH 42/67] Bump the golang-x-dependencies group with 5 updates (#1339) Bumps the golang-x-dependencies group with 5 updates: | Package | From | To | | --- | --- | --- | | [golang.org/x/crypto](https://github.com/golang/crypto) | `0.32.0` | `0.36.0` | | [golang.org/x/net](https://github.com/golang/net) | `0.34.0` | `0.37.0` | | [golang.org/x/sync](https://github.com/golang/sync) | `0.10.0` | `0.12.0` | | [golang.org/x/sys](https://github.com/golang/sys) | `0.29.0` | `0.31.0` | | [golang.org/x/term](https://github.com/golang/term) | `0.28.0` | `0.30.0` | Updates `golang.org/x/crypto` from 0.32.0 to 0.36.0 - [Commits](https://github.com/golang/crypto/compare/v0.32.0...v0.36.0) Updates `golang.org/x/net` from 0.34.0 to 0.37.0 - [Commits](https://github.com/golang/net/compare/v0.34.0...v0.37.0) Updates `golang.org/x/sync` from 0.10.0 to 0.12.0 - [Commits](https://github.com/golang/sync/compare/v0.10.0...v0.12.0) Updates `golang.org/x/sys` from 0.29.0 to 0.31.0 - [Commits](https://github.com/golang/sys/compare/v0.29.0...v0.31.0) Updates `golang.org/x/term` from 0.28.0 to 0.30.0 - [Commits](https://github.com/golang/term/compare/v0.28.0...v0.30.0) --- updated-dependencies: - dependency-name: golang.org/x/crypto dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/net dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/sync dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/sys dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/term dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 10 +++++----- go.sum | 20 ++++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index 22792cb..3e008f6 100644 --- a/go.mod +++ b/go.mod @@ -24,12 +24,12 @@ require ( github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stretchr/testify v1.9.0 github.com/vishvananda/netlink v1.3.0 - golang.org/x/crypto v0.32.0 + golang.org/x/crypto v0.36.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 - golang.org/x/net v0.34.0 - golang.org/x/sync v0.10.0 - golang.org/x/sys v0.29.0 - golang.org/x/term v0.28.0 + golang.org/x/net v0.37.0 + golang.org/x/sync v0.12.0 + golang.org/x/sys v0.31.0 + golang.org/x/term v0.30.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 diff --git a/go.sum b/go.sum index 3b1c8a6..ce91b5a 100644 --- a/go.sum +++ b/go.sum @@ -156,8 +156,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= -golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= @@ -176,8 +176,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= -golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= +golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= +golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -185,8 +185,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= -golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= +golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -204,11 +204,11 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= -golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= -golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= +golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= +golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= From 775c6bc83df70d7991036466e9de877f99c22f57 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 6 Mar 2025 15:43:55 -0500 Subject: [PATCH 43/67] Bump google.golang.org/protobuf (#1344) Bumps the protobuf-dependencies group with 1 update in the / directory: google.golang.org/protobuf. Updates `google.golang.org/protobuf` from 1.35.1 to 1.36.5 --- updated-dependencies: - dependency-name: google.golang.org/protobuf dependency-type: direct:production update-type: version-update:semver-minor dependency-group: protobuf-dependencies ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 3e008f6..7bd4925 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,7 @@ require ( golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/protobuf v1.35.1 + google.golang.org/protobuf v1.36.5 gopkg.in/yaml.v2 v2.4.0 gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe ) diff --git a/go.sum b/go.sum index ce91b5a..2813b5f 100644 --- a/go.sum +++ b/go.sum @@ -239,8 +239,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= -google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= +google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From c46ef435904efc198849106f4322e6b1ed83c8dc Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Thu, 6 Mar 2025 15:44:41 -0500 Subject: [PATCH 44/67] smoke-test-extra: cleanup ncat references (#1343) * smoke-extra: cleanup ncat references We can't run the `ncat` tests unless we can make sure to install it to all of the vagrant boxes. * more ncat * more cleanup --- .github/workflows/smoke/smoke-vagrant.sh | 34 ++++++++++++------------ 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/.github/workflows/smoke/smoke-vagrant.sh b/.github/workflows/smoke/smoke-vagrant.sh index 76cf72f..1c1e3c5 100755 --- a/.github/workflows/smoke/smoke-vagrant.sh +++ b/.github/workflows/smoke/smoke-vagrant.sh @@ -29,13 +29,13 @@ docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test docker run --name host2 --rm "$CONTAINER" -config host2.yml -test vagrant up -vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" +vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & sleep 1 docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & sleep 1 -vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" & +vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/ [host3] /' & sleep 15 # grab tcpdump pcaps for debugging @@ -46,8 +46,8 @@ docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host # vagrant ssh -c "tcpdump -i nebula1 -q -w - -U" 2>logs/host3.inside.log >logs/host3.inside.pcap & # vagrant ssh -c "tcpdump -i eth0 -q -w - -U" 2>logs/host3.outside.log >logs/host3.outside.pcap & -docker exec host2 ncat -nklv 0.0.0.0 2000 & -vagrant ssh -c "ncat -nklv 0.0.0.0 2000" & +#docker exec host2 ncat -nklv 0.0.0.0 2000 & +#vagrant ssh -c "ncat -nklv 0.0.0.0 2000" & #docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 & #vagrant ssh -c "ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000" & @@ -68,11 +68,11 @@ docker exec host2 ping -c1 192.168.100.1 # Should fail because not allowed by host3 inbound firewall ! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1 -set +x -echo -echo " *** Testing ncat from host2" -echo -set -x +#set +x +#echo +#echo " *** Testing ncat from host2" +#echo +#set -x # Should fail because not allowed by host3 inbound firewall #! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1 #! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1 @@ -82,18 +82,18 @@ echo echo " *** Testing ping from host3" echo set -x -vagrant ssh -c "ping -c1 192.168.100.1" -vagrant ssh -c "ping -c1 192.168.100.2" +vagrant ssh -c "ping -c1 192.168.100.1" -- -T +vagrant ssh -c "ping -c1 192.168.100.2" -- -T -set +x -echo -echo " *** Testing ncat from host3" -echo -set -x +#set +x +#echo +#echo " *** Testing ncat from host3" +#echo +#set -x #vagrant ssh -c "ncat -nzv -w5 192.168.100.2 2000" #vagrant ssh -c "ncat -nzuv -w5 192.168.100.2 3000" | grep -q host2 -vagrant ssh -c "sudo xargs kill Date: Fri, 7 Mar 2025 09:44:30 -0500 Subject: [PATCH 45/67] Bump github.com/prometheus/client_golang from 1.20.4 to 1.21.1 (#1340) Bumps [github.com/prometheus/client_golang](https://github.com/prometheus/client_golang) from 1.20.4 to 1.21.1. - [Release notes](https://github.com/prometheus/client_golang/releases) - [Changelog](https://github.com/prometheus/client_golang/blob/main/CHANGELOG.md) - [Commits](https://github.com/prometheus/client_golang/compare/v1.20.4...v1.21.1) --- updated-dependencies: - dependency-name: github.com/prometheus/client_golang dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 8 ++++---- go.sum | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index 7bd4925..eca466c 100644 --- a/go.mod +++ b/go.mod @@ -17,12 +17,12 @@ require ( github.com/miekg/dns v1.1.62 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f - github.com/prometheus/client_golang v1.20.4 + github.com/prometheus/client_golang v1.21.1 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/sirupsen/logrus v1.9.3 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 - github.com/stretchr/testify v1.9.0 + github.com/stretchr/testify v1.10.0 github.com/vishvananda/netlink v1.3.0 golang.org/x/crypto v0.36.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 @@ -43,11 +43,11 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/btree v1.1.2 // indirect - github.com/klauspost/compress v1.17.9 // indirect + github.com/klauspost/compress v1.17.11 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.6.1 // indirect - github.com/prometheus/common v0.55.0 // indirect + github.com/prometheus/common v0.62.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect github.com/vishvananda/netns v0.0.4 // indirect golang.org/x/mod v0.18.0 // indirect diff --git a/go.sum b/go.sum index 2813b5f..d0cfb92 100644 --- a/go.sum +++ b/go.sum @@ -68,8 +68,8 @@ github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= -github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= @@ -106,8 +106,8 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= -github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI= -github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= +github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk= +github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -116,8 +116,8 @@ github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQy github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= -github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc= -github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= +github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= +github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= @@ -143,8 +143,8 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= From f8734ffa431143df89dbd6323d21ebd0e3659bdf Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 7 Mar 2025 10:45:31 -0600 Subject: [PATCH 46/67] Improve logging when handshaking with an invalid cert (#1345) --- cert/cert.go | 22 +++--------- handshake_ix.go | 79 ++++++++++++++++++++++++++++---------------- handshake_manager.go | 2 +- 3 files changed, 56 insertions(+), 47 deletions(-) diff --git a/cert/cert.go b/cert/cert.go index 4246571..38a2528 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -113,10 +113,10 @@ func (cc *CachedCertificate) String() string { return cc.Certificate.String() } -// RecombineAndValidate will attempt to unmarshal a certificate received in a handshake. +// Recombine will attempt to unmarshal a certificate received in a handshake. // Handshakes save space by placing the peers public key in a different part of the packet, we have to // reassemble the actual certificate structure with that in mind. -func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) { +func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certificate, error) { if publicKey == nil { return nil, ErrNoPeerStaticKey } @@ -125,29 +125,15 @@ func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve return nil, ErrNoPayload } - c, err := unmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve) - if err != nil { - return nil, fmt.Errorf("error unmarshaling cert: %w", err) - } - - cc, err := caPool.VerifyCertificate(time.Now(), c) - if err != nil { - return nil, fmt.Errorf("certificate validation failed: %w", err) - } - - return cc, nil -} - -func unmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) { var c Certificate var err error switch v { // Implementations must ensure the result is a valid cert! case VersionPre1, Version1: - c, err = unmarshalCertificateV1(b, publicKey) + c, err = unmarshalCertificateV1(rawCertBytes, publicKey) case Version2: - c, err = unmarshalCertificateV2(b, publicKey, curve) + c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve) default: //TODO: CERT-V2 make a static var return nil, fmt.Errorf("unknown certificate version %d", v) diff --git a/handshake_ix.go b/handshake_ix.go index 9b8b3e9..daea526 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -132,13 +132,28 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } - remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool()) + rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) if err != nil { - e := f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) + f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Info("Handshake did not contain a certificate") + return + } - if f.l.Level > logrus.DebugLevel { - e = e.WithField("cert", remoteCert) + remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc) + if err != nil { + fp, err := rc.Fingerprint() + if err != nil { + fp = "" + } + + e := f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + WithField("certVpnNetworks", rc.Networks()). + WithField("certFingerprint", fp) + + if f.l.Level >= logrus.DebugLevel { + e = e.WithField("cert", rc) } e.Info("Invalid certificate from host") @@ -160,14 +175,10 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet } if len(remoteCert.Certificate.Networks()) == 0 { - e := f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) - - if f.l.Level > logrus.DebugLevel { - e = e.WithField("cert", remoteCert) - } - - e.Info("Invalid vpn ip from host") + f.l.WithError(err).WithField("udpAddr", addr). + WithField("cert", remoteCert). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Info("No networks in certificate") return } @@ -487,30 +498,42 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha return true } - remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool()) + rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) if err != nil { - e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) + f.l.WithError(err).WithField("udpAddr", addr). + WithField("vpnAddrs", hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + Info("Handshake did not contain a certificate") + return true + } - if f.l.Level > logrus.DebugLevel { - e = e.WithField("cert", remoteCert) + remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc) + if err != nil { + fp, err := rc.Fingerprint() + if err != nil { + fp = "" } - e.Error("Invalid certificate from host") + e := f.l.WithError(err).WithField("udpAddr", addr). + WithField("vpnAddrs", hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + WithField("certFingerprint", fp). + WithField("certVpnNetworks", rc.Networks()) - // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again + if f.l.Level >= logrus.DebugLevel { + e = e.WithField("cert", rc) + } + + e.Info("Invalid certificate from host") return true } if len(remoteCert.Certificate.Networks()) == 0 { - e := f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) - - if f.l.Level > logrus.DebugLevel { - e = e.WithField("cert", remoteCert) - } - - e.Info("Empty networks from host") + f.l.WithError(err).WithField("udpAddr", addr). + WithField("vpnAddrs", hostinfo.vpnAddrs). + WithField("cert", remoteCert). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + Info("No networks in certificate") return true } diff --git a/handshake_manager.go b/handshake_manager.go index 6d3ed12..6f95402 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -257,7 +257,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered WithField("initiatorIndex", hostinfo.localIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Info("Handshake message sent") - } else if hm.l.IsLevelEnabled(logrus.DebugLevel) { + } else if hm.l.Level >= logrus.DebugLevel { hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). WithField("initiatorIndex", hostinfo.localIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). From 096179a8c9d1644126d0703b57db12a849a86182 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 7 Mar 2025 12:05:36 -0500 Subject: [PATCH 47/67] Bump github.com/miekg/dns from 1.1.62 to 1.1.63 (#1346) Bumps [github.com/miekg/dns](https://github.com/miekg/dns) from 1.1.62 to 1.1.63. - [Changelog](https://github.com/miekg/dns/blob/master/Makefile.release) - [Commits](https://github.com/miekg/dns/compare/v1.1.62...v1.1.63) --- updated-dependencies: - dependency-name: github.com/miekg/dns dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index eca466c..de09c18 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.2 - github.com/miekg/dns v1.1.62 + github.com/miekg/dns v1.1.63 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f github.com/prometheus/client_golang v1.21.1 diff --git a/go.sum b/go.sum index d0cfb92..11f57c7 100644 --- a/go.sum +++ b/go.sum @@ -83,8 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= -github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ= +github.com/miekg/dns v1.1.63 h1:8M5aAw6OMZfFXTT7K5V0Eu5YiiL8l7nUAkyN6C9YwaY= +github.com/miekg/dns v1.1.63/go.mod h1:6NGHfjhpmr5lt3XPLuyfDJi5AXbNIPM9PY6H6sF1Nfs= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= From f7540ad3556b5a4a9b4bf846b1a3f8e885468f64 Mon Sep 17 00:00:00 2001 From: Caleb Jasik Date: Fri, 7 Mar 2025 14:37:07 -0600 Subject: [PATCH 48/67] Remove commented out metadata.go (#1320) --- metadata.go | 18 ------------------ 1 file changed, 18 deletions(-) delete mode 100644 metadata.go diff --git a/metadata.go b/metadata.go deleted file mode 100644 index 6a023ab..0000000 --- a/metadata.go +++ /dev/null @@ -1,18 +0,0 @@ -package nebula - -/* - -import ( - proto "google.golang.org/protobuf/proto" -) - -func HandleMetaProto(p []byte) { - m := &NebulaMeta{} - err := proto.Unmarshal(p, m) - if err != nil { - l.Debugf("problem unmarshaling meta message: %s", err) - } - //fmt.Println(m) -} - -*/ From 94e89a10453a0f33a96ae5cda30f4d28606858d2 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Mon, 10 Mar 2025 10:17:54 -0400 Subject: [PATCH 49/67] smoke-tests: guess the lighthouse container IP better (#1347) Currently we just assume you are using the default Docker bridge network config of `172.17.0.0/24`. This change works to try to detect if you are using a different config, but still only works if you are using a `/24` and aren't running any other containers. A future PR could make this better by launching the lighthouse container first and then fetching what the IP address is before continuing with the configuration. --- .github/workflows/smoke/build.sh | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/smoke/build.sh b/.github/workflows/smoke/build.sh index c546653..dcd132b 100755 --- a/.github/workflows/smoke/build.sh +++ b/.github/workflows/smoke/build.sh @@ -5,6 +5,10 @@ set -e -x rm -rf ./build mkdir ./build +# TODO: Assumes your docker bridge network is a /24, and the first container that launches will be .1 +# - We could make this better by launching the lighthouse first and then fetching what IP it is. +NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{ end }}' | cut -d. -f1-3)" + ( cd build @@ -21,16 +25,16 @@ mkdir ./build ../genconfig.sh >lighthouse1.yml HOST="host2" \ - LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \ + LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ ../genconfig.sh >host2.yml HOST="host3" \ - LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \ + LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \ ../genconfig.sh >host3.yml HOST="host4" \ - LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \ + LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \ ../genconfig.sh >host4.yml From 612637f5290186c29e71f88ccfa9fcbda06e1666 Mon Sep 17 00:00:00 2001 From: Caleb Jasik Date: Mon, 10 Mar 2025 09:18:34 -0500 Subject: [PATCH 50/67] Fix `testifylint` lint errors (#1321) * Fix bool-compare * Fix empty * Fix encoded-compare * Fix error-is-as * Fix error-nil * Fix expected-actual * Fix len --- allow_list_test.go | 30 ++++----- cert/ca_pool_test.go | 118 ++++++++++++++++----------------- cert/cert_v1_test.go | 36 +++++----- cert/cert_v2_test.go | 30 ++++----- cert/crypto_test.go | 10 +-- cert/pem_test.go | 22 +++--- cert/sign_test.go | 12 ++-- cmd/nebula-cert/ca_test.go | 34 +++++----- cmd/nebula-cert/keygen_test.go | 14 ++-- cmd/nebula-cert/print_test.go | 6 +- cmd/nebula-cert/sign_test.go | 38 +++++------ cmd/nebula-cert/verify_test.go | 9 ++- config/config_test.go | 24 +++---- firewall_test.go | 86 ++++++++++++------------ handshake_manager_test.go | 2 +- header/header_test.go | 2 +- lighthouse_test.go | 2 +- outside_test.go | 20 +++--- overlay/route_test.go | 24 +++---- punchy_test.go | 16 ++--- 20 files changed, 267 insertions(+), 268 deletions(-) diff --git a/allow_list_test.go b/allow_list_test.go index c8b3d08..6d5e76b 100644 --- a/allow_list_test.go +++ b/allow_list_test.go @@ -98,7 +98,7 @@ func TestNewAllowListFromConfig(t *testing.T) { } func TestAllowList_Allow(t *testing.T) { - assert.Equal(t, true, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1"))) + assert.True(t, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1"))) tree := new(bart.Table[bool]) tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true) @@ -111,17 +111,17 @@ func TestAllowList_Allow(t *testing.T) { tree.Insert(netip.MustParsePrefix("::2/128"), false) al := &AllowList{cidrTree: tree} - assert.Equal(t, true, al.Allow(netip.MustParseAddr("1.1.1.1"))) - assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.0.0.4"))) - assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.42.42"))) - assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.42.42.41"))) - assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.0.1"))) - assert.Equal(t, true, al.Allow(netip.MustParseAddr("::1"))) - assert.Equal(t, false, al.Allow(netip.MustParseAddr("::2"))) + assert.True(t, al.Allow(netip.MustParseAddr("1.1.1.1"))) + assert.False(t, al.Allow(netip.MustParseAddr("10.0.0.4"))) + assert.True(t, al.Allow(netip.MustParseAddr("10.42.42.42"))) + assert.False(t, al.Allow(netip.MustParseAddr("10.42.42.41"))) + assert.True(t, al.Allow(netip.MustParseAddr("10.42.0.1"))) + assert.True(t, al.Allow(netip.MustParseAddr("::1"))) + assert.False(t, al.Allow(netip.MustParseAddr("::2"))) } func TestLocalAllowList_AllowName(t *testing.T) { - assert.Equal(t, true, ((*LocalAllowList)(nil)).AllowName("docker0")) + assert.True(t, ((*LocalAllowList)(nil)).AllowName("docker0")) rules := []AllowListNameRule{ {Name: regexp.MustCompile("^docker.*$"), Allow: false}, @@ -129,9 +129,9 @@ func TestLocalAllowList_AllowName(t *testing.T) { } al := &LocalAllowList{nameRules: rules} - assert.Equal(t, false, al.AllowName("docker0")) - assert.Equal(t, false, al.AllowName("tun0")) - assert.Equal(t, true, al.AllowName("eth0")) + assert.False(t, al.AllowName("docker0")) + assert.False(t, al.AllowName("tun0")) + assert.True(t, al.AllowName("eth0")) rules = []AllowListNameRule{ {Name: regexp.MustCompile("^eth.*$"), Allow: true}, @@ -139,7 +139,7 @@ func TestLocalAllowList_AllowName(t *testing.T) { } al = &LocalAllowList{nameRules: rules} - assert.Equal(t, false, al.AllowName("docker0")) - assert.Equal(t, true, al.AllowName("eth0")) - assert.Equal(t, true, al.AllowName("ens5")) + assert.False(t, al.AllowName("docker0")) + assert.True(t, al.AllowName("eth0")) + assert.True(t, al.AllowName("ens5")) } diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go index f03b2ba..2f9255f 100644 --- a/cert/ca_pool_test.go +++ b/cert/ca_pool_test.go @@ -82,32 +82,32 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe } p, err := NewCAPoolFromPEM([]byte(noNewLines)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) pp, err := NewCAPoolFromPEM([]byte(withNewLines)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) // expired cert, no valid certs ppp, err := NewCAPoolFromPEM([]byte(expired)) assert.Equal(t, ErrExpired, err) - assert.Equal(t, ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired") + assert.Equal(t, "expired", ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name()) // expired cert, with valid certs pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...)) assert.Equal(t, ErrExpired, err) assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) - assert.Equal(t, pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired") - assert.Equal(t, len(pppp.CAs), 3) + assert.Equal(t, "expired", pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name()) + assert.Len(t, pppp.CAs, 3) ppppp, err := NewCAPoolFromPEM([]byte(p256)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name) - assert.Equal(t, len(ppppp.CAs), 1) + assert.Len(t, ppppp.CAs, 1) } func TestCertificateV1_Verify(t *testing.T) { @@ -118,7 +118,7 @@ func TestCertificateV1_Verify(t *testing.T) { assert.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.Nil(t, err) + assert.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) @@ -126,7 +126,7 @@ func TestCertificateV1_Verify(t *testing.T) { caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) assert.EqualError(t, err, "root certificate is expired") @@ -138,7 +138,7 @@ func TestCertificateV1_Verify(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.Nil(t, err) + assert.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) @@ -150,9 +150,9 @@ func TestCertificateV1_Verify(t *testing.T) { }) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) } func TestCertificateV1_VerifyP256(t *testing.T) { @@ -163,7 +163,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) { assert.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.Nil(t, err) + assert.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) @@ -171,7 +171,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) { caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) assert.EqualError(t, err, "root certificate is expired") @@ -183,7 +183,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.Nil(t, err) + assert.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) @@ -196,7 +196,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) { c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) } func TestCertificateV1_Verify_IPs(t *testing.T) { @@ -205,7 +205,7 @@ func TestCertificateV1_Verify_IPs(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.Nil(t, err) + assert.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) @@ -245,25 +245,25 @@ func TestCertificateV1_Verify_IPs(t *testing.T) { cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) } func TestCertificateV1_Verify_Subnets(t *testing.T) { @@ -272,7 +272,7 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.Nil(t, err) + assert.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) @@ -311,27 +311,27 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) { cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) } func TestCertificateV2_Verify(t *testing.T) { @@ -342,7 +342,7 @@ func TestCertificateV2_Verify(t *testing.T) { assert.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.Nil(t, err) + assert.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) @@ -350,7 +350,7 @@ func TestCertificateV2_Verify(t *testing.T) { caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) assert.EqualError(t, err, "root certificate is expired") @@ -362,7 +362,7 @@ func TestCertificateV2_Verify(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.Nil(t, err) + assert.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) @@ -374,9 +374,9 @@ func TestCertificateV2_Verify(t *testing.T) { }) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) } func TestCertificateV2_VerifyP256(t *testing.T) { @@ -387,7 +387,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) { assert.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.Nil(t, err) + assert.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) @@ -395,7 +395,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) { caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) assert.EqualError(t, err, "root certificate is expired") @@ -407,7 +407,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.Nil(t, err) + assert.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) @@ -420,7 +420,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) { c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) } func TestCertificateV2_Verify_IPs(t *testing.T) { @@ -429,7 +429,7 @@ func TestCertificateV2_Verify_IPs(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.Nil(t, err) + assert.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) @@ -469,25 +469,25 @@ func TestCertificateV2_Verify_IPs(t *testing.T) { cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) } func TestCertificateV2_Verify_Subnets(t *testing.T) { @@ -496,7 +496,7 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.Nil(t, err) + assert.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) @@ -535,25 +535,25 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) { cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) - assert.Nil(t, err) + assert.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) + assert.NoError(t, err) } diff --git a/cert/cert_v1_test.go b/cert/cert_v1_test.go index 8c3fe93..ea98b08 100644 --- a/cert/cert_v1_test.go +++ b/cert/cert_v1_test.go @@ -39,14 +39,14 @@ func TestCertificateV1_Marshal(t *testing.T) { } b, err := nc.Marshal() - assert.Nil(t, err) + assert.NoError(t, err) //t.Log("Cert size:", len(b)) nc2, err := unmarshalCertificateV1(b, nil) - assert.Nil(t, err) + assert.NoError(t, err) - assert.Equal(t, nc.Version(), Version1) - assert.Equal(t, nc.Curve(), Curve_CURVE25519) + assert.Equal(t, Version1, nc.Version()) + assert.Equal(t, Curve_CURVE25519, nc.Curve()) assert.Equal(t, nc.Signature(), nc2.Signature()) assert.Equal(t, nc.Name(), nc2.Name()) assert.Equal(t, nc.NotBefore(), nc2.NotBefore()) @@ -99,8 +99,8 @@ func TestCertificateV1_MarshalJSON(t *testing.T) { } b, err := nc.MarshalJSON() - assert.Nil(t, err) - assert.Equal( + assert.NoError(t, err) + assert.JSONEq( t, "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}", string(b), @@ -110,12 +110,12 @@ func TestCertificateV1_MarshalJSON(t *testing.T) { func TestCertificateV1_VerifyPrivateKey(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) - assert.Nil(t, err) + assert.NoError(t, err) _, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) - assert.Nil(t, err) + assert.NoError(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) - assert.NotNil(t, err) + assert.Error(t, err) c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) @@ -123,22 +123,22 @@ func TestCertificateV1_VerifyPrivateKey(t *testing.T) { assert.Empty(t, b) assert.Equal(t, Curve_CURVE25519, curve) err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) - assert.Nil(t, err) + assert.NoError(t, err) _, priv2 := X25519Keypair() err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) - assert.NotNil(t, err) + assert.Error(t, err) } func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_P256, caKey) - assert.Nil(t, err) + assert.NoError(t, err) _, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) - assert.Nil(t, err) + assert.NoError(t, err) err = ca.VerifyPrivateKey(Curve_P256, caKey2) - assert.NotNil(t, err) + assert.Error(t, err) c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) @@ -146,11 +146,11 @@ func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) { assert.Empty(t, b) assert.Equal(t, Curve_P256, curve) err = c.VerifyPrivateKey(Curve_P256, rawPriv) - assert.Nil(t, err) + assert.NoError(t, err) _, priv2 := P256Keypair() err = c.VerifyPrivateKey(Curve_P256, priv2) - assert.NotNil(t, err) + assert.Error(t, err) } // Ensure that upgrading the protobuf library does not change how certificates @@ -182,11 +182,11 @@ func TestMarshalingCertificateV1Consistency(t *testing.T) { } b, err := nc.Marshal() - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) b, err = proto.Marshal(nc.getRawDetails()) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) } diff --git a/cert/cert_v2_test.go b/cert/cert_v2_test.go index 3afbcab..6d55750 100644 --- a/cert/cert_v2_test.go +++ b/cert/cert_v2_test.go @@ -45,14 +45,14 @@ func TestCertificateV2_Marshal(t *testing.T) { nc.rawDetails = db b, err := nc.Marshal() - require.Nil(t, err) + require.NoError(t, err) //t.Log("Cert size:", len(b)) nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519) - assert.Nil(t, err) + assert.NoError(t, err) - assert.Equal(t, nc.Version(), Version2) - assert.Equal(t, nc.Curve(), Curve_CURVE25519) + assert.Equal(t, Version2, nc.Version()) + assert.Equal(t, Curve_CURVE25519, nc.Curve()) assert.Equal(t, nc.Signature(), nc2.Signature()) assert.Equal(t, nc.Name(), nc2.Name()) assert.Equal(t, nc.NotBefore(), nc2.NotBefore()) @@ -121,8 +121,8 @@ func TestCertificateV2_MarshalJSON(t *testing.T) { nc.rawDetails = rd b, err = nc.MarshalJSON() - assert.Nil(t, err) - assert.Equal( + assert.NoError(t, err) + assert.JSONEq( t, "{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}", string(b), @@ -132,13 +132,13 @@ func TestCertificateV2_MarshalJSON(t *testing.T) { func TestCertificateV2_VerifyPrivateKey(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) - assert.Nil(t, err) + assert.NoError(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16]) assert.ErrorIs(t, err, ErrInvalidPrivateKey) _, caKey2, err := ed25519.GenerateKey(rand.Reader) - require.Nil(t, err) + require.NoError(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) @@ -148,7 +148,7 @@ func TestCertificateV2_VerifyPrivateKey(t *testing.T) { assert.Empty(t, b) assert.Equal(t, Curve_CURVE25519, curve) err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) - assert.Nil(t, err) + assert.NoError(t, err) _, priv2 := X25519Keypair() err = c.VerifyPrivateKey(Curve_P256, priv2) @@ -168,7 +168,7 @@ func TestCertificateV2_VerifyPrivateKey(t *testing.T) { ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) - assert.Nil(t, err) + assert.NoError(t, err) err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16]) assert.ErrorIs(t, err, ErrInvalidPrivateKey) @@ -193,12 +193,12 @@ func TestCertificateV2_VerifyPrivateKey(t *testing.T) { func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_P256, caKey) - assert.Nil(t, err) + assert.NoError(t, err) _, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) - assert.Nil(t, err) + assert.NoError(t, err) err = ca.VerifyPrivateKey(Curve_P256, caKey2) - assert.NotNil(t, err) + assert.Error(t, err) c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) @@ -206,11 +206,11 @@ func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) { assert.Empty(t, b) assert.Equal(t, Curve_P256, curve) err = c.VerifyPrivateKey(Curve_P256, rawPriv) - assert.Nil(t, err) + assert.NoError(t, err) _, priv2 := P256Keypair() err = c.VerifyPrivateKey(Curve_P256, priv2) - assert.NotNil(t, err) + assert.Error(t, err) } func TestCertificateV2_Copy(t *testing.T) { diff --git a/cert/crypto_test.go b/cert/crypto_test.go index c9aba3e..c43eed7 100644 --- a/cert/crypto_test.go +++ b/cert/crypto_test.go @@ -61,7 +61,7 @@ qrlJ69wer3ZUHFXA // Success test case curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, Curve_CURVE25519, curve) assert.Len(t, k, 64) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) @@ -89,7 +89,7 @@ qrlJ69wer3ZUHFXA curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey) assert.EqualError(t, err, "invalid passphrase or corrupt private key") assert.Nil(t, k) - assert.Equal(t, rest, []byte{}) + assert.Equal(t, []byte{}, rest) } func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) { @@ -99,14 +99,14 @@ func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) { bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") kdfParams := NewArgon2Parameters(64*1024, 4, 3) key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams) - assert.Nil(t, err) + assert.NoError(t, err) // Verify the "key" can be decrypted successfully curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key) assert.Len(t, k, 64) assert.Equal(t, Curve_CURVE25519, curve) - assert.Equal(t, rest, []byte{}) - assert.Nil(t, err) + assert.Equal(t, []byte{}, rest) + assert.NoError(t, err) // EncryptAndMarshalEd25519PrivateKey does not create any errors itself } diff --git a/cert/pem_test.go b/cert/pem_test.go index a0c6e74..9ad8a69 100644 --- a/cert/pem_test.go +++ b/cert/pem_test.go @@ -35,7 +35,7 @@ bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB cert, rest, err := UnmarshalCertificateFromPEM(certBundle) assert.NotNil(t, cert) assert.Equal(t, rest, append(badBanner, invalidPem...)) - assert.Nil(t, err) + assert.NoError(t, err) // Fail due to invalid banner. cert, rest, err = UnmarshalCertificateFromPEM(rest) @@ -84,14 +84,14 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA assert.Len(t, k, 64) assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_CURVE25519, curve) - assert.Nil(t, err) + assert.NoError(t, err) // Success test case k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Len(t, k, 32) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_P256, curve) - assert.Nil(t, err) + assert.NoError(t, err) // Fail due to short key k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) @@ -146,14 +146,14 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= assert.Len(t, k, 32) assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_CURVE25519, curve) - assert.Nil(t, err) + assert.NoError(t, err) // Success test case k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Len(t, k, 32) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_P256, curve) - assert.Nil(t, err) + assert.NoError(t, err) // Fail due to short key k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) @@ -200,9 +200,9 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= // Success test case k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) - assert.Equal(t, 32, len(k)) + assert.Len(t, k, 32) assert.Equal(t, Curve_CURVE25519, curve) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) // Fail due to short key @@ -259,15 +259,15 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= // Success test case k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) - assert.Equal(t, 32, len(k)) - assert.Nil(t, err) + assert.Len(t, k, 32) + assert.NoError(t, err) assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_CURVE25519, curve) // Success test case k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) - assert.Equal(t, 65, len(k)) - assert.Nil(t, err) + assert.Len(t, k, 65) + assert.NoError(t, err) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_P256, curve) diff --git a/cert/sign_test.go b/cert/sign_test.go index 2b8dbe8..30d8480 100644 --- a/cert/sign_test.go +++ b/cert/sign_test.go @@ -37,14 +37,14 @@ func TestCertificateV1_Sign(t *testing.T) { pub, priv, err := ed25519.GenerateKey(rand.Reader) c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, c) assert.True(t, c.CheckSignature(pub)) b, err := c.Marshal() - assert.Nil(t, err) + assert.NoError(t, err) uc, err := unmarshalCertificateV1(b, nil) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, uc) } @@ -78,13 +78,13 @@ func TestCertificateV1_SignP256(t *testing.T) { rawPriv := priv.D.FillBytes(make([]byte, 32)) c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, c) assert.True(t, c.CheckSignature(pub)) b, err := c.Marshal() - assert.Nil(t, err) + assert.NoError(t, err) uc, err := unmarshalCertificateV1(b, nil) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, uc) } diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index 9da0ad4..71b69be 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -112,8 +112,8 @@ func Test_ca(t *testing.T) { // create temp key file keyF, err := os.CreateTemp("", "test.key") - assert.Nil(t, err) - assert.Nil(t, os.Remove(keyF.Name())) + assert.NoError(t, err) + assert.NoError(t, os.Remove(keyF.Name())) // failed cert write ob.Reset() @@ -125,15 +125,15 @@ func Test_ca(t *testing.T) { // create temp cert file crtF, err := os.CreateTemp("", "test.crt") - assert.Nil(t, err) - assert.Nil(t, os.Remove(crtF.Name())) - assert.Nil(t, os.Remove(keyF.Name())) + assert.NoError(t, err) + assert.NoError(t, os.Remove(crtF.Name())) + assert.NoError(t, os.Remove(keyF.Name())) // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.Nil(t, ca(args, ob, eb, nopw)) + assert.NoError(t, ca(args, ob, eb, nopw)) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -141,20 +141,20 @@ func Test_ca(t *testing.T) { rb, _ := os.ReadFile(keyF.Name()) lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, c) - assert.Len(t, b, 0) - assert.Nil(t, err) + assert.Empty(t, b) + assert.NoError(t, err) assert.Len(t, lKey, 64) rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) - assert.Len(t, b, 0) - assert.Nil(t, err) + assert.Empty(t, b) + assert.NoError(t, err) assert.Equal(t, "test", lCrt.Name()) - assert.Len(t, lCrt.Networks(), 0) + assert.Empty(t, lCrt.Networks()) assert.True(t, lCrt.IsCA()) assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups()) - assert.Len(t, lCrt.UnsafeNetworks(), 0) + assert.Empty(t, lCrt.UnsafeNetworks()) assert.Len(t, lCrt.PublicKey(), 32) assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore())) assert.Equal(t, "", lCrt.Issuer()) @@ -166,7 +166,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.Nil(t, ca(args, ob, eb, testpw)) + assert.NoError(t, ca(args, ob, eb, testpw)) assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, "", eb.String()) @@ -174,7 +174,7 @@ func Test_ca(t *testing.T) { rb, _ = os.ReadFile(keyF.Name()) k, _ := pem.Decode(rb) ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes) - assert.Nil(t, err) + assert.NoError(t, err) // we won't know salt in advance, so just check start of string assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory) assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism) @@ -184,8 +184,8 @@ func Test_ca(t *testing.T) { var curve cert.Curve curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb) assert.Equal(t, cert.Curve_CURVE25519, curve) - assert.Nil(t, err) - assert.Len(t, b, 0) + assert.NoError(t, err) + assert.Empty(t, b) assert.Len(t, lKey, 64) // test when reading passsword results in an error @@ -214,7 +214,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.Nil(t, ca(args, ob, eb, nopw)) + assert.NoError(t, ca(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file ob.Reset() diff --git a/cmd/nebula-cert/keygen_test.go b/cmd/nebula-cert/keygen_test.go index fcfd77b..3427254 100644 --- a/cmd/nebula-cert/keygen_test.go +++ b/cmd/nebula-cert/keygen_test.go @@ -53,7 +53,7 @@ func Test_keygen(t *testing.T) { // create temp key file keyF, err := os.CreateTemp("", "test.key") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(keyF.Name()) // failed pub write @@ -66,14 +66,14 @@ func Test_keygen(t *testing.T) { // create temp pub file pubF, err := os.CreateTemp("", "test.pub") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(pubF.Name()) // test proper keygen ob.Reset() eb.Reset() args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()} - assert.Nil(t, keygen(args, ob, eb)) + assert.NoError(t, keygen(args, ob, eb)) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -81,14 +81,14 @@ func Test_keygen(t *testing.T) { rb, _ := os.ReadFile(keyF.Name()) lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, curve) - assert.Len(t, b, 0) - assert.Nil(t, err) + assert.Empty(t, b) + assert.NoError(t, err) assert.Len(t, lKey, 32) rb, _ = os.ReadFile(pubF.Name()) lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, curve) - assert.Len(t, b, 0) - assert.Nil(t, err) + assert.Empty(t, b) + assert.NoError(t, err) assert.Len(t, lPub, 32) } diff --git a/cmd/nebula-cert/print_test.go b/cmd/nebula-cert/print_test.go index 86795e4..77e98e6 100644 --- a/cmd/nebula-cert/print_test.go +++ b/cmd/nebula-cert/print_test.go @@ -58,7 +58,7 @@ func Test_printCert(t *testing.T) { ob.Reset() eb.Reset() tf, err := os.CreateTemp("", "print-cert") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(tf.Name()) tf.WriteString("-----BEGIN NOPE-----") @@ -84,7 +84,7 @@ func Test_printCert(t *testing.T) { fp, _ := c.Fingerprint() pk := hex.EncodeToString(c.PublicKey()) sig := hex.EncodeToString(c.Signature()) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal( t, //"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n", @@ -169,7 +169,7 @@ func Test_printCert(t *testing.T) { fp, _ = c.Fingerprint() pk = hex.EncodeToString(c.PublicKey()) sig = hex.EncodeToString(c.Signature()) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal( t, `[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}] diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index 466cb8c..4b242a4 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -109,7 +109,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() caKeyF, err := os.CreateTemp("", "sign-cert.key") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(caKeyF.Name()) args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} @@ -133,7 +133,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() caCrtF, err := os.CreateTemp("", "sign-cert.crt") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(caCrtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} @@ -156,7 +156,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() inPubF, err := os.CreateTemp("", "in.pub") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(inPubF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} @@ -210,7 +210,7 @@ func Test_signCert(t *testing.T) { // mismatched ca key _, caPriv2, _ := ed25519.GenerateKey(rand.Reader) caKeyF2, err := os.CreateTemp("", "sign-cert-2.key") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(caKeyF2.Name()) caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2)) @@ -231,7 +231,7 @@ func Test_signCert(t *testing.T) { // create temp key file keyF, err := os.CreateTemp("", "test.key") - assert.Nil(t, err) + assert.NoError(t, err) os.Remove(keyF.Name()) // failed cert write @@ -245,14 +245,14 @@ func Test_signCert(t *testing.T) { // create temp cert file crtF, err := os.CreateTemp("", "test.crt") - assert.Nil(t, err) + assert.NoError(t, err) os.Remove(crtF.Name()) // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb, nopw)) + assert.NoError(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -260,14 +260,14 @@ func Test_signCert(t *testing.T) { rb, _ := os.ReadFile(keyF.Name()) lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, curve) - assert.Len(t, b, 0) - assert.Nil(t, err) + assert.Empty(t, b) + assert.NoError(t, err) assert.Len(t, lKey, 32) rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) - assert.Len(t, b, 0) - assert.Nil(t, err) + assert.Empty(t, b) + assert.NoError(t, err) assert.Equal(t, "test", lCrt.Name()) assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String()) @@ -295,15 +295,15 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} - assert.Nil(t, signCert(args, ob, eb, nopw)) + assert.NoError(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // read cert file and check pub key matches in-pub rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb) - assert.Len(t, b, 0) - assert.Nil(t, err) + assert.Empty(t, b) + assert.NoError(t, err) assert.Equal(t, lCrt.PublicKey(), inPub) // test refuse to sign cert with duration beyond root @@ -320,7 +320,7 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb, nopw)) + assert.NoError(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing key file os.Remove(crtF.Name()) @@ -335,7 +335,7 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb, nopw)) + assert.NoError(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file os.Remove(keyF.Name()) @@ -355,11 +355,11 @@ func Test_signCert(t *testing.T) { eb.Reset() caKeyF, err = os.CreateTemp("", "sign-cert.key") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(caKeyF.Name()) caCrtF, err = os.CreateTemp("", "sign-cert.crt") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(caCrtF.Name()) // generate the encrypted key @@ -374,7 +374,7 @@ func Test_signCert(t *testing.T) { // test with the proper password args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Nil(t, signCert(args, ob, eb, testpw)) + assert.NoError(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index d94bd1f..c2a9f55 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -3,7 +3,6 @@ package main import ( "bytes" "crypto/rand" - "errors" "os" "testing" "time" @@ -57,7 +56,7 @@ func Test_verify(t *testing.T) { ob.Reset() eb.Reset() caFile, err := os.CreateTemp("", "verify-ca") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(caFile.Name()) caFile.WriteString("-----BEGIN NOPE-----") @@ -84,7 +83,7 @@ func Test_verify(t *testing.T) { ob.Reset() eb.Reset() certFile, err := os.CreateTemp("", "verify-cert") - assert.Nil(t, err) + assert.NoError(t, err) defer os.Remove(certFile.Name()) certFile.WriteString("-----BEGIN NOPE-----") @@ -108,7 +107,7 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.True(t, errors.Is(err, cert.ErrSignatureMismatch)) + assert.ErrorIs(t, err, cert.ErrSignatureMismatch) // verified cert at path crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) @@ -120,5 +119,5 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.Nil(t, err) + assert.NoError(t, err) } diff --git a/config/config_test.go b/config/config_test.go index c3a1a73..39301f9 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -26,11 +26,11 @@ func TestConfig_Load(t *testing.T) { os.RemoveAll(dir) os.Mkdir(dir, 0755) - assert.Nil(t, err) + assert.NoError(t, err) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644) - assert.Nil(t, c.Load(dir)) + assert.NoError(t, c.Load(dir)) expected := map[interface{}]interface{}{ "outer": map[interface{}]interface{}{ "inner": "override", @@ -67,28 +67,28 @@ func TestConfig_GetBool(t *testing.T) { l := test.NewLogger() c := NewC(l) c.Settings["bool"] = true - assert.Equal(t, true, c.GetBool("bool", false)) + assert.True(t, c.GetBool("bool", false)) c.Settings["bool"] = "true" - assert.Equal(t, true, c.GetBool("bool", false)) + assert.True(t, c.GetBool("bool", false)) c.Settings["bool"] = false - assert.Equal(t, false, c.GetBool("bool", true)) + assert.False(t, c.GetBool("bool", true)) c.Settings["bool"] = "false" - assert.Equal(t, false, c.GetBool("bool", true)) + assert.False(t, c.GetBool("bool", true)) c.Settings["bool"] = "Y" - assert.Equal(t, true, c.GetBool("bool", false)) + assert.True(t, c.GetBool("bool", false)) c.Settings["bool"] = "yEs" - assert.Equal(t, true, c.GetBool("bool", false)) + assert.True(t, c.GetBool("bool", false)) c.Settings["bool"] = "N" - assert.Equal(t, false, c.GetBool("bool", true)) + assert.False(t, c.GetBool("bool", true)) c.Settings["bool"] = "nO" - assert.Equal(t, false, c.GetBool("bool", true)) + assert.False(t, c.GetBool("bool", true)) } func TestConfig_HasChanged(t *testing.T) { @@ -117,11 +117,11 @@ func TestConfig_ReloadConfig(t *testing.T) { l := test.NewLogger() done := make(chan bool, 1) dir, err := os.MkdirTemp("", "config-test") - assert.Nil(t, err) + assert.NoError(t, err) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) c := NewC(l) - assert.Nil(t, c.Load(dir)) + assert.NoError(t, c.Load(dir)) assert.False(t, c.HasChanged("outer.inner")) assert.False(t, c.HasChanged("outer")) diff --git a/firewall_test.go b/firewall_test.go index 8d32369..92914af 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -68,53 +68,53 @@ func TestFirewall_AddRule(t *testing.T) { ti, err := netip.ParsePrefix("1.2.3.4/32") assert.NoError(t, err) - assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) // An empty rule is any assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) + assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) anyIp, err := netip.ParsePrefix("0.0.0.0/0") assert.NoError(t, err) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) // Test error conditions @@ -155,7 +155,7 @@ func TestFirewall_Drop(t *testing.T) { h.buildNetworks(c.networks, c.unsafeNetworks) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound @@ -174,28 +174,28 @@ func TestFirewall_Drop(t *testing.T) { // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) } @@ -350,11 +350,11 @@ func TestFirewall_Drop2(t *testing.T) { h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // h1/c1 lacks the proper groups - assert.Error(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) + assert.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) // c has the proper groups resetConntrack(fw) assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) @@ -428,8 +428,8 @@ func TestFirewall_Drop3(t *testing.T) { h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) cp := cert.NewCAPool() // c1 should pass because host match @@ -443,7 +443,7 @@ func TestFirewall_Drop3(t *testing.T) { // Test a remote address match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", "")) assert.NoError(t, fw.Drop(p, true, &h1, cp, nil)) } @@ -480,7 +480,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound @@ -493,7 +493,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { oldFw := fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -502,7 +502,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { oldFw = fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -605,22 +605,22 @@ func Test_parsePort(t *testing.T) { s, e, err := parsePort(" 1 - 2 ") assert.Equal(t, int32(1), s) assert.Equal(t, int32(2), e) - assert.Nil(t, err) + assert.NoError(t, err) s, e, err = parsePort("0-1") assert.Equal(t, int32(0), s) assert.Equal(t, int32(0), e) - assert.Nil(t, err) + assert.NoError(t, err) s, e, err = parsePort("9919") assert.Equal(t, int32(9919), s) assert.Equal(t, int32(9919), e) - assert.Nil(t, err) + assert.NoError(t, err) s, e, err = parsePort("any") assert.Equal(t, int32(0), s) assert.Equal(t, int32(0), e) - assert.Nil(t, err) + assert.NoError(t, err) } func TestNewFirewallFromConfig(t *testing.T) { @@ -688,28 +688,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { conf := config.NewC(l) mf := &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding udp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding icmp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding any rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with cidr @@ -717,49 +717,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with local_cidr conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall) // Test adding rule with ca_sha conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall) // Test single group conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test single groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test multiple AND groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} - assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test Add error @@ -782,7 +782,7 @@ func TestFirewall_convertRule(t *testing.T) { r, err := convertRule(l, c, "test", 1) assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, "group1", r.Group) // Ensure group array of > 1 is errord @@ -802,7 +802,7 @@ func TestFirewall_convertRule(t *testing.T) { } r, err = convertRule(l, c, "test", 1) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, "group1", r.Group) } diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 7edc55b..4b898af 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -44,7 +44,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { i.remotes = NewRemoteList([]netip.Addr{}, nil) // Adding something to pending should not affect the main hostmap - assert.Len(t, mainHM.Hosts, 0) + assert.Empty(t, mainHM.Hosts) // Confirm they are in the pending index list assert.Contains(t, blah.vpnIps, ip) diff --git a/header/header_test.go b/header/header_test.go index 765a006..1836a75 100644 --- a/header/header_test.go +++ b/header/header_test.go @@ -111,7 +111,7 @@ func TestHeader_String(t *testing.T) { func TestHeader_MarshalJSON(t *testing.T) { b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON() - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal( t, "{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}", diff --git a/lighthouse_test.go b/lighthouse_test.go index d5947aa..9e9ad53 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -42,7 +42,7 @@ func Test_lhStaticMapping(t *testing.T) { c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) - assert.Nil(t, err) + assert.NoError(t, err) lh2 := "10.128.0.3" c = config.NewC(l) diff --git a/outside_test.go b/outside_test.go index f197594..944bf16 100644 --- a/outside_test.go +++ b/outside_test.go @@ -63,7 +63,7 @@ func Test_newPacket(t *testing.T) { b = append(b, []byte{0, 3, 0, 4}...) err = newPacket(b, true, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr) @@ -85,7 +85,7 @@ func Test_newPacket(t *testing.T) { b = append(b, []byte{0, 5, 0, 6}...) err = newPacket(b, false, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(2), p.Protocol) assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr) @@ -134,7 +134,7 @@ func Test_newPacket_v6(t *testing.T) { } err = newPacket(buffer.Bytes(), true, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -146,7 +146,7 @@ func Test_newPacket_v6(t *testing.T) { b := buffer.Bytes() b[6] = byte(layers.IPProtocolESP) err = newPacket(b, true, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -158,7 +158,7 @@ func Test_newPacket_v6(t *testing.T) { b = buffer.Bytes() b[6] = byte(layers.IPProtocolNoNextHeader) err = newPacket(b, true, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -197,7 +197,7 @@ func Test_newPacket_v6(t *testing.T) { // incoming err = newPacket(b, true, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -207,7 +207,7 @@ func Test_newPacket_v6(t *testing.T) { // outgoing err = newPacket(b, false, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) @@ -224,7 +224,7 @@ func Test_newPacket_v6(t *testing.T) { // incoming err = newPacket(b, true, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -234,7 +234,7 @@ func Test_newPacket_v6(t *testing.T) { // outgoing err = newPacket(b, false, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) @@ -279,7 +279,7 @@ func Test_newPacket_v6(t *testing.T) { b = append(b, udpHeader...) err = newPacket(b, true, p) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) diff --git a/overlay/route_test.go b/overlay/route_test.go index c60e4c2..4fa30af 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -18,8 +18,8 @@ func Test_parseRoutes(t *testing.T) { // test no routes config routes, err := parseRoutes(c, []netip.Prefix{n}) - assert.Nil(t, err) - assert.Len(t, routes, 0) + assert.NoError(t, err) + assert.Empty(t, routes) // not an array c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"} @@ -30,8 +30,8 @@ func Test_parseRoutes(t *testing.T) { // no routes c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}} routes, err = parseRoutes(c, []netip.Prefix{n}) - assert.Nil(t, err) - assert.Len(t, routes, 0) + assert.NoError(t, err) + assert.Empty(t, routes) // weird route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}} @@ -93,7 +93,7 @@ func Test_parseRoutes(t *testing.T) { map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"}, }} routes, err = parseRoutes(c, []netip.Prefix{n}) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, routes, 2) tested := 0 @@ -123,8 +123,8 @@ func Test_parseUnsafeRoutes(t *testing.T) { // test no routes config routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) - assert.Nil(t, err) - assert.Len(t, routes, 0) + assert.NoError(t, err) + assert.Empty(t, routes) // not an array c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"} @@ -135,8 +135,8 @@ func Test_parseUnsafeRoutes(t *testing.T) { // no routes c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) - assert.Nil(t, err) - assert.Len(t, routes, 0) + assert.NoError(t, err) + assert.Empty(t, routes) // weird route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}} @@ -188,13 +188,13 @@ func Test_parseUnsafeRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) - assert.Nil(t, err) + assert.NoError(t, err) // above network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) - assert.Nil(t, err) + assert.NoError(t, err) // no mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} @@ -228,7 +228,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, }} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, routes, 4) tested := 0 diff --git a/punchy_test.go b/punchy_test.go index bedd2b2..7918449 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -15,31 +15,31 @@ func TestNewPunchyFromConfig(t *testing.T) { // Test defaults p := NewPunchyFromConfig(l, c) - assert.Equal(t, false, p.GetPunch()) - assert.Equal(t, false, p.GetRespond()) + assert.False(t, p.GetPunch()) + assert.False(t, p.GetRespond()) assert.Equal(t, time.Second, p.GetDelay()) assert.Equal(t, 5*time.Second, p.GetRespondDelay()) // punchy deprecation c.Settings["punchy"] = true p = NewPunchyFromConfig(l, c) - assert.Equal(t, true, p.GetPunch()) + assert.True(t, p.GetPunch()) // punchy.punch c.Settings["punchy"] = map[interface{}]interface{}{"punch": true} p = NewPunchyFromConfig(l, c) - assert.Equal(t, true, p.GetPunch()) + assert.True(t, p.GetPunch()) // punch_back deprecation c.Settings["punch_back"] = true p = NewPunchyFromConfig(l, c) - assert.Equal(t, true, p.GetRespond()) + assert.True(t, p.GetRespond()) // punchy.respond c.Settings["punchy"] = map[interface{}]interface{}{"respond": true} c.Settings["punch_back"] = false p = NewPunchyFromConfig(l, c) - assert.Equal(t, true, p.GetRespond()) + assert.True(t, p.GetRespond()) // punchy.delay c.Settings["punchy"] = map[interface{}]interface{}{"delay": "1m"} @@ -63,7 +63,7 @@ punchy: `)) p := NewPunchyFromConfig(l, c) assert.Equal(t, delay, p.GetDelay()) - assert.Equal(t, false, p.GetRespond()) + assert.False(t, p.GetRespond()) newDelay, _ := time.ParseDuration("10m") assert.NoError(t, c.ReloadConfigString(` @@ -73,5 +73,5 @@ punchy: `)) p.reload(c, false) assert.Equal(t, newDelay, p.GetDelay()) - assert.Equal(t, true, p.GetRespond()) + assert.True(t, p.GetRespond()) } From 088af8edb264ec1a25d947e192da7938c48d18d4 Mon Sep 17 00:00:00 2001 From: Caleb Jasik Date: Mon, 10 Mar 2025 17:38:14 -0500 Subject: [PATCH 51/67] Enable running testifylint in CI (#1350) --- .github/workflows/test.yml | 10 +++ .golangci.yaml | 9 ++ allow_list_test.go | 13 +-- calculated_remote_test.go | 16 ++-- cert/ca_pool_test.go | 151 +++++++++++++++++---------------- cert/cert_v1_test.go | 34 ++++---- cert/cert_v2_test.go | 50 +++++------ cert/crypto_test.go | 15 ++-- cert/pem_test.go | 45 +++++----- cert/sign_test.go | 15 ++-- cmd/nebula-cert/ca_test.go | 37 ++++---- cmd/nebula-cert/keygen_test.go | 15 ++-- cmd/nebula-cert/main_test.go | 3 +- cmd/nebula-cert/print_test.go | 11 +-- cmd/nebula-cert/sign_test.go | 63 +++++++------- cmd/nebula-cert/verify_test.go | 17 ++-- config/config_test.go | 10 +-- connection_manager_test.go | 7 +- e2e/handshakes_test.go | 21 ++--- firewall_test.go | 150 ++++++++++++++++---------------- header/header_test.go | 3 +- lighthouse_test.go | 31 ++++--- outside_test.go | 57 +++++++------ overlay/route_test.go | 79 ++++++++--------- punchy_test.go | 5 +- 25 files changed, 451 insertions(+), 416 deletions(-) create mode 100644 .golangci.yaml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4f3f2ed..b8a4f03 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -31,6 +31,11 @@ jobs: - name: Vet run: make vet + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: v1.64 + - name: Test run: make test @@ -109,6 +114,11 @@ jobs: - name: Vet run: make vet + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: v1.64 + - name: Test run: make test diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..f792069 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,9 @@ +# yaml-language-server: $schema=https://golangci-lint.run/jsonschema/golangci.jsonschema.json +linters: + # Disable all linters. + # Default: false + disable-all: true + # Enable specific linter + # https://golangci-lint.run/usage/linters/#enabled-by-default + enable: + - testifylint diff --git a/allow_list_test.go b/allow_list_test.go index 6d5e76b..d7d2c9a 100644 --- a/allow_list_test.go +++ b/allow_list_test.go @@ -9,6 +9,7 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewAllowListFromConfig(t *testing.T) { @@ -18,21 +19,21 @@ func TestNewAllowListFromConfig(t *testing.T) { "192.168.0.0": true, } r, err := newAllowListFromConfig(c, "allowlist", nil) - assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'") + require.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'") assert.Nil(t, r) c.Settings["allowlist"] = map[interface{}]interface{}{ "192.168.0.0/16": "abc", } r, err = newAllowListFromConfig(c, "allowlist", nil) - assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc") + require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc") c.Settings["allowlist"] = map[interface{}]interface{}{ "192.168.0.0/16": true, "10.0.0.0/8": false, } r, err = newAllowListFromConfig(c, "allowlist", nil) - assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0") + require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0") c.Settings["allowlist"] = map[interface{}]interface{}{ "0.0.0.0/0": true, @@ -42,7 +43,7 @@ func TestNewAllowListFromConfig(t *testing.T) { "fd00:fd00::/16": false, } r, err = newAllowListFromConfig(c, "allowlist", nil) - assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0") + require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0") c.Settings["allowlist"] = map[interface{}]interface{}{ "0.0.0.0/0": true, @@ -75,7 +76,7 @@ func TestNewAllowListFromConfig(t *testing.T) { }, } lr, err := NewLocalAllowListFromConfig(c, "allowlist") - assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo") + require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo") c.Settings["allowlist"] = map[interface{}]interface{}{ "interfaces": map[interface{}]interface{}{ @@ -84,7 +85,7 @@ func TestNewAllowListFromConfig(t *testing.T) { }, } lr, err = NewLocalAllowListFromConfig(c, "allowlist") - assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value") + require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value") c.Settings["allowlist"] = map[interface{}]interface{}{ "interfaces": map[interface{}]interface{}{ diff --git a/calculated_remote_test.go b/calculated_remote_test.go index 066213e..6df893c 100644 --- a/calculated_remote_test.go +++ b/calculated_remote_test.go @@ -15,10 +15,10 @@ func TestCalculatedRemoteApply(t *testing.T) { require.NoError(t, err) input, err := netip.ParseAddr("10.0.10.182") - assert.NoError(t, err) + require.NoError(t, err) expected, err := netip.ParseAddr("192.168.1.182") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input)) @@ -28,10 +28,10 @@ func TestCalculatedRemoteApply(t *testing.T) { require.NoError(t, err) input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") - assert.NoError(t, err) + require.NoError(t, err) expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) @@ -41,10 +41,10 @@ func TestCalculatedRemoteApply(t *testing.T) { require.NoError(t, err) input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") - assert.NoError(t, err) + require.NoError(t, err) expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) @@ -54,10 +54,10 @@ func TestCalculatedRemoteApply(t *testing.T) { require.NoError(t, err) input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") - assert.NoError(t, err) + require.NoError(t, err) expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) } diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go index 2f9255f..b0fdd5f 100644 --- a/cert/ca_pool_test.go +++ b/cert/ca_pool_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewCAPoolFromBytes(t *testing.T) { @@ -82,12 +83,12 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe } p, err := NewCAPoolFromPEM([]byte(noNewLines)) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) pp, err := NewCAPoolFromPEM([]byte(withNewLines)) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) @@ -105,7 +106,7 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe assert.Len(t, pppp.CAs, 3) ppppp, err := NewCAPoolFromPEM([]byte(p256)) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name) assert.Len(t, ppppp.CAs, 1) } @@ -115,21 +116,21 @@ func TestCertificateV1_Verify(t *testing.T) { c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) caPool := NewCAPool() - assert.NoError(t, caPool.AddCA(ca)) + require.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.NoError(t, err) + require.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.EqualError(t, err, "certificate is in the block list") + require.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) - assert.EqualError(t, err, "root certificate is expired") + require.EqualError(t, err, "root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) @@ -138,11 +139,11 @@ func TestCertificateV1_Verify(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { @@ -150,9 +151,9 @@ func TestCertificateV1_Verify(t *testing.T) { }) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV1_VerifyP256(t *testing.T) { @@ -160,21 +161,21 @@ func TestCertificateV1_VerifyP256(t *testing.T) { c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) caPool := NewCAPool() - assert.NoError(t, caPool.AddCA(ca)) + require.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.NoError(t, err) + require.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.EqualError(t, err, "certificate is in the block list") + require.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) - assert.EqualError(t, err, "root certificate is expired") + require.EqualError(t, err, "root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) @@ -183,11 +184,11 @@ func TestCertificateV1_VerifyP256(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { @@ -196,7 +197,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) { c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV1_Verify_IPs(t *testing.T) { @@ -205,11 +206,11 @@ func TestCertificateV1_Verify_IPs(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) // ip is outside the network @@ -245,25 +246,25 @@ func TestCertificateV1_Verify_IPs(t *testing.T) { cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV1_Verify_Subnets(t *testing.T) { @@ -272,11 +273,11 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) // ip is outside the network @@ -311,27 +312,27 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) { cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV2_Verify(t *testing.T) { @@ -339,21 +340,21 @@ func TestCertificateV2_Verify(t *testing.T) { c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) caPool := NewCAPool() - assert.NoError(t, caPool.AddCA(ca)) + require.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.NoError(t, err) + require.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.EqualError(t, err, "certificate is in the block list") + require.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) - assert.EqualError(t, err, "root certificate is expired") + require.EqualError(t, err, "root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) @@ -362,11 +363,11 @@ func TestCertificateV2_Verify(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { @@ -374,9 +375,9 @@ func TestCertificateV2_Verify(t *testing.T) { }) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV2_VerifyP256(t *testing.T) { @@ -384,21 +385,21 @@ func TestCertificateV2_VerifyP256(t *testing.T) { c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) caPool := NewCAPool() - assert.NoError(t, caPool.AddCA(ca)) + require.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.NoError(t, err) + require.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.EqualError(t, err, "certificate is in the block list") + require.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) - assert.EqualError(t, err, "root certificate is expired") + require.EqualError(t, err, "root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) @@ -407,11 +408,11 @@ func TestCertificateV2_VerifyP256(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { @@ -420,7 +421,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) { c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV2_Verify_IPs(t *testing.T) { @@ -429,11 +430,11 @@ func TestCertificateV2_Verify_IPs(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) // ip is outside the network @@ -469,25 +470,25 @@ func TestCertificateV2_Verify_IPs(t *testing.T) { cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV2_Verify_Subnets(t *testing.T) { @@ -496,11 +497,11 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) // ip is outside the network @@ -535,25 +536,25 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) { cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } diff --git a/cert/cert_v1_test.go b/cert/cert_v1_test.go index ea98b08..c687172 100644 --- a/cert/cert_v1_test.go +++ b/cert/cert_v1_test.go @@ -39,11 +39,11 @@ func TestCertificateV1_Marshal(t *testing.T) { } b, err := nc.Marshal() - assert.NoError(t, err) + require.NoError(t, err) //t.Log("Cert size:", len(b)) nc2, err := unmarshalCertificateV1(b, nil) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, Version1, nc.Version()) assert.Equal(t, Curve_CURVE25519, nc.Curve()) @@ -99,7 +99,7 @@ func TestCertificateV1_MarshalJSON(t *testing.T) { } b, err := nc.MarshalJSON() - assert.NoError(t, err) + require.NoError(t, err) assert.JSONEq( t, "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}", @@ -110,47 +110,47 @@ func TestCertificateV1_MarshalJSON(t *testing.T) { func TestCertificateV1_VerifyPrivateKey(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) - assert.NoError(t, err) + require.NoError(t, err) _, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) - assert.Error(t, err) + require.Error(t, err) c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.Equal(t, Curve_CURVE25519, curve) err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) - assert.NoError(t, err) + require.NoError(t, err) _, priv2 := X25519Keypair() err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) - assert.Error(t, err) + require.Error(t, err) } func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_P256, caKey) - assert.NoError(t, err) + require.NoError(t, err) _, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = ca.VerifyPrivateKey(Curve_P256, caKey2) - assert.Error(t, err) + require.Error(t, err) c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.Equal(t, Curve_P256, curve) err = c.VerifyPrivateKey(Curve_P256, rawPriv) - assert.NoError(t, err) + require.NoError(t, err) _, priv2 := P256Keypair() err = c.VerifyPrivateKey(Curve_P256, priv2) - assert.Error(t, err) + require.Error(t, err) } // Ensure that upgrading the protobuf library does not change how certificates @@ -186,7 +186,7 @@ func TestMarshalingCertificateV1Consistency(t *testing.T) { assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) b, err = proto.Marshal(nc.getRawDetails()) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) } @@ -201,7 +201,7 @@ func TestUnmarshalCertificateV1(t *testing.T) { // Test that we don't panic with an invalid certificate (#332) data := []byte("\x98\x00\x00") _, err := unmarshalCertificateV1(data, nil) - assert.EqualError(t, err, "encoded Details was nil") + require.EqualError(t, err, "encoded Details was nil") } func appendByteSlices(b ...[]byte) []byte { diff --git a/cert/cert_v2_test.go b/cert/cert_v2_test.go index 6d55750..c84f8c9 100644 --- a/cert/cert_v2_test.go +++ b/cert/cert_v2_test.go @@ -49,7 +49,7 @@ func TestCertificateV2_Marshal(t *testing.T) { //t.Log("Cert size:", len(b)) nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, Version2, nc.Version()) assert.Equal(t, Curve_CURVE25519, nc.Curve()) @@ -114,14 +114,14 @@ func TestCertificateV2_MarshalJSON(t *testing.T) { } b, err := nc.MarshalJSON() - assert.ErrorIs(t, err, ErrMissingDetails) + require.ErrorIs(t, err, ErrMissingDetails) rd, err := nc.details.Marshal() - assert.NoError(t, err) + require.NoError(t, err) nc.rawDetails = rd b, err = nc.MarshalJSON() - assert.NoError(t, err) + require.NoError(t, err) assert.JSONEq( t, "{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}", @@ -132,85 +132,85 @@ func TestCertificateV2_MarshalJSON(t *testing.T) { func TestCertificateV2_VerifyPrivateKey(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) - assert.NoError(t, err) + require.NoError(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16]) - assert.ErrorIs(t, err, ErrInvalidPrivateKey) + require.ErrorIs(t, err, ErrInvalidPrivateKey) _, caKey2, err := ed25519.GenerateKey(rand.Reader) require.NoError(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) - assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) + require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.Equal(t, Curve_CURVE25519, curve) err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) - assert.NoError(t, err) + require.NoError(t, err) _, priv2 := X25519Keypair() err = c.VerifyPrivateKey(Curve_P256, priv2) - assert.ErrorIs(t, err, ErrPublicPrivateCurveMismatch) + require.ErrorIs(t, err, ErrPublicPrivateCurveMismatch) err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) - assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) + require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16]) - assert.ErrorIs(t, err, ErrInvalidPrivateKey) + require.ErrorIs(t, err, ErrInvalidPrivateKey) ac, ok := c.(*certificateV2) require.True(t, ok) ac.curve = Curve(99) err = c.VerifyPrivateKey(Curve(99), priv2) - assert.EqualError(t, err, "invalid curve: 99") + require.EqualError(t, err, "invalid curve: 99") ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) - assert.NoError(t, err) + require.NoError(t, err) err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16]) - assert.ErrorIs(t, err, ErrInvalidPrivateKey) + require.ErrorIs(t, err, ErrInvalidPrivateKey) c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv) err = c.VerifyPrivateKey(Curve_P256, priv[:16]) - assert.ErrorIs(t, err, ErrInvalidPrivateKey) + require.ErrorIs(t, err, ErrInvalidPrivateKey) err = c.VerifyPrivateKey(Curve_P256, priv) - assert.ErrorIs(t, err, ErrInvalidPrivateKey) + require.ErrorIs(t, err, ErrInvalidPrivateKey) aCa, ok := ca2.(*certificateV2) require.True(t, ok) aCa.curve = Curve(99) err = aCa.VerifyPrivateKey(Curve(99), priv2) - assert.EqualError(t, err, "invalid curve: 99") + require.EqualError(t, err, "invalid curve: 99") } func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_P256, caKey) - assert.NoError(t, err) + require.NoError(t, err) _, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = ca.VerifyPrivateKey(Curve_P256, caKey2) - assert.Error(t, err) + require.Error(t, err) c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.Equal(t, Curve_P256, curve) err = c.VerifyPrivateKey(Curve_P256, rawPriv) - assert.NoError(t, err) + require.NoError(t, err) _, priv2 := P256Keypair() err = c.VerifyPrivateKey(Curve_P256, priv2) - assert.Error(t, err) + require.Error(t, err) } func TestCertificateV2_Copy(t *testing.T) { @@ -223,7 +223,7 @@ func TestCertificateV2_Copy(t *testing.T) { func TestUnmarshalCertificateV2(t *testing.T) { data := []byte("\x98\x00\x00") _, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519) - assert.EqualError(t, err, "bad wire format") + require.EqualError(t, err, "bad wire format") } func TestCertificateV2_marshalForSigningStability(t *testing.T) { diff --git a/cert/crypto_test.go b/cert/crypto_test.go index c43eed7..ee671c0 100644 --- a/cert/crypto_test.go +++ b/cert/crypto_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/crypto/argon2" ) @@ -61,33 +62,33 @@ qrlJ69wer3ZUHFXA // Success test case curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, Curve_CURVE25519, curve) assert.Len(t, k, 64) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) // Fail due to short key curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) - assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") + require.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") assert.Nil(t, k) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) // Fail due to invalid banner curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) - assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") + require.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") assert.Nil(t, k) assert.Equal(t, rest, invalidPem) // Fail due to ivalid PEM format, because // it's missing the requisite pre-encapsulation boundary. curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + require.EqualError(t, err, "input did not contain a valid PEM encoded block") assert.Nil(t, k) assert.Equal(t, rest, invalidPem) // Fail due to invalid passphrase curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey) - assert.EqualError(t, err, "invalid passphrase or corrupt private key") + require.EqualError(t, err, "invalid passphrase or corrupt private key") assert.Nil(t, k) assert.Equal(t, []byte{}, rest) } @@ -99,14 +100,14 @@ func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) { bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") kdfParams := NewArgon2Parameters(64*1024, 4, 3) key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams) - assert.NoError(t, err) + require.NoError(t, err) // Verify the "key" can be decrypted successfully curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key) assert.Len(t, k, 64) assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, []byte{}, rest) - assert.NoError(t, err) + require.NoError(t, err) // EncryptAndMarshalEd25519PrivateKey does not create any errors itself } diff --git a/cert/pem_test.go b/cert/pem_test.go index 9ad8a69..6e49249 100644 --- a/cert/pem_test.go +++ b/cert/pem_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUnmarshalCertificateFromPEM(t *testing.T) { @@ -35,20 +36,20 @@ bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB cert, rest, err := UnmarshalCertificateFromPEM(certBundle) assert.NotNil(t, cert) assert.Equal(t, rest, append(badBanner, invalidPem...)) - assert.NoError(t, err) + require.NoError(t, err) // Fail due to invalid banner. cert, rest, err = UnmarshalCertificateFromPEM(rest) assert.Nil(t, cert) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper certificate banner") + require.EqualError(t, err, "bytes did not contain a proper certificate banner") // Fail due to ivalid PEM format, because // it's missing the requisite pre-encapsulation boundary. cert, rest, err = UnmarshalCertificateFromPEM(rest) assert.Nil(t, cert) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + require.EqualError(t, err, "input did not contain a valid PEM encoded block") } func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) { @@ -84,33 +85,33 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA assert.Len(t, k, 64) assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_CURVE25519, curve) - assert.NoError(t, err) + require.NoError(t, err) // Success test case k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Len(t, k, 32) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_P256, curve) - assert.NoError(t, err) + require.NoError(t, err) // Fail due to short key k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key") + require.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key") // Fail due to invalid banner k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner") + require.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner") // Fail due to ivalid PEM format, because // it's missing the requisite pre-encapsulation boundary. k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + require.EqualError(t, err, "input did not contain a valid PEM encoded block") } func TestUnmarshalPrivateKeyFromPEM(t *testing.T) { @@ -146,33 +147,33 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= assert.Len(t, k, 32) assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_CURVE25519, curve) - assert.NoError(t, err) + require.NoError(t, err) // Success test case k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Len(t, k, 32) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_P256, curve) - assert.NoError(t, err) + require.NoError(t, err) // Fail due to short key k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key") + require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key") // Fail due to invalid banner k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper private key banner") + require.EqualError(t, err, "bytes did not contain a proper private key banner") // Fail due to ivalid PEM format, because // it's missing the requisite pre-encapsulation boundary. k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + require.EqualError(t, err, "input did not contain a valid PEM encoded block") } func TestUnmarshalPublicKeyFromPEM(t *testing.T) { @@ -202,7 +203,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) assert.Len(t, k, 32) assert.Equal(t, Curve_CURVE25519, curve) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) // Fail due to short key @@ -210,13 +211,13 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= assert.Nil(t, k) assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") + require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") // Fail due to invalid banner k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, Curve_CURVE25519, curve) - assert.EqualError(t, err, "bytes did not contain a proper public key banner") + require.EqualError(t, err, "bytes did not contain a proper public key banner") assert.Equal(t, rest, invalidPem) // Fail due to ivalid PEM format, because @@ -225,7 +226,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= assert.Nil(t, k) assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + require.EqualError(t, err, "input did not contain a valid PEM encoded block") } func TestUnmarshalX25519PublicKey(t *testing.T) { @@ -260,14 +261,14 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= // Success test case k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) assert.Len(t, k, 32) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_CURVE25519, curve) // Success test case k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Len(t, k, 65) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_P256, curve) @@ -275,12 +276,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") + require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") // Fail due to invalid banner k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) - assert.EqualError(t, err, "bytes did not contain a proper public key banner") + require.EqualError(t, err, "bytes did not contain a proper public key banner") assert.Equal(t, rest, invalidPem) // Fail due to ivalid PEM format, because @@ -288,5 +289,5 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + require.EqualError(t, err, "input did not contain a valid PEM encoded block") } diff --git a/cert/sign_test.go b/cert/sign_test.go index 30d8480..e6f43cd 100644 --- a/cert/sign_test.go +++ b/cert/sign_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCertificateV1_Sign(t *testing.T) { @@ -37,14 +38,14 @@ func TestCertificateV1_Sign(t *testing.T) { pub, priv, err := ed25519.GenerateKey(rand.Reader) c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, c) assert.True(t, c.CheckSignature(pub)) b, err := c.Marshal() - assert.NoError(t, err) + require.NoError(t, err) uc, err := unmarshalCertificateV1(b, nil) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, uc) } @@ -73,18 +74,18 @@ func TestCertificateV1_SignP256(t *testing.T) { } priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - assert.NoError(t, err) + require.NoError(t, err) pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) rawPriv := priv.D.FillBytes(make([]byte, 32)) c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, c) assert.True(t, c.CheckSignature(pub)) b, err := c.Marshal() - assert.NoError(t, err) + require.NoError(t, err) uc, err := unmarshalCertificateV1(b, nil) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, uc) } diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index 71b69be..189fc02 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -14,6 +14,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_caSummary(t *testing.T) { @@ -106,34 +107,34 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} - assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // create temp key file keyF, err := os.CreateTemp("", "test.key") - assert.NoError(t, err) - assert.NoError(t, os.Remove(keyF.Name())) + require.NoError(t, err) + require.NoError(t, os.Remove(keyF.Name())) // failed cert write ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) + require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // create temp cert file crtF, err := os.CreateTemp("", "test.crt") - assert.NoError(t, err) - assert.NoError(t, os.Remove(crtF.Name())) - assert.NoError(t, os.Remove(keyF.Name())) + require.NoError(t, err) + require.NoError(t, os.Remove(crtF.Name())) + require.NoError(t, os.Remove(keyF.Name())) // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.NoError(t, ca(args, ob, eb, nopw)) + require.NoError(t, ca(args, ob, eb, nopw)) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -142,13 +143,13 @@ func Test_ca(t *testing.T) { lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, c) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, lKey, 64) rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "test", lCrt.Name()) assert.Empty(t, lCrt.Networks()) @@ -166,7 +167,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.NoError(t, ca(args, ob, eb, testpw)) + require.NoError(t, ca(args, ob, eb, testpw)) assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, "", eb.String()) @@ -174,7 +175,7 @@ func Test_ca(t *testing.T) { rb, _ = os.ReadFile(keyF.Name()) k, _ := pem.Decode(rb) ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes) - assert.NoError(t, err) + require.NoError(t, err) // we won't know salt in advance, so just check start of string assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory) assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism) @@ -184,7 +185,7 @@ func Test_ca(t *testing.T) { var curve cert.Curve curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb) assert.Equal(t, cert.Curve_CURVE25519, curve) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.Len(t, lKey, 64) @@ -194,7 +195,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.Error(t, ca(args, ob, eb, errpw)) + require.Error(t, ca(args, ob, eb, errpw)) assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, "", eb.String()) @@ -204,7 +205,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") + require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up assert.Equal(t, "", eb.String()) @@ -214,13 +215,13 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.NoError(t, ca(args, ob, eb, nopw)) + require.NoError(t, ca(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) + require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -229,7 +230,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) + require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) os.Remove(keyF.Name()) diff --git a/cmd/nebula-cert/keygen_test.go b/cmd/nebula-cert/keygen_test.go index 3427254..7eed5d2 100644 --- a/cmd/nebula-cert/keygen_test.go +++ b/cmd/nebula-cert/keygen_test.go @@ -7,6 +7,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_keygenSummary(t *testing.T) { @@ -47,33 +48,33 @@ func Test_keygen(t *testing.T) { ob.Reset() eb.Reset() args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"} - assert.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + require.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // create temp key file keyF, err := os.CreateTemp("", "test.key") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(keyF.Name()) // failed pub write ob.Reset() eb.Reset() args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()} - assert.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError) + require.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // create temp pub file pubF, err := os.CreateTemp("", "test.pub") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(pubF.Name()) // test proper keygen ob.Reset() eb.Reset() args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()} - assert.NoError(t, keygen(args, ob, eb)) + require.NoError(t, keygen(args, ob, eb)) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -82,13 +83,13 @@ func Test_keygen(t *testing.T) { lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, lKey, 32) rb, _ = os.ReadFile(pubF.Name()) lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, lPub, 32) } diff --git a/cmd/nebula-cert/main_test.go b/cmd/nebula-cert/main_test.go index f332895..2e92e7e 100644 --- a/cmd/nebula-cert/main_test.go +++ b/cmd/nebula-cert/main_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_help(t *testing.T) { @@ -79,7 +80,7 @@ func assertHelpError(t *testing.T, err error, msg string) { t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg)) } - assert.EqualError(t, err, msg) + require.EqualError(t, err, msg) } func optionalPkcs11String(msg string) string { diff --git a/cmd/nebula-cert/print_test.go b/cmd/nebula-cert/print_test.go index 77e98e6..061e472 100644 --- a/cmd/nebula-cert/print_test.go +++ b/cmd/nebula-cert/print_test.go @@ -12,6 +12,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_printSummary(t *testing.T) { @@ -52,20 +53,20 @@ func Test_printCert(t *testing.T) { err = printCert([]string{"-path", "does_not_exist"}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError) + require.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError) // invalid cert at path ob.Reset() eb.Reset() tf, err := os.CreateTemp("", "print-cert") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(tf.Name()) tf.WriteString("-----BEGIN NOPE-----") err = printCert([]string{"-path", tf.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block") + require.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block") // test multiple certs ob.Reset() @@ -84,7 +85,7 @@ func Test_printCert(t *testing.T) { fp, _ := c.Fingerprint() pk := hex.EncodeToString(c.PublicKey()) sig := hex.EncodeToString(c.Signature()) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal( t, //"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n", @@ -169,7 +170,7 @@ func Test_printCert(t *testing.T) { fp, _ = c.Fingerprint() pk = hex.EncodeToString(c.PublicKey()) sig = hex.EncodeToString(c.Signature()) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal( t, `[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}] diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index 4b242a4..b2bba76 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -13,6 +13,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/crypto/ed25519" ) @@ -103,17 +104,17 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) + require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) // failed to unmarshal key ob.Reset() eb.Reset() caKeyF, err := os.CreateTemp("", "sign-cert.key") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(caKeyF.Name()) args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block") + require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -125,7 +126,7 @@ func Test_signCert(t *testing.T) { // failed to read cert args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) + require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -133,11 +134,11 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() caCrtF, err := os.CreateTemp("", "sign-cert.crt") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(caCrtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") + require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -148,7 +149,7 @@ func Test_signCert(t *testing.T) { // failed to read pub args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) + require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -156,11 +157,11 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() inPubF, err := os.CreateTemp("", "in.pub") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(inPubF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") + require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -210,14 +211,14 @@ func Test_signCert(t *testing.T) { // mismatched ca key _, caPriv2, _ := ed25519.GenerateKey(rand.Reader) caKeyF2, err := os.CreateTemp("", "sign-cert-2.key") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(caKeyF2.Name()) caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2)) ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") + require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -225,34 +226,34 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // create temp key file keyF, err := os.CreateTemp("", "test.key") - assert.NoError(t, err) + require.NoError(t, err) os.Remove(keyF.Name()) // failed cert write ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) + require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) os.Remove(keyF.Name()) // create temp cert file crtF, err := os.CreateTemp("", "test.crt") - assert.NoError(t, err) + require.NoError(t, err) os.Remove(crtF.Name()) // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.NoError(t, signCert(args, ob, eb, nopw)) + require.NoError(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -261,13 +262,13 @@ func Test_signCert(t *testing.T) { lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, lKey, 32) rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "test", lCrt.Name()) assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String()) @@ -295,7 +296,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} - assert.NoError(t, signCert(args, ob, eb, nopw)) + require.NoError(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -303,7 +304,7 @@ func Test_signCert(t *testing.T) { rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, lCrt.PublicKey(), inPub) // test refuse to sign cert with duration beyond root @@ -312,7 +313,7 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate") + require.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -320,14 +321,14 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.NoError(t, signCert(args, ob, eb, nopw)) + require.NoError(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing key file os.Remove(crtF.Name()) ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name()) + require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -335,14 +336,14 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.NoError(t, signCert(args, ob, eb, nopw)) + require.NoError(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file os.Remove(keyF.Name()) ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) + require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -355,11 +356,11 @@ func Test_signCert(t *testing.T) { eb.Reset() caKeyF, err = os.CreateTemp("", "sign-cert.key") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(caKeyF.Name()) caCrtF, err = os.CreateTemp("", "sign-cert.crt") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(caCrtF.Name()) // generate the encrypted key @@ -374,7 +375,7 @@ func Test_signCert(t *testing.T) { // test with the proper password args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.NoError(t, signCert(args, ob, eb, testpw)) + require.NoError(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -384,7 +385,7 @@ func Test_signCert(t *testing.T) { testpw.password = []byte("invalid password") args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Error(t, signCert(args, ob, eb, testpw)) + require.Error(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -393,7 +394,7 @@ func Test_signCert(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Error(t, signCert(args, ob, eb, nopw)) + require.Error(t, signCert(args, ob, eb, nopw)) // normally the user hitting enter on the prompt would add newlines between these assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -403,7 +404,7 @@ func Test_signCert(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Error(t, signCert(args, ob, eb, errpw)) + require.Error(t, signCert(args, ob, eb, errpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) } diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index c2a9f55..acc9cca 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -9,6 +9,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/crypto/ed25519" ) @@ -50,20 +51,20 @@ func Test_verify(t *testing.T) { err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError) + require.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError) // invalid ca at path ob.Reset() eb.Reset() caFile, err := os.CreateTemp("", "verify-ca") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(caFile.Name()) caFile.WriteString("-----BEGIN NOPE-----") err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block") + require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block") // make a ca for later caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) @@ -77,20 +78,20 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError) + require.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError) // invalid crt at path ob.Reset() eb.Reset() certFile, err := os.CreateTemp("", "verify-cert") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(certFile.Name()) certFile.WriteString("-----BEGIN NOPE-----") err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block") + require.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block") // unverifiable cert at path crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) @@ -107,7 +108,7 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.ErrorIs(t, err, cert.ErrSignatureMismatch) + require.ErrorIs(t, err, cert.ErrSignatureMismatch) // verified cert at path crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) @@ -119,5 +120,5 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.NoError(t, err) + require.NoError(t, err) } diff --git a/config/config_test.go b/config/config_test.go index 39301f9..468c642 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -19,18 +19,18 @@ func TestConfig_Load(t *testing.T) { // invalid yaml c := NewC(l) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644) - assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}") + require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}") // simple multi config merge c = NewC(l) os.RemoveAll(dir) os.Mkdir(dir, 0755) - assert.NoError(t, err) + require.NoError(t, err) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644) - assert.NoError(t, c.Load(dir)) + require.NoError(t, c.Load(dir)) expected := map[interface{}]interface{}{ "outer": map[interface{}]interface{}{ "inner": "override", @@ -117,11 +117,11 @@ func TestConfig_ReloadConfig(t *testing.T) { l := test.NewLogger() done := make(chan bool, 1) dir, err := os.MkdirTemp("", "config-test") - assert.NoError(t, err) + require.NoError(t, err) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) c := NewC(l) - assert.NoError(t, c.Load(dir)) + require.NoError(t, c.Load(dir)) assert.False(t, c.HasChanged("outer.inner")) assert.False(t, c.HasChanged("outer")) diff --git a/connection_manager_test.go b/connection_manager_test.go index 8e2ef15..2c9baa1 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -14,6 +14,7 @@ import ( "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func newTestLighthouse() *LightHouse { @@ -223,9 +224,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { } caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA) - assert.NoError(t, err) + require.NoError(t, err) ncp := cert.NewCAPool() - assert.NoError(t, ncp.AddCA(caCert)) + require.NoError(t, ncp.AddCA(caCert)) pubCrt, _, _ := ed25519.GenerateKey(rand.Reader) tbs = &cert.TBSCertificate{ @@ -237,7 +238,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { PublicKey: pubCrt, } peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA) - assert.NoError(t, err) + require.NoError(t, err) cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert) diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 2e7e6e4..06f2a21 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -19,6 +19,7 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/yaml.v2" ) @@ -771,7 +772,7 @@ func TestRehandshakingRelays(t *testing.T) { "key": string(myNextPrivKey), } rc, err := yaml.Marshal(relayConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) relayConfig.ReloadConfigString(string(rc)) for { @@ -875,7 +876,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { "key": string(myNextPrivKey), } rc, err := yaml.Marshal(relayConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) relayConfig.ReloadConfigString(string(rc)) for { @@ -970,7 +971,7 @@ func TestRehandshaking(t *testing.T) { "key": string(myNextPrivKey), } rc, err := yaml.Marshal(myConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) myConfig.ReloadConfigString(string(rc)) for { @@ -987,9 +988,9 @@ func TestRehandshaking(t *testing.T) { r.Log("Got the new cert") // Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly rc, err = yaml.Marshal(theirConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) var theirNewConfig m - assert.NoError(t, yaml.Unmarshal(rc, &theirNewConfig)) + require.NoError(t, yaml.Unmarshal(rc, &theirNewConfig)) theirFirewall := theirNewConfig["firewall"].(map[interface{}]interface{}) theirFirewall["inbound"] = []m{{ "proto": "any", @@ -997,7 +998,7 @@ func TestRehandshaking(t *testing.T) { "group": "new group", }} rc, err = yaml.Marshal(theirNewConfig) - assert.NoError(t, err) + require.NoError(t, err) theirConfig.ReloadConfigString(string(rc)) r.Log("Spin until there is only 1 tunnel") @@ -1067,7 +1068,7 @@ func TestRehandshakingLoser(t *testing.T) { "key": string(theirNextPrivKey), } rc, err := yaml.Marshal(theirConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) theirConfig.ReloadConfigString(string(rc)) for { @@ -1083,9 +1084,9 @@ func TestRehandshakingLoser(t *testing.T) { // Flip my firewall to only allowing the new group to catch the tunnels reverting incorrectly rc, err = yaml.Marshal(myConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) var myNewConfig m - assert.NoError(t, yaml.Unmarshal(rc, &myNewConfig)) + require.NoError(t, yaml.Unmarshal(rc, &myNewConfig)) theirFirewall := myNewConfig["firewall"].(map[interface{}]interface{}) theirFirewall["inbound"] = []m{{ "proto": "any", @@ -1093,7 +1094,7 @@ func TestRehandshakingLoser(t *testing.T) { "group": "their new group", }} rc, err = yaml.Marshal(myNewConfig) - assert.NoError(t, err) + require.NoError(t, err) myConfig.ReloadConfigString(string(rc)) r.Log("Spin until there is only 1 tunnel") diff --git a/firewall_test.go b/firewall_test.go index 92914af..8c2eeb0 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -66,61 +66,61 @@ func TestFirewall_AddRule(t *testing.T) { assert.NotNil(t, fw.OutRules) ti, err := netip.ParsePrefix("1.2.3.4/32") - assert.NoError(t, err) + require.NoError(t, err) - assert.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) // An empty rule is any assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha")) + require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) anyIp, err := netip.ParsePrefix("0.0.0.0/0") - assert.NoError(t, err) + require.NoError(t, err) - assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) // Test error conditions fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) - assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) } func TestFirewall_Drop(t *testing.T) { @@ -155,16 +155,16 @@ func TestFirewall_Drop(t *testing.T) { h.buildNetworks(c.networks, c.unsafeNetworks) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) // Allow outbound because conntrack - assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, false, &h, cp, nil)) // test remote mismatch oldRemote := p.RemoteAddr @@ -174,29 +174,29 @@ func TestFirewall_Drop(t *testing.T) { // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) } func BenchmarkFirewallTable_match(b *testing.B) { @@ -350,14 +350,14 @@ func TestFirewall_Drop2(t *testing.T) { h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // h1/c1 lacks the proper groups - assert.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) + require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) // c has the proper groups resetConntrack(fw) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) } func TestFirewall_Drop3(t *testing.T) { @@ -428,23 +428,23 @@ func TestFirewall_Drop3(t *testing.T) { h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) cp := cert.NewCAPool() // c1 should pass because host match - assert.NoError(t, fw.Drop(p, true, &h1, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) // c2 should pass because ca sha match resetConntrack(fw) - assert.NoError(t, fw.Drop(p, true, &h2, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h2, cp, nil)) // c3 should fail because no match resetConntrack(fw) assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule) // Test a remote address match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", "")) - assert.NoError(t, fw.Drop(p, true, &h1, cp, nil)) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", "")) + require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) } func TestFirewall_DropConntrackReload(t *testing.T) { @@ -480,29 +480,29 @@ func TestFirewall_DropConntrackReload(t *testing.T) { h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) // Allow outbound because conntrack - assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, false, &h, cp, nil)) oldFw := fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 // Allow outbound because conntrack and new rules allow port 10 - assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, false, &h, cp, nil)) oldFw = fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -585,42 +585,42 @@ func BenchmarkLookup(b *testing.B) { func Test_parsePort(t *testing.T) { _, _, err := parsePort("") - assert.EqualError(t, err, "was not a number; ``") + require.EqualError(t, err, "was not a number; ``") _, _, err = parsePort(" ") - assert.EqualError(t, err, "was not a number; ` `") + require.EqualError(t, err, "was not a number; ` `") _, _, err = parsePort("-") - assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`") + require.EqualError(t, err, "appears to be a range but could not be parsed; `-`") _, _, err = parsePort(" - ") - assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `") + require.EqualError(t, err, "appears to be a range but could not be parsed; ` - `") _, _, err = parsePort("a-b") - assert.EqualError(t, err, "beginning range was not a number; `a`") + require.EqualError(t, err, "beginning range was not a number; `a`") _, _, err = parsePort("1-b") - assert.EqualError(t, err, "ending range was not a number; `b`") + require.EqualError(t, err, "ending range was not a number; `b`") s, e, err := parsePort(" 1 - 2 ") assert.Equal(t, int32(1), s) assert.Equal(t, int32(2), e) - assert.NoError(t, err) + require.NoError(t, err) s, e, err = parsePort("0-1") assert.Equal(t, int32(0), s) assert.Equal(t, int32(0), e) - assert.NoError(t, err) + require.NoError(t, err) s, e, err = parsePort("9919") assert.Equal(t, int32(9919), s) assert.Equal(t, int32(9919), e) - assert.NoError(t, err) + require.NoError(t, err) s, e, err = parsePort("any") assert.Equal(t, int32(0), s) assert.Equal(t, int32(0), e) - assert.NoError(t, err) + require.NoError(t, err) } func TestNewFirewallFromConfig(t *testing.T) { @@ -633,53 +633,53 @@ func TestNewFirewallFromConfig(t *testing.T) { conf := config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") + require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") + require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") + require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") + require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") + require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") + require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") + require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") + require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") + require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") } func TestAddFirewallRulesFromConfig(t *testing.T) { @@ -688,28 +688,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { conf := config.NewC(l) mf := &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding udp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding icmp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding any rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with cidr @@ -717,49 +717,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with local_cidr conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall) // Test adding rule with ca_sha conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall) // Test single group conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test single groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test multiple AND groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test Add error @@ -767,7 +767,7 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { mf = &mockFirewall{} mf.nextCallReturn = errors.New("test error") conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} - assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`") + require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`") } func TestFirewall_convertRule(t *testing.T) { @@ -782,7 +782,7 @@ func TestFirewall_convertRule(t *testing.T) { r, err := convertRule(l, c, "test", 1) assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "group1", r.Group) // Ensure group array of > 1 is errord @@ -793,7 +793,7 @@ func TestFirewall_convertRule(t *testing.T) { r, err = convertRule(l, c, "test", 1) assert.Equal(t, "", ob.String()) - assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided") + require.Error(t, err, "group should contain a single value, an array with more than one entry was provided") // Make sure a well formed group is alright ob.Reset() @@ -802,7 +802,7 @@ func TestFirewall_convertRule(t *testing.T) { } r, err = convertRule(l, c, "test", 1) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "group1", r.Group) } diff --git a/header/header_test.go b/header/header_test.go index 1836a75..a7e5374 100644 --- a/header/header_test.go +++ b/header/header_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type headerTest struct { @@ -111,7 +112,7 @@ func TestHeader_String(t *testing.T) { func TestHeader_MarshalJSON(t *testing.T) { b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON() - assert.NoError(t, err) + require.NoError(t, err) assert.Equal( t, "{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}", diff --git a/lighthouse_test.go b/lighthouse_test.go index 9e9ad53..3b1295a 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -13,6 +13,7 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/yaml.v2" ) @@ -21,7 +22,7 @@ func TestOldIPv4Only(t *testing.T) { b := []byte{8, 129, 130, 132, 80, 16, 10} var m V4AddrPort err := m.Unmarshal(b) - assert.NoError(t, err) + require.NoError(t, err) ip := netip.MustParseAddr("10.1.1.1") bp := ip.As4() assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr()) @@ -42,14 +43,14 @@ func Test_lhStaticMapping(t *testing.T) { c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) lh2 := "10.128.0.3" c = config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} _, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) - assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") + require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") } func TestReloadLighthouseInterval(t *testing.T) { @@ -71,19 +72,19 @@ func TestReloadLighthouseInterval(t *testing.T) { c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) lh.ifce = &mockEncWriter{} // The first one routine is kicked off by main.go currently, lets make sure that one dies - assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5")) + require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5")) assert.Equal(t, int64(5), lh.interval.Load()) // Subsequent calls are killed off by the LightHouse.Reload function - assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10")) + require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10")) assert.Equal(t, int64(10), lh.interval.Load()) // If this completes then nothing is stealing our reload routine - assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11")) + require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11")) assert.Equal(t, int64(11), lh.interval.Load()) } @@ -99,9 +100,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { c := config.NewC(l) lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) - if !assert.NoError(b, err) { - b.Fatal() - } + require.NoError(b, err) hAddr := netip.MustParseAddrPort("4.5.6.7:12345") hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346") @@ -145,7 +144,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { }, } p, err := req.Marshal() - assert.NoError(b, err) + require.NoError(b, err) for n := 0; n < b.N; n++ { lhh.HandleRequest(rAddr, hi, p, mw) } @@ -160,7 +159,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { }, } p, err := req.Marshal() - assert.NoError(b, err) + require.NoError(b, err) for n := 0; n < b.N; n++ { lhh.HandleRequest(rAddr, hi, p, mw) @@ -205,7 +204,7 @@ func TestLighthouse_Memory(t *testing.T) { } lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) lh.ifce = &mockEncWriter{} - assert.NoError(t, err) + require.NoError(t, err) lhh := lh.NewRequestHandler() // Test that my first update responds with just that @@ -290,7 +289,7 @@ func TestLighthouse_reload(t *testing.T) { } lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) nc := map[interface{}]interface{}{ "static_host_map": map[interface{}]interface{}{ @@ -298,11 +297,11 @@ func TestLighthouse_reload(t *testing.T) { }, } rc, err := yaml.Marshal(nc) - assert.NoError(t, err) + require.NoError(t, err) c.ReloadConfigString(string(rc)) err = lh.reload(c, false) - assert.NoError(t, err) + require.NoError(t, err) } func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply { diff --git a/outside_test.go b/outside_test.go index 944bf16..c63e57d 100644 --- a/outside_test.go +++ b/outside_test.go @@ -12,6 +12,7 @@ import ( "github.com/slackhq/nebula/firewall" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/net/ipv4" ) @@ -20,13 +21,13 @@ func Test_newPacket(t *testing.T) { // length fails err := newPacket([]byte{}, true, p) - assert.ErrorIs(t, err, ErrPacketTooShort) + require.ErrorIs(t, err, ErrPacketTooShort) err = newPacket([]byte{0x40}, true, p) - assert.ErrorIs(t, err, ErrIPv4PacketTooShort) + require.ErrorIs(t, err, ErrIPv4PacketTooShort) err = newPacket([]byte{0x60}, true, p) - assert.ErrorIs(t, err, ErrIPv6PacketTooShort) + require.ErrorIs(t, err, ErrIPv6PacketTooShort) // length fail with ip options h := ipv4.Header{ @@ -39,15 +40,15 @@ func Test_newPacket(t *testing.T) { b, _ := h.Marshal() err = newPacket(b, true, p) - assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) + require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) // not an ipv4 packet err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.ErrorIs(t, err, ErrUnknownIPVersion) + require.ErrorIs(t, err, ErrUnknownIPVersion) // invalid ihl err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) + require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) // account for variable ip header length - incoming h = ipv4.Header{ @@ -63,7 +64,7 @@ func Test_newPacket(t *testing.T) { b = append(b, []byte{0, 3, 0, 4}...) err = newPacket(b, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr) @@ -85,7 +86,7 @@ func Test_newPacket(t *testing.T) { b = append(b, []byte{0, 5, 0, 6}...) err = newPacket(b, false, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(2), p.Protocol) assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr) @@ -111,10 +112,10 @@ func Test_newPacket_v6(t *testing.T) { FixLengths: false, } err := gopacket.SerializeLayers(buffer, opt, &ip) - assert.NoError(t, err) + require.NoError(t, err) err = newPacket(buffer.Bytes(), true, p) - assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) // A good ICMP packet ip = layers.IPv6{ @@ -134,7 +135,7 @@ func Test_newPacket_v6(t *testing.T) { } err = newPacket(buffer.Bytes(), true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -146,7 +147,7 @@ func Test_newPacket_v6(t *testing.T) { b := buffer.Bytes() b[6] = byte(layers.IPProtocolESP) err = newPacket(b, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -158,7 +159,7 @@ func Test_newPacket_v6(t *testing.T) { b = buffer.Bytes() b[6] = byte(layers.IPProtocolNoNextHeader) err = newPacket(b, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -170,7 +171,7 @@ func Test_newPacket_v6(t *testing.T) { b = buffer.Bytes() b[6] = 255 // 255 is a reserved protocol number err = newPacket(b, true, p) - assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) // A good UDP packet ip = layers.IPv6{ @@ -186,7 +187,7 @@ func Test_newPacket_v6(t *testing.T) { DstPort: layers.UDPPort(22), } err = udp.SetNetworkLayerForChecksum(&ip) - assert.NoError(t, err) + require.NoError(t, err) buffer.Clear() err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef})) @@ -197,7 +198,7 @@ func Test_newPacket_v6(t *testing.T) { // incoming err = newPacket(b, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -207,7 +208,7 @@ func Test_newPacket_v6(t *testing.T) { // outgoing err = newPacket(b, false, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) @@ -217,14 +218,14 @@ func Test_newPacket_v6(t *testing.T) { // Too short UDP packet err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes - assert.ErrorIs(t, err, ErrIPv6PacketTooShort) + require.ErrorIs(t, err, ErrIPv6PacketTooShort) // A good TCP packet b[6] = byte(layers.IPProtocolTCP) // incoming err = newPacket(b, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -234,7 +235,7 @@ func Test_newPacket_v6(t *testing.T) { // outgoing err = newPacket(b, false, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) @@ -244,7 +245,7 @@ func Test_newPacket_v6(t *testing.T) { // Too short TCP packet err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes - assert.ErrorIs(t, err, ErrIPv6PacketTooShort) + require.ErrorIs(t, err, ErrIPv6PacketTooShort) // A good UDP packet with an AH header ip = layers.IPv6{ @@ -279,7 +280,7 @@ func Test_newPacket_v6(t *testing.T) { b = append(b, udpHeader...) err = newPacket(b, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -290,7 +291,7 @@ func Test_newPacket_v6(t *testing.T) { // Invalid AH header b = buffer.Bytes() err = newPacket(b, true, p) - assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) } func Test_newPacket_ipv6Fragment(t *testing.T) { @@ -338,7 +339,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) { // Test first fragment incoming err = newPacket(firstFrag, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) @@ -348,7 +349,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) { // Test first fragment outgoing err = newPacket(firstFrag, false, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) @@ -377,7 +378,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) { // Test second fragment incoming err = newPacket(secondFrag, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) @@ -387,7 +388,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) { // Test second fragment outgoing err = newPacket(secondFrag, false, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) @@ -397,7 +398,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) { // Too short of a fragment packet err = newPacket(secondFrag[:len(secondFrag)-10], false, p) - assert.ErrorIs(t, err, ErrIPv6PacketTooShort) + require.ErrorIs(t, err, ErrIPv6PacketTooShort) } func BenchmarkParseV6(b *testing.B) { diff --git a/overlay/route_test.go b/overlay/route_test.go index 4fa30af..8f2c094 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -8,84 +8,85 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_parseRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) n, err := netip.ParsePrefix("10.0.0.0/24") - assert.NoError(t, err) + require.NoError(t, err) // test no routes config routes, err := parseRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, routes) // not an array c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "tun.routes is not an array") + require.EqualError(t, err, "tun.routes is not an array") // no routes c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}} routes, err = parseRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, routes) // weird route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1 in tun.routes is invalid") + require.EqualError(t, err, "entry 1 in tun.routes is invalid") // no mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present") + require.EqualError(t, err, "entry 1.mtu in tun.routes is not present") // bad mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") + require.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") + require.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") // missing route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not present") + require.EqualError(t, err, "entry 1.route in tun.routes is not present") // unparsable route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") + require.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // below network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]") + require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]") // above network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]") + require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]") // Not in multiple ranges c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}} routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]") + require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]") // happy case c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{ @@ -93,7 +94,7 @@ func Test_parseRoutes(t *testing.T) { map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"}, }} routes, err = parseRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, routes, 2) tested := 0 @@ -119,36 +120,36 @@ func Test_parseUnsafeRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) n, err := netip.ParsePrefix("10.0.0.0/24") - assert.NoError(t, err) + require.NoError(t, err) // test no routes config routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, routes) // not an array c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "tun.unsafe_routes is not an array") + require.EqualError(t, err, "tun.unsafe_routes is not an array") // no routes c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, routes) // weird route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") + require.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") // no via c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") + require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") // invalid via for _, invalidValue := range []interface{}{ @@ -157,44 +158,44 @@ func Test_parseUnsafeRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue)) + require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue)) } // unparsable via c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") + require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") // missing route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") + require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") // unparsable route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") + require.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // within network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24") + require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24") // below network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) - assert.NoError(t, err) + require.NoError(t, err) // above network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) - assert.NoError(t, err) + require.NoError(t, err) // no mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} @@ -206,19 +207,19 @@ func Test_parseUnsafeRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") + require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") + require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") // bad install c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") + require.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") // happy case c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ @@ -228,7 +229,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, }} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, routes, 4) tested := 0 @@ -260,38 +261,38 @@ func Test_makeRouteTree(t *testing.T) { l := test.NewLogger() c := config.NewC(l) n, err := netip.ParsePrefix("10.0.0.0/24") - assert.NoError(t, err) + require.NoError(t, err) c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"}, }} routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, routes, 2) routeTree, err := makeRouteTree(l, routes, true) - assert.NoError(t, err) + require.NoError(t, err) ip, err := netip.ParseAddr("1.0.0.2") - assert.NoError(t, err) + require.NoError(t, err) r, ok := routeTree.Lookup(ip) assert.True(t, ok) nip, err := netip.ParseAddr("192.168.0.1") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, nip, r) ip, err = netip.ParseAddr("1.0.0.1") - assert.NoError(t, err) + require.NoError(t, err) r, ok = routeTree.Lookup(ip) assert.True(t, ok) nip, err = netip.ParseAddr("192.168.0.2") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, nip, r) ip, err = netip.ParseAddr("1.1.0.1") - assert.NoError(t, err) + require.NoError(t, err) r, ok = routeTree.Lookup(ip) assert.False(t, ok) } diff --git a/punchy_test.go b/punchy_test.go index 7918449..99d703d 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -7,6 +7,7 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewPunchyFromConfig(t *testing.T) { @@ -56,7 +57,7 @@ func TestPunchy_reload(t *testing.T) { l := test.NewLogger() c := config.NewC(l) delay, _ := time.ParseDuration("1m") - assert.NoError(t, c.LoadString(` + require.NoError(t, c.LoadString(` punchy: delay: 1m respond: false @@ -66,7 +67,7 @@ punchy: assert.False(t, p.GetRespond()) newDelay, _ := time.ParseDuration("10m") - assert.NoError(t, c.ReloadConfigString(` + require.NoError(t, c.ReloadConfigString(` punchy: delay: 10m respond: true From 2fb018ced85be1f254de77eb1703584642aad49d Mon Sep 17 00:00:00 2001 From: Aleksandr Zykov Date: Wed, 12 Mar 2025 04:58:52 +0100 Subject: [PATCH 52/67] Fixed homebrew formula path (#1219) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 56e4c9d..5eea0e2 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for $ sudo apk add nebula ``` -- [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/nebula.rb) +- [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/n/nebula.rb) ``` $ brew install nebula ``` From 1d3c85338c104a7869607b6c23272cec6026ea9e Mon Sep 17 00:00:00 2001 From: jampe Date: Wed, 12 Mar 2025 15:35:33 +0100 Subject: [PATCH 53/67] add so_mark sockopt support (#1331) --- examples/config.yml | 5 +++++ udp/udp_linux.go | 24 ++++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/examples/config.yml b/examples/config.yml index 1c3584e..aae0d98 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -144,6 +144,11 @@ listen: # valid values: always, never, private # This setting is reloadable. #send_recv_error: always + # The so_sock option is a Linux-specific feature that allows all outgoing Nebula packets to be tagged with a specific identifier. + # This tagging enables IP rule-based filtering. For example, it supports 0.0.0.0/0 unsafe_routes, + # allowing for more precise routing decisions based on the packet tags. Default is 0 meaning no mark is set. + # This setting is reloadable. + #so_mark: 0 # Routines is the number of thread pairs to run that consume from the tun and UDP queues. # Currently, this defaults to 1 which means we have 1 tun queue reader and 1 diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 32a567e..f1936b4 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -84,6 +84,10 @@ func (u *StdConn) SetSendBuffer(n int) error { return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n) } +func (u *StdConn) SetSoMark(mark int) error { + return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_MARK, mark) +} + func (u *StdConn) GetRecvBuffer() (int, error) { return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF) } @@ -92,6 +96,10 @@ func (u *StdConn) GetSendBuffer() (int, error) { return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) } +func (u *StdConn) GetSoMark() (int, error) { + return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_MARK) +} + func (u *StdConn) LocalAddr() (netip.AddrPort, error) { sa, err := unix.Getsockname(u.sysFd) if err != nil { @@ -270,6 +278,22 @@ func (u *StdConn) ReloadConfig(c *config.C) { u.l.WithError(err).Error("Failed to set listen.write_buffer") } } + + b = c.GetInt("listen.so_mark", 0) + s, err := u.GetSoMark() + if b > 0 || (err == nil && s != 0) { + err := u.SetSoMark(b) + if err == nil { + s, err := u.GetSoMark() + if err == nil { + u.l.WithField("mark", s).Info("listen.so_mark was set") + } else { + u.l.WithError(err).Warn("Failed to get listen.so_mark") + } + } else { + u.l.WithError(err).Error("Failed to set listen.so_mark") + } + } } func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { From 50473bd2a893404de88841464781ac7deaba9ea9 Mon Sep 17 00:00:00 2001 From: Caleb Jasik Date: Wed, 12 Mar 2025 22:53:16 -0500 Subject: [PATCH 54/67] Update example config to listen on `::` by default (#1351) --- examples/config.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/config.yml b/examples/config.yml index aae0d98..4e7a4ae 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -126,8 +126,8 @@ lighthouse: # Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined, # however using port 0 will dynamically assign a port and is recommended for roaming nodes. listen: - # To listen on both any ipv4 and ipv6 use "::" - host: 0.0.0.0 + # To listen on only ipv4, use "0.0.0.0" + host: "::" port: 4242 # Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg) # default is 64, does not support reload From 3de36c99b6c7e304a463128ae9319d96bfd822e9 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Fri, 14 Mar 2025 13:49:27 -0400 Subject: [PATCH 55/67] build with go1.24 (#1338) This doesn't change our go.mod, which still only requires go1.22 as a minimum. It only changes our builds to use go1.24 so we have the latest improvements. --- .github/workflows/gofmt.yml | 2 +- .github/workflows/release.yml | 6 +++--- .github/workflows/smoke.yml | 2 +- .github/workflows/test.yml | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/gofmt.yml b/.github/workflows/gofmt.yml index 20a39cf..288f32c 100644 --- a/.github/workflows/gofmt.yml +++ b/.github/workflows/gofmt.yml @@ -18,7 +18,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24' check-latest: true - name: Install goimports diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 392f71b..f9df115 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,7 +14,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24' check-latest: true - name: Build @@ -37,7 +37,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24' check-latest: true - name: Build @@ -70,7 +70,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24' check-latest: true - name: Import certificates diff --git a/.github/workflows/smoke.yml b/.github/workflows/smoke.yml index 3f63008..fc654da 100644 --- a/.github/workflows/smoke.yml +++ b/.github/workflows/smoke.yml @@ -22,7 +22,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24' check-latest: true - name: build diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b8a4f03..28f0590 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,7 +22,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24' check-latest: true - name: Build @@ -60,7 +60,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24' check-latest: true - name: Build @@ -102,7 +102,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24' check-latest: true - name: Build nebula From f86953ca56e623ec7629ef7753024d8ced944a72 Mon Sep 17 00:00:00 2001 From: dioss-Machiel Date: Mon, 24 Mar 2025 23:15:59 +0100 Subject: [PATCH 56/67] Implement ECMP for unsafe_routes (#1332) --- examples/config.yml | 23 ++++++- inside.go | 95 +++++++++++++++++++++++--- overlay/device.go | 4 +- overlay/route.go | 78 ++++++++++++++++++---- overlay/route_test.go | 112 ++++++++++++++++++++++++++++++- overlay/tun_android.go | 5 +- overlay/tun_darwin.go | 9 +-- overlay/tun_disabled.go | 5 +- overlay/tun_freebsd.go | 7 +- overlay/tun_ios.go | 5 +- overlay/tun_linux.go | 91 +++++++++++++++++++------ overlay/tun_netbsd.go | 7 +- overlay/tun_openbsd.go | 7 +- overlay/tun_tester.go | 5 +- overlay/tun_windows.go | 15 +++-- overlay/user.go | 11 ++- routing/balance.go | 39 +++++++++++ routing/balance_test.go | 144 ++++++++++++++++++++++++++++++++++++++++ routing/gateway.go | 70 +++++++++++++++++++ routing/gateway_test.go | 34 ++++++++++ test/tun.go | 6 +- 21 files changed, 690 insertions(+), 82 deletions(-) create mode 100644 routing/balance.go create mode 100644 routing/balance_test.go create mode 100644 routing/gateway.go create mode 100644 routing/gateway_test.go diff --git a/examples/config.yml b/examples/config.yml index 4e7a4ae..3b7c38b 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -239,7 +239,28 @@ tun: # Unsafe routes allows you to route traffic over nebula to non-nebula nodes # Unsafe routes should be avoided unless you have hosts/services that cannot run nebula - # NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate + # Supports weighted ECMP if you define a list of gateways, this can be used for load balancing or redundancy to hosts outside of nebula + # NOTES: + # * You will only see a single gateway in the routing table if you are not on linux + # * If a gateway is not reachable through the overlay another gateway will be selected to send the traffic through, ignoring weights + # + # unsafe_routes: + # # Multiple gateways without defining a weight defaults to a weight of 1, this will balance traffic equally between the three gateways + # - route: 192.168.87.0/24 + # via: + # - gateway: 10.0.0.1 + # - gateway: 10.0.0.2 + # - gateway: 10.0.0.3 + # # Multiple gateways with a weight, this will balance traffic accordingly + # - route: 192.168.87.0/24 + # via: + # - gateway: 10.0.0.1 + # weight: 10 + # - gateway: 10.0.0.2 + # weight: 5 + # + # NOTE: The nebula certificate of the "via" node(s) *MUST* have the "route" defined as a subnet in its certificate + # `via`: single node or list of gateways to use for this route # `mtu`: will default to tun mtu if this option is not specified # `metric`: will default to 0 if this option is not specified # `install`: will default to true, controls whether this route is installed in the systems routing table. diff --git a/inside.go b/inside.go index 9629947..0af350d 100644 --- a/inside.go +++ b/inside.go @@ -8,6 +8,7 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/noiseutil" + "github.com/slackhq/nebula/routing" ) func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { @@ -49,7 +50,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) { + hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) }) @@ -121,22 +122,94 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) } +// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established func (f *Interface) Handshake(vpnAddr netip.Addr) { - f.getOrHandshake(vpnAddr, nil) + f.getOrHandshakeNoRouting(vpnAddr, nil) } -// getOrHandshake returns nil if the vpnAddr is not routable. +// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel -func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { +func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { _, found := f.myVpnNetworksTable.Lookup(vpnAddr) - if !found { - vpnAddr = f.inside.RouteFor(vpnAddr) - if !vpnAddr.IsValid() { - return nil, false - } + if found { + return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback) + } + + return nil, false +} + +// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary. +// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel. +func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { + + destinationAddr := fwPacket.RemoteAddr + + hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback) + + // Host is inside the mesh, no routing required + if hostinfo != nil { + return hostinfo, ready + } + + gateways := f.inside.RoutesFor(destinationAddr) + + switch len(gateways) { + case 0: + return nil, false + case 1: + // Single gateway route + return f.handshakeManager.GetOrHandshake(gateways[0].Addr(), cacheCallback) + default: + // Multi gateway route, perform ECMP categorization + gatewayAddr, balancingOk := routing.BalancePacket(fwPacket, gateways) + + if !balancingOk { + // This happens if the gateway buckets were not calculated, this _should_ never happen + f.l.Error("Gateway buckets not calculated, fallback from ECMP to random routing. Please report this bug.") + } + + var handshakeInfoForChosenGateway *HandshakeHostInfo + var hhReceiver = func(hh *HandshakeHostInfo) { + handshakeInfoForChosenGateway = hh + } + + // Store the handshakeHostInfo for later. + // If this node is not reachable we will attempt other nodes, if none are reachable we will + // cache the packet for this gateway. + if hostinfo, ready = f.handshakeManager.GetOrHandshake(gatewayAddr, hhReceiver); ready { + return hostinfo, true + } + + // It appears the selected gateway cannot be reached, find another gateway to fallback on. + // The current implementation breaks ECMP but that seems better than no connectivity. + // If ECMP is also required when a gateway is down then connectivity status + // for each gateway needs to be kept and the weights recalculated when they go up or down. + // This would also need to interact with unsafe_route updates through reloading the config or + // use of the use_system_route_table option + + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("destination", destinationAddr). + WithField("originalGateway", gatewayAddr). + Debugln("Calculated gateway for ECMP not available, attempting other gateways") + } + + for i := range gateways { + // Skip the gateway that failed previously + if gateways[i].Addr() == gatewayAddr { + continue + } + + // We do not need the HandshakeHostInfo since we cache the packet in the originally chosen gateway + if hostinfo, ready = f.handshakeManager.GetOrHandshake(gateways[i].Addr(), nil); ready { + return hostinfo, true + } + } + + // No gateways reachable, cache the packet in the originally chosen gateway + cacheCallback(handshakeInfoForChosenGateway) + return hostinfo, false } - return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback) } func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { @@ -163,7 +236,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) { - hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) { + hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) }) diff --git a/overlay/device.go b/overlay/device.go index da8cbe9..07146ab 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -3,6 +3,8 @@ package overlay import ( "io" "net/netip" + + "github.com/slackhq/nebula/routing" ) type Device interface { @@ -10,6 +12,6 @@ type Device interface { Activate() error Networks() []netip.Prefix Name() string - RouteFor(netip.Addr) netip.Addr + RoutesFor(netip.Addr) routing.Gateways NewMultiQueueReader() (io.ReadWriteCloser, error) } diff --git a/overlay/route.go b/overlay/route.go index 687cc11..12364ec 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -11,13 +11,14 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" ) type Route struct { MTU int Metric int Cidr netip.Prefix - Via netip.Addr + Via routing.Gateways Install bool } @@ -47,15 +48,17 @@ func (r Route) String() string { return s } -func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) { - routeTree := new(bart.Table[netip.Addr]) +func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) { + routeTree := new(bart.Table[routing.Gateways]) for _, r := range routes { if !allowMTU && r.MTU > 0 { l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS) } - if r.Via.IsValid() { - routeTree.Insert(r.Cidr, r.Via) + gateways := r.Via + if len(gateways) > 0 { + routing.CalculateBucketsForGateways(gateways) + routeTree.Insert(r.Cidr, gateways) } } return routeTree, nil @@ -201,14 +204,63 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1) } - via, ok := rVia.(string) - if !ok { - return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia) - } + var gateways routing.Gateways - viaVpnIp, err := netip.ParseAddr(via) - if err != nil { - return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err) + switch via := rVia.(type) { + case string: + viaIp, err := netip.ParseAddr(via) + if err != nil { + return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err) + } + + gateways = routing.Gateways{routing.NewGateway(viaIp, 1)} + + case []interface{}: + gateways = make(routing.Gateways, len(via)) + for ig, v := range via { + gatewayMap, ok := v.(map[interface{}]interface{}) + if !ok { + return nil, fmt.Errorf("entry %v in tun.unsafe_routes[%v].via is invalid", i+1, ig+1) + } + + rGateway, ok := gatewayMap["gateway"] + if !ok { + return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not present", i+1, ig+1) + } + + parsedGateway, ok := rGateway.(string) + if !ok { + return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not a string", i+1, ig+1) + } + + gatewayIp, err := netip.ParseAddr(parsedGateway) + if err != nil { + return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] failed to parse address: %v", i+1, ig+1, err) + } + + rGatewayWeight, ok := gatewayMap["weight"] + if !ok { + rGatewayWeight = 1 + } + + gatewayWeight, ok := rGatewayWeight.(int) + if !ok { + _, err = strconv.ParseInt(rGatewayWeight.(string), 10, 32) + if err != nil { + return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not an integer", i+1, ig+1) + } + } + + if gatewayWeight < 1 || gatewayWeight > math.MaxInt32 { + return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not in range (1-%d) : %v", i+1, ig+1, math.MaxInt32, gatewayWeight) + } + + gateways[ig] = routing.NewGateway(gatewayIp, gatewayWeight) + + } + + default: + return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string or list of gateways: found %T", i+1, rVia) } rRoute, ok := m["route"] @@ -226,7 +278,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { } r := Route{ - Via: viaVpnIp, + Via: gateways, MTU: mtu, Metric: metric, Install: install, diff --git a/overlay/route_test.go b/overlay/route_test.go index 8f2c094..eb5e914 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -158,15 +159,39 @@ func Test_parseUnsafeRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue)) + require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string or list of gateways: found %T", invalidValue)) } + // Unparsable list of via + c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": []string{"1", "2"}}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) + assert.Nil(t, routes) + require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not a string or list of gateways: found []string") + // unparsable via c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") + // unparsable gateway + c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"gateway": "1"}}}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) + assert.Nil(t, routes) + require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] failed to parse address: ParseAddr(\"1\"): unable to parse IP") + + // missing gateway element + c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"weight": "1"}}}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) + assert.Nil(t, routes) + require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] is not present") + + // unparsable weight element + c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"gateway": "10.0.0.1", "weight": "a"}}}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) + assert.Nil(t, routes) + require.EqualError(t, err, "entry .weight in tun.unsafe_routes[1].via[1] is not an integer") + // missing route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) @@ -280,7 +305,7 @@ func Test_makeRouteTree(t *testing.T) { nip, err := netip.ParseAddr("192.168.0.1") require.NoError(t, err) - assert.Equal(t, nip, r) + assert.Equal(t, nip, r[0].Addr()) ip, err = netip.ParseAddr("1.0.0.1") require.NoError(t, err) @@ -289,10 +314,91 @@ func Test_makeRouteTree(t *testing.T) { nip, err = netip.ParseAddr("192.168.0.2") require.NoError(t, err) - assert.Equal(t, nip, r) + assert.Equal(t, nip, r[0].Addr()) ip, err = netip.ParseAddr("1.1.0.1") require.NoError(t, err) r, ok = routeTree.Lookup(ip) assert.False(t, ok) } + +func Test_makeMultipathUnsafeRouteTree(t *testing.T) { + l := test.NewLogger() + c := config.NewC(l) + n, err := netip.ParsePrefix("10.0.0.0/24") + require.NoError(t, err) + + c.Settings["tun"] = map[interface{}]interface{}{ + "unsafe_routes": []interface{}{ + map[interface{}]interface{}{ + "route": "192.168.86.0/24", + "via": "192.168.100.10", + }, + map[interface{}]interface{}{ + "route": "192.168.87.0/24", + "via": []interface{}{ + map[interface{}]interface{}{ + "gateway": "10.0.0.1", + }, + map[interface{}]interface{}{ + "gateway": "10.0.0.2", + }, + map[interface{}]interface{}{ + "gateway": "10.0.0.3", + }, + }, + }, + map[interface{}]interface{}{ + "route": "192.168.89.0/24", + "via": []interface{}{ + map[interface{}]interface{}{ + "gateway": "10.0.0.1", + "weight": 10, + }, + map[interface{}]interface{}{ + "gateway": "10.0.0.2", + "weight": 5, + }, + }, + }, + }, + } + + routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) + require.NoError(t, err) + assert.Len(t, routes, 3) + routeTree, err := makeRouteTree(l, routes, true) + require.NoError(t, err) + + ip, err := netip.ParseAddr("192.168.86.1") + require.NoError(t, err) + r, ok := routeTree.Lookup(ip) + assert.True(t, ok) + + nip, err := netip.ParseAddr("192.168.100.10") + require.NoError(t, err) + assert.Equal(t, nip, r[0].Addr()) + + ip, err = netip.ParseAddr("192.168.87.1") + require.NoError(t, err) + r, ok = routeTree.Lookup(ip) + assert.True(t, ok) + + expectedGateways := routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 1), + routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 1), + routing.NewGateway(netip.MustParseAddr("10.0.0.3"), 1)} + + routing.CalculateBucketsForGateways(expectedGateways) + assert.ElementsMatch(t, expectedGateways, r) + + ip, err = netip.ParseAddr("192.168.89.1") + require.NoError(t, err) + r, ok = routeTree.Lookup(ip) + assert.True(t, ok) + + expectedGateways = routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 10), + routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 5)} + + routing.CalculateBucketsForGateways(expectedGateways) + assert.ElementsMatch(t, expectedGateways, r) +} diff --git a/overlay/tun_android.go b/overlay/tun_android.go index 72a6565..df1ed8d 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -13,6 +13,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) @@ -21,7 +22,7 @@ type tun struct { fd int vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger } @@ -56,7 +57,7 @@ func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, erro return nil, fmt.Errorf("newTun not supported in Android") } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 1a02b49..d2b2896 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -17,6 +17,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" "golang.org/x/sys/unix" @@ -28,7 +29,7 @@ type tun struct { vpnNetworks []netip.Prefix DefaultMTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] linkAddr *netroute.LinkAddr l *logrus.Logger @@ -342,12 +343,12 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, ok := t.routeTree.Load().Lookup(ip) if ok { return r } - return netip.Addr{} + return routing.Gateways{} } // Get the LinkAddr for the interface of the given name @@ -382,7 +383,7 @@ func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if !r.Via.IsValid() || !r.Install { + if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index cfbf17d..131879d 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -9,6 +9,7 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/routing" ) type disabledTun struct { @@ -43,8 +44,8 @@ func (*disabledTun) Activate() error { return nil } -func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr { - return netip.Addr{} +func (*disabledTun) RoutesFor(addr netip.Addr) routing.Gateways { + return routing.Gateways{} } func (t *disabledTun) Networks() []netip.Prefix { diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 69690e9..bcb82b3 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -20,6 +20,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) @@ -50,7 +51,7 @@ type tun struct { vpnNetworks []netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger io.ReadWriteCloser @@ -242,7 +243,7 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -262,7 +263,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if !r.Via.IsValid() || !r.Install { + if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index e99d447..e51e112 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -16,6 +16,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) @@ -23,7 +24,7 @@ type tun struct { io.ReadWriteCloser vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger } @@ -79,7 +80,7 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 993bd4a..809536f 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -17,6 +17,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" @@ -34,7 +35,7 @@ type tun struct { ioctlFd uintptr Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeChan chan struct{} useSystemRoutes bool @@ -231,7 +232,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return file, nil } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -550,20 +551,7 @@ func (t *tun) watchRoutes() { }() } -func (t *tun) updateRoutes(r netlink.RouteUpdate) { - if r.Gw == nil { - // Not a gateway route, ignore - t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route") - return - } - - gwAddr, ok := netip.AddrFromSlice(r.Gw) - if !ok { - t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address") - return - } - - gwAddr = gwAddr.Unmap() +func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool { withinNetworks := false for i := range t.vpnNetworks { if t.vpnNetworks[i].Contains(gwAddr) { @@ -571,9 +559,68 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { break } } - if !withinNetworks { - // Gateway isn't in our overlay network, ignore - t.l.WithField("route", r).Debug("Ignoring route update, not in our networks") + + return withinNetworks +} + +func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { + + var gateways routing.Gateways + + link, err := netlink.LinkByName(t.Device) + if err != nil { + t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name") + return gateways + } + + // If this route is relevant to our interface and there is a gateway then add it + if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 { + gwAddr, ok := netip.AddrFromSlice(r.Gw) + if !ok { + t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address") + } else { + gwAddr = gwAddr.Unmap() + + if !t.isGatewayInVpnNetworks(gwAddr) { + // Gateway isn't in our overlay network, ignore + t.l.WithField("route", r).Debug("Ignoring route update, not in our network") + } else { + gateways = append(gateways, routing.NewGateway(gwAddr, 1)) + } + } + } + + for _, p := range r.MultiPath { + // If this route is relevant to our interface and there is a gateway then add it + if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 { + gwAddr, ok := netip.AddrFromSlice(p.Gw) + if !ok { + t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address") + } else { + gwAddr = gwAddr.Unmap() + + if !t.isGatewayInVpnNetworks(gwAddr) { + // Gateway isn't in our overlay network, ignore + t.l.WithField("route", r).Debug("Ignoring route update, not in our network") + } else { + // p.Hops+1 = weight of the route + gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1)) + } + } + } + } + + routing.CalculateBucketsForGateways(gateways) + return gateways +} + +func (t *tun) updateRoutes(r netlink.RouteUpdate) { + + gateways := t.getGatewaysFromRoute(&r.Route) + + if len(gateways) == 0 { + // No gateways relevant to our network, no routing changes required. + t.l.WithField("route", r).Debug("Ignoring route update, no gateways") return } @@ -589,12 +636,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { newTree := t.routeTree.Load().Clone() if r.Type == unix.RTM_NEWROUTE { - t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route") - newTree.Insert(dst, gwAddr) + t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route") + newTree.Insert(dst, gateways) } else { + t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route") newTree.Delete(dst) - t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route") } t.routeTree.Store(newTree) } diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index f7586cb..847f1b5 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -18,6 +18,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) @@ -31,7 +32,7 @@ type tun struct { vpnNetworks []netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger io.ReadWriteCloser @@ -177,7 +178,7 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -197,7 +198,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if !r.Via.IsValid() || !r.Install { + if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index a2fd184..03fb3a0 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -17,6 +17,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) @@ -25,7 +26,7 @@ type tun struct { vpnNetworks []netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger io.ReadWriteCloser @@ -158,7 +159,7 @@ func (t *tun) Activate() error { return nil } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -166,7 +167,7 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if !r.Via.IsValid() || !r.Install { + if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index cc3942f..b6712fb 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -13,13 +13,14 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" ) type TestTun struct { Device string vpnNetworks []netip.Prefix Routes []Route - routeTree *bart.Table[netip.Addr] + routeTree *bart.Table[routing.Gateways] l *logrus.Logger closed atomic.Bool @@ -86,7 +87,7 @@ func (t *TestTun) Get(block bool) []byte { // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// -func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr { +func (t *TestTun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Lookup(ip) return r } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 289999d..1d66eac 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -18,6 +18,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" "github.com/slackhq/nebula/wintun" "golang.org/x/sys/windows" @@ -31,7 +32,7 @@ type winTun struct { vpnNetworks []netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger tun *wintun.NativeTun @@ -147,13 +148,16 @@ func (t *winTun) addRoutes(logErrors bool) error { foundDefault4 := false for _, r := range routes { - if !r.Via.IsValid() || !r.Install { + if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } // Add our unsafe route - err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric)) + // Windows does not support multipath routes natively, so we install only a single route. + // This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally. + // In effect this provides multipath routing support to windows supporting loadbalancing and redundancy. + err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric)) if err != nil { retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) if logErrors { @@ -198,7 +202,8 @@ func (t *winTun) removeRoutes(routes []Route) error { continue } - err := luid.DeleteRoute(r.Cidr, r.Via) + // See comment on luid.AddRoute + err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr()) if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else { @@ -208,7 +213,7 @@ func (t *winTun) removeRoutes(routes []Route) error { return nil } -func (t *winTun) RouteFor(ip netip.Addr) netip.Addr { +func (t *winTun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } diff --git a/overlay/user.go b/overlay/user.go index ae665f3..8a56d66 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -6,6 +6,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" ) func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { @@ -38,9 +39,13 @@ type UserDevice struct { func (d *UserDevice) Activate() error { return nil } -func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks } -func (d *UserDevice) Name() string { return "faketun0" } -func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip } + +func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks } +func (d *UserDevice) Name() string { return "faketun0" } +func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways { + return routing.Gateways{routing.NewGateway(ip, 1)} +} + func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { return d, nil } diff --git a/routing/balance.go b/routing/balance.go new file mode 100644 index 0000000..6f52497 --- /dev/null +++ b/routing/balance.go @@ -0,0 +1,39 @@ +package routing + +import ( + "net/netip" + + "github.com/slackhq/nebula/firewall" +) + +// Hashes the packet source and destination port and always returns a positive integer +// Based on 'Prospecting for Hash Functions' +// - https://nullprogram.com/blog/2018/07/31/ +// - https://github.com/skeeto/hash-prospector +// [16 21f0aaad 15 d35a2d97 15] = 0.10760229515479501 +func hashPacket(p *firewall.Packet) int { + x := (uint32(p.LocalPort) << 16) | uint32(p.RemotePort) + x ^= x >> 16 + x *= 0x21f0aaad + x ^= x >> 15 + x *= 0xd35a2d97 + x ^= x >> 15 + + return int(x) & 0x7FFFFFFF +} + +// For this function to work correctly it requires that the buckets for the gateways have been calculated +// If the contract is violated balancing will not work properly and the second return value will return false +func BalancePacket(fwPacket *firewall.Packet, gateways []Gateway) (netip.Addr, bool) { + hash := hashPacket(fwPacket) + + for i := range gateways { + if hash <= gateways[i].BucketUpperBound() { + return gateways[i].Addr(), true + } + } + + // If you land here then the buckets for the gateways are not properly calculated + // Fallback to random routing and let the caller know + return gateways[hash%len(gateways)].Addr(), false +} diff --git a/routing/balance_test.go b/routing/balance_test.go new file mode 100644 index 0000000..bbfcb22 --- /dev/null +++ b/routing/balance_test.go @@ -0,0 +1,144 @@ +package routing + +import ( + "net/netip" + "testing" + + "github.com/slackhq/nebula/firewall" + "github.com/stretchr/testify/assert" +) + +func TestPacketsAreBalancedEqually(t *testing.T) { + + gateways := []Gateway{} + + gw1Addr := netip.MustParseAddr("1.0.0.1") + gw2Addr := netip.MustParseAddr("1.0.0.2") + gw3Addr := netip.MustParseAddr("1.0.0.3") + + gateways = append(gateways, NewGateway(gw1Addr, 1)) + gateways = append(gateways, NewGateway(gw2Addr, 1)) + gateways = append(gateways, NewGateway(gw3Addr, 1)) + + CalculateBucketsForGateways(gateways) + + gw1count := 0 + gw2count := 0 + gw3count := 0 + + iterationCount := uint16(65535) + for i := uint16(0); i < iterationCount; i++ { + packet := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: i, + RemotePort: 65535 - i, + Protocol: 6, // TCP + Fragment: false, + } + + selectedGw, ok := BalancePacket(&packet, gateways) + assert.True(t, ok) + + switch selectedGw { + case gw1Addr: + gw1count += 1 + case gw2Addr: + gw2count += 1 + case gw3Addr: + gw3count += 1 + } + + } + + // Assert packets are balanced, allow variation of up to 100 packets per gateway + assert.InDeltaf(t, iterationCount/3, gw1count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count) + assert.InDeltaf(t, iterationCount/3, gw2count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count) + assert.InDeltaf(t, iterationCount/3, gw3count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count) + +} + +func TestPacketsAreBalancedByPriority(t *testing.T) { + + gateways := []Gateway{} + + gw1Addr := netip.MustParseAddr("1.0.0.1") + gw2Addr := netip.MustParseAddr("1.0.0.2") + + gateways = append(gateways, NewGateway(gw1Addr, 10)) + gateways = append(gateways, NewGateway(gw2Addr, 5)) + + CalculateBucketsForGateways(gateways) + + gw1count := 0 + gw2count := 0 + + iterationCount := uint16(65535) + for i := uint16(0); i < iterationCount; i++ { + packet := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: i, + RemotePort: 65535 - i, + Protocol: 6, // TCP + Fragment: false, + } + + selectedGw, ok := BalancePacket(&packet, gateways) + assert.True(t, ok) + + switch selectedGw { + case gw1Addr: + gw1count += 1 + case gw2Addr: + gw2count += 1 + } + + } + + iterationCountAsFloat := float32(iterationCount) + + assert.InDeltaf(t, iterationCountAsFloat*(2.0/3.0), gw1count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(2.0/3.0), gw1count) + assert.InDeltaf(t, iterationCountAsFloat*(1.0/3.0), gw2count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(1.0/3.0), gw2count) +} + +func TestBalancePacketDistributsRandomlyAndReturnsFalseIfBucketsNotCalculated(t *testing.T) { + gateways := []Gateway{} + + gw1Addr := netip.MustParseAddr("1.0.0.1") + gw2Addr := netip.MustParseAddr("1.0.0.2") + + gateways = append(gateways, NewGateway(gw1Addr, 10)) + gateways = append(gateways, NewGateway(gw2Addr, 5)) + + iterationCount := uint16(65535) + gw1count := 0 + gw2count := 0 + + for i := uint16(0); i < iterationCount; i++ { + packet := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: i, + RemotePort: 65535 - i, + Protocol: 6, // TCP + Fragment: false, + } + + selectedGw, ok := BalancePacket(&packet, gateways) + assert.False(t, ok) + + switch selectedGw { + case gw1Addr: + gw1count += 1 + case gw2Addr: + gw2count += 1 + } + + } + + assert.Equal(t, int(iterationCount), (gw1count + gw2count)) + assert.NotEqual(t, 0, gw1count) + assert.NotEqual(t, 0, gw2count) + +} diff --git a/routing/gateway.go b/routing/gateway.go new file mode 100644 index 0000000..59d38a9 --- /dev/null +++ b/routing/gateway.go @@ -0,0 +1,70 @@ +package routing + +import ( + "fmt" + "net/netip" +) + +const ( + // Sentinal value + BucketNotCalculated = -1 +) + +type Gateways []Gateway + +func (g Gateways) String() string { + str := "" + for i, gw := range g { + str += gw.String() + if i < len(g)-1 { + str += ", " + } + } + return str +} + +type Gateway struct { + addr netip.Addr + weight int + bucketUpperBound int +} + +func NewGateway(addr netip.Addr, weight int) Gateway { + return Gateway{addr: addr, weight: weight, bucketUpperBound: BucketNotCalculated} +} + +func (g *Gateway) BucketUpperBound() int { + return g.bucketUpperBound +} + +func (g *Gateway) Addr() netip.Addr { + return g.addr +} + +func (g *Gateway) String() string { + return fmt.Sprintf("{addr: %s, weight: %d}", g.addr, g.weight) +} + +// Divide and round to nearest integer +func divideAndRound(v uint64, d uint64) uint64 { + var tmp uint64 = v + d/2 + return tmp / d +} + +// Implements Hash-Threshold mapping, equivalent to the implementation in the linux kernel. +// After this function returns each gateway will have a +// positive bucketUpperBound with a maximum value of 2147483647 (INT_MAX) +func CalculateBucketsForGateways(gateways []Gateway) { + + var totalWeight int = 0 + for i := range gateways { + totalWeight += gateways[i].weight + } + + var loopWeight int = 0 + for i := range gateways { + loopWeight += gateways[i].weight + gateways[i].bucketUpperBound = int(divideAndRound(uint64(loopWeight)<<31, uint64(totalWeight))) - 1 + } + +} diff --git a/routing/gateway_test.go b/routing/gateway_test.go new file mode 100644 index 0000000..8ae78f3 --- /dev/null +++ b/routing/gateway_test.go @@ -0,0 +1,34 @@ +package routing + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRebalance3_2Split(t *testing.T) { + gateways := []Gateway{} + + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 10}) + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 5}) + + CalculateBucketsForGateways(gateways) + + assert.Equal(t, 1431655764, gateways[0].bucketUpperBound) // INT_MAX/3*2 + assert.Equal(t, 2147483647, gateways[1].bucketUpperBound) // INT_MAX +} + +func TestRebalanceEqualSplit(t *testing.T) { + gateways := []Gateway{} + + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1}) + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1}) + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1}) + + CalculateBucketsForGateways(gateways) + + assert.Equal(t, 715827882, gateways[0].bucketUpperBound) // INT_MAX/3 + assert.Equal(t, 1431655764, gateways[1].bucketUpperBound) // INT_MAX/3*2 + assert.Equal(t, 2147483647, gateways[2].bucketUpperBound) // INT_MAX +} diff --git a/test/tun.go b/test/tun.go index b29d61c..ca65805 100644 --- a/test/tun.go +++ b/test/tun.go @@ -4,12 +4,14 @@ import ( "errors" "io" "net/netip" + + "github.com/slackhq/nebula/routing" ) type NoopTun struct{} -func (NoopTun) RouteFor(addr netip.Addr) netip.Addr { - return netip.Addr{} +func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways { + return routing.Gateways{} } func (NoopTun) Activate() error { From 4444ed166ac163bcf4296d62d826c06b3376957b Mon Sep 17 00:00:00 2001 From: Caleb Jasik Date: Tue, 25 Mar 2025 16:08:36 -0500 Subject: [PATCH 57/67] Add `certVersion` field to logs when logging the cert name in handshakes (#1359) --- handshake_ix.go | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/handshake_ix.go b/handshake_ix.go index daea526..0783999 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -71,7 +71,8 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v). + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). + WithField("certVersion", v). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") return false } @@ -185,6 +186,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet var vpnAddrs []netip.Addr var filteredNetworks []netip.Prefix certName := remoteCert.Certificate.Name() + certVersion := remoteCert.Certificate.Version() fingerprint := remoteCert.Fingerprint issuer := remoteCert.Certificate.Issuer() @@ -194,6 +196,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if found { f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") @@ -212,6 +215,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if len(vpnAddrs) == 0 { f.l.WithError(err).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake") @@ -231,6 +235,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if err != nil { f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") @@ -253,6 +258,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -264,6 +270,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if hs.Details.Cert == nil { f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -281,6 +288,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if err != nil { f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") @@ -292,6 +300,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if err != nil { f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") @@ -299,6 +308,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet } else if dKey == nil || eKey == nil { f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key") @@ -366,6 +376,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet // This means there was an existing tunnel and this handshake was older than the one we are currently based on f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("oldHandshakeTime", existing.lastHandshakeTime). WithField("newHandshakeTime", hostinfo.lastHandshakeTime). WithField("fingerprint", fingerprint). @@ -381,6 +392,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -393,6 +405,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet // And we forget to update it here f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -409,6 +422,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if err != nil { f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -417,6 +431,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet } else { f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -435,6 +450,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). @@ -539,6 +555,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha vpnNetworks := remoteCert.Certificate.Networks() certName := remoteCert.Certificate.Name() + certVersion := remoteCert.Certificate.Version() fingerprint := remoteCert.Fingerprint issuer := remoteCert.Certificate.Issuer() @@ -573,6 +590,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha if len(vpnAddrs) == 0 { f.l.WithError(err).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake") @@ -582,7 +600,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha // Ensure the right host responded if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) { f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). - WithField("udpAddr", addr).WithField("certName", certName). + WithField("udpAddr", addr). + WithField("certName", certName). + WithField("certVersion", certVersion). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Info("Incorrect host responded to handshake") @@ -618,6 +638,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha duration := time.Since(hh.startTime).Nanoseconds() f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). + WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). From 75faa5f2e5f551e21fcb75a9aeb3805366f30d90 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 31 Mar 2025 16:05:07 -0400 Subject: [PATCH 58/67] Bump golang.org/x/net in the golang-x-dependencies group (#1370) Bumps the golang-x-dependencies group with 1 update: [golang.org/x/net](https://github.com/golang/net). Updates `golang.org/x/net` from 0.37.0 to 0.38.0 - [Commits](https://github.com/golang/net/compare/v0.37.0...v0.38.0) --- updated-dependencies: - dependency-name: golang.org/x/net dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 4 ++-- go.sum | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index de09c18..3b13170 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module github.com/slackhq/nebula go 1.23.6 -toolchain go1.23.7 +toolchain go1.24.1 require ( dario.cat/mergo v1.0.1 @@ -26,7 +26,7 @@ require ( github.com/vishvananda/netlink v1.3.0 golang.org/x/crypto v0.36.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 - golang.org/x/net v0.37.0 + golang.org/x/net v0.38.0 golang.org/x/sync v0.12.0 golang.org/x/sys v0.31.0 golang.org/x/term v0.30.0 diff --git a/go.sum b/go.sum index 11f57c7..78f2671 100644 --- a/go.sum +++ b/go.sum @@ -176,8 +176,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= -golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= From 879852c32a385ac5059af91d89615178fcef532c Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Mon, 31 Mar 2025 16:08:34 -0400 Subject: [PATCH 59/67] upgrade to yaml.v3 (#1148) * upgrade to yaml.v3 The main nice fix here is that maps unmarshal into `map[string]any` instead of `map[any]any`, so it cleans things up a bit. * add config.AsBool Since yaml.v3 doesn't automatically convert yes to bool now, for backwards compat * use type aliases for m * more cleanup * more cleanup * more cleanup * go mod cleanup --- allow_list.go | 38 ++++++----------- allow_list_test.go | 24 +++++------ cert/cert_v1.go | 2 +- cmd/nebula-cert/main.go | 2 +- config/config.go | 50 ++++++++++++++-------- config/config_test.go | 38 ++++++++--------- control_test.go | 2 +- dns_server_test.go | 16 +++---- e2e/handshakes_test.go | 6 +-- e2e/helpers_test.go | 4 +- firewall.go | 10 ++--- firewall/packet.go | 2 +- firewall_test.go | 52 +++++++++++------------ go.mod | 3 +- go.sum | 2 - header/header.go | 2 +- lighthouse.go | 6 +-- lighthouse_test.go | 30 +++++++------- main.go | 4 +- overlay/route.go | 8 ++-- overlay/route_test.go | 76 +++++++++++++++++----------------- overlay/tun_darwin.go | 2 +- overlay/tun_freebsd.go | 2 +- overlay/tun_linux.go | 2 +- overlay/tun_netbsd.go | 2 +- overlay/tun_openbsd.go | 2 +- overlay/tun_windows.go | 2 +- punchy_test.go | 8 ++-- service/service_test.go | 4 +- ssh.go | 92 ++++++++++++++++++++--------------------- sshd/command.go | 10 ++--- sshd/server.go | 2 +- sshd/session.go | 2 +- test/assert.go | 2 +- util/error.go | 4 +- util/error_test.go | 2 +- 36 files changed, 257 insertions(+), 258 deletions(-) diff --git a/allow_list.go b/allow_list.go index cfdd983..cba56fc 100644 --- a/allow_list.go +++ b/allow_list.go @@ -36,7 +36,7 @@ type AllowListNameRule struct { func NewLocalAllowListFromConfig(c *config.C, k string) (*LocalAllowList, error) { var nameRules []AllowListNameRule - handleKey := func(key string, value interface{}) (bool, error) { + handleKey := func(key string, value any) (bool, error) { if key == "interfaces" { var err error nameRules, err = getAllowListInterfaces(k, value) @@ -70,7 +70,7 @@ func NewRemoteAllowListFromConfig(c *config.C, k, rangesKey string) (*RemoteAllo // If the handleKey func returns true, the rest of the parsing is skipped // for this key. This allows parsing of special values like `interfaces`. -func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) { +func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value any) (bool, error)) (*AllowList, error) { r := c.Get(k) if r == nil { return nil, nil @@ -81,8 +81,8 @@ func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, va // If the handleKey func returns true, the rest of the parsing is skipped // for this key. This allows parsing of special values like `interfaces`. -func newAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) { - rawMap, ok := raw.(map[interface{}]interface{}) +func newAllowList(k string, raw any, handleKey func(key string, value any) (bool, error)) (*AllowList, error) { + rawMap, ok := raw.(map[string]any) if !ok { return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw) } @@ -100,12 +100,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false} rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false} - for rawKey, rawValue := range rawMap { - rawCIDR, ok := rawKey.(string) - if !ok { - return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) - } - + for rawCIDR, rawValue := range rawMap { if handleKey != nil { handled, err := handleKey(rawCIDR, rawValue) if err != nil { @@ -116,7 +111,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in } } - value, ok := rawValue.(bool) + value, ok := config.AsBool(rawValue) if !ok { return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue) } @@ -173,22 +168,18 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in return &AllowList{cidrTree: tree}, nil } -func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) { +func getAllowListInterfaces(k string, v any) ([]AllowListNameRule, error) { var nameRules []AllowListNameRule - rawRules, ok := v.(map[interface{}]interface{}) + rawRules, ok := v.(map[string]any) if !ok { return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v) } firstEntry := true var allValues bool - for rawName, rawAllow := range rawRules { - name, ok := rawName.(string) - if !ok { - return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName) - } - allow, ok := rawAllow.(bool) + for name, rawAllow := range rawRules { + allow, ok := config.AsBool(rawAllow) if !ok { return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow) } @@ -224,16 +215,11 @@ func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error remoteAllowRanges := new(bart.Table[*AllowList]) - rawMap, ok := value.(map[interface{}]interface{}) + rawMap, ok := value.(map[string]any) if !ok { return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value) } - for rawKey, rawValue := range rawMap { - rawCIDR, ok := rawKey.(string) - if !ok { - return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) - } - + for rawCIDR, rawValue := range rawMap { allowList, err := newAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil) if err != nil { return nil, err diff --git a/allow_list_test.go b/allow_list_test.go index d7d2c9a..6135f36 100644 --- a/allow_list_test.go +++ b/allow_list_test.go @@ -15,27 +15,27 @@ import ( func TestNewAllowListFromConfig(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - c.Settings["allowlist"] = map[interface{}]interface{}{ + c.Settings["allowlist"] = map[string]any{ "192.168.0.0": true, } r, err := newAllowListFromConfig(c, "allowlist", nil) require.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'") assert.Nil(t, r) - c.Settings["allowlist"] = map[interface{}]interface{}{ + c.Settings["allowlist"] = map[string]any{ "192.168.0.0/16": "abc", } r, err = newAllowListFromConfig(c, "allowlist", nil) require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc") - c.Settings["allowlist"] = map[interface{}]interface{}{ + c.Settings["allowlist"] = map[string]any{ "192.168.0.0/16": true, "10.0.0.0/8": false, } r, err = newAllowListFromConfig(c, "allowlist", nil) require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0") - c.Settings["allowlist"] = map[interface{}]interface{}{ + c.Settings["allowlist"] = map[string]any{ "0.0.0.0/0": true, "10.0.0.0/8": false, "10.42.42.0/24": true, @@ -45,7 +45,7 @@ func TestNewAllowListFromConfig(t *testing.T) { r, err = newAllowListFromConfig(c, "allowlist", nil) require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0") - c.Settings["allowlist"] = map[interface{}]interface{}{ + c.Settings["allowlist"] = map[string]any{ "0.0.0.0/0": true, "10.0.0.0/8": false, "10.42.42.0/24": true, @@ -55,7 +55,7 @@ func TestNewAllowListFromConfig(t *testing.T) { assert.NotNil(t, r) } - c.Settings["allowlist"] = map[interface{}]interface{}{ + c.Settings["allowlist"] = map[string]any{ "0.0.0.0/0": true, "10.0.0.0/8": false, "10.42.42.0/24": true, @@ -70,16 +70,16 @@ func TestNewAllowListFromConfig(t *testing.T) { // Test interface names - c.Settings["allowlist"] = map[interface{}]interface{}{ - "interfaces": map[interface{}]interface{}{ + c.Settings["allowlist"] = map[string]any{ + "interfaces": map[string]any{ `docker.*`: "foo", }, } lr, err := NewLocalAllowListFromConfig(c, "allowlist") require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo") - c.Settings["allowlist"] = map[interface{}]interface{}{ - "interfaces": map[interface{}]interface{}{ + c.Settings["allowlist"] = map[string]any{ + "interfaces": map[string]any{ `docker.*`: false, `eth.*`: true, }, @@ -87,8 +87,8 @@ func TestNewAllowListFromConfig(t *testing.T) { lr, err = NewLocalAllowListFromConfig(c, "allowlist") require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value") - c.Settings["allowlist"] = map[interface{}]interface{}{ - "interfaces": map[interface{}]interface{}{ + c.Settings["allowlist"] = map[string]any{ + "interfaces": map[string]any{ `docker.*`: false, }, } diff --git a/cert/cert_v1.go b/cert/cert_v1.go index 6bb146f..71d36eb 100644 --- a/cert/cert_v1.go +++ b/cert/cert_v1.go @@ -41,7 +41,7 @@ type detailsV1 struct { curve Curve } -type m map[string]interface{} +type m = map[string]any func (c *certificateV1) Version() Version { return Version1 diff --git a/cmd/nebula-cert/main.go b/cmd/nebula-cert/main.go index b803d30..c88626f 100644 --- a/cmd/nebula-cert/main.go +++ b/cmd/nebula-cert/main.go @@ -17,7 +17,7 @@ func (he *helpError) Error() string { return he.s } -func newHelpErrorf(s string, v ...interface{}) error { +func newHelpErrorf(s string, v ...any) error { return &helpError{s: fmt.Sprintf(s, v...)} } diff --git a/config/config.go b/config/config.go index 1aea832..b1531e9 100644 --- a/config/config.go +++ b/config/config.go @@ -17,14 +17,14 @@ import ( "dario.cat/mergo" "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) type C struct { path string files []string - Settings map[interface{}]interface{} - oldSettings map[interface{}]interface{} + Settings map[string]any + oldSettings map[string]any callbacks []func(*C) l *logrus.Logger reloadLock sync.Mutex @@ -32,7 +32,7 @@ type C struct { func NewC(l *logrus.Logger) *C { return &C{ - Settings: make(map[interface{}]interface{}), + Settings: make(map[string]any), l: l, } } @@ -92,8 +92,8 @@ func (c *C) HasChanged(k string) bool { } var ( - nv interface{} - ov interface{} + nv any + ov any ) if k == "" { @@ -147,7 +147,7 @@ func (c *C) ReloadConfig() { c.reloadLock.Lock() defer c.reloadLock.Unlock() - c.oldSettings = make(map[interface{}]interface{}) + c.oldSettings = make(map[string]any) for k, v := range c.Settings { c.oldSettings[k] = v } @@ -167,7 +167,7 @@ func (c *C) ReloadConfigString(raw string) error { c.reloadLock.Lock() defer c.reloadLock.Unlock() - c.oldSettings = make(map[interface{}]interface{}) + c.oldSettings = make(map[string]any) for k, v := range c.Settings { c.oldSettings[k] = v } @@ -201,7 +201,7 @@ func (c *C) GetStringSlice(k string, d []string) []string { return d } - rv, ok := r.([]interface{}) + rv, ok := r.([]any) if !ok { return d } @@ -215,13 +215,13 @@ func (c *C) GetStringSlice(k string, d []string) []string { } // GetMap will get the map for k or return the default d if not found or invalid -func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} { +func (c *C) GetMap(k string, d map[string]any) map[string]any { r := c.Get(k) if r == nil { return d } - v, ok := r.(map[interface{}]interface{}) + v, ok := r.(map[string]any) if !ok { return d } @@ -266,6 +266,22 @@ func (c *C) GetBool(k string, d bool) bool { return v } +func AsBool(v any) (value bool, ok bool) { + switch x := v.(type) { + case bool: + return x, true + case string: + switch x { + case "y", "yes": + return true, true + case "n", "no": + return false, true + } + } + + return false, false +} + // GetDuration will get the duration for k or return the default d if not found or invalid func (c *C) GetDuration(k string, d time.Duration) time.Duration { r := c.GetString(k, "") @@ -276,7 +292,7 @@ func (c *C) GetDuration(k string, d time.Duration) time.Duration { return v } -func (c *C) Get(k string) interface{} { +func (c *C) Get(k string) any { return c.get(k, c.Settings) } @@ -284,10 +300,10 @@ func (c *C) IsSet(k string) bool { return c.get(k, c.Settings) != nil } -func (c *C) get(k string, v interface{}) interface{} { +func (c *C) get(k string, v any) any { parts := strings.Split(k, ".") for _, p := range parts { - m, ok := v.(map[interface{}]interface{}) + m, ok := v.(map[string]any) if !ok { return nil } @@ -346,7 +362,7 @@ func (c *C) addFile(path string, direct bool) error { } func (c *C) parseRaw(b []byte) error { - var m map[interface{}]interface{} + var m map[string]any err := yaml.Unmarshal(b, &m) if err != nil { @@ -358,7 +374,7 @@ func (c *C) parseRaw(b []byte) error { } func (c *C) parse() error { - var m map[interface{}]interface{} + var m map[string]any for _, path := range c.files { b, err := os.ReadFile(path) @@ -366,7 +382,7 @@ func (c *C) parse() error { return err } - var nm map[interface{}]interface{} + var nm map[string]any err = yaml.Unmarshal(b, &nm) if err != nil { return err diff --git a/config/config_test.go b/config/config_test.go index 468c642..ec5a4b0 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -10,7 +10,7 @@ import ( "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) func TestConfig_Load(t *testing.T) { @@ -19,7 +19,7 @@ func TestConfig_Load(t *testing.T) { // invalid yaml c := NewC(l) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644) - require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}") + require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[string]interface {}") // simple multi config merge c = NewC(l) @@ -31,8 +31,8 @@ func TestConfig_Load(t *testing.T) { os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644) require.NoError(t, c.Load(dir)) - expected := map[interface{}]interface{}{ - "outer": map[interface{}]interface{}{ + expected := map[string]any{ + "outer": map[string]any{ "inner": "override", }, "new": "hi", @@ -44,12 +44,12 @@ func TestConfig_Get(t *testing.T) { l := test.NewLogger() // test simple type c := NewC(l) - c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"} + c.Settings["firewall"] = map[string]any{"outbound": "hi"} assert.Equal(t, "hi", c.Get("firewall.outbound")) // test complex type - inner := []map[interface{}]interface{}{{"port": "1", "code": "2"}} - c.Settings["firewall"] = map[interface{}]interface{}{"outbound": inner} + inner := []map[string]any{{"port": "1", "code": "2"}} + c.Settings["firewall"] = map[string]any{"outbound": inner} assert.EqualValues(t, inner, c.Get("firewall.outbound")) // test missing @@ -59,7 +59,7 @@ func TestConfig_Get(t *testing.T) { func TestConfig_GetStringSlice(t *testing.T) { l := test.NewLogger() c := NewC(l) - c.Settings["slice"] = []interface{}{"one", "two"} + c.Settings["slice"] = []any{"one", "two"} assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{})) } @@ -101,14 +101,14 @@ func TestConfig_HasChanged(t *testing.T) { // Test key change c = NewC(l) c.Settings["test"] = "hi" - c.oldSettings = map[interface{}]interface{}{"test": "no"} + c.oldSettings = map[string]any{"test": "no"} assert.True(t, c.HasChanged("test")) assert.True(t, c.HasChanged("")) // No key change c = NewC(l) c.Settings["test"] = "hi" - c.oldSettings = map[interface{}]interface{}{"test": "hi"} + c.oldSettings = map[string]any{"test": "hi"} assert.False(t, c.HasChanged("test")) assert.False(t, c.HasChanged("")) } @@ -184,11 +184,11 @@ firewall: `), } - var m map[any]any + var m map[string]any // merge the same way config.parse() merges for _, b := range configs { - var nm map[any]any + var nm map[string]any err := yaml.Unmarshal(b, &nm) require.NoError(t, err) @@ -205,15 +205,15 @@ firewall: t.Logf("Merged Config as YAML:\n%s", mYaml) // If a bug is present, some items might be replaced instead of merged like we expect - expected := map[any]any{ - "firewall": map[any]any{ + expected := map[string]any{ + "firewall": map[string]any{ "inbound": []any{ - map[any]any{"host": "any", "port": "any", "proto": "icmp"}, - map[any]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"}, - map[any]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}}, + map[string]any{"host": "any", "port": "any", "proto": "icmp"}, + map[string]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"}, + map[string]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}}, "outbound": []any{ - map[any]any{"host": "any", "port": "any", "proto": "any"}}}, - "listen": map[any]any{ + map[string]any{"host": "any", "port": "any", "proto": "any"}}}, + "listen": map[string]any{ "host": "0.0.0.0", "port": 4242, }, diff --git a/control_test.go b/control_test.go index 6ce7083..de85fee 100644 --- a/control_test.go +++ b/control_test.go @@ -110,7 +110,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }) } -func assertFields(t *testing.T, expected []string, actualStruct interface{}) { +func assertFields(t *testing.T, expected []string, actualStruct any) { val := reflect.ValueOf(actualStruct).Elem() fields := make([]string, val.NumField()) for i := 0; i < val.NumField(); i++ { diff --git a/dns_server_test.go b/dns_server_test.go index f4643a3..356e589 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -38,24 +38,24 @@ func TestParsequery(t *testing.T) { func Test_getDnsServerAddr(t *testing.T) { c := config.NewC(nil) - c.Settings["lighthouse"] = map[interface{}]interface{}{ - "dns": map[interface{}]interface{}{ + c.Settings["lighthouse"] = map[string]any{ + "dns": map[string]any{ "host": "0.0.0.0", "port": "1", }, } assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c)) - c.Settings["lighthouse"] = map[interface{}]interface{}{ - "dns": map[interface{}]interface{}{ + c.Settings["lighthouse"] = map[string]any{ + "dns": map[string]any{ "host": "::", "port": "1", }, } assert.Equal(t, "[::]:1", getDnsServerAddr(c)) - c.Settings["lighthouse"] = map[interface{}]interface{}{ - "dns": map[interface{}]interface{}{ + c.Settings["lighthouse"] = map[string]any{ + "dns": map[string]any{ "host": "[::]", "port": "1", }, @@ -63,8 +63,8 @@ func Test_getDnsServerAddr(t *testing.T) { assert.Equal(t, "[::]:1", getDnsServerAddr(c)) // Make sure whitespace doesn't mess us up - c.Settings["lighthouse"] = map[interface{}]interface{}{ - "dns": map[interface{}]interface{}{ + c.Settings["lighthouse"] = map[string]any{ + "dns": map[string]any{ "host": "[::] ", "port": "1", }, diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 06f2a21..bc080ce 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -20,7 +20,7 @@ import ( "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) func BenchmarkHotPath(b *testing.B) { @@ -991,7 +991,7 @@ func TestRehandshaking(t *testing.T) { require.NoError(t, err) var theirNewConfig m require.NoError(t, yaml.Unmarshal(rc, &theirNewConfig)) - theirFirewall := theirNewConfig["firewall"].(map[interface{}]interface{}) + theirFirewall := theirNewConfig["firewall"].(map[string]any) theirFirewall["inbound"] = []m{{ "proto": "any", "port": "any", @@ -1087,7 +1087,7 @@ func TestRehandshakingLoser(t *testing.T) { require.NoError(t, err) var myNewConfig m require.NoError(t, yaml.Unmarshal(rc, &myNewConfig)) - theirFirewall := myNewConfig["firewall"].(map[interface{}]interface{}) + theirFirewall := myNewConfig["firewall"].(map[string]any) theirFirewall["inbound"] = []m{{ "proto": "any", "port": "any", diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index e1b7ac2..a63b3d0 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -22,10 +22,10 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/e2e/router" "github.com/stretchr/testify/assert" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) -type m map[string]interface{} +type m = map[string]any // newSimpleServer creates a nebula instance with many assumptions func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { diff --git a/firewall.go b/firewall.go index e9f454d..e730114 100644 --- a/firewall.go +++ b/firewall.go @@ -331,7 +331,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw return nil } - rs, ok := r.([]interface{}) + rs, ok := r.([]any) if !ok { return fmt.Errorf("%s failed to parse, should be an array of rules", table) } @@ -918,15 +918,15 @@ type rule struct { CASha string } -func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) { +func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) { r := rule{} - m, ok := p.(map[interface{}]interface{}) + m, ok := p.(map[string]any) if !ok { return r, errors.New("could not parse rule") } - toString := func(k string, m map[interface{}]interface{}) string { + toString := func(k string, m map[string]any) string { v, ok := m[k] if !ok { return "" @@ -944,7 +944,7 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er r.CASha = toString("ca_sha", m) // Make sure group isn't an array - if v, ok := m["group"].([]interface{}); ok { + if v, ok := m["group"].([]any); ok { if len(v) > 1 { return r, errors.New("group should contain a single value, an array with more than one entry was provided") } diff --git a/firewall/packet.go b/firewall/packet.go index 1d8f12a..40c7fc5 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -6,7 +6,7 @@ import ( "net/netip" ) -type m map[string]interface{} +type m = map[string]any const ( ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever diff --git a/firewall_test.go b/firewall_test.go index 8c2eeb0..c90fb20 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -631,53 +631,53 @@ func TestNewFirewallFromConfig(t *testing.T) { require.NoError(t, err) conf := config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} + conf.Settings["firewall"] = map[string]any{"outbound": "asdf"} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups conf = config.NewC(l) - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") } @@ -687,28 +687,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { // Test adding tcp rule conf := config.NewC(l) mf := &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding udp rule conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding icmp rule conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} + conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding any rule conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) @@ -716,49 +716,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { cidr := netip.MustParsePrefix("10.0.0.0/8") conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with local_cidr conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall) // Test adding rule with ca_sha conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall) // Test single group conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test single groups conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test multiple AND groups conf = config.NewC(l) mf = &mockFirewall{} - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) @@ -766,7 +766,7 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { conf = config.NewC(l) mf = &mockFirewall{} mf.nextCallReturn = errors.New("test error") - conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`") } @@ -776,8 +776,8 @@ func TestFirewall_convertRule(t *testing.T) { l.SetOutput(ob) // Ensure group array of 1 is converted and a warning is printed - c := map[interface{}]interface{}{ - "group": []interface{}{"group1"}, + c := map[string]any{ + "group": []any{"group1"}, } r, err := convertRule(l, c, "test", 1) @@ -787,8 +787,8 @@ func TestFirewall_convertRule(t *testing.T) { // Ensure group array of > 1 is errord ob.Reset() - c = map[interface{}]interface{}{ - "group": []interface{}{"group1", "group2"}, + c = map[string]any{ + "group": []any{"group1", "group2"}, } r, err = convertRule(l, c, "test", 1) @@ -797,7 +797,7 @@ func TestFirewall_convertRule(t *testing.T) { // Make sure a well formed group is alright ob.Reset() - c = map[interface{}]interface{}{ + c = map[string]any{ "group": "group1", } diff --git a/go.mod b/go.mod index 3b13170..bbd5d8b 100644 --- a/go.mod +++ b/go.mod @@ -34,7 +34,7 @@ require ( golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 google.golang.org/protobuf v1.36.5 - gopkg.in/yaml.v2 v2.4.0 + gopkg.in/yaml.v3 v3.0.1 gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe ) @@ -53,5 +53,4 @@ require ( golang.org/x/mod v0.18.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.22.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 78f2671..8237bfa 100644 --- a/go.sum +++ b/go.sum @@ -251,8 +251,6 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/header/header.go b/header/header.go index 50b7d62..f22509b 100644 --- a/header/header.go +++ b/header/header.go @@ -19,7 +19,7 @@ import ( // |-----------------------------------------------------------------------| // | payload... | -type m map[string]interface{} +type m = map[string]any const ( Version uint8 = 1 diff --git a/lighthouse.go b/lighthouse.go index ce37023..f13afd3 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -422,7 +422,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc return err } - shm := c.GetMap("static_host_map", map[interface{}]interface{}{}) + shm := c.GetMap("static_host_map", map[string]any{}) i := 0 for k, v := range shm { @@ -436,9 +436,9 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil) } - vals, ok := v.([]interface{}) + vals, ok := v.([]any) if !ok { - vals = []interface{}{v} + vals = []any{v} } remoteAddrs := []string{} for _, v := range vals { diff --git a/lighthouse_test.go b/lighthouse_test.go index 3b1295a..6a541c2 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -14,7 +14,7 @@ import ( "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) func TestOldIPv4Only(t *testing.T) { @@ -40,15 +40,15 @@ func Test_lhStaticMapping(t *testing.T) { lh1 := "10.128.0.2" c := config.NewC(l) - c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} - c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} + c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1}} + c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}} _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) require.NoError(t, err) lh2 := "10.128.0.3" c = config.NewC(l) - c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} - c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} + c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1, lh2}} + c.Settings["static_host_map"] = map[string]any{lh1: []any{"100.1.1.1:4242"}} _, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") } @@ -65,12 +65,12 @@ func TestReloadLighthouseInterval(t *testing.T) { lh1 := "10.128.0.2" c := config.NewC(l) - c.Settings["lighthouse"] = map[interface{}]interface{}{ - "hosts": []interface{}{lh1}, + c.Settings["lighthouse"] = map[string]any{ + "hosts": []any{lh1}, "interval": "1s", } - c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} + c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}} lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) require.NoError(t, err) lh.ifce = &mockEncWriter{} @@ -192,8 +192,8 @@ func TestLighthouse_Memory(t *testing.T) { theirVpnIp := netip.MustParseAddr("10.128.0.3") c := config.NewC(l) - c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} - c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} + c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true} + c.Settings["listen"] = map[string]any{"port": 4242} myVpnNet := netip.MustParsePrefix("10.128.0.1/24") nt := new(bart.Table[struct{}]) @@ -277,8 +277,8 @@ func TestLighthouse_Memory(t *testing.T) { func TestLighthouse_reload(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} - c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} + c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true} + c.Settings["listen"] = map[string]any{"port": 4242} myVpnNet := netip.MustParsePrefix("10.128.0.1/24") nt := new(bart.Table[struct{}]) @@ -291,9 +291,9 @@ func TestLighthouse_reload(t *testing.T) { lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) require.NoError(t, err) - nc := map[interface{}]interface{}{ - "static_host_map": map[interface{}]interface{}{ - "10.128.0.2": []interface{}{"1.1.1.1:4242"}, + nc := map[string]any{ + "static_host_map": map[string]any{ + "10.128.0.2": []any{"1.1.1.1:4242"}, }, } rc, err := yaml.Marshal(nc) diff --git a/main.go b/main.go index 7e94c32..b278fa6 100644 --- a/main.go +++ b/main.go @@ -13,10 +13,10 @@ import ( "github.com/slackhq/nebula/sshd" "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) -type m map[string]interface{} +type m = map[string]any func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) { ctx, cancel := context.WithCancel(context.Background()) diff --git a/overlay/route.go b/overlay/route.go index 12364ec..360921f 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -72,7 +72,7 @@ func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { return []Route{}, nil } - rawRoutes, ok := r.([]interface{}) + rawRoutes, ok := r.([]any) if !ok { return nil, fmt.Errorf("tun.routes is not an array") } @@ -83,7 +83,7 @@ func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { routes := make([]Route, len(rawRoutes)) for i, r := range rawRoutes { - m, ok := r.(map[interface{}]interface{}) + m, ok := r.(map[string]any) if !ok { return nil, fmt.Errorf("entry %v in tun.routes is invalid", i+1) } @@ -151,7 +151,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { return []Route{}, nil } - rawRoutes, ok := r.([]interface{}) + rawRoutes, ok := r.([]any) if !ok { return nil, fmt.Errorf("tun.unsafe_routes is not an array") } @@ -162,7 +162,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { routes := make([]Route, len(rawRoutes)) for i, r := range rawRoutes { - m, ok := r.(map[interface{}]interface{}) + m, ok := r.(map[string]any) if !ok { return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1) } diff --git a/overlay/route_test.go b/overlay/route_test.go index eb5e914..6b5ae2e 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -24,75 +24,75 @@ func Test_parseRoutes(t *testing.T) { assert.Empty(t, routes) // not an array - c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"} + c.Settings["tun"] = map[string]any{"routes": "hi"} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "tun.routes is not an array") // no routes - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}} + c.Settings["tun"] = map[string]any{"routes": []any{}} routes, err = parseRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Empty(t, routes) // weird route - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}} + c.Settings["tun"] = map[string]any{"routes": []any{"asdf"}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1 in tun.routes is invalid") // no mtu - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}} + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.mtu in tun.routes is not present") // bad mtu - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}} + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "nope"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}} + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "499"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") // missing route - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}} + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes is not present") // unparsable route - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "nope"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // below network range - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "1.0.0.0/8"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]") // above network range - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}} + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "10.0.1.0/24"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]") // Not in multiple ranges - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}} + c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "192.0.0.0/24"}}} routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]") // happy case - c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{ - map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"}, - map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"}, + c.Settings["tun"] = map[string]any{"routes": []any{ + map[string]any{"mtu": "9000", "route": "10.0.0.0/29"}, + map[string]any{"mtu": "8000", "route": "10.0.0.1/32"}, }} routes, err = parseRoutes(c, []netip.Prefix{n}) require.NoError(t, err) @@ -129,34 +129,34 @@ func Test_parseUnsafeRoutes(t *testing.T) { assert.Empty(t, routes) // not an array - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"} + c.Settings["tun"] = map[string]any{"unsafe_routes": "hi"} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "tun.unsafe_routes is not an array") // no routes - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Empty(t, routes) // weird route - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{"asdf"}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") // no via - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") // invalid via - for _, invalidValue := range []interface{}{ + for _, invalidValue := range []any{ 127, false, nil, 1.0, []string{"1", "2"}, } { - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": invalidValue}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string or list of gateways: found %T", invalidValue)) @@ -169,7 +169,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not a string or list of gateways: found []string") // unparsable via - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") @@ -193,65 +193,65 @@ func Test_parseUnsafeRoutes(t *testing.T) { require.EqualError(t, err, "entry .weight in tun.unsafe_routes[1].via[1] is not an integer") // missing route - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") // unparsable route - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // within network range - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24") // below network range - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) require.NoError(t, err) // above network range - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) require.NoError(t, err) // no mtu - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) assert.Equal(t, 0, routes[0].MTU) // bad mtu - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "499"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") // bad install - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") // happy case - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ - map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"}, - map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0}, - map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1}, - map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{ + map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"}, + map[string]any{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0}, + map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1}, + map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, }} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) require.NoError(t, err) @@ -288,9 +288,9 @@ func Test_makeRouteTree(t *testing.T) { n, err := netip.ParsePrefix("10.0.0.0/24") require.NoError(t, err) - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ - map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, - map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"}, + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{ + map[string]any{"via": "192.168.0.1", "route": "1.0.0.0/28"}, + map[string]any{"via": "192.168.0.2", "route": "1.0.0.1/32"}, }} routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) require.NoError(t, err) diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index d2b2896..7f6ba4f 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -394,7 +394,7 @@ func (t *tun) addRoutes(logErrors bool) error { t.l.WithField("route", r.Cidr). Warnf("unable to add unsafe_route, identical route already exists") } else { - retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) + retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index bcb82b3..2a89cbc 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -271,7 +271,7 @@ func (t *tun) addRoutes(logErrors bool) error { cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { - retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) + retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 809536f..7d19c85 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -464,7 +464,7 @@ func (t *tun) addRoutes(logErrors bool) error { err := netlink.RouteReplace(&nr) if err != nil { - retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) + retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 847f1b5..5ff9b0f 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -206,7 +206,7 @@ func (t *tun) addRoutes(logErrors bool) error { cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { - retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) + retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 03fb3a0..67a9a5f 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -175,7 +175,7 @@ func (t *tun) addRoutes(logErrors bool) error { cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { - retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) + retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 1d66eac..7aac128 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -159,7 +159,7 @@ func (t *winTun) addRoutes(logErrors bool) error { // In effect this provides multipath routing support to windows supporting loadbalancing and redundancy. err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric)) if err != nil { - retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) + retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) continue diff --git a/punchy_test.go b/punchy_test.go index 99d703d..56dd1c2 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -27,7 +27,7 @@ func TestNewPunchyFromConfig(t *testing.T) { assert.True(t, p.GetPunch()) // punchy.punch - c.Settings["punchy"] = map[interface{}]interface{}{"punch": true} + c.Settings["punchy"] = map[string]any{"punch": true} p = NewPunchyFromConfig(l, c) assert.True(t, p.GetPunch()) @@ -37,18 +37,18 @@ func TestNewPunchyFromConfig(t *testing.T) { assert.True(t, p.GetRespond()) // punchy.respond - c.Settings["punchy"] = map[interface{}]interface{}{"respond": true} + c.Settings["punchy"] = map[string]any{"respond": true} c.Settings["punch_back"] = false p = NewPunchyFromConfig(l, c) assert.True(t, p.GetRespond()) // punchy.delay - c.Settings["punchy"] = map[interface{}]interface{}{"delay": "1m"} + c.Settings["punchy"] = map[string]any{"delay": "1m"} p = NewPunchyFromConfig(l, c) assert.Equal(t, time.Minute, p.GetDelay()) // punchy.respond_delay - c.Settings["punchy"] = map[interface{}]interface{}{"respond_delay": "1m"} + c.Settings["punchy"] = map[string]any{"respond_delay": "1m"} p = NewPunchyFromConfig(l, c) assert.Equal(t, time.Minute, p.GetRespondDelay()) } diff --git a/service/service_test.go b/service/service_test.go index 613758e..b9810cd 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -13,10 +13,10 @@ import ( "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" "golang.org/x/sync/errgroup" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) -type m map[string]interface{} +type m = map[string]any func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { _, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{}) diff --git a/ssh.go b/ssh.go index 203166c..9a26c29 100644 --- a/ssh.go +++ b/ssh.go @@ -124,10 +124,10 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro } rawKeys := c.Get("sshd.authorized_users") - keys, ok := rawKeys.([]interface{}) + keys, ok := rawKeys.([]any) if ok { for _, rk := range keys { - kDef, ok := rk.(map[interface{}]interface{}) + kDef, ok := rk.(map[string]any) if !ok { l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring") continue @@ -148,7 +148,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro continue } - case []interface{}: + case []any: for _, subK := range v { sk, ok := subK.(string) if !ok { @@ -190,7 +190,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "list-hostmap", ShortDescription: "List all known previously connected hosts", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshListHostMapFlags{} fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") @@ -198,7 +198,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshListHostMap(f.hostMap, fs, w) }, }) @@ -206,7 +206,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "list-pending-hostmap", ShortDescription: "List all handshaking hosts", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshListHostMapFlags{} fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") @@ -214,7 +214,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshListHostMap(f.handshakeManager, fs, w) }, }) @@ -222,14 +222,14 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "list-lighthouse-addrmap", ShortDescription: "List all lighthouse map entries", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshListHostMapFlags{} fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshListLighthouseMap(f.lightHouse, fs, w) }, }) @@ -237,7 +237,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "reload", ShortDescription: "Reloads configuration from disk, same as sending HUP to the process", - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshReload(c, w) }, }) @@ -251,7 +251,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "stop-cpu-profile", ShortDescription: "Stops a cpu profile and writes output to the previously provided file", - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { pprof.StopCPUProfile() return w.WriteLine("If a CPU profile was running it is now stopped") }, @@ -278,7 +278,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "log-level", ShortDescription: "Gets or sets the current log level", - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshLogLevel(l, fs, a, w) }, }) @@ -286,7 +286,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "log-format", ShortDescription: "Gets or sets the current log format", - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshLogFormat(l, fs, a, w) }, }) @@ -294,7 +294,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "version", ShortDescription: "Prints the currently running version of nebula", - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshVersion(f, fs, a, w) }, }) @@ -302,14 +302,14 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "device-info", ShortDescription: "Prints information about the network device.", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshDeviceInfoFlags{} fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshDeviceInfo(f, fs, w) }, }) @@ -317,7 +317,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "print-cert", ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn addr", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshPrintCertFlags{} fl.BoolVar(&s.Json, "json", false, "outputs as json") @@ -325,7 +325,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter fl.BoolVar(&s.Raw, "raw", false, "raw prints the PEM encoded certificate, not compatible with -json or -pretty") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshPrintCert(f, fs, a, w) }, }) @@ -333,13 +333,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "print-tunnel", ShortDescription: "Prints json details about a tunnel for the provided vpn addr", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshPrintTunnelFlags{} fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshPrintTunnel(f, fs, a, w) }, }) @@ -347,13 +347,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "print-relays", ShortDescription: "Prints json details about all relay info", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshPrintTunnelFlags{} fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshPrintRelays(f, fs, a, w) }, }) @@ -361,13 +361,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "change-remote", ShortDescription: "Changes the remote address used in the tunnel for the provided vpn addr", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshChangeRemoteFlags{} fl.StringVar(&s.Address, "address", "", "The new remote address, ip:port") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshChangeRemote(f, fs, a, w) }, }) @@ -375,13 +375,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "close-tunnel", ShortDescription: "Closes a tunnel for the provided vpn addr", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshCloseTunnelFlags{} fl.BoolVar(&s.LocalOnly, "local-only", false, "Disables notifying the remote that the tunnel is shutting down") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshCloseTunnel(f, fs, a, w) }, }) @@ -390,13 +390,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter Name: "create-tunnel", ShortDescription: "Creates a tunnel for the provided vpn address", Help: "The lighthouses will be queried for real addresses but you can provide one as well.", - Flags: func() (*flag.FlagSet, interface{}) { + Flags: func() (*flag.FlagSet, any) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshCreateTunnelFlags{} fl.StringVar(&s.Address, "address", "", "Optionally provide a real remote address, ip:port ") return fl, &s }, - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshCreateTunnel(f, fs, a, w) }, }) @@ -405,13 +405,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter Name: "query-lighthouse", ShortDescription: "Query the lighthouses for the provided vpn address", Help: "This command is asynchronous. Only currently known udp addresses will be printed.", - Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + Callback: func(fs any, a []string, w sshd.StringWriter) error { return sshQueryLighthouse(f, fs, a, w) }, }) } -func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error { +func sshListHostMap(hl controlHostLister, a any, w sshd.StringWriter) error { fs, ok := a.(*sshListHostMapFlags) if !ok { return nil @@ -451,7 +451,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er return nil } -func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWriter) error { +func sshListLighthouseMap(lightHouse *LightHouse, a any, w sshd.StringWriter) error { fs, ok := a.(*sshListHostMapFlags) if !ok { return nil @@ -505,7 +505,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr return nil } -func sshStartCpuProfile(fs interface{}, a []string, w sshd.StringWriter) error { +func sshStartCpuProfile(fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { err := w.WriteLine("No path to write profile provided") return err @@ -527,11 +527,11 @@ func sshStartCpuProfile(fs interface{}, a []string, w sshd.StringWriter) error { return err } -func sshVersion(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshVersion(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { return w.WriteLine(fmt.Sprintf("%s", ifce.version)) } -func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshQueryLighthouse(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine("No vpn address was provided") } @@ -553,7 +553,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri return json.NewEncoder(w.GetWriter()).Encode(cm) } -func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshCloseTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { flags, ok := fs.(*sshCloseTunnelFlags) if !ok { return nil @@ -593,7 +593,7 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine("Closed") } -func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshCreateTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { flags, ok := fs.(*sshCreateTunnelFlags) if !ok { return nil @@ -638,7 +638,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine("Created") } -func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshChangeRemote(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { flags, ok := fs.(*sshChangeRemoteFlags) if !ok { return nil @@ -675,7 +675,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine("Changed") } -func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error { +func sshGetHeapProfile(fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine("No path to write profile provided") } @@ -696,7 +696,7 @@ func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error { return err } -func sshMutexProfileFraction(fs interface{}, a []string, w sshd.StringWriter) error { +func sshMutexProfileFraction(fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { rate := runtime.SetMutexProfileFraction(-1) return w.WriteLine(fmt.Sprintf("Current value: %d", rate)) @@ -711,7 +711,7 @@ func sshMutexProfileFraction(fs interface{}, a []string, w sshd.StringWriter) er return w.WriteLine(fmt.Sprintf("New value: %d. Old value: %d", newRate, oldRate)) } -func sshGetMutexProfile(fs interface{}, a []string, w sshd.StringWriter) error { +func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine("No path to write profile provided") } @@ -735,7 +735,7 @@ func sshGetMutexProfile(fs interface{}, a []string, w sshd.StringWriter) error { return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a)) } -func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error { +func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) } @@ -749,7 +749,7 @@ func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWrit return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) } -func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error { +func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) } @@ -767,7 +767,7 @@ func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWri return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) } -func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { args, ok := fs.(*sshPrintCertFlags) if !ok { return nil @@ -822,7 +822,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit return w.WriteLine(cert.String()) } -func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshPrintRelays(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { args, ok := fs.(*sshPrintTunnelFlags) if !ok { w.WriteLine(fmt.Sprintf("sshPrintRelays failed to convert args type")) @@ -919,7 +919,7 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return nil } -func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { +func sshPrintTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { args, ok := fs.(*sshPrintTunnelFlags) if !ok { return nil @@ -951,7 +951,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges())) } -func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error { +func sshDeviceInfo(ifce *Interface, fs any, w sshd.StringWriter) error { data := struct { Name string `json:"name"` diff --git a/sshd/command.go b/sshd/command.go index 66646a6..7323d12 100644 --- a/sshd/command.go +++ b/sshd/command.go @@ -12,7 +12,7 @@ import ( // CommandFlags is a function called before help or command execution to parse command line flags // It should return a flag.FlagSet instance and a pointer to the struct that will contain parsed flags -type CommandFlags func() (*flag.FlagSet, interface{}) +type CommandFlags func() (*flag.FlagSet, any) // CommandCallback is the function called when your command should execute. // fs will be a a pointer to the struct provided by Command.Flags callback, if there was one. -h and -help are reserved @@ -21,7 +21,7 @@ type CommandFlags func() (*flag.FlagSet, interface{}) // w is the writer to use when sending messages back to the client. // If an error is returned by the callback it is logged locally, the callback should handle messaging errors to the user // where appropriate -type CommandCallback func(fs interface{}, a []string, w StringWriter) error +type CommandCallback func(fs any, a []string, w StringWriter) error type Command struct { Name string @@ -34,7 +34,7 @@ type Command struct { func execCommand(c *Command, args []string, w StringWriter) error { var ( fl *flag.FlagSet - fs interface{} + fs any ) if c.Flags != nil { @@ -85,7 +85,7 @@ func lookupCommand(c *radix.Tree, sCmd string) (*Command, error) { func matchCommand(c *radix.Tree, cmd string) []string { cmds := make([]string, 0) - c.WalkPrefix(cmd, func(found string, v interface{}) bool { + c.WalkPrefix(cmd, func(found string, v any) bool { cmds = append(cmds, found) return false }) @@ -95,7 +95,7 @@ func matchCommand(c *radix.Tree, cmd string) []string { func allCommands(c *radix.Tree) []*Command { cmds := make([]*Command, 0) - c.WalkPrefix("", func(found string, v interface{}) bool { + c.WalkPrefix("", func(found string, v any) bool { cmd, ok := v.(*Command) if ok { cmds = append(cmds, cmd) diff --git a/sshd/server.go b/sshd/server.go index c151f91..a8b60ba 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -86,7 +86,7 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { s.RegisterCommand(&Command{ Name: "help", ShortDescription: "prints available commands or help for specific usage info", - Callback: func(a interface{}, args []string, w StringWriter) error { + Callback: func(a any, args []string, w StringWriter) error { return helpCallback(s.commands, args, w) }, }) diff --git a/sshd/session.go b/sshd/session.go index 7c5869e..03b20cd 100644 --- a/sshd/session.go +++ b/sshd/session.go @@ -31,7 +31,7 @@ func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.New s.commands.Insert("logout", &Command{ Name: "logout", ShortDescription: "Ends the current session", - Callback: func(a interface{}, args []string, w StringWriter) error { + Callback: func(a any, args []string, w StringWriter) error { s.Close() return nil }, diff --git a/test/assert.go b/test/assert.go index d34252e..1856877 100644 --- a/test/assert.go +++ b/test/assert.go @@ -13,7 +13,7 @@ import ( // AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory // There is currently a special case for `time.loc` (as this code traverses into unexported fields) -func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) { +func AssertDeepCopyEqual(t *testing.T, a any, b any) { v1 := reflect.ValueOf(a) v2 := reflect.ValueOf(b) diff --git a/util/error.go b/util/error.go index d7710f9..814c77a 100644 --- a/util/error.go +++ b/util/error.go @@ -9,11 +9,11 @@ import ( type ContextualError struct { RealError error - Fields map[string]interface{} + Fields map[string]any Context string } -func NewContextualError(msg string, fields map[string]interface{}, realError error) *ContextualError { +func NewContextualError(msg string, fields map[string]any, realError error) *ContextualError { return &ContextualError{Context: msg, Fields: fields, RealError: realError} } diff --git a/util/error_test.go b/util/error_test.go index 5041f82..692c184 100644 --- a/util/error_test.go +++ b/util/error_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/assert" ) -type m map[string]interface{} +type m = map[string]any type TestLogWriter struct { Logs []string From 36bc9dd26134e31b6893b431e92a5d37e58711e0 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Tue, 1 Apr 2025 09:49:26 -0400 Subject: [PATCH 60/67] fix parseUnsafeRoutes for yaml.v3 (#1371) We switched to yaml.v3 with #1148, but missed this spot that was still casting into `map[any]any` when yaml.v3 makes it `map[string]any`. Also clean up a few more `interface{}` that were added as we changed them all to `any` with #1148. --- overlay/route.go | 4 ++-- overlay/route_test.go | 32 ++++++++++++++++---------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/overlay/route.go b/overlay/route.go index 360921f..6198958 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -215,10 +215,10 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { gateways = routing.Gateways{routing.NewGateway(viaIp, 1)} - case []interface{}: + case []any: gateways = make(routing.Gateways, len(via)) for ig, v := range via { - gatewayMap, ok := v.(map[interface{}]interface{}) + gatewayMap, ok := v.(map[string]any) if !ok { return nil, fmt.Errorf("entry %v in tun.unsafe_routes[%v].via is invalid", i+1, ig+1) } diff --git a/overlay/route_test.go b/overlay/route_test.go index 6b5ae2e..9a959a5 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -163,7 +163,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { } // Unparsable list of via - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": []string{"1", "2"}}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": []string{"1", "2"}}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not a string or list of gateways: found []string") @@ -175,19 +175,19 @@ func Test_parseUnsafeRoutes(t *testing.T) { require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") // unparsable gateway - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"gateway": "1"}}}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "1"}}}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] failed to parse address: ParseAddr(\"1\"): unable to parse IP") // missing gateway element - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"weight": "1"}}}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"weight": "1"}}}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] is not present") // unparsable weight element - c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"gateway": "10.0.0.1", "weight": "a"}}}}} + c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "10.0.0.1", "weight": "a"}}}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry .weight in tun.unsafe_routes[1].via[1] is not an integer") @@ -328,34 +328,34 @@ func Test_makeMultipathUnsafeRouteTree(t *testing.T) { n, err := netip.ParsePrefix("10.0.0.0/24") require.NoError(t, err) - c.Settings["tun"] = map[interface{}]interface{}{ - "unsafe_routes": []interface{}{ - map[interface{}]interface{}{ + c.Settings["tun"] = map[string]any{ + "unsafe_routes": []any{ + map[string]any{ "route": "192.168.86.0/24", "via": "192.168.100.10", }, - map[interface{}]interface{}{ + map[string]any{ "route": "192.168.87.0/24", - "via": []interface{}{ - map[interface{}]interface{}{ + "via": []any{ + map[string]any{ "gateway": "10.0.0.1", }, - map[interface{}]interface{}{ + map[string]any{ "gateway": "10.0.0.2", }, - map[interface{}]interface{}{ + map[string]any{ "gateway": "10.0.0.3", }, }, }, - map[interface{}]interface{}{ + map[string]any{ "route": "192.168.89.0/24", - "via": []interface{}{ - map[interface{}]interface{}{ + "via": []any{ + map[string]any{ "gateway": "10.0.0.1", "weight": 10, }, - map[interface{}]interface{}{ + map[string]any{ "gateway": "10.0.0.2", "weight": 5, }, From d2adebf26daed3d29ac4a2c664de999dc8c79fac Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 1 Apr 2025 13:24:19 -0400 Subject: [PATCH 61/67] Bump golangci/golangci-lint-action from 6 to 7 (#1361) * Bump golangci/golangci-lint-action from 6 to 7 Bumps [golangci/golangci-lint-action](https://github.com/golangci/golangci-lint-action) from 6 to 7. - [Release notes](https://github.com/golangci/golangci-lint-action/releases) - [Commits](https://github.com/golangci/golangci-lint-action/compare/v6...v7) --- updated-dependencies: - dependency-name: golangci/golangci-lint-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] * use latest golangci-lint * pin to v2.0 * golangci-lint migrate * make the tests happy --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Wade Simmons --- .github/workflows/test.yml | 8 +++---- .golangci.yaml | 26 +++++++++++++++++----- cert/crypto_test.go | 4 ++-- cmd/nebula-cert/ca_test.go | 40 +++++++++++++++++----------------- cmd/nebula-cert/keygen_test.go | 20 ++++++++--------- cmd/nebula-cert/print_test.go | 16 +++++++------- cmd/nebula-cert/verify_test.go | 32 +++++++++++++-------------- control_test.go | 2 +- firewall_test.go | 2 +- hostmap_test.go | 4 ++-- 10 files changed, 84 insertions(+), 70 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 28f0590..006115d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,9 +32,9 @@ jobs: run: make vet - name: golangci-lint - uses: golangci/golangci-lint-action@v6 + uses: golangci/golangci-lint-action@v7 with: - version: v1.64 + version: v2.0 - name: Test run: make test @@ -115,9 +115,9 @@ jobs: run: make vet - name: golangci-lint - uses: golangci/golangci-lint-action@v6 + uses: golangci/golangci-lint-action@v7 with: - version: v1.64 + version: v2.0 - name: Test run: make test diff --git a/.golangci.yaml b/.golangci.yaml index f792069..bd82a95 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -1,9 +1,23 @@ -# yaml-language-server: $schema=https://golangci-lint.run/jsonschema/golangci.jsonschema.json +version: "2" linters: - # Disable all linters. - # Default: false - disable-all: true - # Enable specific linter - # https://golangci-lint.run/usage/linters/#enabled-by-default + default: none enable: - testifylint + exclusions: + generated: lax + presets: + - comments + - common-false-positives + - legacy + - std-error-handling + paths: + - third_party$ + - builtin$ + - examples$ +formatters: + exclusions: + generated: lax + paths: + - third_party$ + - builtin$ + - examples$ diff --git a/cert/crypto_test.go b/cert/crypto_test.go index ee671c0..6358ba6 100644 --- a/cert/crypto_test.go +++ b/cert/crypto_test.go @@ -10,14 +10,14 @@ import ( func TestNewArgon2Parameters(t *testing.T) { p := NewArgon2Parameters(64*1024, 4, 3) - assert.EqualValues(t, &Argon2Parameters{ + assert.Equal(t, &Argon2Parameters{ version: argon2.Version, Memory: 64 * 1024, Parallelism: 4, Iterations: 3, }, p) p = NewArgon2Parameters(2*1024*1024, 2, 1) - assert.EqualValues(t, &Argon2Parameters{ + assert.Equal(t, &Argon2Parameters{ version: argon2.Version, Memory: 2 * 1024 * 1024, Parallelism: 2, diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index 189fc02..b1cbde9 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -90,26 +90,26 @@ func Test_ca(t *testing.T) { assertHelpError(t, ca( []string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw, ), "-name is required") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // ipv4 only ips assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // ipv4 only subnets assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // failed key write ob.Reset() eb.Reset() args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // create temp key file keyF, err := os.CreateTemp("", "test.key") @@ -121,8 +121,8 @@ func Test_ca(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // create temp cert file crtF, err := os.CreateTemp("", "test.crt") @@ -135,8 +135,8 @@ func Test_ca(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.NoError(t, ca(args, ob, eb, nopw)) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // read cert and key files rb, _ := os.ReadFile(keyF.Name()) @@ -158,7 +158,7 @@ func Test_ca(t *testing.T) { assert.Empty(t, lCrt.UnsafeNetworks()) assert.Len(t, lCrt.PublicKey(), 32) assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore())) - assert.Equal(t, "", lCrt.Issuer()) + assert.Empty(t, lCrt.Issuer()) assert.True(t, lCrt.CheckSignature(lCrt.PublicKey())) // test encrypted key @@ -169,7 +169,7 @@ func Test_ca(t *testing.T) { args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.NoError(t, ca(args, ob, eb, testpw)) assert.Equal(t, pwPromptOb, ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, eb.String()) // read encrypted key file and verify default params rb, _ = os.ReadFile(keyF.Name()) @@ -197,7 +197,7 @@ func Test_ca(t *testing.T) { args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.Error(t, ca(args, ob, eb, errpw)) assert.Equal(t, pwPromptOb, ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, eb.String()) // test when user fails to enter a password os.Remove(keyF.Name()) @@ -207,7 +207,7 @@ func Test_ca(t *testing.T) { args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up - assert.Equal(t, "", eb.String()) + assert.Empty(t, eb.String()) // create valid cert/key for overwrite tests os.Remove(keyF.Name()) @@ -222,8 +222,8 @@ func Test_ca(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // test that we won't overwrite existing key file os.Remove(keyF.Name()) @@ -231,8 +231,8 @@ func Test_ca(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) os.Remove(keyF.Name()) } diff --git a/cmd/nebula-cert/keygen_test.go b/cmd/nebula-cert/keygen_test.go index 7eed5d2..95d9893 100644 --- a/cmd/nebula-cert/keygen_test.go +++ b/cmd/nebula-cert/keygen_test.go @@ -37,20 +37,20 @@ func Test_keygen(t *testing.T) { // required args assertHelpError(t, keygen([]string{"-out-pub", "nope"}, ob, eb), "-out-key is required") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) assertHelpError(t, keygen([]string{"-out-key", "nope"}, ob, eb), "-out-pub is required") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // failed key write ob.Reset() eb.Reset() args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"} require.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // create temp key file keyF, err := os.CreateTemp("", "test.key") @@ -62,8 +62,8 @@ func Test_keygen(t *testing.T) { eb.Reset() args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()} require.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // create temp pub file pubF, err := os.CreateTemp("", "test.pub") @@ -75,8 +75,8 @@ func Test_keygen(t *testing.T) { eb.Reset() args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()} require.NoError(t, keygen(args, ob, eb)) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // read cert and key files rb, _ := os.ReadFile(keyF.Name()) diff --git a/cmd/nebula-cert/print_test.go b/cmd/nebula-cert/print_test.go index 061e472..221ab77 100644 --- a/cmd/nebula-cert/print_test.go +++ b/cmd/nebula-cert/print_test.go @@ -43,16 +43,16 @@ func Test_printCert(t *testing.T) { // no path err := printCert([]string{}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) assertHelpError(t, err, "-path is required") // no cert at path ob.Reset() eb.Reset() err = printCert([]string{"-path", "does_not_exist"}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) require.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError) // invalid cert at path @@ -64,8 +64,8 @@ func Test_printCert(t *testing.T) { tf.WriteString("-----BEGIN NOPE-----") err = printCert([]string{"-path", tf.Name()}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) require.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block") // test multiple certs @@ -155,7 +155,7 @@ func Test_printCert(t *testing.T) { `, ob.String(), ) - assert.Equal(t, "", eb.String()) + assert.Empty(t, eb.String()) // test json ob.Reset() @@ -177,7 +177,7 @@ func Test_printCert(t *testing.T) { `, ob.String(), ) - assert.Equal(t, "", eb.String()) + assert.Empty(t, eb.String()) } // NewTestCaCert will generate a CA cert diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index acc9cca..f555e5f 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -38,19 +38,19 @@ func Test_verify(t *testing.T) { // required args assertHelpError(t, verify([]string{"-ca", "derp"}, ob, eb), "-crt is required") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) assertHelpError(t, verify([]string{"-crt", "derp"}, ob, eb), "-ca is required") - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) // no ca at path ob.Reset() eb.Reset() err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) require.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError) // invalid ca at path @@ -62,8 +62,8 @@ func Test_verify(t *testing.T) { caFile.WriteString("-----BEGIN NOPE-----") err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block") // make a ca for later @@ -76,8 +76,8 @@ func Test_verify(t *testing.T) { // no crt at path err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) require.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError) // invalid crt at path @@ -89,8 +89,8 @@ func Test_verify(t *testing.T) { certFile.WriteString("-----BEGIN NOPE-----") err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) require.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block") // unverifiable cert at path @@ -106,8 +106,8 @@ func Test_verify(t *testing.T) { certFile.Write(b) err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) require.ErrorIs(t, err, cert.ErrSignatureMismatch) // verified cert at path @@ -118,7 +118,7 @@ func Test_verify(t *testing.T) { certFile.Write(b) err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) - assert.Equal(t, "", ob.String()) - assert.Equal(t, "", eb.String()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) require.NoError(t, err) } diff --git a/control_test.go b/control_test.go index de85fee..e400992 100644 --- a/control_test.go +++ b/control_test.go @@ -101,7 +101,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { // Make sure we don't have any unexpected fields assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) - assert.EqualValues(t, &expectedInfo, thi) + assert.Equal(t, &expectedInfo, thi) test.AssertDeepCopyEqual(t, &expectedInfo, thi) // Make sure we don't panic if the host info doesn't have a cert yet diff --git a/firewall_test.go b/firewall_test.go index c90fb20..4731a6f 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -792,7 +792,7 @@ func TestFirewall_convertRule(t *testing.T) { } r, err = convertRule(l, c, "test", 1) - assert.Equal(t, "", ob.String()) + assert.Empty(t, ob.String()) require.Error(t, err, "group should contain a single value, an array with more than one entry was provided") // Make sure a well formed group is alright diff --git a/hostmap_test.go b/hostmap_test.go index e974340..b3580cf 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -210,8 +210,8 @@ func TestHostMap_reload(t *testing.T) { assert.Empty(t, hm.GetPreferredRanges()) c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]") - assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges())) + assert.Equal(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges())) c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]") - assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges())) + assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges())) } From e136d1d47a630c9ac2de01949f3f3286fa110c23 Mon Sep 17 00:00:00 2001 From: John Maguire Date: Tue, 1 Apr 2025 17:08:03 -0400 Subject: [PATCH 62/67] Update example config with default_local_cidr_any changes (#1373) --- CHANGELOG.md | 7 +++++++ examples/config.yml | 18 ++++++++---------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ad17147..1de3c19 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- `default_local_cidr_any` now defaults to false, meaning that any firewall rule + intended to target an `unsafe_routes` entry must explicitly declare it via the + `local_cidr` field. This is almost always the intended behavior. This flag is + deprecated and will be removed in a future release. + ## [1.9.4] - 2024-09-09 ### Added diff --git a/examples/config.yml b/examples/config.yml index 3b7c38b..534608d 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -346,11 +346,11 @@ firewall: outbound_action: drop inbound_action: drop - # Controls the default value for local_cidr. Default is true, will be deprecated after v1.9 and defaulted to false. - # This setting only affects nebula hosts with subnets encoded in their certificate. A nebula host acting as an - # unsafe router with `default_local_cidr_any: true` will expose their unsafe routes to every inbound rule regardless - # of the actual destination for the packet. Setting this to false requires each inbound rule to contain a `local_cidr` - # if the intention is to allow traffic to flow to an unsafe route. + # THIS FLAG IS DEPRECATED AND WILL BE REMOVED IN A FUTURE RELEASE. (Defaults to false.) + # This setting only affects nebula hosts exposing unsafe_routes. When set to false, each inbound rule must contain a + # `local_cidr` if the intention is to allow traffic to flow to an unsafe route. When set to true, every firewall rule + # will apply to all configured unsafe_routes regardless of the actual destination of the packet, unless `local_cidr` + # is explicitly defined. This is usually not the desired behavior and should be avoided! #default_local_cidr_any: false conntrack: @@ -368,11 +368,9 @@ firewall: # group: `any` or a literal group name, ie `default-group` # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass # cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. - # local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes. - # If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network. - # Otherwise the default is any vpn network assigned to via the certificate. - # `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release. - # If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation. + # local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes. + # By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true. + # If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case. # ca_name: An issuing CA name # ca_sha: An issuing CA shasum From 58ead4116ff6de08b56a7a32f930663cc9d2e9c4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 1 Apr 2025 16:10:20 -0500 Subject: [PATCH 63/67] Bump github.com/gaissmai/bart from 0.18.1 to 0.20.1 (#1369) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index bbd5d8b..1b6be0b 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/flynn/noise v1.1.0 - github.com/gaissmai/bart v0.18.1 + github.com/gaissmai/bart v0.20.1 github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.2 diff --git a/go.sum b/go.sum index 8237bfa..f142a20 100644 --- a/go.sum +++ b/go.sum @@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= -github.com/gaissmai/bart v0.18.1 h1:bX2j560JC1MJpoEDevBGvXL5OZ1mkls320Vl8Igb5QQ= -github.com/gaissmai/bart v0.18.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY= +github.com/gaissmai/bart v0.20.1 h1:igNss0zDsSY8e+ophKgD9KJVPKBOo7uSVjyKCL7nIzo= +github.com/gaissmai/bart v0.20.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= From e4bae1582556a264f7629b4d368098db1efbf723 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 1 Apr 2025 16:23:35 -0500 Subject: [PATCH 64/67] Bump google.golang.org/protobuf in the protobuf-dependencies group (#1365) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 1b6be0b..19e83ab 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,7 @@ require ( golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/protobuf v1.36.5 + google.golang.org/protobuf v1.36.6 gopkg.in/yaml.v3 v3.0.1 gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe ) diff --git a/go.sum b/go.sum index f142a20..fa8c29b 100644 --- a/go.sum +++ b/go.sum @@ -239,8 +239,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= -google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From d99fd60e0622dd48b5ea67a14a251dc44efd404d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 1 Apr 2025 16:26:23 -0500 Subject: [PATCH 65/67] Bump Apple-Actions/import-codesign-certs from 3 to 5 (#1364) --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f9df115..3107b47 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -75,7 +75,7 @@ jobs: - name: Import certificates if: env.HAS_SIGNING_CREDS == 'true' - uses: Apple-Actions/import-codesign-certs@v3 + uses: Apple-Actions/import-codesign-certs@v5 with: p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }} p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }} From e2d6f4e444d51d46f2a4715836fdf7d116187148 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 1 Apr 2025 16:28:27 -0500 Subject: [PATCH 66/67] Bump github.com/miekg/dns from 1.1.63 to 1.1.64 (#1363) --- go.mod | 6 +++--- go.sum | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 19e83ab..7302092 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.2 - github.com/miekg/dns v1.1.63 + github.com/miekg/dns v1.1.64 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f github.com/prometheus/client_golang v1.21.1 @@ -50,7 +50,7 @@ require ( github.com/prometheus/common v0.62.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect github.com/vishvananda/netns v0.0.4 // indirect - golang.org/x/mod v0.18.0 // indirect + golang.org/x/mod v0.23.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.22.0 // indirect + golang.org/x/tools v0.30.0 // indirect ) diff --git a/go.sum b/go.sum index fa8c29b..030d6ef 100644 --- a/go.sum +++ b/go.sum @@ -83,8 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/miekg/dns v1.1.63 h1:8M5aAw6OMZfFXTT7K5V0Eu5YiiL8l7nUAkyN6C9YwaY= -github.com/miekg/dns v1.1.63/go.mod h1:6NGHfjhpmr5lt3XPLuyfDJi5AXbNIPM9PY6H6sF1Nfs= +github.com/miekg/dns v1.1.64 h1:wuZgD9wwCE6XMT05UU/mlSko71eRSXEAm2EbjQXLKnQ= +github.com/miekg/dns v1.1.64/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -164,8 +164,8 @@ golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPI golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= -golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM= +golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -219,8 +219,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= -golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= +golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= +golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= From f5d096dd2b719e75b0b41bdb7a0e1fd8f86b02bb Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Wed, 2 Apr 2025 09:11:34 -0400 Subject: [PATCH 67/67] move to golang.org/x/term (#1372) The `golang.org/x/crypto/ssh/terminal` was deprecated and moved to `golang.org/x/term`. We already use the new package in `cmd/nebula-cert`, so fix our remaining reference here. See: - https://github.com/golang/go/issues/31044 --- sshd/session.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sshd/session.go b/sshd/session.go index 03b20cd..87cc216 100644 --- a/sshd/session.go +++ b/sshd/session.go @@ -9,13 +9,13 @@ import ( "github.com/armon/go-radix" "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/terminal" + "golang.org/x/term" ) type session struct { l *logrus.Entry c *ssh.ServerConn - term *terminal.Terminal + term *term.Terminal commands *radix.Tree exitChan chan bool } @@ -106,8 +106,8 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { } } -func (s *session) createTerm(channel ssh.Channel) *terminal.Terminal { - term := terminal.NewTerminal(channel, s.c.User()+"@nebula > ") +func (s *session) createTerm(channel ssh.Channel) *term.Terminal { + term := term.NewTerminal(channel, s.c.User()+"@nebula > ") term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) { // key 9 is tab if key == 9 {