diff --git a/.github/workflows/gofmt.yml b/.github/workflows/gofmt.yml index 1552cc6..399bc95 100644 --- a/.github/workflows/gofmt.yml +++ b/.github/workflows/gofmt.yml @@ -16,7 +16,7 @@ jobs: - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: 'go.mod' check-latest: true diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index ef4e507..b5b8ced 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -12,7 +12,7 @@ jobs: steps: - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: 'go.mod' check-latest: true @@ -35,7 +35,7 @@ jobs: steps: - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: 'go.mod' check-latest: true @@ -68,7 +68,7 @@ jobs: steps: - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: 'go.mod' check-latest: true diff --git a/.github/workflows/smoke.yml b/.github/workflows/smoke.yml index 99c7e82..2b4adf6 100644 --- a/.github/workflows/smoke.yml +++ b/.github/workflows/smoke.yml @@ -20,13 +20,13 @@ jobs: - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: 'go.mod' check-latest: true - name: build - run: make bin-docker + run: make bin-docker CGO_ENABLED=1 BUILD_ARGS=-race - name: setup docker image working-directory: ./.github/workflows/smoke diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index cc3725f..34fe5f3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,7 +20,7 @@ jobs: - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: 'go.mod' check-latest: true @@ -37,6 +37,9 @@ jobs: - name: End 2 end run: make e2evv + - name: Build test mobile + run: make build-test-mobile + - uses: actions/upload-artifact@v3 with: name: e2e packet flow @@ -50,7 +53,7 @@ jobs: - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: 'go.mod' check-latest: true @@ -74,7 +77,7 @@ jobs: - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: 'go.mod' check-latest: true diff --git a/CHANGELOG.md b/CHANGELOG.md index 6951a4a..71c3ed4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,76 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.8.2] - 2024-01-08 + +### Fixed + +- Fix multiple routines when listen.port is zero. This was a regression + introduced in v1.6.0. (#1057) + +### Changed + +- Small dependency update for Noise. (#1038) + +## [1.8.1] - 2023-12-19 + +### Security + +- Update `golang.org/x/crypto`, which includes a fix for CVE-2023-48795. (#1048) + +### Fixed + +- Fix a deadlock introduced in v1.8.0 that could occur during handshakes. (#1044) + +- Fix mobile builds. (#1035) + +## [1.8.0] - 2023-12-06 + +### Deprecated + +- The next minor release of Nebula, 1.9.0, will require at least Windows 10 or + Windows Server 2016. This is because support for earlier versions was removed + in Go 1.21. See https://go.dev/doc/go1.21#windows + +### Added + +- Linux: Notify systemd of service readiness. This should resolve timing issues + with services that depend on Nebula being active. For an example of how to + enable this, see: `examples/service_scripts/nebula.service`. (#929) + +- Windows: Use Registered IO (RIO) when possible. Testing on a Windows 11 + machine shows ~50x improvement in throughput. (#905) + +- NetBSD, OpenBSD: Added rudimentary support. (#916, #812) + +- FreeBSD: Add support for naming tun devices. (#903) + +### Changed + +- `pki.disconnect_invalid` will now default to true. This means that once a + certificate expires, the tunnel will be disconnected. If you use SIGHUP to + reload certificates without restarting Nebula, you should ensure all of your + clients are on 1.7.0 or newer before you enable this feature. (#859) + +- Limit how often a busy tunnel can requery the lighthouse. The new config + option `timers.requery_wait_duration` defaults to `60s`. (#940) + +- The internal structures for hostmaps were refactored to reduce memory usage + and the potential for subtle bugs. (#843, #938, #953, #954, #955) + +- Lots of dependency updates. + +### Fixed + +- Windows: Retry wintun device creation if it fails the first time. (#985) + +- Fix issues with firewall reject packets that could cause panics. (#957) + +- Fix relay migration during re-handshakes. (#964) + +- Various other refactors and fixes. (#935, #952, #972, #961, #996, #1002, + #987, #1004, #1030, #1032, ...) + ## [1.7.2] - 2023-06-01 ### Fixed @@ -488,7 +558,10 @@ created.) - Initial public release. -[Unreleased]: https://github.com/slackhq/nebula/compare/v1.7.2...HEAD +[Unreleased]: https://github.com/slackhq/nebula/compare/v1.8.2...HEAD +[1.8.2]: https://github.com/slackhq/nebula/releases/tag/v1.8.2 +[1.8.1]: https://github.com/slackhq/nebula/releases/tag/v1.8.1 +[1.8.0]: https://github.com/slackhq/nebula/releases/tag/v1.8.0 [1.7.2]: https://github.com/slackhq/nebula/releases/tag/v1.7.2 [1.7.1]: https://github.com/slackhq/nebula/releases/tag/v1.7.1 [1.7.0]: https://github.com/slackhq/nebula/releases/tag/v1.7.0 diff --git a/Makefile b/Makefile index 89bd284..3f53cd9 100644 --- a/Makefile +++ b/Makefile @@ -169,6 +169,12 @@ test-cov-html: go test -coverprofile=coverage.out go tool cover -html=coverage.out +build-test-mobile: + GOARCH=amd64 GOOS=ios go build $(shell go list ./... | grep -v '/cmd/\|/examples/') + GOARCH=arm64 GOOS=ios go build $(shell go list ./... | grep -v '/cmd/\|/examples/') + GOARCH=amd64 GOOS=android go build $(shell go list ./... | grep -v '/cmd/\|/examples/') + GOARCH=arm64 GOOS=android go build $(shell go list ./... | grep -v '/cmd/\|/examples/') + bench: go test -bench=. @@ -214,8 +220,9 @@ smoke-multiport-docker: bin-docker cd .github/workflows/smoke/ && NAME="smoke-multiport" ./smoke.sh smoke-docker-race: BUILD_ARGS = -race +smoke-docker-race: CGO_ENABLED = 1 smoke-docker-race: smoke-docker .FORCE: -.PHONY: e2e e2ev e2evv e2evvv e2evvvv test test-cov-html bench bench-cpu bench-cpu-long bin proto release service smoke-docker smoke-docker-race +.PHONY: bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv proto release service smoke-docker smoke-docker-race test test-cov-html .DEFAULT_GOAL := bin diff --git a/README.md b/README.md index 6a7e5f2..51e913d 100644 --- a/README.md +++ b/README.md @@ -27,15 +27,26 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for #### Distribution Packages -- [Arch Linux](https://archlinux.org/packages/community/x86_64/nebula/) +- [Arch Linux](https://archlinux.org/packages/extra/x86_64/nebula/) ``` $ sudo pacman -S nebula ``` + - [Fedora Linux](https://src.fedoraproject.org/rpms/nebula) ``` $ sudo dnf install nebula ``` +- [Debian Linux](https://packages.debian.org/source/stable/nebula) + ``` + $ sudo apt install nebula + ``` + +- [Alpine Linux](https://pkgs.alpinelinux.org/packages?name=nebula) + ``` + $ sudo apk add nebula + ``` + - [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/nebula.rb) ``` $ brew install nebula diff --git a/allow_list.go b/allow_list.go index 0e44a12..9186b2f 100644 --- a/allow_list.go +++ b/allow_list.go @@ -12,7 +12,7 @@ import ( type AllowList struct { // The values of this cidrTree are `bool`, signifying allow/deny - cidrTree *cidr.Tree6 + cidrTree *cidr.Tree6[bool] } type RemoteAllowList struct { @@ -20,7 +20,7 @@ type RemoteAllowList struct { // Inside Range Specific, keys of this tree are inside CIDRs and values // are *AllowList - insideAllowLists *cidr.Tree6 + insideAllowLists *cidr.Tree6[*AllowList] } type LocalAllowList struct { @@ -88,7 +88,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw) } - tree := cidr.NewTree6() + tree := cidr.NewTree6[bool]() // Keep track of the rules we have added for both ipv4 and ipv6 type allowListRules struct { @@ -218,13 +218,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error return nameRules, nil } -func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6, error) { +func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) { value := c.Get(k) if value == nil { return nil, nil } - remoteAllowRanges := cidr.NewTree6() + remoteAllowRanges := cidr.NewTree6[*AllowList]() rawMap, ok := value.(map[interface{}]interface{}) if !ok { @@ -257,13 +257,8 @@ func (al *AllowList) Allow(ip net.IP) bool { return true } - result := al.cidrTree.MostSpecificContains(ip) - switch v := result.(type) { - case bool: - return v - default: - panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result)) - } + _, result := al.cidrTree.MostSpecificContains(ip) + return result } func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool { @@ -271,13 +266,8 @@ func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool { return true } - result := al.cidrTree.MostSpecificContainsIpV4(ip) - switch v := result.(type) { - case bool: - return v - default: - panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result)) - } + _, result := al.cidrTree.MostSpecificContainsIpV4(ip) + return result } func (al *AllowList) AllowIpV6(hi, lo uint64) bool { @@ -285,13 +275,8 @@ func (al *AllowList) AllowIpV6(hi, lo uint64) bool { return true } - result := al.cidrTree.MostSpecificContainsIpV6(hi, lo) - switch v := result.(type) { - case bool: - return v - default: - panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result)) - } + _, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo) + return result } func (al *LocalAllowList) Allow(ip net.IP) bool { @@ -352,9 +337,9 @@ func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool { func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList { if al.insideAllowLists != nil { - inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp) - if inside != nil { - return inside.(*AllowList) + ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp) + if ok { + return inside } } return nil diff --git a/allow_list_test.go b/allow_list_test.go index 991b8a3..334cb60 100644 --- a/allow_list_test.go +++ b/allow_list_test.go @@ -100,7 +100,7 @@ func TestNewAllowListFromConfig(t *testing.T) { func TestAllowList_Allow(t *testing.T) { assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1"))) - tree := cidr.NewTree6() + tree := cidr.NewTree6[bool]() tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true) tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false) tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true) diff --git a/calculated_remote.go b/calculated_remote.go index 910f757..38f5bea 100644 --- a/calculated_remote.go +++ b/calculated_remote.go @@ -51,13 +51,13 @@ func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort { return &Ip4AndPort{Ip: uint32(masked), Port: c.port} } -func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4, error) { +func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) { value := c.Get(k) if value == nil { return nil, nil } - calculatedRemotes := cidr.NewTree4() + calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]() rawMap, ok := value.(map[any]any) if !ok { diff --git a/cidr/tree4.go b/cidr/tree4.go index 0839c90..fd4b358 100644 --- a/cidr/tree4.go +++ b/cidr/tree4.go @@ -6,35 +6,36 @@ import ( "github.com/slackhq/nebula/iputil" ) -type Node struct { - left *Node - right *Node - parent *Node - value interface{} +type Node[T any] struct { + left *Node[T] + right *Node[T] + parent *Node[T] + hasValue bool + value T } -type entry struct { +type entry[T any] struct { CIDR *net.IPNet - Value *interface{} + Value T } -type Tree4 struct { - root *Node - list []entry +type Tree4[T any] struct { + root *Node[T] + list []entry[T] } const ( startbit = iputil.VpnIp(0x80000000) ) -func NewTree4() *Tree4 { - tree := new(Tree4) - tree.root = &Node{} - tree.list = []entry{} +func NewTree4[T any]() *Tree4[T] { + tree := new(Tree4[T]) + tree.root = &Node[T]{} + tree.list = []entry[T]{} return tree } -func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) { +func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) { bit := startbit node := tree.root next := tree.root @@ -68,14 +69,15 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) { } } - tree.list = append(tree.list, entry{CIDR: cidr, Value: &val}) + tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val}) node.value = val + node.hasValue = true return } // Build up the rest of the tree we don't already have for bit&mask != 0 { - next = &Node{} + next = &Node[T]{} next.parent = node if ip&bit != 0 { @@ -90,17 +92,18 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) { // Final node marks our cidr, set the value node.value = val - tree.list = append(tree.list, entry{CIDR: cidr, Value: &val}) + node.hasValue = true + tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val}) } // Contains finds the first match, which may be the least specific -func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) { +func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) { bit := startbit node := tree.root for node != nil { - if node.value != nil { - return node.value + if node.hasValue { + return true, node.value } if ip&bit != 0 { @@ -113,17 +116,18 @@ func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) { } - return value + return false, value } // MostSpecificContains finds the most specific match -func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) { +func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) { bit := startbit node := tree.root for node != nil { - if node.value != nil { + if node.hasValue { value = node.value + ok = true } if ip&bit != 0 { @@ -135,11 +139,12 @@ func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) { bit >>= 1 } - return value + return ok, value } // Match finds the most specific match -func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) { +// TODO this is exact match +func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) { bit := startbit node := tree.root lastNode := node @@ -157,11 +162,12 @@ func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) { if bit == 0 && lastNode != nil { value = lastNode.value + ok = true } - return value + return ok, value } // List will return all CIDRs and their current values. Do not modify the contents! -func (tree *Tree4) List() []entry { +func (tree *Tree4[T]) List() []entry[T] { return tree.list } diff --git a/cidr/tree4_test.go b/cidr/tree4_test.go index dce8d54..acd403e 100644 --- a/cidr/tree4_test.go +++ b/cidr/tree4_test.go @@ -9,7 +9,7 @@ import ( ) func TestCIDRTree_List(t *testing.T) { - tree := NewTree4() + tree := NewTree4[string]() tree.AddCIDR(Parse("1.0.0.0/16"), "1") tree.AddCIDR(Parse("1.0.0.0/8"), "2") tree.AddCIDR(Parse("1.0.0.0/16"), "3") @@ -17,13 +17,13 @@ func TestCIDRTree_List(t *testing.T) { list := tree.List() assert.Len(t, list, 2) assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String()) - assert.Equal(t, "2", *list[0].Value) + assert.Equal(t, "2", list[0].Value) assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String()) - assert.Equal(t, "4", *list[1].Value) + assert.Equal(t, "4", list[1].Value) } func TestCIDRTree_Contains(t *testing.T) { - tree := NewTree4() + tree := NewTree4[string]() tree.AddCIDR(Parse("1.0.0.0/8"), "1") tree.AddCIDR(Parse("2.1.0.0/16"), "2") tree.AddCIDR(Parse("3.1.1.0/24"), "3") @@ -33,35 +33,43 @@ func TestCIDRTree_Contains(t *testing.T) { tree.AddCIDR(Parse("254.0.0.0/4"), "5") tests := []struct { + Found bool Result interface{} IP string }{ - {"1", "1.0.0.0"}, - {"1", "1.255.255.255"}, - {"2", "2.1.0.0"}, - {"2", "2.1.255.255"}, - {"3", "3.1.1.0"}, - {"3", "3.1.1.255"}, - {"4a", "4.1.1.255"}, - {"4a", "4.1.1.1"}, - {"5", "240.0.0.0"}, - {"5", "255.255.255.255"}, - {nil, "239.0.0.0"}, - {nil, "4.1.2.2"}, + {true, "1", "1.0.0.0"}, + {true, "1", "1.255.255.255"}, + {true, "2", "2.1.0.0"}, + {true, "2", "2.1.255.255"}, + {true, "3", "3.1.1.0"}, + {true, "3", "3.1.1.255"}, + {true, "4a", "4.1.1.255"}, + {true, "4a", "4.1.1.1"}, + {true, "5", "240.0.0.0"}, + {true, "5", "255.255.255.255"}, + {false, "", "239.0.0.0"}, + {false, "", "4.1.2.2"}, } for _, tt := range tests { - assert.Equal(t, tt.Result, tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))) + ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) + assert.Equal(t, tt.Found, ok) + assert.Equal(t, tt.Result, r) } - tree = NewTree4() + tree = NewTree4[string]() tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))) - assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))) + ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))) + assert.True(t, ok) + assert.Equal(t, "cool", r) + + ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))) + assert.True(t, ok) + assert.Equal(t, "cool", r) } func TestCIDRTree_MostSpecificContains(t *testing.T) { - tree := NewTree4() + tree := NewTree4[string]() tree.AddCIDR(Parse("1.0.0.0/8"), "1") tree.AddCIDR(Parse("2.1.0.0/16"), "2") tree.AddCIDR(Parse("3.1.1.0/24"), "3") @@ -71,59 +79,75 @@ func TestCIDRTree_MostSpecificContains(t *testing.T) { tree.AddCIDR(Parse("254.0.0.0/4"), "5") tests := []struct { + Found bool Result interface{} IP string }{ - {"1", "1.0.0.0"}, - {"1", "1.255.255.255"}, - {"2", "2.1.0.0"}, - {"2", "2.1.255.255"}, - {"3", "3.1.1.0"}, - {"3", "3.1.1.255"}, - {"4a", "4.1.1.255"}, - {"4b", "4.1.1.2"}, - {"4c", "4.1.1.1"}, - {"5", "240.0.0.0"}, - {"5", "255.255.255.255"}, - {nil, "239.0.0.0"}, - {nil, "4.1.2.2"}, + {true, "1", "1.0.0.0"}, + {true, "1", "1.255.255.255"}, + {true, "2", "2.1.0.0"}, + {true, "2", "2.1.255.255"}, + {true, "3", "3.1.1.0"}, + {true, "3", "3.1.1.255"}, + {true, "4a", "4.1.1.255"}, + {true, "4b", "4.1.1.2"}, + {true, "4c", "4.1.1.1"}, + {true, "5", "240.0.0.0"}, + {true, "5", "255.255.255.255"}, + {false, "", "239.0.0.0"}, + {false, "", "4.1.2.2"}, } for _, tt := range tests { - assert.Equal(t, tt.Result, tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))) + ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) + assert.Equal(t, tt.Found, ok) + assert.Equal(t, tt.Result, r) } - tree = NewTree4() + tree = NewTree4[string]() tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))) - assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))) + ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))) + assert.True(t, ok) + assert.Equal(t, "cool", r) + + ok, r = tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))) + assert.True(t, ok) + assert.Equal(t, "cool", r) } func TestCIDRTree_Match(t *testing.T) { - tree := NewTree4() + tree := NewTree4[string]() tree.AddCIDR(Parse("4.1.1.0/32"), "1a") tree.AddCIDR(Parse("4.1.1.1/32"), "1b") tests := []struct { + Found bool Result interface{} IP string }{ - {"1a", "4.1.1.0"}, - {"1b", "4.1.1.1"}, + {true, "1a", "4.1.1.0"}, + {true, "1b", "4.1.1.1"}, } for _, tt := range tests { - assert.Equal(t, tt.Result, tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))) + ok, r := tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) + assert.Equal(t, tt.Found, ok) + assert.Equal(t, tt.Result, r) } - tree = NewTree4() + tree = NewTree4[string]() tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))) - assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))) + ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))) + assert.True(t, ok) + assert.Equal(t, "cool", r) + + ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))) + assert.True(t, ok) + assert.Equal(t, "cool", r) } func BenchmarkCIDRTree_Contains(b *testing.B) { - tree := NewTree4() + tree := NewTree4[string]() tree.AddCIDR(Parse("1.1.0.0/16"), "1") tree.AddCIDR(Parse("1.2.1.1/32"), "1") tree.AddCIDR(Parse("192.2.1.1/32"), "1") @@ -145,7 +169,7 @@ func BenchmarkCIDRTree_Contains(b *testing.B) { } func BenchmarkCIDRTree_Match(b *testing.B) { - tree := NewTree4() + tree := NewTree4[string]() tree.AddCIDR(Parse("1.1.0.0/16"), "1") tree.AddCIDR(Parse("1.2.1.1/32"), "1") tree.AddCIDR(Parse("192.2.1.1/32"), "1") diff --git a/cidr/tree6.go b/cidr/tree6.go index d13c93d..3f2cd2a 100644 --- a/cidr/tree6.go +++ b/cidr/tree6.go @@ -8,20 +8,20 @@ import ( const startbit6 = uint64(1 << 63) -type Tree6 struct { - root4 *Node - root6 *Node +type Tree6[T any] struct { + root4 *Node[T] + root6 *Node[T] } -func NewTree6() *Tree6 { - tree := new(Tree6) - tree.root4 = &Node{} - tree.root6 = &Node{} +func NewTree6[T any]() *Tree6[T] { + tree := new(Tree6[T]) + tree.root4 = &Node[T]{} + tree.root6 = &Node[T]{} return tree } -func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) { - var node, next *Node +func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) { + var node, next *Node[T] cidrIP, ipv4 := isIPV4(cidr.IP) if ipv4 { @@ -56,7 +56,7 @@ func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) { // Build up the rest of the tree we don't already have for bit&mask != 0 { - next = &Node{} + next = &Node[T]{} next.parent = node if ip&bit != 0 { @@ -72,11 +72,12 @@ func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) { // Final node marks our cidr, set the value node.value = val + node.hasValue = true } // Finds the most specific match -func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) { - var node *Node +func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) { + var node *Node[T] wholeIP, ipv4 := isIPV4(ip) if ipv4 { @@ -90,8 +91,9 @@ func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) { bit := startbit for node != nil { - if node.value != nil { + if node.hasValue { value = node.value + ok = true } if bit == 0 { @@ -108,16 +110,17 @@ func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) { } } - return value + return ok, value } -func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{}) { +func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) { bit := startbit node := tree.root4 for node != nil { - if node.value != nil { + if node.hasValue { value = node.value + ok = true } if ip&bit != 0 { @@ -129,10 +132,10 @@ func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{}) bit >>= 1 } - return value + return ok, value } -func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) { +func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) { ip := hi node := tree.root6 @@ -140,8 +143,9 @@ func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) { bit := startbit6 for node != nil { - if node.value != nil { + if node.hasValue { value = node.value + ok = true } if bit == 0 { @@ -160,7 +164,7 @@ func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) { ip = lo } - return value + return ok, value } func isIPV4(ip net.IP) (net.IP, bool) { diff --git a/cidr/tree6_test.go b/cidr/tree6_test.go index b6dc4c2..eb159ec 100644 --- a/cidr/tree6_test.go +++ b/cidr/tree6_test.go @@ -9,7 +9,7 @@ import ( ) func TestCIDR6Tree_MostSpecificContains(t *testing.T) { - tree := NewTree6() + tree := NewTree6[string]() tree.AddCIDR(Parse("1.0.0.0/8"), "1") tree.AddCIDR(Parse("2.1.0.0/16"), "2") tree.AddCIDR(Parse("3.1.1.0/24"), "3") @@ -22,53 +22,68 @@ func TestCIDR6Tree_MostSpecificContains(t *testing.T) { tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c") tests := []struct { + Found bool Result interface{} IP string }{ - {"1", "1.0.0.0"}, - {"1", "1.255.255.255"}, - {"2", "2.1.0.0"}, - {"2", "2.1.255.255"}, - {"3", "3.1.1.0"}, - {"3", "3.1.1.255"}, - {"4a", "4.1.1.255"}, - {"4b", "4.1.1.2"}, - {"4c", "4.1.1.1"}, - {"5", "240.0.0.0"}, - {"5", "255.255.255.255"}, - {"6a", "1:2:0:4:1:1:1:1"}, - {"6b", "1:2:0:4:5:1:1:1"}, - {"6c", "1:2:0:4:5:0:0:0"}, - {nil, "239.0.0.0"}, - {nil, "4.1.2.2"}, + {true, "1", "1.0.0.0"}, + {true, "1", "1.255.255.255"}, + {true, "2", "2.1.0.0"}, + {true, "2", "2.1.255.255"}, + {true, "3", "3.1.1.0"}, + {true, "3", "3.1.1.255"}, + {true, "4a", "4.1.1.255"}, + {true, "4b", "4.1.1.2"}, + {true, "4c", "4.1.1.1"}, + {true, "5", "240.0.0.0"}, + {true, "5", "255.255.255.255"}, + {true, "6a", "1:2:0:4:1:1:1:1"}, + {true, "6b", "1:2:0:4:5:1:1:1"}, + {true, "6c", "1:2:0:4:5:0:0:0"}, + {false, "", "239.0.0.0"}, + {false, "", "4.1.2.2"}, } for _, tt := range tests { - assert.Equal(t, tt.Result, tree.MostSpecificContains(net.ParseIP(tt.IP))) + ok, r := tree.MostSpecificContains(net.ParseIP(tt.IP)) + assert.Equal(t, tt.Found, ok) + assert.Equal(t, tt.Result, r) } - tree = NewTree6() + tree = NewTree6[string]() tree.AddCIDR(Parse("1.1.1.1/0"), "cool") tree.AddCIDR(Parse("::/0"), "cool6") - assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("0.0.0.0"))) - assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("255.255.255.255"))) - assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("::"))) - assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8"))) + ok, r := tree.MostSpecificContains(net.ParseIP("0.0.0.0")) + assert.True(t, ok) + assert.Equal(t, "cool", r) + + ok, r = tree.MostSpecificContains(net.ParseIP("255.255.255.255")) + assert.True(t, ok) + assert.Equal(t, "cool", r) + + ok, r = tree.MostSpecificContains(net.ParseIP("::")) + assert.True(t, ok) + assert.Equal(t, "cool6", r) + + ok, r = tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8")) + assert.True(t, ok) + assert.Equal(t, "cool6", r) } func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) { - tree := NewTree6() + tree := NewTree6[string]() tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a") tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b") tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c") tests := []struct { + Found bool Result interface{} IP string }{ - {"6a", "1:2:0:4:1:1:1:1"}, - {"6b", "1:2:0:4:5:1:1:1"}, - {"6c", "1:2:0:4:5:0:0:0"}, + {true, "6a", "1:2:0:4:1:1:1:1"}, + {true, "6b", "1:2:0:4:5:1:1:1"}, + {true, "6c", "1:2:0:4:5:0:0:0"}, } for _, tt := range tests { @@ -76,6 +91,8 @@ func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) { hi := binary.BigEndian.Uint64(ip[:8]) lo := binary.BigEndian.Uint64(ip[8:]) - assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(hi, lo)) + ok, r := tree.MostSpecificContainsIpV6(hi, lo) + assert.Equal(t, tt.Found, ok) + assert.Equal(t, tt.Result, r) } } diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go index e9ad3cb..69df4ab 100644 --- a/cmd/nebula-cert/ca.go +++ b/cmd/nebula-cert/ca.go @@ -7,7 +7,6 @@ import ( "flag" "fmt" "io" - "io/ioutil" "math" "net" "os" @@ -213,27 +212,27 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error return fmt.Errorf("error while signing: %s", err) } + var b []byte if *cf.encryption { - b, err := cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams) + b, err = cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams) if err != nil { return fmt.Errorf("error while encrypting out-key: %s", err) } - - err = ioutil.WriteFile(*cf.outKeyPath, b, 0600) } else { - err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalSigningPrivateKey(curve, rawPriv), 0600) + b = cert.MarshalSigningPrivateKey(curve, rawPriv) } + err = os.WriteFile(*cf.outKeyPath, b, 0600) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } - b, err := nc.MarshalToPEM() + b, err = nc.MarshalToPEM() if err != nil { return fmt.Errorf("error while marshalling certificate: %s", err) } - err = ioutil.WriteFile(*cf.outCertPath, b, 0600) + err = os.WriteFile(*cf.outCertPath, b, 0600) if err != nil { return fmt.Errorf("error while writing out-crt: %s", err) } @@ -244,7 +243,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error return fmt.Errorf("error while generating qr code: %s", err) } - err = ioutil.WriteFile(*cf.outQRPath, b, 0600) + err = os.WriteFile(*cf.outQRPath, b, 0600) if err != nil { return fmt.Errorf("error while writing out-qr: %s", err) } diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index ae79baf..3a53405 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -7,7 +7,6 @@ import ( "bytes" "encoding/pem" "errors" - "io/ioutil" "os" "strings" "testing" @@ -107,7 +106,7 @@ func Test_ca(t *testing.T) { assert.Equal(t, "", eb.String()) // create temp key file - keyF, err := ioutil.TempFile("", "test.key") + keyF, err := os.CreateTemp("", "test.key") assert.Nil(t, err) os.Remove(keyF.Name()) @@ -120,7 +119,7 @@ func Test_ca(t *testing.T) { assert.Equal(t, "", eb.String()) // create temp cert file - crtF, err := ioutil.TempFile("", "test.crt") + crtF, err := os.CreateTemp("", "test.crt") assert.Nil(t, err) os.Remove(crtF.Name()) os.Remove(keyF.Name()) @@ -134,13 +133,13 @@ func Test_ca(t *testing.T) { assert.Equal(t, "", eb.String()) // read cert and key files - rb, _ := ioutil.ReadFile(keyF.Name()) + rb, _ := os.ReadFile(keyF.Name()) lKey, b, err := cert.UnmarshalEd25519PrivateKey(rb) assert.Len(t, b, 0) assert.Nil(t, err) assert.Len(t, lKey, 64) - rb, _ = ioutil.ReadFile(crtF.Name()) + rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb) assert.Len(t, b, 0) assert.Nil(t, err) @@ -166,7 +165,7 @@ func Test_ca(t *testing.T) { assert.Equal(t, "", eb.String()) // read encrypted key file and verify default params - rb, _ = ioutil.ReadFile(keyF.Name()) + rb, _ = os.ReadFile(keyF.Name()) k, _ := pem.Decode(rb) ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes) assert.Nil(t, err) diff --git a/cmd/nebula-cert/keygen.go b/cmd/nebula-cert/keygen.go index 0016fc1..d94cbf1 100644 --- a/cmd/nebula-cert/keygen.go +++ b/cmd/nebula-cert/keygen.go @@ -4,7 +4,6 @@ import ( "flag" "fmt" "io" - "io/ioutil" "os" "github.com/slackhq/nebula/cert" @@ -54,12 +53,12 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("invalid curve: %s", *cf.curve) } - err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600) + err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } - err = ioutil.WriteFile(*cf.outPubPath, cert.MarshalPublicKey(curve, pub), 0600) + err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKey(curve, pub), 0600) if err != nil { return fmt.Errorf("error while writing out-pub: %s", err) } diff --git a/cmd/nebula-cert/keygen_test.go b/cmd/nebula-cert/keygen_test.go index 1480e5a..9a3b3f3 100644 --- a/cmd/nebula-cert/keygen_test.go +++ b/cmd/nebula-cert/keygen_test.go @@ -2,7 +2,6 @@ package main import ( "bytes" - "io/ioutil" "os" "testing" @@ -54,7 +53,7 @@ func Test_keygen(t *testing.T) { assert.Equal(t, "", eb.String()) // create temp key file - keyF, err := ioutil.TempFile("", "test.key") + keyF, err := os.CreateTemp("", "test.key") assert.Nil(t, err) defer os.Remove(keyF.Name()) @@ -67,7 +66,7 @@ func Test_keygen(t *testing.T) { assert.Equal(t, "", eb.String()) // create temp pub file - pubF, err := ioutil.TempFile("", "test.pub") + pubF, err := os.CreateTemp("", "test.pub") assert.Nil(t, err) defer os.Remove(pubF.Name()) @@ -80,13 +79,13 @@ func Test_keygen(t *testing.T) { assert.Equal(t, "", eb.String()) // read cert and key files - rb, _ := ioutil.ReadFile(keyF.Name()) + rb, _ := os.ReadFile(keyF.Name()) lKey, b, err := cert.UnmarshalX25519PrivateKey(rb) assert.Len(t, b, 0) assert.Nil(t, err) assert.Len(t, lKey, 32) - rb, _ = ioutil.ReadFile(pubF.Name()) + rb, _ = os.ReadFile(pubF.Name()) lPub, b, err := cert.UnmarshalX25519PublicKey(rb) assert.Len(t, b, 0) assert.Nil(t, err) diff --git a/cmd/nebula-cert/print.go b/cmd/nebula-cert/print.go index 222dbc0..746d6a3 100644 --- a/cmd/nebula-cert/print.go +++ b/cmd/nebula-cert/print.go @@ -5,7 +5,6 @@ import ( "flag" "fmt" "io" - "io/ioutil" "os" "strings" @@ -41,7 +40,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { return err } - rawCert, err := ioutil.ReadFile(*pf.path) + rawCert, err := os.ReadFile(*pf.path) if err != nil { return fmt.Errorf("unable to read cert; %s", err) } @@ -87,7 +86,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("error while generating qr code: %s", err) } - err = ioutil.WriteFile(*pf.outQRPath, b, 0600) + err = os.WriteFile(*pf.outQRPath, b, 0600) if err != nil { return fmt.Errorf("error while writing out-qr: %s", err) } diff --git a/cmd/nebula-cert/print_test.go b/cmd/nebula-cert/print_test.go index eb117f1..9fa8a54 100644 --- a/cmd/nebula-cert/print_test.go +++ b/cmd/nebula-cert/print_test.go @@ -2,7 +2,6 @@ package main import ( "bytes" - "io/ioutil" "os" "testing" "time" @@ -54,7 +53,7 @@ func Test_printCert(t *testing.T) { // invalid cert at path ob.Reset() eb.Reset() - tf, err := ioutil.TempFile("", "print-cert") + tf, err := os.CreateTemp("", "print-cert") assert.Nil(t, err) defer os.Remove(tf.Name()) diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index 9938401..35d6446 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -6,7 +6,6 @@ import ( "flag" "fmt" "io" - "io/ioutil" "net" "os" "strings" @@ -73,7 +72,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return newHelpErrorf("cannot set both -in-pub and -out-key") } - rawCAKey, err := ioutil.ReadFile(*sf.caKeyPath) + rawCAKey, err := os.ReadFile(*sf.caKeyPath) if err != nil { return fmt.Errorf("error while reading ca-key: %s", err) } @@ -112,7 +111,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("error while parsing ca-key: %s", err) } - rawCACert, err := ioutil.ReadFile(*sf.caCertPath) + rawCACert, err := os.ReadFile(*sf.caCertPath) if err != nil { return fmt.Errorf("error while reading ca-crt: %s", err) } @@ -178,7 +177,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) var pub, rawPriv []byte if *sf.inPubPath != "" { - rawPub, err := ioutil.ReadFile(*sf.inPubPath) + rawPub, err := os.ReadFile(*sf.inPubPath) if err != nil { return fmt.Errorf("error while reading in-pub: %s", err) } @@ -235,7 +234,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath) } - err = ioutil.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600) + err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } @@ -246,7 +245,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("error while marshalling certificate: %s", err) } - err = ioutil.WriteFile(*sf.outCertPath, b, 0600) + err = os.WriteFile(*sf.outCertPath, b, 0600) if err != nil { return fmt.Errorf("error while writing out-crt: %s", err) } @@ -257,7 +256,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("error while generating qr code: %s", err) } - err = ioutil.WriteFile(*sf.outQRPath, b, 0600) + err = os.WriteFile(*sf.outQRPath, b, 0600) if err != nil { return fmt.Errorf("error while writing out-qr: %s", err) } diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index 7018fd2..adf83a2 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -7,7 +7,6 @@ import ( "bytes" "crypto/rand" "errors" - "io/ioutil" "os" "testing" "time" @@ -104,7 +103,7 @@ func Test_signCert(t *testing.T) { // failed to unmarshal key ob.Reset() eb.Reset() - caKeyF, err := ioutil.TempFile("", "sign-cert.key") + caKeyF, err := os.CreateTemp("", "sign-cert.key") assert.Nil(t, err) defer os.Remove(caKeyF.Name()) @@ -128,7 +127,7 @@ func Test_signCert(t *testing.T) { // failed to unmarshal cert ob.Reset() eb.Reset() - caCrtF, err := ioutil.TempFile("", "sign-cert.crt") + caCrtF, err := os.CreateTemp("", "sign-cert.crt") assert.Nil(t, err) defer os.Remove(caCrtF.Name()) @@ -159,7 +158,7 @@ func Test_signCert(t *testing.T) { // failed to unmarshal pub ob.Reset() eb.Reset() - inPubF, err := ioutil.TempFile("", "in.pub") + inPubF, err := os.CreateTemp("", "in.pub") assert.Nil(t, err) defer os.Remove(inPubF.Name()) @@ -206,7 +205,7 @@ func Test_signCert(t *testing.T) { // mismatched ca key _, caPriv2, _ := ed25519.GenerateKey(rand.Reader) - caKeyF2, err := ioutil.TempFile("", "sign-cert-2.key") + caKeyF2, err := os.CreateTemp("", "sign-cert-2.key") assert.Nil(t, err) defer os.Remove(caKeyF2.Name()) caKeyF2.Write(cert.MarshalEd25519PrivateKey(caPriv2)) @@ -227,7 +226,7 @@ func Test_signCert(t *testing.T) { assert.Empty(t, eb.String()) // create temp key file - keyF, err := ioutil.TempFile("", "test.key") + keyF, err := os.CreateTemp("", "test.key") assert.Nil(t, err) os.Remove(keyF.Name()) @@ -241,7 +240,7 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) // create temp cert file - crtF, err := ioutil.TempFile("", "test.crt") + crtF, err := os.CreateTemp("", "test.crt") assert.Nil(t, err) os.Remove(crtF.Name()) @@ -254,13 +253,13 @@ func Test_signCert(t *testing.T) { assert.Empty(t, eb.String()) // read cert and key files - rb, _ := ioutil.ReadFile(keyF.Name()) + rb, _ := os.ReadFile(keyF.Name()) lKey, b, err := cert.UnmarshalX25519PrivateKey(rb) assert.Len(t, b, 0) assert.Nil(t, err) assert.Len(t, lKey, 32) - rb, _ = ioutil.ReadFile(crtF.Name()) + rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb) assert.Len(t, b, 0) assert.Nil(t, err) @@ -296,7 +295,7 @@ func Test_signCert(t *testing.T) { assert.Empty(t, eb.String()) // read cert file and check pub key matches in-pub - rb, _ = ioutil.ReadFile(crtF.Name()) + rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err = cert.UnmarshalNebulaCertificateFromPEM(rb) assert.Len(t, b, 0) assert.Nil(t, err) @@ -348,11 +347,11 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() - caKeyF, err = ioutil.TempFile("", "sign-cert.key") + caKeyF, err = os.CreateTemp("", "sign-cert.key") assert.Nil(t, err) defer os.Remove(caKeyF.Name()) - caCrtF, err = ioutil.TempFile("", "sign-cert.crt") + caCrtF, err = os.CreateTemp("", "sign-cert.crt") assert.Nil(t, err) defer os.Remove(caCrtF.Name()) diff --git a/cmd/nebula-cert/verify.go b/cmd/nebula-cert/verify.go index 51b9a93..c955913 100644 --- a/cmd/nebula-cert/verify.go +++ b/cmd/nebula-cert/verify.go @@ -4,7 +4,6 @@ import ( "flag" "fmt" "io" - "io/ioutil" "os" "strings" "time" @@ -40,7 +39,7 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { return err } - rawCACert, err := ioutil.ReadFile(*vf.caPath) + rawCACert, err := os.ReadFile(*vf.caPath) if err != nil { return fmt.Errorf("error while reading ca: %s", err) } @@ -57,7 +56,7 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { } } - rawCert, err := ioutil.ReadFile(*vf.certPath) + rawCert, err := os.ReadFile(*vf.certPath) if err != nil { return fmt.Errorf("unable to read crt; %s", err) } diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index 25014fd..f0f4c78 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -3,7 +3,6 @@ package main import ( "bytes" "crypto/rand" - "io/ioutil" "os" "testing" "time" @@ -56,7 +55,7 @@ func Test_verify(t *testing.T) { // invalid ca at path ob.Reset() eb.Reset() - caFile, err := ioutil.TempFile("", "verify-ca") + caFile, err := os.CreateTemp("", "verify-ca") assert.Nil(t, err) defer os.Remove(caFile.Name()) @@ -92,7 +91,7 @@ func Test_verify(t *testing.T) { // invalid crt at path ob.Reset() eb.Reset() - certFile, err := ioutil.TempFile("", "verify-cert") + certFile, err := os.CreateTemp("", "verify-cert") assert.Nil(t, err) defer os.Remove(certFile.Name()) diff --git a/config/config.go b/config/config.go index bc3818d..1aea832 100644 --- a/config/config.go +++ b/config/config.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io/ioutil" "math" "os" "os/signal" @@ -122,6 +121,10 @@ func (c *C) HasChanged(k string) bool { // CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the // original path provided to Load. The old settings are shallow copied for change detection after the reload. func (c *C) CatchHUP(ctx context.Context) { + if c.path == "" { + return + } + ch := make(chan os.Signal, 1) signal.Notify(ch, syscall.SIGHUP) @@ -358,7 +361,7 @@ func (c *C) parse() error { var m map[interface{}]interface{} for _, path := range c.files { - b, err := ioutil.ReadFile(path) + b, err := os.ReadFile(path) if err != nil { return err } diff --git a/config/config_test.go b/config/config_test.go index 1001f8d..fa94393 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1,7 +1,6 @@ package config import ( - "io/ioutil" "os" "path/filepath" "testing" @@ -16,10 +15,10 @@ import ( func TestConfig_Load(t *testing.T) { l := test.NewLogger() - dir, err := ioutil.TempDir("", "config-test") + dir, err := os.MkdirTemp("", "config-test") // invalid yaml c := NewC(l) - ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644) + os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644) assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}") // simple multi config merge @@ -29,8 +28,8 @@ func TestConfig_Load(t *testing.T) { assert.Nil(t, err) - ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) - ioutil.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644) + os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) + os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644) assert.Nil(t, c.Load(dir)) expected := map[interface{}]interface{}{ "outer": map[interface{}]interface{}{ @@ -120,9 +119,9 @@ func TestConfig_HasChanged(t *testing.T) { func TestConfig_ReloadConfig(t *testing.T) { l := test.NewLogger() done := make(chan bool, 1) - dir, err := ioutil.TempDir("", "config-test") + dir, err := os.MkdirTemp("", "config-test") assert.Nil(t, err) - ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) + os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) c := NewC(l) assert.Nil(t, c.Load(dir)) @@ -131,7 +130,7 @@ func TestConfig_ReloadConfig(t *testing.T) { assert.False(t, c.HasChanged("outer")) assert.False(t, c.HasChanged("")) - ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: ho"), 0644) + os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: ho"), 0644) c.RegisterReloadCallback(func(c *C) { done <- true diff --git a/connection_manager.go b/connection_manager.go index ce11f19..f5dd594 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -23,6 +23,7 @@ const ( swapPrimary trafficDecision = 3 migrateRelays trafficDecision = 4 tryRehandshake trafficDecision = 5 + sendTestPacket trafficDecision = 6 ) type connectionManager struct { @@ -176,7 +177,7 @@ func (n *connectionManager) Run(ctx context.Context) { } func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) { - decision, hostinfo, primary := n.makeTrafficDecision(localIndex, p, nb, out, now) + decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now) switch decision { case deleteTunnel: @@ -197,6 +198,9 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, case tryRehandshake: n.tryRehandshake(hostinfo) + + case sendTestPacket: + n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out) } n.resetRelayTrafficCheck(hostinfo) @@ -289,7 +293,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) } } -func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []byte, now time.Time) (trafficDecision, *HostInfo, *HostInfo) { +func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) { n.hostMap.RLock() defer n.hostMap.RUnlock() @@ -356,6 +360,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out [] return deleteTunnel, hostinfo, nil } + decision := doNothing if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo { if !outTraffic { // If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel. @@ -380,7 +385,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out [] } // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues - n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out) + decision = sendTestPacket } else { if n.l.Level >= logrus.DebugLevel { @@ -390,7 +395,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out [] n.pendingDeletion[hostinfo.localIndexId] = struct{}{} n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval) - return doNothing, nil, nil + return decision, hostinfo, nil } func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { @@ -432,7 +437,7 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn return false } - if !n.intf.disconnectInvalid && err != cert.ErrBlockListed { + if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed { // Block listed certificates should always be disconnected return false } diff --git a/connection_manager_test.go b/connection_manager_test.go index e802904..a2607a2 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -21,8 +21,9 @@ var vpnIp iputil.VpnIp func newTestLighthouse() *LightHouse { lh := &LightHouse{ - l: test.NewLogger(), - addrMap: map[iputil.VpnIp]*RemoteList{}, + l: test.NewLogger(), + addrMap: map[iputil.VpnIp]*RemoteList{}, + queryChan: make(chan iputil.VpnIp, 10), } lighthouses := map[iputil.VpnIp]struct{}{} staticList := map[iputil.VpnIp]struct{}{} @@ -253,18 +254,18 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { lh := newTestLighthouse() ifce := &Interface{ - hostMap: hostMap, - inside: &test.NoopTun{}, - outside: &udp.NoopConn{}, - firewall: &Firewall{}, - lightHouse: lh, - handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), - l: l, - disconnectInvalid: true, - pki: &PKI{}, + hostMap: hostMap, + inside: &test.NoopTun{}, + outside: &udp.NoopConn{}, + firewall: &Firewall{}, + lightHouse: lh, + handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), + l: l, + pki: &PKI{}, } ifce.pki.cs.Store(cs) ifce.pki.caPool.Store(ncp) + ifce.disconnectInvalid.Store(true) // Create manager ctx, cancel := context.WithCancel(context.Background()) diff --git a/connection_state.go b/connection_state.go index f8c31f6..8ef8b3a 100644 --- a/connection_state.go +++ b/connection_state.go @@ -24,7 +24,6 @@ type ConnectionState struct { messageCounter atomic.Uint64 window *Bits writeLock sync.Mutex - ready bool } func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { @@ -71,7 +70,6 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i H: hs, initiator: initiator, window: b, - ready: false, myCert: certState.Certificate, } @@ -83,6 +81,5 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) { "certificate": cs.peerCert, "initiator": cs.initiator, "message_counter": cs.messageCounter.Load(), - "ready": cs.ready, }) } diff --git a/control.go b/control.go index 4af115c..1e27b0f 100644 --- a/control.go +++ b/control.go @@ -11,6 +11,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" ) @@ -29,6 +30,7 @@ type controlHostLister interface { type Control struct { f *Interface l *logrus.Logger + ctx context.Context cancel context.CancelFunc sshStart func() statsStart func() @@ -41,7 +43,6 @@ type ControlHostInfo struct { LocalIndex uint32 `json:"localIndex"` RemoteIndex uint32 `json:"remoteIndex"` RemoteAddrs []*udp.Addr `json:"remoteAddrs"` - CachedPackets int `json:"cachedPackets"` Cert *cert.NebulaCertificate `json:"cert"` MessageCounter uint64 `json:"messageCounter"` CurrentRemote *udp.Addr `json:"currentRemote"` @@ -72,6 +73,10 @@ func (c *Control) Start() { c.f.run() } +func (c *Control) Context() context.Context { + return c.ctx +} + // Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete func (c *Control) Stop() { // Stop the handshakeManager (and other services), to prevent new tunnels from @@ -227,6 +232,10 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { return } +func (c *Control) Device() overlay.Device { + return c.f.inside +} + func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { chi := ControlHostInfo{ @@ -234,7 +243,6 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { LocalIndex: h.localIndexId, RemoteIndex: h.remoteIndexId, RemoteAddrs: h.remotes.CopyAddrs(preferredRanges), - CachedPackets: len(h.packetStore), CurrentRelaysToMe: h.relayState.CopyRelayIps(), CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(), } diff --git a/control_test.go b/control_test.go index 56a2b2f..847332b 100644 --- a/control_test.go +++ b/control_test.go @@ -96,7 +96,6 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { LocalIndex: 201, RemoteIndex: 200, RemoteAddrs: []*udp.Addr{remote2, remote1}, - CachedPackets: 0, Cert: crt.Copy(), MessageCounter: 0, CurrentRemote: udp.NewAddr(net.ParseIP("0.0.0.100"), 4444), @@ -105,7 +104,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { } // Make sure we don't have any unexpected fields - assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) + assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) test.AssertDeepCopyEqual(t, &expectedInfo, thi) // Make sure we don't panic if the host info doesn't have a cert yet diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 022b5a3..59f1d0e 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -20,7 +20,7 @@ import ( ) func BenchmarkHotPath(b *testing.B) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) @@ -44,7 +44,7 @@ func BenchmarkHotPath(b *testing.B) { } func TestGoodHandshake(t *testing.T) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) @@ -95,7 +95,7 @@ func TestGoodHandshake(t *testing.T) { } func TestWrongResponderHandshake(t *testing.T) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) // The IPs here are chosen on purpose: // The current remote handling will sort by preference, public, and then lexically. @@ -164,7 +164,7 @@ func TestStage1Race(t *testing.T) { // This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow // But will eventually collapse down to a single tunnel - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) @@ -241,7 +241,7 @@ func TestStage1Race(t *testing.T) { } func TestUncleanShutdownRaceLoser(t *testing.T) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) @@ -290,7 +290,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { } func TestUncleanShutdownRaceWinner(t *testing.T) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) @@ -341,7 +341,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { } func TestRelays(t *testing.T) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) @@ -372,7 +372,7 @@ func TestRelays(t *testing.T) { func TestStage1RaceRelays(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) @@ -421,7 +421,7 @@ func TestStage1RaceRelays(t *testing.T) { func TestStage1RaceRelays2(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) @@ -508,7 +508,7 @@ func TestStage1RaceRelays2(t *testing.T) { ////TODO: assert hostmaps } func TestRehandshakingRelays(t *testing.T) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) @@ -538,7 +538,7 @@ func TestRehandshakingRelays(t *testing.T) { // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalToPEM() if err != nil { @@ -612,7 +612,7 @@ func TestRehandshakingRelays(t *testing.T) { func TestRehandshakingRelaysPrimary(t *testing.T) { // This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) @@ -642,7 +642,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalToPEM() if err != nil { @@ -715,7 +715,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } func TestRehandshaking(t *testing.T) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) @@ -737,7 +737,7 @@ func TestRehandshaking(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew my certificate and spin until their sees it") - _, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalToPEM() if err != nil { @@ -811,7 +811,7 @@ func TestRehandshaking(t *testing.T) { func TestRehandshakingLoser(t *testing.T) { // The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel // Should be the one with the new certificate - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) @@ -837,7 +837,7 @@ func TestRehandshakingLoser(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew their certificate and spin until mine sees it") - _, _, theirNextPrivKey, theirNextPEM := newTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}) + _, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}) caB, err := ca.MarshalToPEM() if err != nil { @@ -912,7 +912,7 @@ func TestRaceRegression(t *testing.T) { // This test forces stage 1, stage 2, stage 1 to be received by me from them // We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which // caused a cross-linked hostinfo - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) diff --git a/e2e/helpers.go b/e2e/helpers.go new file mode 100644 index 0000000..13146ab --- /dev/null +++ b/e2e/helpers.go @@ -0,0 +1,118 @@ +package e2e + +import ( + "crypto/rand" + "io" + "net" + "time" + + "github.com/slackhq/nebula/cert" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/ed25519" +) + +// NewTestCaCert will generate a CA cert +func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if before.IsZero() { + before = time.Now().Add(time.Second * -60).Round(time.Second) + } + if after.IsZero() { + after = time.Now().Add(time.Second * 60).Round(time.Second) + } + + nc := &cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "test ca", + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: true, + InvertedGroups: make(map[string]struct{}), + }, + } + + if len(ips) > 0 { + nc.Details.Ips = ips + } + + if len(subnets) > 0 { + nc.Details.Subnets = subnets + } + + if len(groups) > 0 { + nc.Details.Groups = groups + } + + err = nc.Sign(cert.Curve_CURVE25519, priv) + if err != nil { + panic(err) + } + + pem, err := nc.MarshalToPEM() + if err != nil { + panic(err) + } + + return nc, pub, priv, pem +} + +// NewTestCert will generate a signed certificate with the provided details. +// Expiry times are defaulted if you do not pass them in +func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { + issuer, err := ca.Sha256Sum() + if err != nil { + panic(err) + } + + if before.IsZero() { + before = time.Now().Add(time.Second * -60).Round(time.Second) + } + + if after.IsZero() { + after = time.Now().Add(time.Second * 60).Round(time.Second) + } + + pub, rawPriv := x25519Keypair() + + nc := &cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: name, + Ips: []*net.IPNet{ip}, + Subnets: subnets, + Groups: groups, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: false, + Issuer: issuer, + InvertedGroups: make(map[string]struct{}), + }, + } + + err = nc.Sign(ca.Details.Curve, key) + if err != nil { + panic(err) + } + + pem, err := nc.MarshalToPEM() + if err != nil { + panic(err) + } + + return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem +} + +func x25519Keypair() ([]byte, []byte) { + privkey := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, privkey); err != nil { + panic(err) + } + + pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) + if err != nil { + panic(err) + } + + return pubkey, privkey +} diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 8440a72..b05c84a 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -4,7 +4,6 @@ package e2e import ( - "crypto/rand" "fmt" "io" "net" @@ -22,8 +21,6 @@ import ( "github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/ed25519" "gopkg.in/yaml.v2" ) @@ -40,7 +37,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u IP: udpIp, Port: 4242, } - _, _, myPrivKey, myPEM := newTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) + _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) caB, err := caCrt.MarshalToPEM() if err != nil { @@ -108,112 +105,6 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u return control, vpnIpNet, &udpAddr, c } -// newTestCaCert will generate a CA cert -func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { - pub, priv, err := ed25519.GenerateKey(rand.Reader) - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - nc := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test ca", - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: true, - InvertedGroups: make(map[string]struct{}), - }, - } - - if len(ips) > 0 { - nc.Details.Ips = ips - } - - if len(subnets) > 0 { - nc.Details.Subnets = subnets - } - - if len(groups) > 0 { - nc.Details.Groups = groups - } - - err = nc.Sign(cert.Curve_CURVE25519, priv) - if err != nil { - panic(err) - } - - pem, err := nc.MarshalToPEM() - if err != nil { - panic(err) - } - - return nc, pub, priv, pem -} - -// newTestCert will generate a signed certificate with the provided details. -// Expiry times are defaulted if you do not pass them in -func newTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { - issuer, err := ca.Sha256Sum() - if err != nil { - panic(err) - } - - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - pub, rawPriv := x25519Keypair() - - nc := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: name, - Ips: []*net.IPNet{ip}, - Subnets: subnets, - Groups: groups, - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: false, - Issuer: issuer, - InvertedGroups: make(map[string]struct{}), - }, - } - - err = nc.Sign(ca.Details.Curve, key) - if err != nil { - panic(err) - } - - pem, err := nc.MarshalToPEM() - if err != nil { - panic(err) - } - - return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem -} - -func x25519Keypair() ([]byte, []byte) { - privkey := make([]byte, 32) - if _, err := io.ReadFull(rand.Reader, privkey); err != nil { - panic(err) - } - - pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) - if err != nil { - panic(err) - } - - return pubkey, privkey -} - type doneCb func() func deadline(t *testing.T, seconds time.Duration) doneCb { diff --git a/examples/config.yml b/examples/config.yml index 96ae8de..21cda3b 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -11,7 +11,7 @@ pki: #blocklist: # - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72 # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid. - #disconnect_invalid: false + #disconnect_invalid: true # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network). # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel. @@ -330,6 +330,10 @@ logging: # A 100ms interval with the default 10 retries will give a handshake 5.5 seconds to resolve before timing out #try_interval: 100ms #retries: 20 + + # query_buffer is the size of the buffer channel for querying lighthouses + #query_buffer: 64 + # trigger_buffer is the size of the buffer channel for quickly sending handshakes # after receiving the response for lighthouse queries #trigger_buffer: 64 diff --git a/examples/go_service/main.go b/examples/go_service/main.go new file mode 100644 index 0000000..f46273a --- /dev/null +++ b/examples/go_service/main.go @@ -0,0 +1,100 @@ +package main + +import ( + "bufio" + "fmt" + "log" + + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/service" +) + +func main() { + if err := run(); err != nil { + log.Fatalf("%+v", err) + } +} + +func run() error { + configStr := ` +tun: + user: true + +static_host_map: + '192.168.100.1': ['localhost:4242'] + +listen: + host: 0.0.0.0 + port: 4241 + +lighthouse: + am_lighthouse: false + interval: 60 + hosts: + - '192.168.100.1' + +firewall: + outbound: + # Allow all outbound traffic from this node + - port: any + proto: any + host: any + + inbound: + # Allow icmp between any nebula hosts + - port: any + proto: icmp + host: any + - port: any + proto: any + host: any + +pki: + ca: /home/rice/Developer/nebula-config/ca.crt + cert: /home/rice/Developer/nebula-config/app.crt + key: /home/rice/Developer/nebula-config/app.key +` + var config config.C + if err := config.LoadString(configStr); err != nil { + return err + } + service, err := service.New(&config) + if err != nil { + return err + } + + ln, err := service.Listen("tcp", ":1234") + if err != nil { + return err + } + for { + conn, err := ln.Accept() + if err != nil { + log.Printf("accept error: %s", err) + break + } + defer conn.Close() + + log.Printf("got connection") + + conn.Write([]byte("hello world\n")) + + scanner := bufio.NewScanner(conn) + for scanner.Scan() { + message := scanner.Text() + fmt.Fprintf(conn, "echo: %q\n", message) + log.Printf("got message %q", message) + } + + if err := scanner.Err(); err != nil { + log.Printf("scanner error: %s", err) + break + } + } + + service.Close() + if err := service.Wait(); err != nil { + return err + } + return nil +} diff --git a/firewall.go b/firewall.go index 93d940d..64fada3 100644 --- a/firewall.go +++ b/firewall.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "errors" "fmt" + "hash/fnv" "net" "reflect" "strconv" @@ -57,7 +58,7 @@ type Firewall struct { DefaultTimeout time.Duration //linux: 600s // Used to ensure we don't emit local packets for ips we don't own - localIps *cidr.Tree4 + localIps *cidr.Tree4[struct{}] rules string rulesVersion uint16 @@ -110,8 +111,8 @@ type FirewallRule struct { Any bool Hosts map[string]struct{} Groups [][]string - CIDR *cidr.Tree4 - LocalCIDR *cidr.Tree4 + CIDR *cidr.Tree4[struct{}] + LocalCIDR *cidr.Tree4[struct{}] } // Even though ports are uint16, int32 maps are faster for lookup @@ -137,7 +138,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D max = defaultTimeout } - localIps := cidr.NewTree4() + localIps := cidr.NewTree4[struct{}]() for _, ip := range c.Details.Ips { localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) } @@ -278,6 +279,18 @@ func (f *Firewall) GetRuleHash() string { return hex.EncodeToString(sum[:]) } +// GetRuleHashFNV returns a uint32 FNV-1 hash representation the rules, for use as a metric value +func (f *Firewall) GetRuleHashFNV() uint32 { + h := fnv.New32a() + h.Write([]byte(f.rules)) + return h.Sum32() +} + +// GetRuleHashes returns both the sha256 and FNV-1 hashes, suitable for logging +func (f *Firewall) GetRuleHashes() string { + return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10) +} + func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error { var table string if inbound { @@ -391,7 +404,8 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos // Make sure remote address matches nebula certificate if remoteCidr := h.remoteCidr; remoteCidr != nil { - if remoteCidr.Contains(fp.RemoteIP) == nil { + ok, _ := remoteCidr.Contains(fp.RemoteIP) + if !ok { f.metrics(incoming).droppedRemoteIP.Inc(1) return ErrInvalidRemoteIP } @@ -404,7 +418,8 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos } // Make sure we are supposed to be handling this local ip address - if f.localIps.Contains(fp.LocalIP) == nil { + ok, _ := f.localIps.Contains(fp.LocalIP) + if !ok { f.metrics(incoming).droppedLocalIP.Inc(1) return ErrInvalidLocalIP } @@ -447,6 +462,7 @@ func (f *Firewall) EmitStats() { conntrack.Unlock() metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount)) metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion)) + metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) } func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool { @@ -657,8 +673,8 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN return &FirewallRule{ Hosts: make(map[string]struct{}), Groups: make([][]string, 0), - CIDR: cidr.NewTree4(), - LocalCIDR: cidr.NewTree4(), + CIDR: cidr.NewTree4[struct{}](), + LocalCIDR: cidr.NewTree4[struct{}](), } } @@ -726,8 +742,8 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, loc // If it's any we need to wipe out any pre-existing rules to save on memory fr.Groups = make([][]string, 0) fr.Hosts = make(map[string]struct{}) - fr.CIDR = cidr.NewTree4() - fr.LocalCIDR = cidr.NewTree4() + fr.CIDR = cidr.NewTree4[struct{}]() + fr.LocalCIDR = cidr.NewTree4[struct{}]() } else { if len(groups) > 0 { fr.Groups = append(fr.Groups, groups) @@ -809,12 +825,18 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool } } - if fr.CIDR != nil && fr.CIDR.Contains(p.RemoteIP) != nil { - return true + if fr.CIDR != nil { + ok, _ := fr.CIDR.Contains(p.RemoteIP) + if ok { + return true + } } - if fr.LocalCIDR != nil && fr.LocalCIDR.Contains(p.LocalIP) != nil { - return true + if fr.LocalCIDR != nil { + ok, _ := fr.LocalCIDR.Contains(p.LocalIP) + if ok { + return true + } } // No host, group, or cidr matched, bye bye diff --git a/firewall_test.go b/firewall_test.go index 7ffa747..83da899 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -92,14 +92,16 @@ func TestFirewall_AddRule(t *testing.T) { assert.False(t, fw.OutRules.AnyProto[1].Any.Any) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts) - assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))) + ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)) + assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", "")) assert.False(t, fw.OutRules.AnyProto[1].Any.Any) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts) - assert.NotNil(t, fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))) + ok, _ = fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)) + assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", "")) @@ -114,8 +116,10 @@ func TestFirewall_AddRule(t *testing.T) { assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", "")) assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0]) assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1") - assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))) - assert.NotNil(t, fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))) + ok, _ = fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)) + assert.True(t, ok) + ok, _ = fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)) + assert.True(t, ok) // run twice just to make sure //TODO: these ANY rules should clear the CA firewall portion diff --git a/go.mod b/go.mod index ba57aa1..5c6e87a 100644 --- a/go.mod +++ b/go.mod @@ -7,29 +7,31 @@ require ( github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 - github.com/flynn/noise v1.0.0 + github.com/flynn/noise v1.0.1 github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.2 github.com/miekg/dns v1.1.56 github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f - github.com/prometheus/client_golang v1.16.0 + github.com/prometheus/client_golang v1.17.0 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/sirupsen/logrus v1.9.3 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/stretchr/testify v1.8.4 - github.com/vishvananda/netlink v1.1.0 - golang.org/x/crypto v0.14.0 + github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54 + golang.org/x/crypto v0.17.0 golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 - golang.org/x/net v0.17.0 - golang.org/x/sys v0.13.0 - golang.org/x/term v0.13.0 + golang.org/x/net v0.19.0 + golang.org/x/sync v0.5.0 + golang.org/x/sys v0.15.0 + golang.org/x/term v0.15.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 google.golang.org/protobuf v1.31.0 gopkg.in/yaml.v2 v2.4.0 + gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f ) require ( @@ -37,14 +39,15 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/golang/protobuf v1.5.3 // indirect + github.com/google/btree v1.0.1 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/prometheus/client_model v0.4.0 // indirect - github.com/prometheus/common v0.42.0 // indirect - github.com/prometheus/procfs v0.10.1 // indirect - github.com/rogpeppe/go-internal v1.10.0 // indirect + github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect + github.com/prometheus/common v0.44.0 // indirect + github.com/prometheus/procfs v0.11.1 // indirect github.com/vishvananda/netns v0.0.4 // indirect golang.org/x/mod v0.12.0 // indirect + golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect golang.org/x/tools v0.13.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 445f18a..6ef9874 100644 --- a/go.sum +++ b/go.sum @@ -22,8 +22,8 @@ github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go. github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/flynn/noise v1.0.0 h1:DlTHqmzmvcEiKj+4RYo/imoswx/4r6iBlCMfVtrMXpQ= -github.com/flynn/noise v1.0.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= +github.com/flynn/noise v1.0.1 h1:vPp/jdQLXC6ppsXSj/pM3W1BIJ5FEHE2TulSJBpb43Y= +github.com/flynn/noise v1.0.1/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= @@ -47,6 +47,8 @@ github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= +github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -97,28 +99,27 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= -github.com/prometheus/client_golang v1.16.0 h1:yk/hx9hDbrGHovbci4BY+pRMfSuuat626eFsHb7tmT8= -github.com/prometheus/client_golang v1.16.0/go.mod h1:Zsulrv/L9oM40tJ7T815tM89lFEugiJ9HzIqaAx4LKc= +github.com/prometheus/client_golang v1.17.0 h1:rl2sfwZMtSthVU752MqfjQozy7blglC+1SOtjMAMh+Q= +github.com/prometheus/client_golang v1.17.0/go.mod h1:VeL+gMmOAxkS2IqfCq0ZmHSL+LjWfWDUmp1mBz9JgUY= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.4.0 h1:5lQXD3cAg1OXBf4Wq03gTrXHeaV0TQvGfUooCfx1yqY= -github.com/prometheus/client_model v0.4.0/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU= +github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 h1:v7DLqVdK4VrYkVD5diGdl4sxJurKJEMnODWRJlxV9oM= +github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= -github.com/prometheus/common v0.42.0 h1:EKsfXEYo4JpWMHH5cg+KOUWeuJSov1Id8zGR8eeI1YM= -github.com/prometheus/common v0.42.0/go.mod h1:xBwqVerjNdUDjgODMpudtOMwlOwf2SaTr1yjz4b7Zbc= +github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdOOfY= +github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/prometheus/procfs v0.10.1 h1:kYK1Va/YMlutzCGazswoHKo//tZVlFpKYh+PymziUAg= -github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM= +github.com/prometheus/procfs v0.11.1 h1:xRC8Iq1yyca5ypa9n1EZnWZkt7dwcoRPQwX/5gwaUuI= +github.com/prometheus/procfs v0.11.1/go.mod h1:eesXgaPo1q7lBpVMoMy0ZOFTth9hBn4W/y0/p/ScXhY= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= -github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= @@ -136,9 +137,9 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0= -github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE= -github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= +github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54 h1:8mhqcHPqTMhSPoslhGYihEgSfc77+7La1P6kiB6+9So= +github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= +github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -148,8 +149,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= -golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= +golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 h1:5llv2sWeaMSnA3w2kS57ouQQ4pudlXrR0dCgw51QK9o= golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= @@ -168,8 +169,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -177,31 +178,35 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek= -golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= +golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= +golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44= +golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= @@ -245,3 +250,5 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f h1:8GE2MRjGiFmfpon8dekPI08jEuNMQzSffVHgdupcO4E= +gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f/go.mod h1:pzr6sy8gDLfVmDAg8OYrlKvGEHw5C3PGTiBXBTCx76Q= diff --git a/handshake.go b/handshake.go deleted file mode 100644 index 8cfba21..0000000 --- a/handshake.go +++ /dev/null @@ -1,31 +0,0 @@ -package nebula - -import ( - "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/udp" -) - -func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H, hostinfo *HostInfo) { - // First remote allow list check before we know the vpnIp - if addr != nil { - if !f.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) { - f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") - return - } - } - - switch h.Subtype { - case header.HandshakeIXPSK0: - switch h.MessageCounter { - case 1: - ixHandshakeStage1(f, addr, via, packet, h) - case 2: - newHostinfo := f.handshakeManager.QueryIndex(h.RemoteIndex) - tearDown := ixHandshakeStage2(f, addr, via, newHostinfo, packet, h) - if tearDown && newHostinfo != nil { - f.handshakeManager.DeleteHostInfo(newHostinfo) - } - } - } - -} diff --git a/handshake_ix.go b/handshake_ix.go index 26cc983..68998e9 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -4,6 +4,7 @@ import ( "time" "github.com/flynn/noise" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" @@ -13,20 +14,20 @@ import ( // This function constructs a handshake packet, but does not actually send it // Sending is done by the handshake manager -func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool { - err := f.handshakeManager.allocateIndex(hostinfo) +func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { + err := f.handshakeManager.allocateIndex(hh) if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") return false } certState := f.pki.GetCertState() ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0) - hostinfo.ConnectionState = ci + hh.hostinfo.ConnectionState = ci hsProto := &NebulaHandshakeDetails{ - InitiatorIndex: hostinfo.localIndexId, + InitiatorIndex: hh.hostinfo.localIndexId, Time: uint64(time.Now().UnixNano()), Cert: certState.RawCertificateNoKey, } @@ -48,7 +49,7 @@ func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool { hsBytes, err = hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") return false } @@ -58,7 +59,7 @@ func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool { msg, _, _, err := ci.H.WriteMessage(h, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") return false } @@ -67,9 +68,8 @@ func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool { // handshake packet 1 from the responder ci.window.Update(f.l, 1) - hostinfo.HandshakePacket[0] = msg - hostinfo.HandshakeReady = true - hostinfo.handshakeStart = time.Now() + hh.hostinfo.HandshakePacket[0] = msg + hh.ready = true return true } @@ -174,9 +174,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by }, } - hostinfo.Lock() - defer hostinfo.Unlock() - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). @@ -243,19 +240,16 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by if err != nil { switch err { case ErrAlreadySeen: - // Update remote if preferred (Note we have to switch to locking - // the existing hostinfo, and then switch back so the defer Unlock - // higher in this function still works) - hostinfo.Unlock() - existing.Lock() + if hostinfo.multiportRx { + // The other host is sending to us with multiport, so only grab the IP + addr.Port = hostinfo.remote.Port + } // Update remote if preferred if existing.SetRemoteIfPreferred(f.hostMap, addr) { // Send a test packet to ensure the other side has also switched to // the preferred remote f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) } - existing.Unlock() - hostinfo.Lock() msg = existing.HandshakePacket[2] f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) @@ -356,7 +350,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithField("sentCachedPackets", len(hostinfo.packetStore)). Info("Handshake message sent") } } else { @@ -372,25 +365,26 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithField("sentCachedPackets", len(hostinfo.packetStore)). Info("Handshake message sent") } f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) - hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics) + hostinfo.ConnectionState.messageCounter.Store(2) + hostinfo.remotes.ResetBlockedRemotes() return } -func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *HostInfo, packet []byte, h *header.H) bool { - if hostinfo == nil { +func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { + if hh == nil { // Nothing here to tear down, got a bogus stage 2 packet return true } - hostinfo.Lock() - defer hostinfo.Unlock() + hh.Lock() + defer hh.Unlock() + hostinfo := hh.hostinfo if addr != nil { if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) { f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") @@ -399,27 +393,6 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H } ci := hostinfo.ConnectionState - if ci.ready { - if hostinfo.multiportRx { - // The other host is sending to us with multiport, so only grab the IP - addr.Port = hostinfo.remote.Port - } - - f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). - Info("Handshake is already complete") - - // Update remote if preferred - if hostinfo.SetRemoteIfPreferred(f.hostMap, addr) { - // Send a test packet to ensure the other side has also switched to - // the preferred remote - f.SendMessageToVpnIp(header.Test, header.TestRequest, hostinfo.vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) - } - - // We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets - return false - } - msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). @@ -490,22 +463,22 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H f.handshakeManager.DeleteHostInfo(hostinfo) // Create a new hostinfo/handshake for the intended vpn ip - f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHostInfo *HostInfo) { + f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) { //TODO: this doesnt know if its being added or is being used for caching a packet // Block the current used address - newHostInfo.remotes = hostinfo.remotes - newHostInfo.remotes.BlockRemote(addr) + newHH.hostinfo.remotes = hostinfo.remotes + newHH.hostinfo.remotes.BlockRemote(addr) // Get the correct remote list for the host we did handshake with hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) - f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp). - WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)). + f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp). + WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.preferredRanges)). Info("Blocked addresses for handshakes") // Swap the packet store to benefit the original intended recipient - newHostInfo.packetStore = hostinfo.packetStore - hostinfo.packetStore = []*cachedPacket{} + newHH.packetStore = hh.packetStore + hh.packetStore = []*cachedPacket{} // Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down hostinfo.vpnIp = vpnIp @@ -518,7 +491,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H // Mark packet 2 as seen so it doesn't show up as missed ci.window.Update(f.l, 2) - duration := time.Since(hostinfo.handshakeStart).Nanoseconds() + duration := time.Since(hh.startTime).Nanoseconds() f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). @@ -526,7 +499,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("durationNs", duration). - WithField("sentCachedPackets", len(hostinfo.packetStore)). + WithField("sentCachedPackets", len(hh.packetStore)). WithField("multiportTx", hostinfo.multiportTx).WithField("multiportRx", hostinfo.multiportRx). Info("Handshake message received") @@ -551,7 +524,23 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp f.handshakeManager.Complete(hostinfo, f) f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) - hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics) + + hostinfo.ConnectionState.messageCounter.Store(2) + + if f.l.Level >= logrus.DebugLevel { + hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore)) + } + + if len(hh.packetStore) > 0 { + nb := make([]byte, 12, 12) + out := make([]byte, mtu) + for _, cp := range hh.packetStore { + cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out) + } + f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore))) + } + + hostinfo.remotes.ResetBlockedRemotes() f.metricHandshakes.Update(duration) return false diff --git a/handshake_manager.go b/handshake_manager.go index 107b1f3..0d50843 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -46,8 +46,8 @@ type HandshakeManager struct { // Mutex for interacting with the vpnIps and indexes maps sync.RWMutex - vpnIps map[iputil.VpnIp]*HostInfo - indexes map[uint32]*HostInfo + vpnIps map[iputil.VpnIp]*HandshakeHostInfo + indexes map[uint32]*HandshakeHostInfo mainHostMap *HostMap lightHouse *LightHouse @@ -67,10 +67,47 @@ type HandshakeManager struct { trigger chan iputil.VpnIp } +type HandshakeHostInfo struct { + sync.Mutex + + startTime time.Time // Time that we first started trying with this handshake + ready bool // Is the handshake ready + counter int // How many attempts have we made so far + lastRemotes []*udp.Addr // Remotes that we sent to during the previous attempt + packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes + + hostinfo *HostInfo +} + +func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { + if len(hh.packetStore) < 100 { + tempPacket := make([]byte, len(packet)) + copy(tempPacket, packet) + + hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket}) + if l.Level >= logrus.DebugLevel { + hh.hostinfo.logger(l). + WithField("length", len(hh.packetStore)). + WithField("stored", true). + Debugf("Packet store") + } + + } else { + m.dropped.Inc(1) + + if l.Level >= logrus.DebugLevel { + hh.hostinfo.logger(l). + WithField("length", len(hh.packetStore)). + WithField("stored", false). + Debugf("Packet store") + } + } +} + func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { return &HandshakeManager{ - vpnIps: map[iputil.VpnIp]*HostInfo{}, - indexes: map[uint32]*HostInfo{}, + vpnIps: map[iputil.VpnIp]*HandshakeHostInfo{}, + indexes: map[uint32]*HandshakeHostInfo{}, mainHostMap: mainHostMap, lightHouse: lightHouse, outside: outside, @@ -100,6 +137,31 @@ func (c *HandshakeManager) Run(ctx context.Context) { } } +func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { + // First remote allow list check before we know the vpnIp + if addr != nil { + if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) { + hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + return + } + } + + switch h.Subtype { + case header.HandshakeIXPSK0: + switch h.MessageCounter { + case 1: + ixHandshakeStage1(hm.f, addr, via, packet, h) + + case 2: + newHostinfo := hm.queryIndex(h.RemoteIndex) + tearDown := ixHandshakeStage2(hm.f, addr, via, newHostinfo, packet, h) + if tearDown && newHostinfo != nil { + hm.DeleteHostInfo(newHostinfo.hostinfo) + } + } + } +} + func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { c.OutboundHandshakeTimer.Advance(now) for { @@ -111,41 +173,35 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { } } -func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) { - hostinfo := c.QueryVpnIp(vpnIp) - if hostinfo == nil { - return - } - hostinfo.Lock() - defer hostinfo.Unlock() - - // We may have raced to completion but now that we have a lock we should ensure we have not yet completed. - if hostinfo.HandshakeComplete { - // Ensure we don't exist in the pending hostmap anymore since we have completed - c.DeleteHostInfo(hostinfo) +func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) { + hh := hm.queryVpnIp(vpnIp) + if hh == nil { return } + hh.Lock() + defer hh.Unlock() + hostinfo := hh.hostinfo // If we are out of time, clean up - if hostinfo.HandshakeCounter >= c.config.retries { - hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges)). - WithField("initiatorIndex", hostinfo.localIndexId). - WithField("remoteIndex", hostinfo.remoteIndexId). + if hh.counter >= hm.config.retries { + hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)). + WithField("initiatorIndex", hh.hostinfo.localIndexId). + WithField("remoteIndex", hh.hostinfo.remoteIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("durationNs", time.Since(hostinfo.handshakeStart).Nanoseconds()). + WithField("durationNs", time.Since(hh.startTime).Nanoseconds()). Info("Handshake timed out") - c.metricTimedOut.Inc(1) - c.DeleteHostInfo(hostinfo) + hm.metricTimedOut.Inc(1) + hm.DeleteHostInfo(hostinfo) return } // Increment the counter to increase our delay, linear backoff - hostinfo.HandshakeCounter++ + hh.counter++ // Check if we have a handshake packet to transmit yet - if !hostinfo.HandshakeReady { - if !ixHandshakeStage0(c.f, hostinfo) { - c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) + if !hh.ready { + if !ixHandshakeStage0(hm.f, hh) { + hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter)) return } } @@ -155,11 +211,11 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere // NB ^ This comment doesn't jive. It's how the thing gets initialized. // It's the common path. Should it update every time, in case a future LH query/queries give us more info? if hostinfo.remotes == nil { - hostinfo.remotes = c.lightHouse.QueryCache(vpnIp) + hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp) } - remotes := hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges) - remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hostinfo.HandshakeLastRemotes) + remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges) + remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes) // We only care about a lighthouse trigger if we have new remotes to send to. // This is a very specific optimization for a fast lighthouse reply. @@ -168,26 +224,26 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere return } - hostinfo.HandshakeLastRemotes = remotes + hh.lastRemotes = remotes // TODO: this will generate a load of queries for hosts with only 1 ip // (such as ones registered to the lighthouse with only a private IP) // So we only do it one time after attempting 5 handshakes already. - if len(remotes) <= 1 && hostinfo.HandshakeCounter == 5 { + if len(remotes) <= 1 && hh.counter == 5 { // If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse // Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about // the learned public ip for them. Query again to short circuit the promotion counter - c.lightHouse.QueryServer(vpnIp, c.f) + hm.lightHouse.QueryServer(vpnIp) } // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply var sentTo []*udp.Addr var sentMultiport bool - hostinfo.remotes.ForEach(c.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) { - c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) - err := c.outside.WriteTo(hostinfo.HandshakePacket[0], addr) + hostinfo.remotes.ForEach(hm.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) { + hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) + err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) if err != nil { - hostinfo.logger(c.l).WithField("udpAddr", addr). + hostinfo.logger(hm.l).WithField("udpAddr", addr). WithField("initiatorIndex", hostinfo.localIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithError(err).Error("Failed to send handshake message") @@ -197,7 +253,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere } // Attempt a multiport handshake if we are past the TxHandshakeDelay attempts - if c.multiPort.TxHandshake && c.udpRaw != nil && hostinfo.HandshakeCounter >= c.multiPort.TxHandshakeDelay { + if hm.multiPort.TxHandshake && hm.udpRaw != nil && hh.counter >= hm.multiPort.TxHandshakeDelay { sentMultiport = true // We need to re-allocate with 8 bytes at the start of SOCK_RAW raw := hostinfo.HandshakePacket[0x80] @@ -207,10 +263,10 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere hostinfo.HandshakePacket[0x80] = raw } - c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) - err = c.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(c.multiPort.TxPorts), addr) + hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) + err = hm.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(hm.multiPort.TxPorts), addr) if err != nil { - hostinfo.logger(c.l).WithField("udpAddr", addr). + hostinfo.logger(hm.l).WithField("udpAddr", addr). WithField("initiatorIndex", hostinfo.localIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithError(err).Error("Failed to send handshake message") @@ -221,64 +277,64 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere // Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout, // so only log when the list of remotes has changed if remotesHaveChanged { - hostinfo.logger(c.l).WithField("udpAddrs", sentTo). + hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). WithField("initiatorIndex", hostinfo.localIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("multiportHandshake", sentMultiport). Info("Handshake message sent") - } else if c.l.IsLevelEnabled(logrus.DebugLevel) { - hostinfo.logger(c.l).WithField("udpAddrs", sentTo). + } else if hm.l.IsLevelEnabled(logrus.DebugLevel) { + hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). WithField("initiatorIndex", hostinfo.localIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Debug("Handshake message sent") } - if c.config.useRelays && len(hostinfo.remotes.relays) > 0 { - hostinfo.logger(c.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") + if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 { + hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") // Send a RelayRequest to all known Relay IP's for _, relay := range hostinfo.remotes.relays { // Don't relay to myself, and don't relay through the host I'm trying to connect to - if *relay == vpnIp || *relay == c.lightHouse.myVpnIp { + if *relay == vpnIp || *relay == hm.lightHouse.myVpnIp { continue } - relayHostInfo := c.mainHostMap.QueryVpnIp(*relay) + relayHostInfo := hm.mainHostMap.QueryVpnIp(*relay) if relayHostInfo == nil || relayHostInfo.remote == nil { - hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") - c.f.Handshake(*relay) + hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") + hm.f.Handshake(*relay) continue } // Check the relay HostInfo to see if we already established a relay through it if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok { switch existingRelay.State { case Established: - hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Send handshake via relay") - c.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) + hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay") + hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) case Requested: - hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") + hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") // Re-send the CreateRelay request, in case the previous one was lost. m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: existingRelay.LocalIndex, - RelayFromIp: uint32(c.lightHouse.myVpnIp), + RelayFromIp: uint32(hm.lightHouse.myVpnIp), RelayToIp: uint32(vpnIp), } msg, err := m.Marshal() if err != nil { - hostinfo.logger(c.l). + hostinfo.logger(hm.l). WithError(err). Error("Failed to marshal Control message to create relay") } else { // This must send over the hostinfo, not over hm.Hosts[ip] - c.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - c.l.WithFields(logrus.Fields{ - "relayFrom": c.lightHouse.myVpnIp, + hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + hm.l.WithFields(logrus.Fields{ + "relayFrom": hm.lightHouse.myVpnIp, "relayTo": vpnIp, "initiatorRelayIndex": existingRelay.LocalIndex, "relay": *relay}). Info("send CreateRelayRequest") } default: - hostinfo.logger(c.l). + hostinfo.logger(hm.l). WithField("vpnIp", vpnIp). WithField("state", existingRelay.State). WithField("relay", relayHostInfo.vpnIp). @@ -287,26 +343,26 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere } else { // No relays exist or requested yet. if relayHostInfo.remote != nil { - idx, err := AddRelay(c.l, relayHostInfo, c.mainHostMap, vpnIp, nil, TerminalType, Requested) + idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) if err != nil { - hostinfo.logger(c.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") + hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") } m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: idx, - RelayFromIp: uint32(c.lightHouse.myVpnIp), + RelayFromIp: uint32(hm.lightHouse.myVpnIp), RelayToIp: uint32(vpnIp), } msg, err := m.Marshal() if err != nil { - hostinfo.logger(c.l). + hostinfo.logger(hm.l). WithError(err). Error("Failed to marshal Control message to create relay") } else { - c.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - c.l.WithFields(logrus.Fields{ - "relayFrom": c.lightHouse.myVpnIp, + hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + hm.l.WithFields(logrus.Fields{ + "relayFrom": hm.lightHouse.myVpnIp, "relayTo": vpnIp, "initiatorRelayIndex": idx, "relay": *relay}). @@ -319,13 +375,13 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere // If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add if !lighthouseTriggered { - c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) + hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter)) } } // GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present // The 2nd argument will be true if the hostinfo is ready to transmit traffic -func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HostInfo)) (*HostInfo, bool) { +func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) { // Check the main hostmap and maintain a read lock if our host is not there hm.mainHostMap.RLock() if h, ok := hm.mainHostMap.Hosts[vpnIp]; ok { @@ -342,16 +398,16 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Hos } // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip -func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HostInfo)) *HostInfo { +func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo { hm.Lock() + defer hm.Unlock() - if hostinfo, ok := hm.vpnIps[vpnIp]; ok { + if hh, ok := hm.vpnIps[vpnIp]; ok { // We are already trying to handshake with this vpn ip if cacheCb != nil { - cacheCb(hostinfo) + cacheCb(hh) } - hm.Unlock() - return hostinfo + return hh.hostinfo } hostinfo := &HostInfo{ @@ -364,12 +420,16 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Hos }, } - hm.vpnIps[vpnIp] = hostinfo + hh := &HandshakeHostInfo{ + hostinfo: hostinfo, + startTime: time.Now(), + } + hm.vpnIps[vpnIp] = hh hm.metricInitiated.Inc(1) hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval) if cacheCb != nil { - cacheCb(hostinfo) + cacheCb(hh) } // If this is a static host, we don't need to wait for the HostQueryReply @@ -387,8 +447,7 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Hos } } - hm.Unlock() - hm.lightHouse.QueryServer(vpnIp, hm.f) + hm.lightHouse.QueryServer(vpnIp) return hostinfo } @@ -442,8 +501,8 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket return existingIndex, ErrLocalIndexCollision } - existingIndex, found = c.indexes[hostinfo.localIndexId] - if found && existingIndex != hostinfo { + existingPendingIndex, found := c.indexes[hostinfo.localIndexId] + if found && existingPendingIndex.hostinfo != hostinfo { // We have a collision, but for a different hostinfo return existingIndex, ErrLocalIndexCollision } @@ -487,7 +546,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { // allocateIndex generates a unique localIndexId for this HostInfo // and adds it to the pendingHostMap. Will error if we are unable to generate // a unique localIndexId -func (hm *HandshakeManager) allocateIndex(h *HostInfo) error { +func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { hm.mainHostMap.RLock() defer hm.mainHostMap.RUnlock() hm.Lock() @@ -503,8 +562,8 @@ func (hm *HandshakeManager) allocateIndex(h *HostInfo) error { _, inMain := hm.mainHostMap.Indexes[index] if !inMain && !inPending { - h.localIndexId = index - hm.indexes[index] = h + hh.hostinfo.localIndexId = index + hm.indexes[index] = hh return nil } } @@ -521,12 +580,12 @@ func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { delete(c.vpnIps, hostinfo.vpnIp) if len(c.vpnIps) == 0 { - c.vpnIps = map[iputil.VpnIp]*HostInfo{} + c.vpnIps = map[iputil.VpnIp]*HandshakeHostInfo{} } delete(c.indexes, hostinfo.localIndexId) if len(c.vpnIps) == 0 { - c.indexes = map[uint32]*HostInfo{} + c.indexes = map[uint32]*HandshakeHostInfo{} } if c.l.Level >= logrus.DebugLevel { @@ -536,16 +595,33 @@ func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { } } -func (c *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { - c.RLock() - defer c.RUnlock() - return c.vpnIps[vpnIp] +func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { + hh := hm.queryVpnIp(vpnIp) + if hh != nil { + return hh.hostinfo + } + return nil + } -func (c *HandshakeManager) QueryIndex(index uint32) *HostInfo { - c.RLock() - defer c.RUnlock() - return c.indexes[index] +func (hm *HandshakeManager) queryVpnIp(vpnIp iputil.VpnIp) *HandshakeHostInfo { + hm.RLock() + defer hm.RUnlock() + return hm.vpnIps[vpnIp] +} + +func (hm *HandshakeManager) QueryIndex(index uint32) *HostInfo { + hh := hm.queryIndex(index) + if hh != nil { + return hh.hostinfo + } + return nil +} + +func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo { + hm.RLock() + defer hm.RUnlock() + return hm.indexes[index] } func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet { @@ -557,7 +633,7 @@ func (c *HandshakeManager) ForEachVpnIp(f controlEach) { defer c.RUnlock() for _, v := range c.vpnIps { - f(v) + f(v.hostinfo) } } @@ -566,7 +642,7 @@ func (c *HandshakeManager) ForEachIndex(f controlEach) { defer c.RUnlock() for _, v := range c.indexes { - f(v) + f(v.hostinfo) } } diff --git a/handshake_manager_test.go b/handshake_manager_test.go index d318a9d..303aa50 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" @@ -21,7 +22,16 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { mainHM := NewHostMap(l, vpncidr, preferredRanges) lh := newTestLighthouse() + cs := &CertState{ + RawCertificate: []byte{}, + PrivateKey: []byte{}, + Certificate: &cert.NebulaCertificate{}, + RawCertificateNoKey: []byte{}, + } + blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) + blah.f = &Interface{handshakeManager: blah, pki: &PKI{}, l: l} + blah.f.pki.cs.Store(cs) now := time.Now() blah.NextOutboundHandshakeTimerTick(now) @@ -31,7 +41,6 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { assert.Same(t, i, i2) i.remotes = NewRemoteList(nil) - i.HandshakeReady = true // Adding something to pending should not affect the main hostmap assert.Len(t, mainHM.Hosts, 0) diff --git a/hostmap.go b/hostmap.go index 8f36f65..85af110 100644 --- a/hostmap.go +++ b/hostmap.go @@ -21,6 +21,7 @@ const defaultPromoteEvery = 1000 // Count of packets sent before we try mo const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery const MaxRemotes = 10 +const maxRecvError = 4 // MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip // 5 allows for an initial handshake and each host pair re-handshaking twice @@ -196,27 +197,26 @@ func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) { } type HostInfo struct { - sync.RWMutex + remote *udp.Addr + remotes *RemoteList + promoteCounter atomic.Uint32 + ConnectionState *ConnectionState + remoteIndexId uint32 + localIndexId uint32 + vpnIp iputil.VpnIp + recvError atomic.Uint32 + remoteCidr *cidr.Tree4[struct{}] + relayState RelayState - remote *udp.Addr - remotes *RemoteList - promoteCounter atomic.Uint32 - multiportTx bool - multiportRx bool - ConnectionState *ConnectionState - handshakeStart time.Time //todo: this an entry in the handshake manager - HandshakeReady bool //todo: being in the manager means you are ready - HandshakeCounter int //todo: another handshake manager entry - HandshakeLastRemotes []*udp.Addr //todo: another handshake manager entry, which remotes we sent to last time - HandshakeComplete bool //todo: this should go away in favor of ConnectionState.ready - HandshakePacket map[uint8][]byte - packetStore []*cachedPacket //todo: this is other handshake manager entry - remoteIndexId uint32 - localIndexId uint32 - vpnIp iputil.VpnIp - recvError int - remoteCidr *cidr.Tree4 - relayState RelayState + // If true, we should send to this remote using multiport + multiportTx bool + + // If true, we will receive from this remote using multiport + multiportRx bool + + // HandshakePacket records the packets used to create this hostinfo + // We need these to avoid replayed handshake packets creating new hostinfos which causes churn + HandshakePacket map[uint8][]byte // nextLHQuery is the earliest we can ask the lighthouse for new information. // This is used to limit lighthouse re-queries in chatty clients @@ -414,7 +414,6 @@ func (hm *HostMap) QueryIndex(index uint32) *HostInfo { } func (hm *HostMap) QueryRelayIndex(index uint32) *HostInfo { - //TODO: we probably just want to return bool instead of error, or at least a static error hm.RLock() if h, ok := hm.Relays[index]; ok { hm.RUnlock() @@ -537,10 +536,7 @@ func (hm *HostMap) ForEachIndex(f controlEach) { func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) { c := i.promoteCounter.Add(1) if c%ifce.tryPromoteEvery.Load() == 0 { - // The lock here is currently protecting i.remote access - i.RLock() remote := i.remote - i.RUnlock() // return early if we are already on a preferred remote if remote != nil { @@ -571,62 +567,10 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) } i.nextLHQuery.Store(now + ifce.reQueryWait.Load()) - ifce.lightHouse.QueryServer(i.vpnIp, ifce) + ifce.lightHouse.QueryServer(i.vpnIp) } } -func (i *HostInfo) unlockedCachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { - //TODO: return the error so we can log with more context - if len(i.packetStore) < 100 { - tempPacket := make([]byte, len(packet)) - copy(tempPacket, packet) - //l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket) - i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket}) - if l.Level >= logrus.DebugLevel { - i.logger(l). - WithField("length", len(i.packetStore)). - WithField("stored", true). - Debugf("Packet store") - } - - } else if l.Level >= logrus.DebugLevel { - m.dropped.Inc(1) - i.logger(l). - WithField("length", len(i.packetStore)). - WithField("stored", false). - Debugf("Packet store") - } -} - -// handshakeComplete will set the connection as ready to communicate, as well as flush any stored packets -func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) { - //TODO: I'm not certain the distinction between handshake complete and ConnectionState being ready matters because: - //TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send - //TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical - - i.HandshakeComplete = true - //TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen. - // Clamping it to 2 gets us out of the woods for now - i.ConnectionState.messageCounter.Store(2) - - if l.Level >= logrus.DebugLevel { - i.logger(l).Debugf("Sending %d stored packets", len(i.packetStore)) - } - - if len(i.packetStore) > 0 { - nb := make([]byte, 12, 12) - out := make([]byte, mtu) - for _, cp := range i.packetStore { - cp.callback(cp.messageType, cp.messageSubType, i, cp.packet, nb, out) - } - m.sent.Inc(int64(len(i.packetStore))) - } - - i.remotes.ResetBlockedRemotes() - i.packetStore = make([]*cachedPacket, 0) - i.ConnectionState.ready = true -} - func (i *HostInfo) GetCert() *cert.NebulaCertificate { if i.ConnectionState != nil { return i.ConnectionState.peerCert @@ -683,9 +627,8 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { } func (i *HostInfo) RecvErrorExceeded() bool { - if i.recvError < 3 { - i.recvError += 1 - return false + if i.recvError.Add(1) >= maxRecvError { + return true } return true } @@ -696,7 +639,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) { return } - remoteCidr := cidr.NewTree4() + remoteCidr := cidr.NewTree4[struct{}]() for _, ip := range c.Details.Ips { remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) } diff --git a/inside.go b/inside.go index 728dddd..2f0894b 100644 --- a/inside.go +++ b/inside.go @@ -44,8 +44,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(h *HostInfo) { - h.unlockedCachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) }) if hostinfo == nil { @@ -83,6 +83,10 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) { } out = iputil.CreateRejectPacket(packet, out) + if len(out) == 0 { + return + } + _, err := f.readers[q].Write(out) if err != nil { f.l.WithError(err).Error("Failed to write to tun") @@ -94,12 +98,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * return } - // Use some out buffer space to build the packet before encryption - // Need 40 bytes for the reject packet (20 byte ipv4 header, 20 byte tcp rst packet) - // Leave 100 bytes for the encrypted packet (60 byte Nebula header, 40 byte reject packet) - out = out[:140] - outPacket := iputil.CreateRejectPacket(packet, out[100:]) - f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, outPacket, nb, out, q, nil) + out = iputil.CreateRejectPacket(packet, out) + if len(out) == 0 { + return + } + + if len(out) > iputil.MaxRejectPacketSize { + if f.l.GetLevel() >= logrus.InfoLevel { + f.l. + WithField("packet", packet). + WithField("outPacket", out). + Info("rejectOutside: packet too big, not sending") + } + return + } + + f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, out, nb, packet, q, nil) } func (f *Interface) Handshake(vpnIp iputil.VpnIp) { @@ -108,7 +122,7 @@ func (f *Interface) Handshake(vpnIp iputil.VpnIp) { // getOrHandshake returns nil if the vpnIp is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel -func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(info *HostInfo)) (*HostInfo, bool) { +func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) { vpnIp = f.inside.RouteFor(vpnIp) if vpnIp == 0 { @@ -143,8 +157,8 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) { - hostInfo, ready := f.getOrHandshake(vpnIp, func(h *HostInfo) { - h.unlockedCachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) + hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) }) if hostInfo == nil { @@ -291,7 +305,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount { //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. - f.lightHouse.QueryServer(hostinfo.vpnIp, f) + f.lightHouse.QueryServer(hostinfo.vpnIp) hostinfo.lastRebindCount = f.rebindCount if f.l.Level >= logrus.DebugLevel { f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter") diff --git a/interface.go b/interface.go index 58fa2a2..d933a3e 100644 --- a/interface.go +++ b/interface.go @@ -40,7 +40,6 @@ type InterfaceConfig struct { routines int MessageMetrics *MessageMetrics version string - disconnectInvalid bool relayManager *relayManager punchy *Punchy @@ -69,7 +68,7 @@ type Interface struct { dropLocalBroadcast bool dropMulticast bool routines int - disconnectInvalid bool + disconnectInvalid atomic.Bool closed atomic.Bool relayManager *relayManager @@ -188,7 +187,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { version: c.version, writers: make([]udp.Conn, c.routines), readers: make([]io.ReadWriteCloser, c.routines), - disconnectInvalid: c.disconnectInvalid, myVpnIp: myVpnIp, relayManager: c.relayManager, @@ -308,12 +306,24 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { c.RegisterReloadCallback(f.reloadFirewall) c.RegisterReloadCallback(f.reloadSendRecvError) + c.RegisterReloadCallback(f.reloadDisconnectInvalid) c.RegisterReloadCallback(f.reloadMisc) + for _, udpConn := range f.writers { c.RegisterReloadCallback(udpConn.ReloadConfig) } } +func (f *Interface) reloadDisconnectInvalid(c *config.C) { + initial := c.InitialLoad() + if initial || c.HasChanged("pki.disconnect_invalid") { + f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true)) + if !initial { + f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load()) + } + } +} + func (f *Interface) reloadFirewall(c *config.C) { //TODO: need to trigger/detect if the certificate changed too if c.HasChanged("firewall") == false { @@ -336,8 +346,8 @@ func (f *Interface) reloadFirewall(c *config.C) { // If rulesVersion is back to zero, we have wrapped all the way around. Be // safe and just reset conntrack in this case. if fw.rulesVersion == 0 { - f.l.WithField("firewallHash", fw.GetRuleHash()). - WithField("oldFirewallHash", oldFw.GetRuleHash()). + f.l.WithField("firewallHashes", fw.GetRuleHashes()). + WithField("oldFirewallHashes", oldFw.GetRuleHashes()). WithField("rulesVersion", fw.rulesVersion). Warn("firewall rulesVersion has overflowed, resetting conntrack") } else { @@ -347,8 +357,8 @@ func (f *Interface) reloadFirewall(c *config.C) { f.firewall = fw oldFw.Destroy() - f.l.WithField("firewallHash", fw.GetRuleHash()). - WithField("oldFirewallHash", oldFw.GetRuleHash()). + f.l.WithField("firewallHashes", fw.GetRuleHashes()). + WithField("oldFirewallHashes", oldFw.GetRuleHashes()). WithField("rulesVersion", fw.rulesVersion). Info("New firewall has been installed") } diff --git a/iputil/packet.go b/iputil/packet.go index 74ae37f..b18e524 100644 --- a/iputil/packet.go +++ b/iputil/packet.go @@ -6,8 +6,19 @@ import ( "golang.org/x/net/ipv4" ) +const ( + // Need 96 bytes for the largest reject packet: + // - 20 byte ipv4 header + // - 8 byte icmpv4 header + // - 68 byte body (60 byte max orig ipv4 header + 8 byte orig icmpv4 header) + MaxRejectPacketSize = ipv4.HeaderLen + 8 + 60 + 8 +) + func CreateRejectPacket(packet []byte, out []byte) []byte { - // TODO ipv4 only, need to fix when inside supports ipv6 + if len(packet) < ipv4.HeaderLen || int(packet[0]>>4) != ipv4.Version { + return nil + } + switch packet[9] { case 6: // tcp return ipv4CreateRejectTCPPacket(packet, out) @@ -19,20 +30,28 @@ func CreateRejectPacket(packet []byte, out []byte) []byte { func ipv4CreateRejectICMPPacket(packet []byte, out []byte) []byte { ihl := int(packet[0]&0x0f) << 2 - // ICMP reply includes header and first 8 bytes of the packet + if len(packet) < ihl { + // We need at least this many bytes for this to be a valid packet + return nil + } + + // ICMP reply includes original header and first 8 bytes of the packet packetLen := len(packet) if packetLen > ihl+8 { packetLen = ihl + 8 } outLen := ipv4.HeaderLen + 8 + packetLen + if outLen > cap(out) { + return nil + } - out = out[:(outLen)] + out = out[:outLen] ipHdr := out[0:ipv4.HeaderLen] - ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2) // version, ihl - ipHdr[1] = 0 // DSCP, ECN - binary.BigEndian.PutUint16(ipHdr[2:], uint16(ipv4.HeaderLen+8+packetLen)) // Total Length + ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2) // version, ihl + ipHdr[1] = 0 // DSCP, ECN + binary.BigEndian.PutUint16(ipHdr[2:], uint16(outLen)) // Total Length ipHdr[4] = 0 // id ipHdr[5] = 0 // . @@ -76,7 +95,15 @@ func ipv4CreateRejectTCPPacket(packet []byte, out []byte) []byte { ihl := int(packet[0]&0x0f) << 2 outLen := ipv4.HeaderLen + tcpLen - out = out[:(outLen)] + if len(packet) < ihl+tcpLen { + // We need at least this many bytes for this to be a valid packet + return nil + } + if outLen > cap(out) { + return nil + } + + out = out[:outLen] ipHdr := out[0:ipv4.HeaderLen] ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2) // version, ihl diff --git a/iputil/packet_test.go b/iputil/packet_test.go new file mode 100644 index 0000000..e1d0d95 --- /dev/null +++ b/iputil/packet_test.go @@ -0,0 +1,73 @@ +package iputil + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/net/ipv4" +) + +func Test_CreateRejectPacket(t *testing.T) { + h := ipv4.Header{ + Len: 20, + Src: net.IPv4(10, 0, 0, 1), + Dst: net.IPv4(10, 0, 0, 2), + Protocol: 1, // ICMP + } + + b, err := h.Marshal() + if err != nil { + t.Fatalf("h.Marhshal: %v", err) + } + b = append(b, []byte{0, 3, 0, 4}...) + + expectedLen := ipv4.HeaderLen + 8 + h.Len + 4 + out := make([]byte, expectedLen) + rejectPacket := CreateRejectPacket(b, out) + assert.NotNil(t, rejectPacket) + assert.Len(t, rejectPacket, expectedLen) + + // ICMP with max header len + h = ipv4.Header{ + Len: 60, + Src: net.IPv4(10, 0, 0, 1), + Dst: net.IPv4(10, 0, 0, 2), + Protocol: 1, // ICMP + Options: make([]byte, 40), + } + + b, err = h.Marshal() + if err != nil { + t.Fatalf("h.Marhshal: %v", err) + } + b = append(b, []byte{0, 3, 0, 4, 0, 0, 0, 0}...) + + expectedLen = MaxRejectPacketSize + out = make([]byte, MaxRejectPacketSize) + rejectPacket = CreateRejectPacket(b, out) + assert.NotNil(t, rejectPacket) + assert.Len(t, rejectPacket, expectedLen) + + // TCP with max header len + h = ipv4.Header{ + Len: 60, + Src: net.IPv4(10, 0, 0, 1), + Dst: net.IPv4(10, 0, 0, 2), + Protocol: 6, // TCP + Options: make([]byte, 40), + } + + b, err = h.Marshal() + if err != nil { + t.Fatalf("h.Marhshal: %v", err) + } + b = append(b, []byte{0, 3, 0, 4}...) + b = append(b, make([]byte, 16)...) + + expectedLen = ipv4.HeaderLen + 20 + out = make([]byte, expectedLen) + rejectPacket = CreateRejectPacket(b, out) + assert.NotNil(t, rejectPacket) + assert.Len(t, rejectPacket, expectedLen) +} diff --git a/lighthouse.go b/lighthouse.go index 9b3b837..aa54c4b 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -74,7 +74,9 @@ type LightHouse struct { // IP's of relays that can be used by peers to access me relaysForMe atomic.Pointer[[]iputil.VpnIp] - calculatedRemotes atomic.Pointer[cidr.Tree4] // Maps VpnIp to []*calculatedRemote + queryChan chan iputil.VpnIp + + calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote metrics *MessageMetrics metricHolepunchTx metrics.Counter @@ -110,6 +112,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, nebulaPort: nebulaPort, punchConn: pc, punchy: p, + queryChan: make(chan iputil.VpnIp, c.GetUint32("handshakes.query_buffer", 64)), l: l, } lighthouses := make(map[iputil.VpnIp]struct{}) @@ -139,6 +142,8 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, } }) + h.startQueryWorker() + return &h, nil } @@ -166,7 +171,7 @@ func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp { return *lh.relaysForMe.Load() } -func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4 { +func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4[[]*calculatedRemote] { return lh.calculatedRemotes.Load() } @@ -443,9 +448,9 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList return nil } -func (lh *LightHouse) Query(ip iputil.VpnIp, f EncWriter) *RemoteList { +func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList { if !lh.IsLighthouseIP(ip) { - lh.QueryServer(ip, f) + lh.QueryServer(ip) } lh.RLock() if v, ok := lh.addrMap[ip]; ok { @@ -456,30 +461,14 @@ func (lh *LightHouse) Query(ip iputil.VpnIp, f EncWriter) *RemoteList { return nil } -// This is asynchronous so no reply should be expected -func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f EncWriter) { - if lh.amLighthouse { +// QueryServer is asynchronous so no reply should be expected +func (lh *LightHouse) QueryServer(ip iputil.VpnIp) { + // Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses + if lh.amLighthouse || lh.IsLighthouseIP(ip) { return } - if lh.IsLighthouseIP(ip) { - return - } - - // Send a query to the lighthouses and hope for the best next time - query, err := NewLhQueryByInt(ip).Marshal() - if err != nil { - lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload") - return - } - - lighthouses := lh.GetLighthouses() - lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses))) - nb := make([]byte, 12, 12) - out := make([]byte, mtu) - for n := range lighthouses { - f.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out) - } + lh.queryChan <- ip } func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList { @@ -594,11 +583,10 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool { if tree == nil { return false } - value := tree.MostSpecificContains(vpnIp) - if value == nil { + ok, calculatedRemotes := tree.MostSpecificContains(vpnIp) + if !ok { return false } - calculatedRemotes := value.([]*calculatedRemote) var calculated []*Ip4AndPort for _, cr := range calculatedRemotes { @@ -753,6 +741,46 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr { return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port)) } +func (lh *LightHouse) startQueryWorker() { + if lh.amLighthouse { + return + } + + go func() { + nb := make([]byte, 12, 12) + out := make([]byte, mtu) + + for { + select { + case <-lh.ctx.Done(): + return + case ip := <-lh.queryChan: + lh.innerQueryServer(ip, nb, out) + } + } + }() +} + +func (lh *LightHouse) innerQueryServer(ip iputil.VpnIp, nb, out []byte) { + if lh.IsLighthouseIP(ip) { + return + } + + // Send a query to the lighthouses and hope for the best next time + query, err := NewLhQueryByInt(ip).Marshal() + if err != nil { + lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload") + return + } + + lighthouses := lh.GetLighthouses() + lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses))) + + for n := range lighthouses { + lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out) + } +} + func (lh *LightHouse) StartUpdateWorker() { interval := lh.GetUpdateInterval() if lh.amLighthouse || interval == 0 { diff --git a/main.go b/main.go index 4398328..7a7fde6 100644 --- a/main.go +++ b/main.go @@ -18,7 +18,7 @@ import ( type m map[string]interface{} -func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) { +func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) { ctx, cancel := context.WithCancel(context.Background()) // Automatically cancel the context if Main returns an error, to signal all created goroutines to quit. defer func() { @@ -65,12 +65,15 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if err != nil { return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) } - l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started") + l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") // TODO: make sure mask is 4 bytes tunCidr := certificate.Details.Ips[0] ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) + if err != nil { + return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err) + } wireSSHReload(l, ssh, c) var sshStart func() if c.GetBool("sshd.enabled", false) { @@ -125,7 +128,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if !configTest { c.CatchHUP(ctx) - tun, err = overlay.NewDeviceFromConfig(c, l, tunCidr, tunFd, routines) + if deviceFactory == nil { + deviceFactory = overlay.NewDeviceFromConfig + } + + tun, err = deviceFactory(c, l, tunCidr, routines) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err) } @@ -156,12 +163,23 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } for i := 0; i < routines; i++ { + l.Infof("listening %q %d", listenHost.IP, port) udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64)) if err != nil { return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) } udpServer.ReloadConfig(c) udpConns[i] = udpServer + + // If port is dynamic, discover it before the next pass through the for loop + // This way all routines will use the same port correctly + if port == 0 { + uPort, err := udpServer.LocalAddr() + if err != nil { + return nil, util.NewContextualError("Failed to get listening port", nil, err) + } + port = int(uPort.Port) + } } } @@ -270,7 +288,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg routines: routines, MessageMetrics: messageMetrics, version: buildVersion, - disconnectInvalid: c.GetBool("pki.disconnect_invalid", false), relayManager: NewRelayManager(ctx, l, hostMap, c), punchy: punchy, @@ -333,6 +350,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg c.RegisterReloadCallback(loadMultiPortConfig) ifce.RegisterConfigChangeCallbacks(c) + ifce.reloadDisconnectInvalid(c) ifce.reloadSendRecvError(c) handshakeManager.f = ifce @@ -365,6 +383,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg return &Control{ ifce, l, + ctx, cancel, sshStart, statsStart, diff --git a/outside.go b/outside.go index 970f299..bf2b4dd 100644 --- a/outside.go +++ b/outside.go @@ -198,7 +198,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt case header.Handshake: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - HandleIncomingHandshake(f, addr, via, packet, h, hostinfo) + f.handshakeManager.HandleIncoming(addr, via, packet, h) return case header.RecvError: @@ -419,7 +419,9 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) if dropReason != nil { - f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, out, q) + // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore + // This gives us a buffer to build the reject packet in + f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l).WithField("fwPacket", fwPacket). WithField("reason", dropReason). @@ -468,9 +470,6 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { return } - hostinfo.Lock() - defer hostinfo.Unlock() - if !hostinfo.RecvErrorExceeded() { return } diff --git a/overlay/route.go b/overlay/route.go index 41c7a9c..793c8fd 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -21,8 +21,8 @@ type Route struct { Install bool } -func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4, error) { - routeTree := cidr.NewTree4() +func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) { + routeTree := cidr.NewTree4[iputil.VpnIp]() for _, r := range routes { if !allowMTU && r.MTU > 0 { l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS) diff --git a/overlay/route_test.go b/overlay/route_test.go index f83b5c1..46fb87c 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -265,18 +265,16 @@ func Test_makeRouteTree(t *testing.T) { assert.NoError(t, err) ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2")) - r := routeTree.MostSpecificContains(ip) - assert.NotNil(t, r) - assert.IsType(t, iputil.VpnIp(0), r) - assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r) + ok, r := routeTree.MostSpecificContains(ip) + assert.True(t, ok) + assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r) ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1")) - r = routeTree.MostSpecificContains(ip) - assert.NotNil(t, r) - assert.IsType(t, iputil.VpnIp(0), r) - assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r) + ok, r = routeTree.MostSpecificContains(ip) + assert.True(t, ok) + assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r) ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1")) - r = routeTree.MostSpecificContains(ip) - assert.Nil(t, r) + ok, r = routeTree.MostSpecificContains(ip) + assert.False(t, ok) } diff --git a/overlay/tun.go b/overlay/tun.go index 5eccec9..ca1a64a 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -10,7 +10,9 @@ import ( const DefaultMTU = 1300 -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *int, routines int) (Device, error) { +type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) + +func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { routes, err := parseRoutes(c, tunCidr) if err != nil { return nil, util.NewContextualError("Could not parse tun.routes", nil, err) @@ -27,17 +29,6 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd * tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) return tun, nil - case fd != nil: - return newTunFromFd( - l, - *fd, - tunCidr, - c.GetInt("tun.mtu", DefaultMTU), - routes, - c.GetInt("tun.tx_queue", 500), - c.GetBool("tun.use_system_route_table", false), - ) - default: return newTun( l, @@ -51,3 +42,28 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd * ) } } + +func NewFdDeviceFromConfig(fd *int) DeviceFactory { + return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { + routes, err := parseRoutes(c, tunCidr) + if err != nil { + return nil, util.NewContextualError("Could not parse tun.routes", nil, err) + } + + unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr) + if err != nil { + return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) + } + routes = append(routes, unsafeRoutes...) + return newTunFromFd( + l, + *fd, + tunCidr, + c.GetInt("tun.mtu", DefaultMTU), + routes, + c.GetInt("tun.tx_queue", 500), + c.GetBool("tun.use_system_route_table", false), + ) + + } +} diff --git a/overlay/tun_android.go b/overlay/tun_android.go index c731d78..c5c52db 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -18,7 +18,7 @@ type tun struct { io.ReadWriteCloser fd int cidr *net.IPNet - routeTree *cidr.Tree4 + routeTree *cidr.Tree4[iputil.VpnIp] l *logrus.Logger } @@ -46,12 +46,8 @@ func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ } func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.MostSpecificContains(ip) + return r } func (t tun) Activate() error { diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 428e38f..caec580 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -25,7 +25,7 @@ type tun struct { cidr *net.IPNet DefaultMTU int Routes []Route - routeTree *cidr.Tree4 + routeTree *cidr.Tree4[iputil.VpnIp] l *logrus.Logger // cache out buffer since we need to prepend 4 bytes for tun metadata @@ -304,9 +304,9 @@ func (t *tun) Activate() error { } func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) + ok, r := t.routeTree.MostSpecificContains(ip) + if ok { + return r } return 0 diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 8a52954..338b8f6 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -48,7 +48,7 @@ type tun struct { cidr *net.IPNet MTU int Routes []Route - routeTree *cidr.Tree4 + routeTree *cidr.Tree4[iputil.VpnIp] l *logrus.Logger io.ReadWriteCloser @@ -192,12 +192,8 @@ func (t *tun) Activate() error { } func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.MostSpecificContains(ip) + return r } func (t *tun) Cidr() *net.IPNet { diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 26f34ec..ce65b33 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -20,7 +20,7 @@ import ( type tun struct { io.ReadWriteCloser cidr *net.IPNet - routeTree *cidr.Tree4 + routeTree *cidr.Tree4[iputil.VpnIp] } func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) { @@ -46,12 +46,8 @@ func (t *tun) Activate() error { } func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.MostSpecificContains(ip) + return r } // The following is hoisted up from water, we do this so we can inject our own fd on iOS diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 8751a3f..a576bf3 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -30,7 +30,7 @@ type tun struct { TXQueueLen int Routes []Route - routeTree atomic.Pointer[cidr.Tree4] + routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] routeChan chan struct{} useSystemRoutes bool @@ -154,12 +154,8 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { } func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.Load().MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.Load().MostSpecificContains(ip) + return r } func (t *tun) Write(b []byte) (int, error) { @@ -380,7 +376,7 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { return } - newTree := cidr.NewTree4() + newTree := cidr.NewTree4[iputil.VpnIp]() if r.Type == unix.RTM_NEWROUTE { for _, oldR := range t.routeTree.Load().List() { newTree.AddCIDR(oldR.CIDR, oldR.Value) @@ -392,7 +388,7 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { } else { gw := iputil.Ip2VpnIp(r.Gw) for _, oldR := range t.routeTree.Load().List() { - if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && *oldR.Value != nil && (*oldR.Value).(iputil.VpnIp) == gw { + if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw { // This is the record to delete t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route") continue diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 4d7f897..b1135fe 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -29,7 +29,7 @@ type tun struct { cidr *net.IPNet MTU int Routes []Route - routeTree *cidr.Tree4 + routeTree *cidr.Tree4[iputil.VpnIp] l *logrus.Logger io.ReadWriteCloser @@ -134,12 +134,8 @@ func (t *tun) Activate() error { } func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.MostSpecificContains(ip) + return r } func (t *tun) Cidr() *net.IPNet { diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 709fb42..45c06dc 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -23,7 +23,7 @@ type tun struct { cidr *net.IPNet MTU int Routes []Route - routeTree *cidr.Tree4 + routeTree *cidr.Tree4[iputil.VpnIp] l *logrus.Logger io.ReadWriteCloser @@ -115,12 +115,8 @@ func (t *tun) Activate() error { } func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.MostSpecificContains(ip) + return r } func (t *tun) Cidr() *net.IPNet { diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index a2a57e1..964315a 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -19,7 +19,7 @@ type TestTun struct { Device string cidr *net.IPNet Routes []Route - routeTree *cidr.Tree4 + routeTree *cidr.Tree4[iputil.VpnIp] l *logrus.Logger closed atomic.Bool @@ -83,12 +83,8 @@ func (t *TestTun) Get(block bool) []byte { //********************************************************************************************************************// func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.MostSpecificContains(ip) + return r } func (t *TestTun) Activate() error { diff --git a/overlay/tun_water_windows.go b/overlay/tun_water_windows.go index b1c28d6..e27cff2 100644 --- a/overlay/tun_water_windows.go +++ b/overlay/tun_water_windows.go @@ -18,7 +18,7 @@ type waterTun struct { cidr *net.IPNet MTU int Routes []Route - routeTree *cidr.Tree4 + routeTree *cidr.Tree4[iputil.VpnIp] *water.Interface } @@ -97,12 +97,8 @@ func (t *waterTun) Activate() error { } func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.MostSpecificContains(ip) + return r } func (t *waterTun) Cidr() *net.IPNet { diff --git a/overlay/tun_wintun_windows.go b/overlay/tun_wintun_windows.go index a406123..9647024 100644 --- a/overlay/tun_wintun_windows.go +++ b/overlay/tun_wintun_windows.go @@ -24,7 +24,7 @@ type winTun struct { prefix netip.Prefix MTU int Routes []Route - routeTree *cidr.Tree4 + routeTree *cidr.Tree4[iputil.VpnIp] tun *wintun.NativeTun } @@ -146,12 +146,8 @@ func (t *winTun) Activate() error { } func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.MostSpecificContains(ip) + return r } func (t *winTun) Cidr() *net.IPNet { diff --git a/overlay/user.go b/overlay/user.go new file mode 100644 index 0000000..9d819ae --- /dev/null +++ b/overlay/user.go @@ -0,0 +1,63 @@ +package overlay + +import ( + "io" + "net" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/iputil" +) + +func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { + return NewUserDevice(tunCidr) +} + +func NewUserDevice(tunCidr *net.IPNet) (Device, error) { + // these pipes guarantee each write/read will match 1:1 + or, ow := io.Pipe() + ir, iw := io.Pipe() + return &UserDevice{ + tunCidr: tunCidr, + outboundReader: or, + outboundWriter: ow, + inboundReader: ir, + inboundWriter: iw, + }, nil +} + +type UserDevice struct { + tunCidr *net.IPNet + + outboundReader *io.PipeReader + outboundWriter *io.PipeWriter + + inboundReader *io.PipeReader + inboundWriter *io.PipeWriter +} + +func (d *UserDevice) Activate() error { + return nil +} +func (d *UserDevice) Cidr() *net.IPNet { return d.tunCidr } +func (d *UserDevice) Name() string { return "faketun0" } +func (d *UserDevice) RouteFor(ip iputil.VpnIp) iputil.VpnIp { return ip } +func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { + return d, nil +} + +func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) { + return d.inboundReader, d.outboundWriter +} + +func (d *UserDevice) Read(p []byte) (n int, err error) { + return d.outboundReader.Read(p) +} +func (d *UserDevice) Write(p []byte) (n int, err error) { + return d.inboundWriter.Write(p) +} +func (d *UserDevice) Close() error { + d.inboundWriter.Close() + d.outboundWriter.Close() + return nil +} diff --git a/service/listener.go b/service/listener.go new file mode 100644 index 0000000..6d5c8a4 --- /dev/null +++ b/service/listener.go @@ -0,0 +1,36 @@ +package service + +import ( + "io" + "net" +) + +type tcpListener struct { + port uint16 + s *Service + addr *net.TCPAddr + accept chan net.Conn +} + +func (l *tcpListener) Accept() (net.Conn, error) { + conn, ok := <-l.accept + if !ok { + return nil, io.EOF + } + return conn, nil +} + +func (l *tcpListener) Close() error { + l.s.mu.Lock() + defer l.s.mu.Unlock() + delete(l.s.mu.listeners, uint16(l.addr.Port)) + + close(l.accept) + + return nil +} + +// Addr returns the listener's network address. +func (l *tcpListener) Addr() net.Addr { + return l.addr +} diff --git a/service/service.go b/service/service.go new file mode 100644 index 0000000..66ce864 --- /dev/null +++ b/service/service.go @@ -0,0 +1,248 @@ +package service + +import ( + "bytes" + "context" + "errors" + "fmt" + "log" + "math" + "net" + "os" + "strings" + "sync" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay" + "golang.org/x/sync/errgroup" + "gvisor.dev/gvisor/pkg/bufferv2" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const nicID = 1 + +type Service struct { + eg *errgroup.Group + control *nebula.Control + ipstack *stack.Stack + + mu struct { + sync.Mutex + + listeners map[uint16]*tcpListener + } +} + +func New(config *config.C) (*Service, error) { + logger := logrus.New() + logger.Out = os.Stdout + + control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) + if err != nil { + return nil, err + } + control.Start() + + ctx := control.Context() + eg, ctx := errgroup.WithContext(ctx) + s := Service{ + eg: eg, + control: control, + } + s.mu.listeners = map[uint16]*tcpListener{} + + device, ok := control.Device().(*overlay.UserDevice) + if !ok { + return nil, errors.New("must be using user device") + } + + s.ipstack = stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6}, + }) + sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default + tcpipErr := s.ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) + if tcpipErr != nil { + return nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) + } + linkEP := channel.New( /*size*/ 512 /*mtu*/, 1280, "") + if tcpipProblem := s.ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil { + return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem) + } + ipv4Subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 4)), tcpip.AddressMask(strings.Repeat("\x00", 4))) + s.ipstack.SetRouteTable([]tcpip.Route{ + { + Destination: ipv4Subnet, + NIC: nicID, + }, + }) + + ipNet := device.Cidr() + pa := tcpip.ProtocolAddress{ + AddressWithPrefix: tcpip.Address(ipNet.IP).WithPrefix(), + Protocol: ipv4.ProtocolNumber, + } + if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{ + PEB: stack.CanBePrimaryEndpoint, // zero value default + ConfigType: stack.AddressConfigStatic, // zero value default + }); err != nil { + return nil, fmt.Errorf("error creating IP: %s", err) + } + + const tcpReceiveBufferSize = 0 + const maxInFlightConnectionAttempts = 1024 + tcpFwd := tcp.NewForwarder(s.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.tcpHandler) + s.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket) + + reader, writer := device.Pipe() + + go func() { + <-ctx.Done() + reader.Close() + writer.Close() + }() + + // create Goroutines to forward packets between Nebula and Gvisor + eg.Go(func() error { + buf := make([]byte, header.IPv4MaximumHeaderSize+header.IPv4MaximumPayloadSize) + for { + // this will read exactly one packet + n, err := reader.Read(buf) + if err != nil { + return err + } + packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: bufferv2.MakeWithData(bytes.Clone(buf[:n])), + }) + linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf) + + if err := ctx.Err(); err != nil { + return err + } + } + }) + eg.Go(func() error { + for { + packet := linkEP.ReadContext(ctx) + if packet.IsNil() { + if err := ctx.Err(); err != nil { + return err + } + continue + } + bufView := packet.ToView() + if _, err := bufView.WriteTo(writer); err != nil { + return err + } + bufView.Release() + } + }) + + return &s, nil +} + +// DialContext dials the provided address. Currently only TCP is supported. +func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + if network != "tcp" && network != "tcp4" { + return nil, errors.New("only tcp is supported") + } + + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + + fullAddr := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.Address(addr.IP), + Port: uint16(addr.Port), + } + + return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber) +} + +// Listen listens on the provided address. Currently only TCP with wildcard +// addresses are supported. +func (s *Service) Listen(network, address string) (net.Listener, error) { + if network != "tcp" && network != "tcp4" { + return nil, errors.New("only tcp is supported") + } + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + if addr.IP != nil && !bytes.Equal(addr.IP, []byte{0, 0, 0, 0}) { + return nil, fmt.Errorf("only wildcard address supported, got %q %v", address, addr.IP) + } + if addr.Port == 0 { + return nil, errors.New("specific port required, got 0") + } + if addr.Port < 0 || addr.Port >= math.MaxUint16 { + return nil, fmt.Errorf("invalid port %d", addr.Port) + } + port := uint16(addr.Port) + + l := &tcpListener{ + port: port, + s: s, + addr: addr, + accept: make(chan net.Conn), + } + + s.mu.Lock() + defer s.mu.Unlock() + + if _, ok := s.mu.listeners[port]; ok { + return nil, fmt.Errorf("already listening on port %d", port) + } + s.mu.listeners[port] = l + + return l, nil +} + +func (s *Service) Wait() error { + return s.eg.Wait() +} + +func (s *Service) Close() error { + s.control.Stop() + return nil +} + +func (s *Service) tcpHandler(r *tcp.ForwarderRequest) { + endpointID := r.ID() + + s.mu.Lock() + defer s.mu.Unlock() + + l, ok := s.mu.listeners[endpointID.LocalPort] + if !ok { + r.Complete(true) + return + } + + var wq waiter.Queue + ep, err := r.CreateEndpoint(&wq) + if err != nil { + log.Printf("got error creating endpoint %q", err) + r.Complete(true) + return + } + r.Complete(false) + ep.SocketOptions().SetKeepAlive(true) + + conn := gonet.NewTCPConn(&wq, ep) + l.accept <- conn +} diff --git a/service/service_test.go b/service/service_test.go new file mode 100644 index 0000000..d1909cd --- /dev/null +++ b/service/service_test.go @@ -0,0 +1,165 @@ +package service + +import ( + "bytes" + "context" + "errors" + "net" + "testing" + "time" + + "dario.cat/mergo" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/e2e" + "golang.org/x/sync/errgroup" + "gopkg.in/yaml.v2" +) + +type m map[string]interface{} + +func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service { + + vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} + copy(vpnIpNet.IP, udpIp) + + _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) + caB, err := caCrt.MarshalToPEM() + if err != nil { + panic(err) + } + + mc := m{ + "pki": m{ + "ca": string(caB), + "cert": string(myPEM), + "key": string(myPrivKey), + }, + //"tun": m{"disabled": true}, + "firewall": m{ + "outbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + }}, + "inbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + }}, + }, + "timers": m{ + "pending_deletion_interval": 2, + "connection_alive_interval": 2, + }, + "handshakes": m{ + "try_interval": "200ms", + }, + } + + if overrides != nil { + err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice) + if err != nil { + panic(err) + } + mc = overrides + } + + cb, err := yaml.Marshal(mc) + if err != nil { + panic(err) + } + + var c config.C + if err := c.LoadString(string(cb)); err != nil { + panic(err) + } + + s, err := New(&c) + if err != nil { + panic(err) + } + return s +} + +func TestService(t *testing.T) { + ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{ + "static_host_map": m{}, + "lighthouse": m{ + "am_lighthouse": true, + }, + "listen": m{ + "host": "0.0.0.0", + "port": 4243, + }, + }) + b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{ + "static_host_map": m{ + "10.0.0.1": []string{"localhost:4243"}, + }, + "lighthouse": m{ + "hosts": []string{"10.0.0.1"}, + "interval": 1, + }, + }) + + ln, err := a.Listen("tcp", ":1234") + if err != nil { + t.Fatal(err) + } + var eg errgroup.Group + eg.Go(func() error { + conn, err := ln.Accept() + if err != nil { + return err + } + defer conn.Close() + + t.Log("accepted connection") + + if _, err := conn.Write([]byte("server msg")); err != nil { + return err + } + + t.Log("server: wrote message") + + data := make([]byte, 100) + n, err := conn.Read(data) + if err != nil { + return err + } + data = data[:n] + if !bytes.Equal(data, []byte("client msg")) { + return errors.New("got invalid message from client") + } + t.Log("server: read message") + return conn.Close() + }) + + c, err := b.DialContext(context.Background(), "tcp", "10.0.0.1:1234") + if err != nil { + t.Fatal(err) + } + if _, err := c.Write([]byte("client msg")); err != nil { + t.Fatal(err) + } + + data := make([]byte, 100) + n, err := c.Read(data) + if err != nil { + t.Fatal(err) + } + data = data[:n] + if !bytes.Equal(data, []byte("server msg")) { + t.Fatal("got invalid message from client") + } + + if err := c.Close(); err != nil { + t.Fatal(err) + } + + if err := eg.Wait(); err != nil { + t.Fatal(err) + } +} diff --git a/ssh.go b/ssh.go index 30f9aea..8e48fc4 100644 --- a/ssh.go +++ b/ssh.go @@ -6,7 +6,6 @@ import ( "errors" "flag" "fmt" - "io/ioutil" "net" "os" "reflect" @@ -96,7 +95,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro return nil, fmt.Errorf("sshd.host_key must be provided") } - hostKeyBytes, err := ioutil.ReadFile(hostKeyFile) + hostKeyBytes, err := os.ReadFile(hostKeyFile) if err != nil { return nil, fmt.Errorf("error while loading sshd.host_key file: %s", err) } @@ -519,7 +518,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri } var cm *CacheMap - rl := ifce.lightHouse.Query(vpnIp, ifce) + rl := ifce.lightHouse.Query(vpnIp) if rl != nil { cm = rl.CopyCache() } diff --git a/test/logger.go b/test/logger.go index 197ab44..b5a717d 100644 --- a/test/logger.go +++ b/test/logger.go @@ -1,7 +1,7 @@ package test import ( - "io/ioutil" + "io" "os" "github.com/sirupsen/logrus" @@ -12,7 +12,7 @@ func NewLogger() *logrus.Logger { v := os.Getenv("TEST_LOGS") if v == "" { - l.SetOutput(ioutil.Discard) + l.SetOutput(io.Discard) return l }