mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 08:24:25 +01:00
Compare commits
22 Commits
psk-v2
...
jay.wren-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5ceac2b078 | ||
|
|
b8ea55eb90 | ||
|
|
4eb056af9d | ||
|
|
e49f279004 | ||
|
|
459cb38a6d | ||
|
|
18279ed17b | ||
|
|
c7fb3ad9cf | ||
|
|
d4a7df3083 | ||
|
|
e83a1c6c84 | ||
|
|
f5d096dd2b | ||
|
|
e2d6f4e444 | ||
|
|
d99fd60e06 | ||
|
|
e4bae15825 | ||
|
|
58ead4116f | ||
|
|
e136d1d47a | ||
|
|
d2adebf26d | ||
|
|
36bc9dd261 | ||
|
|
879852c32a | ||
|
|
75faa5f2e5 | ||
|
|
4444ed166a | ||
|
|
f86953ca56 | ||
|
|
3de36c99b6 |
2
.github/workflows/gofmt.yml
vendored
2
.github/workflows/gofmt.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.23'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Install goimports
|
||||
|
||||
8
.github/workflows/release.yml
vendored
8
.github/workflows/release.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.23'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@@ -37,7 +37,7 @@ jobs:
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.23'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@@ -70,12 +70,12 @@ jobs:
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.23'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Import certificates
|
||||
if: env.HAS_SIGNING_CREDS == 'true'
|
||||
uses: Apple-Actions/import-codesign-certs@v3
|
||||
uses: Apple-Actions/import-codesign-certs@v5
|
||||
with:
|
||||
p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }}
|
||||
p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }}
|
||||
|
||||
2
.github/workflows/smoke.yml
vendored
2
.github/workflows/smoke.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.23'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: build
|
||||
|
||||
14
.github/workflows/test.yml
vendored
14
.github/workflows/test.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.23'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@@ -32,9 +32,9 @@ jobs:
|
||||
run: make vet
|
||||
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v6
|
||||
uses: golangci/golangci-lint-action@v7
|
||||
with:
|
||||
version: v1.64
|
||||
version: v2.0
|
||||
|
||||
- name: Test
|
||||
run: make test
|
||||
@@ -60,7 +60,7 @@ jobs:
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.23'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Build
|
||||
@@ -102,7 +102,7 @@ jobs:
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.23'
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Build nebula
|
||||
@@ -115,9 +115,9 @@ jobs:
|
||||
run: make vet
|
||||
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v6
|
||||
uses: golangci/golangci-lint-action@v7
|
||||
with:
|
||||
version: v1.64
|
||||
version: v2.0
|
||||
|
||||
- name: Test
|
||||
run: make test
|
||||
|
||||
@@ -1,9 +1,23 @@
|
||||
# yaml-language-server: $schema=https://golangci-lint.run/jsonschema/golangci.jsonschema.json
|
||||
version: "2"
|
||||
linters:
|
||||
# Disable all linters.
|
||||
# Default: false
|
||||
disable-all: true
|
||||
# Enable specific linter
|
||||
# https://golangci-lint.run/usage/linters/#enabled-by-default
|
||||
default: none
|
||||
enable:
|
||||
- testifylint
|
||||
exclusions:
|
||||
generated: lax
|
||||
presets:
|
||||
- comments
|
||||
- common-false-positives
|
||||
- legacy
|
||||
- std-error-handling
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
formatters:
|
||||
exclusions:
|
||||
generated: lax
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
|
||||
@@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Changed
|
||||
|
||||
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
|
||||
intended to target an `unsafe_routes` entry must explicitly declare it via the
|
||||
`local_cidr` field. This is almost always the intended behavior. This flag is
|
||||
deprecated and will be removed in a future release.
|
||||
|
||||
## [1.9.4] - 2024-09-09
|
||||
|
||||
### Added
|
||||
|
||||
@@ -36,7 +36,7 @@ type AllowListNameRule struct {
|
||||
|
||||
func NewLocalAllowListFromConfig(c *config.C, k string) (*LocalAllowList, error) {
|
||||
var nameRules []AllowListNameRule
|
||||
handleKey := func(key string, value interface{}) (bool, error) {
|
||||
handleKey := func(key string, value any) (bool, error) {
|
||||
if key == "interfaces" {
|
||||
var err error
|
||||
nameRules, err = getAllowListInterfaces(k, value)
|
||||
@@ -70,7 +70,7 @@ func NewRemoteAllowListFromConfig(c *config.C, k, rangesKey string) (*RemoteAllo
|
||||
|
||||
// If the handleKey func returns true, the rest of the parsing is skipped
|
||||
// for this key. This allows parsing of special values like `interfaces`.
|
||||
func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
|
||||
func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value any) (bool, error)) (*AllowList, error) {
|
||||
r := c.Get(k)
|
||||
if r == nil {
|
||||
return nil, nil
|
||||
@@ -81,8 +81,8 @@ func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, va
|
||||
|
||||
// If the handleKey func returns true, the rest of the parsing is skipped
|
||||
// for this key. This allows parsing of special values like `interfaces`.
|
||||
func newAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
|
||||
rawMap, ok := raw.(map[interface{}]interface{})
|
||||
func newAllowList(k string, raw any, handleKey func(key string, value any) (bool, error)) (*AllowList, error) {
|
||||
rawMap, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
|
||||
}
|
||||
@@ -100,12 +100,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
|
||||
rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
|
||||
rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
|
||||
|
||||
for rawKey, rawValue := range rawMap {
|
||||
rawCIDR, ok := rawKey.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
|
||||
}
|
||||
|
||||
for rawCIDR, rawValue := range rawMap {
|
||||
if handleKey != nil {
|
||||
handled, err := handleKey(rawCIDR, rawValue)
|
||||
if err != nil {
|
||||
@@ -116,7 +111,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
|
||||
}
|
||||
}
|
||||
|
||||
value, ok := rawValue.(bool)
|
||||
value, ok := config.AsBool(rawValue)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
|
||||
}
|
||||
@@ -173,22 +168,18 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
|
||||
return &AllowList{cidrTree: tree}, nil
|
||||
}
|
||||
|
||||
func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
|
||||
func getAllowListInterfaces(k string, v any) ([]AllowListNameRule, error) {
|
||||
var nameRules []AllowListNameRule
|
||||
|
||||
rawRules, ok := v.(map[interface{}]interface{})
|
||||
rawRules, ok := v.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
|
||||
}
|
||||
|
||||
firstEntry := true
|
||||
var allValues bool
|
||||
for rawName, rawAllow := range rawRules {
|
||||
name, ok := rawName.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
|
||||
}
|
||||
allow, ok := rawAllow.(bool)
|
||||
for name, rawAllow := range rawRules {
|
||||
allow, ok := config.AsBool(rawAllow)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
|
||||
}
|
||||
@@ -224,16 +215,11 @@ func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error
|
||||
|
||||
remoteAllowRanges := new(bart.Table[*AllowList])
|
||||
|
||||
rawMap, ok := value.(map[interface{}]interface{})
|
||||
rawMap, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
|
||||
}
|
||||
for rawKey, rawValue := range rawMap {
|
||||
rawCIDR, ok := rawKey.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
|
||||
}
|
||||
|
||||
for rawCIDR, rawValue := range rawMap {
|
||||
allowList, err := newAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -15,27 +15,27 @@ import (
|
||||
func TestNewAllowListFromConfig(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
c := config.NewC(l)
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
c.Settings["allowlist"] = map[string]any{
|
||||
"192.168.0.0": true,
|
||||
}
|
||||
r, err := newAllowListFromConfig(c, "allowlist", nil)
|
||||
require.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
|
||||
assert.Nil(t, r)
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
c.Settings["allowlist"] = map[string]any{
|
||||
"192.168.0.0/16": "abc",
|
||||
}
|
||||
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
||||
require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
c.Settings["allowlist"] = map[string]any{
|
||||
"192.168.0.0/16": true,
|
||||
"10.0.0.0/8": false,
|
||||
}
|
||||
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
||||
require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
c.Settings["allowlist"] = map[string]any{
|
||||
"0.0.0.0/0": true,
|
||||
"10.0.0.0/8": false,
|
||||
"10.42.42.0/24": true,
|
||||
@@ -45,7 +45,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
||||
r, err = newAllowListFromConfig(c, "allowlist", nil)
|
||||
require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
c.Settings["allowlist"] = map[string]any{
|
||||
"0.0.0.0/0": true,
|
||||
"10.0.0.0/8": false,
|
||||
"10.42.42.0/24": true,
|
||||
@@ -55,7 +55,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
||||
assert.NotNil(t, r)
|
||||
}
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
c.Settings["allowlist"] = map[string]any{
|
||||
"0.0.0.0/0": true,
|
||||
"10.0.0.0/8": false,
|
||||
"10.42.42.0/24": true,
|
||||
@@ -70,16 +70,16 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
||||
|
||||
// Test interface names
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
"interfaces": map[interface{}]interface{}{
|
||||
c.Settings["allowlist"] = map[string]any{
|
||||
"interfaces": map[string]any{
|
||||
`docker.*`: "foo",
|
||||
},
|
||||
}
|
||||
lr, err := NewLocalAllowListFromConfig(c, "allowlist")
|
||||
require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
"interfaces": map[interface{}]interface{}{
|
||||
c.Settings["allowlist"] = map[string]any{
|
||||
"interfaces": map[string]any{
|
||||
`docker.*`: false,
|
||||
`eth.*`: true,
|
||||
},
|
||||
@@ -87,8 +87,8 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
||||
lr, err = NewLocalAllowListFromConfig(c, "allowlist")
|
||||
require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
|
||||
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
"interfaces": map[interface{}]interface{}{
|
||||
c.Settings["allowlist"] = map[string]any{
|
||||
"interfaces": map[string]any{
|
||||
`docker.*`: false,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ type detailsV1 struct {
|
||||
curve Curve
|
||||
}
|
||||
|
||||
type m map[string]interface{}
|
||||
type m = map[string]any
|
||||
|
||||
func (c *certificateV1) Version() Version {
|
||||
return Version1
|
||||
|
||||
@@ -10,14 +10,14 @@ import (
|
||||
|
||||
func TestNewArgon2Parameters(t *testing.T) {
|
||||
p := NewArgon2Parameters(64*1024, 4, 3)
|
||||
assert.EqualValues(t, &Argon2Parameters{
|
||||
assert.Equal(t, &Argon2Parameters{
|
||||
version: argon2.Version,
|
||||
Memory: 64 * 1024,
|
||||
Parallelism: 4,
|
||||
Iterations: 3,
|
||||
}, p)
|
||||
p = NewArgon2Parameters(2*1024*1024, 2, 1)
|
||||
assert.EqualValues(t, &Argon2Parameters{
|
||||
assert.Equal(t, &Argon2Parameters{
|
||||
version: argon2.Version,
|
||||
Memory: 2 * 1024 * 1024,
|
||||
Parallelism: 2,
|
||||
|
||||
@@ -90,26 +90,26 @@ func Test_ca(t *testing.T) {
|
||||
assertHelpError(t, ca(
|
||||
[]string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
|
||||
), "-name is required")
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// ipv4 only ips
|
||||
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100")
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// ipv4 only subnets
|
||||
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100")
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// failed key write
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
|
||||
require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// create temp key file
|
||||
keyF, err := os.CreateTemp("", "test.key")
|
||||
@@ -121,8 +121,8 @@ func Test_ca(t *testing.T) {
|
||||
eb.Reset()
|
||||
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
|
||||
require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// create temp cert file
|
||||
crtF, err := os.CreateTemp("", "test.crt")
|
||||
@@ -135,8 +135,8 @@ func Test_ca(t *testing.T) {
|
||||
eb.Reset()
|
||||
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
require.NoError(t, ca(args, ob, eb, nopw))
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// read cert and key files
|
||||
rb, _ := os.ReadFile(keyF.Name())
|
||||
@@ -158,7 +158,7 @@ func Test_ca(t *testing.T) {
|
||||
assert.Empty(t, lCrt.UnsafeNetworks())
|
||||
assert.Len(t, lCrt.PublicKey(), 32)
|
||||
assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore()))
|
||||
assert.Equal(t, "", lCrt.Issuer())
|
||||
assert.Empty(t, lCrt.Issuer())
|
||||
assert.True(t, lCrt.CheckSignature(lCrt.PublicKey()))
|
||||
|
||||
// test encrypted key
|
||||
@@ -169,7 +169,7 @@ func Test_ca(t *testing.T) {
|
||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
require.NoError(t, ca(args, ob, eb, testpw))
|
||||
assert.Equal(t, pwPromptOb, ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// read encrypted key file and verify default params
|
||||
rb, _ = os.ReadFile(keyF.Name())
|
||||
@@ -197,7 +197,7 @@ func Test_ca(t *testing.T) {
|
||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
require.Error(t, ca(args, ob, eb, errpw))
|
||||
assert.Equal(t, pwPromptOb, ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// test when user fails to enter a password
|
||||
os.Remove(keyF.Name())
|
||||
@@ -207,7 +207,7 @@ func Test_ca(t *testing.T) {
|
||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
|
||||
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// create valid cert/key for overwrite tests
|
||||
os.Remove(keyF.Name())
|
||||
@@ -222,8 +222,8 @@ func Test_ca(t *testing.T) {
|
||||
eb.Reset()
|
||||
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// test that we won't overwrite existing key file
|
||||
os.Remove(keyF.Name())
|
||||
@@ -231,8 +231,8 @@ func Test_ca(t *testing.T) {
|
||||
eb.Reset()
|
||||
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
os.Remove(keyF.Name())
|
||||
|
||||
}
|
||||
|
||||
@@ -37,20 +37,20 @@ func Test_keygen(t *testing.T) {
|
||||
|
||||
// required args
|
||||
assertHelpError(t, keygen([]string{"-out-pub", "nope"}, ob, eb), "-out-key is required")
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
assertHelpError(t, keygen([]string{"-out-key", "nope"}, ob, eb), "-out-pub is required")
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// failed key write
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"}
|
||||
require.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// create temp key file
|
||||
keyF, err := os.CreateTemp("", "test.key")
|
||||
@@ -62,8 +62,8 @@ func Test_keygen(t *testing.T) {
|
||||
eb.Reset()
|
||||
args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()}
|
||||
require.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// create temp pub file
|
||||
pubF, err := os.CreateTemp("", "test.pub")
|
||||
@@ -75,8 +75,8 @@ func Test_keygen(t *testing.T) {
|
||||
eb.Reset()
|
||||
args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()}
|
||||
require.NoError(t, keygen(args, ob, eb))
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// read cert and key files
|
||||
rb, _ := os.ReadFile(keyF.Name())
|
||||
|
||||
@@ -17,7 +17,7 @@ func (he *helpError) Error() string {
|
||||
return he.s
|
||||
}
|
||||
|
||||
func newHelpErrorf(s string, v ...interface{}) error {
|
||||
func newHelpErrorf(s string, v ...any) error {
|
||||
return &helpError{s: fmt.Sprintf(s, v...)}
|
||||
}
|
||||
|
||||
|
||||
@@ -43,16 +43,16 @@ func Test_printCert(t *testing.T) {
|
||||
|
||||
// no path
|
||||
err := printCert([]string{}, ob, eb)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
assertHelpError(t, err, "-path is required")
|
||||
|
||||
// no cert at path
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
err = printCert([]string{"-path", "does_not_exist"}, ob, eb)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
require.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError)
|
||||
|
||||
// invalid cert at path
|
||||
@@ -64,8 +64,8 @@ func Test_printCert(t *testing.T) {
|
||||
|
||||
tf.WriteString("-----BEGIN NOPE-----")
|
||||
err = printCert([]string{"-path", tf.Name()}, ob, eb)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
require.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block")
|
||||
|
||||
// test multiple certs
|
||||
@@ -155,7 +155,7 @@ func Test_printCert(t *testing.T) {
|
||||
`,
|
||||
ob.String(),
|
||||
)
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// test json
|
||||
ob.Reset()
|
||||
@@ -177,7 +177,7 @@ func Test_printCert(t *testing.T) {
|
||||
`,
|
||||
ob.String(),
|
||||
)
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, eb.String())
|
||||
}
|
||||
|
||||
// NewTestCaCert will generate a CA cert
|
||||
|
||||
@@ -38,19 +38,19 @@ func Test_verify(t *testing.T) {
|
||||
|
||||
// required args
|
||||
assertHelpError(t, verify([]string{"-ca", "derp"}, ob, eb), "-crt is required")
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
assertHelpError(t, verify([]string{"-crt", "derp"}, ob, eb), "-ca is required")
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// no ca at path
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
require.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError)
|
||||
|
||||
// invalid ca at path
|
||||
@@ -62,8 +62,8 @@ func Test_verify(t *testing.T) {
|
||||
|
||||
caFile.WriteString("-----BEGIN NOPE-----")
|
||||
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
|
||||
|
||||
// make a ca for later
|
||||
@@ -76,8 +76,8 @@ func Test_verify(t *testing.T) {
|
||||
|
||||
// no crt at path
|
||||
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
require.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
|
||||
|
||||
// invalid crt at path
|
||||
@@ -89,8 +89,8 @@ func Test_verify(t *testing.T) {
|
||||
|
||||
certFile.WriteString("-----BEGIN NOPE-----")
|
||||
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
require.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
|
||||
|
||||
// unverifiable cert at path
|
||||
@@ -106,8 +106,8 @@ func Test_verify(t *testing.T) {
|
||||
certFile.Write(b)
|
||||
|
||||
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
require.ErrorIs(t, err, cert.ErrSignatureMismatch)
|
||||
|
||||
// verified cert at path
|
||||
@@ -118,7 +118,7 @@ func Test_verify(t *testing.T) {
|
||||
certFile.Write(b)
|
||||
|
||||
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Equal(t, "", eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -17,14 +17,14 @@ import (
|
||||
|
||||
"dario.cat/mergo"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gopkg.in/yaml.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type C struct {
|
||||
path string
|
||||
files []string
|
||||
Settings map[interface{}]interface{}
|
||||
oldSettings map[interface{}]interface{}
|
||||
Settings map[string]any
|
||||
oldSettings map[string]any
|
||||
callbacks []func(*C)
|
||||
l *logrus.Logger
|
||||
reloadLock sync.Mutex
|
||||
@@ -32,7 +32,7 @@ type C struct {
|
||||
|
||||
func NewC(l *logrus.Logger) *C {
|
||||
return &C{
|
||||
Settings: make(map[interface{}]interface{}),
|
||||
Settings: make(map[string]any),
|
||||
l: l,
|
||||
}
|
||||
}
|
||||
@@ -92,8 +92,8 @@ func (c *C) HasChanged(k string) bool {
|
||||
}
|
||||
|
||||
var (
|
||||
nv interface{}
|
||||
ov interface{}
|
||||
nv any
|
||||
ov any
|
||||
)
|
||||
|
||||
if k == "" {
|
||||
@@ -147,7 +147,7 @@ func (c *C) ReloadConfig() {
|
||||
c.reloadLock.Lock()
|
||||
defer c.reloadLock.Unlock()
|
||||
|
||||
c.oldSettings = make(map[interface{}]interface{})
|
||||
c.oldSettings = make(map[string]any)
|
||||
for k, v := range c.Settings {
|
||||
c.oldSettings[k] = v
|
||||
}
|
||||
@@ -167,7 +167,7 @@ func (c *C) ReloadConfigString(raw string) error {
|
||||
c.reloadLock.Lock()
|
||||
defer c.reloadLock.Unlock()
|
||||
|
||||
c.oldSettings = make(map[interface{}]interface{})
|
||||
c.oldSettings = make(map[string]any)
|
||||
for k, v := range c.Settings {
|
||||
c.oldSettings[k] = v
|
||||
}
|
||||
@@ -201,7 +201,7 @@ func (c *C) GetStringSlice(k string, d []string) []string {
|
||||
return d
|
||||
}
|
||||
|
||||
rv, ok := r.([]interface{})
|
||||
rv, ok := r.([]any)
|
||||
if !ok {
|
||||
return d
|
||||
}
|
||||
@@ -215,13 +215,13 @@ func (c *C) GetStringSlice(k string, d []string) []string {
|
||||
}
|
||||
|
||||
// GetMap will get the map for k or return the default d if not found or invalid
|
||||
func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
|
||||
func (c *C) GetMap(k string, d map[string]any) map[string]any {
|
||||
r := c.Get(k)
|
||||
if r == nil {
|
||||
return d
|
||||
}
|
||||
|
||||
v, ok := r.(map[interface{}]interface{})
|
||||
v, ok := r.(map[string]any)
|
||||
if !ok {
|
||||
return d
|
||||
}
|
||||
@@ -243,7 +243,7 @@ func (c *C) GetInt(k string, d int) int {
|
||||
// GetUint32 will get the uint32 for k or return the default d if not found or invalid
|
||||
func (c *C) GetUint32(k string, d uint32) uint32 {
|
||||
r := c.GetInt(k, int(d))
|
||||
if uint64(r) > uint64(math.MaxUint32) {
|
||||
if r < 0 || uint64(r) > uint64(math.MaxUint32) {
|
||||
return d
|
||||
}
|
||||
return uint32(r)
|
||||
@@ -266,6 +266,22 @@ func (c *C) GetBool(k string, d bool) bool {
|
||||
return v
|
||||
}
|
||||
|
||||
func AsBool(v any) (value bool, ok bool) {
|
||||
switch x := v.(type) {
|
||||
case bool:
|
||||
return x, true
|
||||
case string:
|
||||
switch x {
|
||||
case "y", "yes":
|
||||
return true, true
|
||||
case "n", "no":
|
||||
return false, true
|
||||
}
|
||||
}
|
||||
|
||||
return false, false
|
||||
}
|
||||
|
||||
// GetDuration will get the duration for k or return the default d if not found or invalid
|
||||
func (c *C) GetDuration(k string, d time.Duration) time.Duration {
|
||||
r := c.GetString(k, "")
|
||||
@@ -276,7 +292,7 @@ func (c *C) GetDuration(k string, d time.Duration) time.Duration {
|
||||
return v
|
||||
}
|
||||
|
||||
func (c *C) Get(k string) interface{} {
|
||||
func (c *C) Get(k string) any {
|
||||
return c.get(k, c.Settings)
|
||||
}
|
||||
|
||||
@@ -284,10 +300,10 @@ func (c *C) IsSet(k string) bool {
|
||||
return c.get(k, c.Settings) != nil
|
||||
}
|
||||
|
||||
func (c *C) get(k string, v interface{}) interface{} {
|
||||
func (c *C) get(k string, v any) any {
|
||||
parts := strings.Split(k, ".")
|
||||
for _, p := range parts {
|
||||
m, ok := v.(map[interface{}]interface{})
|
||||
m, ok := v.(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
@@ -346,7 +362,7 @@ func (c *C) addFile(path string, direct bool) error {
|
||||
}
|
||||
|
||||
func (c *C) parseRaw(b []byte) error {
|
||||
var m map[interface{}]interface{}
|
||||
var m map[string]any
|
||||
|
||||
err := yaml.Unmarshal(b, &m)
|
||||
if err != nil {
|
||||
@@ -358,7 +374,7 @@ func (c *C) parseRaw(b []byte) error {
|
||||
}
|
||||
|
||||
func (c *C) parse() error {
|
||||
var m map[interface{}]interface{}
|
||||
var m map[string]any
|
||||
|
||||
for _, path := range c.files {
|
||||
b, err := os.ReadFile(path)
|
||||
@@ -366,7 +382,7 @@ func (c *C) parse() error {
|
||||
return err
|
||||
}
|
||||
|
||||
var nm map[interface{}]interface{}
|
||||
var nm map[string]any
|
||||
err = yaml.Unmarshal(b, &nm)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestConfig_Load(t *testing.T) {
|
||||
@@ -19,7 +19,7 @@ func TestConfig_Load(t *testing.T) {
|
||||
// invalid yaml
|
||||
c := NewC(l)
|
||||
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
|
||||
require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
|
||||
require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[string]interface {}")
|
||||
|
||||
// simple multi config merge
|
||||
c = NewC(l)
|
||||
@@ -31,8 +31,8 @@ func TestConfig_Load(t *testing.T) {
|
||||
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
|
||||
os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644)
|
||||
require.NoError(t, c.Load(dir))
|
||||
expected := map[interface{}]interface{}{
|
||||
"outer": map[interface{}]interface{}{
|
||||
expected := map[string]any{
|
||||
"outer": map[string]any{
|
||||
"inner": "override",
|
||||
},
|
||||
"new": "hi",
|
||||
@@ -44,12 +44,12 @@ func TestConfig_Get(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
// test simple type
|
||||
c := NewC(l)
|
||||
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
|
||||
c.Settings["firewall"] = map[string]any{"outbound": "hi"}
|
||||
assert.Equal(t, "hi", c.Get("firewall.outbound"))
|
||||
|
||||
// test complex type
|
||||
inner := []map[interface{}]interface{}{{"port": "1", "code": "2"}}
|
||||
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": inner}
|
||||
inner := []map[string]any{{"port": "1", "code": "2"}}
|
||||
c.Settings["firewall"] = map[string]any{"outbound": inner}
|
||||
assert.EqualValues(t, inner, c.Get("firewall.outbound"))
|
||||
|
||||
// test missing
|
||||
@@ -59,7 +59,7 @@ func TestConfig_Get(t *testing.T) {
|
||||
func TestConfig_GetStringSlice(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
c := NewC(l)
|
||||
c.Settings["slice"] = []interface{}{"one", "two"}
|
||||
c.Settings["slice"] = []any{"one", "two"}
|
||||
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
|
||||
}
|
||||
|
||||
@@ -101,14 +101,14 @@ func TestConfig_HasChanged(t *testing.T) {
|
||||
// Test key change
|
||||
c = NewC(l)
|
||||
c.Settings["test"] = "hi"
|
||||
c.oldSettings = map[interface{}]interface{}{"test": "no"}
|
||||
c.oldSettings = map[string]any{"test": "no"}
|
||||
assert.True(t, c.HasChanged("test"))
|
||||
assert.True(t, c.HasChanged(""))
|
||||
|
||||
// No key change
|
||||
c = NewC(l)
|
||||
c.Settings["test"] = "hi"
|
||||
c.oldSettings = map[interface{}]interface{}{"test": "hi"}
|
||||
c.oldSettings = map[string]any{"test": "hi"}
|
||||
assert.False(t, c.HasChanged("test"))
|
||||
assert.False(t, c.HasChanged(""))
|
||||
}
|
||||
@@ -184,11 +184,11 @@ firewall:
|
||||
`),
|
||||
}
|
||||
|
||||
var m map[any]any
|
||||
var m map[string]any
|
||||
|
||||
// merge the same way config.parse() merges
|
||||
for _, b := range configs {
|
||||
var nm map[any]any
|
||||
var nm map[string]any
|
||||
err := yaml.Unmarshal(b, &nm)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -205,15 +205,15 @@ firewall:
|
||||
t.Logf("Merged Config as YAML:\n%s", mYaml)
|
||||
|
||||
// If a bug is present, some items might be replaced instead of merged like we expect
|
||||
expected := map[any]any{
|
||||
"firewall": map[any]any{
|
||||
expected := map[string]any{
|
||||
"firewall": map[string]any{
|
||||
"inbound": []any{
|
||||
map[any]any{"host": "any", "port": "any", "proto": "icmp"},
|
||||
map[any]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"},
|
||||
map[any]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}},
|
||||
map[string]any{"host": "any", "port": "any", "proto": "icmp"},
|
||||
map[string]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"},
|
||||
map[string]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}},
|
||||
"outbound": []any{
|
||||
map[any]any{"host": "any", "port": "any", "proto": "any"}}},
|
||||
"listen": map[any]any{
|
||||
map[string]any{"host": "any", "port": "any", "proto": "any"}}},
|
||||
"listen": map[string]any{
|
||||
"host": "0.0.0.0",
|
||||
"port": 4242,
|
||||
},
|
||||
|
||||
@@ -498,7 +498,7 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||
cs := n.intf.pki.getCertState()
|
||||
curCrt := hostinfo.ConnectionState.myCert
|
||||
myCrt := cs.getCertificate(curCrt.Version())
|
||||
if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
||||
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
||||
// The current tunnel is using the latest certificate and version, no need to rehandshake.
|
||||
return
|
||||
}
|
||||
|
||||
@@ -44,10 +44,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||
hostMap.preferredRanges.Store(&preferredRanges)
|
||||
|
||||
cs := &CertState{
|
||||
defaultVersion: cert.Version1,
|
||||
privateKey: []byte{},
|
||||
v1Cert: &dummyCert{version: cert.Version1},
|
||||
v1HandshakeBytes: []byte{},
|
||||
initiatingVersion: cert.Version1,
|
||||
privateKey: []byte{},
|
||||
v1Cert: &dummyCert{version: cert.Version1},
|
||||
v1HandshakeBytes: []byte{},
|
||||
}
|
||||
|
||||
lh := newTestLighthouse()
|
||||
@@ -126,10 +126,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
hostMap.preferredRanges.Store(&preferredRanges)
|
||||
|
||||
cs := &CertState{
|
||||
defaultVersion: cert.Version1,
|
||||
privateKey: []byte{},
|
||||
v1Cert: &dummyCert{version: cert.Version1},
|
||||
v1HandshakeBytes: []byte{},
|
||||
initiatingVersion: cert.Version1,
|
||||
privateKey: []byte{},
|
||||
v1Cert: &dummyCert{version: cert.Version1},
|
||||
v1HandshakeBytes: []byte{},
|
||||
}
|
||||
|
||||
lh := newTestLighthouse()
|
||||
|
||||
@@ -27,7 +27,7 @@ type ConnectionState struct {
|
||||
writeLock sync.Mutex
|
||||
}
|
||||
|
||||
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern, psk []byte) (*ConnectionState, error) {
|
||||
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
|
||||
var dhFunc noise.DHFunc
|
||||
switch crt.Curve() {
|
||||
case cert.Curve_CURVE25519:
|
||||
@@ -56,12 +56,13 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
|
||||
b.Update(l, 0)
|
||||
|
||||
hs, err := noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: ncs,
|
||||
Random: rand.Reader,
|
||||
Pattern: pattern,
|
||||
Initiator: initiator,
|
||||
StaticKeypair: static,
|
||||
PresharedKey: psk,
|
||||
CipherSuite: ncs,
|
||||
Random: rand.Reader,
|
||||
Pattern: pattern,
|
||||
Initiator: initiator,
|
||||
StaticKeypair: static,
|
||||
//NOTE: These should come from CertState (pki.go) when we finally implement it
|
||||
PresharedKey: []byte{},
|
||||
PresharedKeyPlacement: 0,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -131,8 +131,7 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
|
||||
|
||||
// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
|
||||
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
|
||||
_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
|
||||
if found {
|
||||
if c.f.myVpnAddrsTable.Contains(vpnIp) {
|
||||
// Only returning the default certificate since its impossible
|
||||
// for any other host but ourselves to have more than 1
|
||||
return c.f.pki.getCertState().GetDefaultCertificate().Copy()
|
||||
|
||||
@@ -101,7 +101,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
||||
|
||||
// Make sure we don't have any unexpected fields
|
||||
assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
|
||||
assert.EqualValues(t, &expectedInfo, thi)
|
||||
assert.Equal(t, &expectedInfo, thi)
|
||||
test.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
||||
|
||||
// Make sure we don't panic if the host info doesn't have a cert yet
|
||||
@@ -110,7 +110,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
|
||||
func assertFields(t *testing.T, expected []string, actualStruct any) {
|
||||
val := reflect.ValueOf(actualStruct).Elem()
|
||||
fields := make([]string, val.NumField())
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -26,7 +27,7 @@ type dnsRecords struct {
|
||||
dnsMap4 map[string]netip.Addr
|
||||
dnsMap6 map[string]netip.Addr
|
||||
hostMap *HostMap
|
||||
myVpnAddrsTable *bart.Table[struct{}]
|
||||
myVpnAddrsTable *bart.Lite
|
||||
}
|
||||
|
||||
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
|
||||
@@ -39,7 +40,7 @@ func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecord
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
|
||||
func (d *dnsRecords) query(q uint16, data string) netip.Addr {
|
||||
data = strings.ToLower(data)
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
@@ -57,7 +58,7 @@ func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
func (d *dnsRecords) QueryCert(data string) string {
|
||||
func (d *dnsRecords) queryCert(data string) string {
|
||||
ip, err := netip.ParseAddr(data[:len(data)-1])
|
||||
if err != nil {
|
||||
return ""
|
||||
@@ -112,8 +113,8 @@ func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
_, found := d.myVpnAddrsTable.Lookup(b)
|
||||
return found //if we found it in this table, it's good
|
||||
//if we found it in this table, it's good
|
||||
return d.myVpnAddrsTable.Contains(b)
|
||||
}
|
||||
|
||||
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||
@@ -122,7 +123,7 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
qType := dns.TypeToString[q.Qtype]
|
||||
d.l.Debugf("Query for %s %s", qType, q.Name)
|
||||
ip := d.Query(q.Qtype, q.Name)
|
||||
ip := d.query(q.Qtype, q.Name)
|
||||
if ip.IsValid() {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
|
||||
if err == nil {
|
||||
@@ -135,7 +136,7 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||
return
|
||||
}
|
||||
d.l.Debugf("Query for TXT %s", q.Name)
|
||||
ip := d.QueryCert(q.Name)
|
||||
ip := d.queryCert(q.Name)
|
||||
if ip != "" {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
||||
if err == nil {
|
||||
@@ -163,18 +164,18 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
|
||||
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
|
||||
func dnsMain(ctx context.Context, l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
|
||||
dnsR = newDnsRecords(l, cs, hostMap)
|
||||
|
||||
// attach request handler func
|
||||
dns.HandleFunc(".", dnsR.handleDnsRequest)
|
||||
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
reloadDns(l, c)
|
||||
reloadDns(ctx, l, c)
|
||||
})
|
||||
|
||||
return func() {
|
||||
startDns(l, c)
|
||||
startDns(ctx, l, c)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,24 +188,24 @@ func getDnsServerAddr(c *config.C) string {
|
||||
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
|
||||
}
|
||||
|
||||
func startDns(l *logrus.Logger, c *config.C) {
|
||||
func startDns(ctx context.Context, l *logrus.Logger, c *config.C) {
|
||||
dnsAddr = getDnsServerAddr(c)
|
||||
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
|
||||
l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
|
||||
err := dnsServer.ListenAndServe()
|
||||
defer dnsServer.Shutdown()
|
||||
defer dnsServer.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
l.Errorf("Failed to start server: %s\n ", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func reloadDns(l *logrus.Logger, c *config.C) {
|
||||
func reloadDns(ctx context.Context, l *logrus.Logger, c *config.C) {
|
||||
if dnsAddr == getDnsServerAddr(c) {
|
||||
l.Debug("No DNS server config change detected")
|
||||
return
|
||||
}
|
||||
|
||||
l.Debug("Restarting DNS server")
|
||||
dnsServer.Shutdown()
|
||||
go startDns(l, c)
|
||||
dnsServer.ShutdownContext(ctx)
|
||||
go startDns(ctx, l, c)
|
||||
}
|
||||
|
||||
@@ -38,24 +38,24 @@ func TestParsequery(t *testing.T) {
|
||||
func Test_getDnsServerAddr(t *testing.T) {
|
||||
c := config.NewC(nil)
|
||||
|
||||
c.Settings["lighthouse"] = map[interface{}]interface{}{
|
||||
"dns": map[interface{}]interface{}{
|
||||
c.Settings["lighthouse"] = map[string]any{
|
||||
"dns": map[string]any{
|
||||
"host": "0.0.0.0",
|
||||
"port": "1",
|
||||
},
|
||||
}
|
||||
assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c))
|
||||
|
||||
c.Settings["lighthouse"] = map[interface{}]interface{}{
|
||||
"dns": map[interface{}]interface{}{
|
||||
c.Settings["lighthouse"] = map[string]any{
|
||||
"dns": map[string]any{
|
||||
"host": "::",
|
||||
"port": "1",
|
||||
},
|
||||
}
|
||||
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
|
||||
|
||||
c.Settings["lighthouse"] = map[interface{}]interface{}{
|
||||
"dns": map[interface{}]interface{}{
|
||||
c.Settings["lighthouse"] = map[string]any{
|
||||
"dns": map[string]any{
|
||||
"host": "[::]",
|
||||
"port": "1",
|
||||
},
|
||||
@@ -63,8 +63,8 @@ func Test_getDnsServerAddr(t *testing.T) {
|
||||
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
|
||||
|
||||
// Make sure whitespace doesn't mess us up
|
||||
c.Settings["lighthouse"] = map[interface{}]interface{}{
|
||||
"dns": map[interface{}]interface{}{
|
||||
c.Settings["lighthouse"] = map[string]any{
|
||||
"dns": map[string]any{
|
||||
"host": "[::] ",
|
||||
"port": "1",
|
||||
},
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func BenchmarkHotPath(b *testing.B) {
|
||||
@@ -991,7 +991,7 @@ func TestRehandshaking(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
var theirNewConfig m
|
||||
require.NoError(t, yaml.Unmarshal(rc, &theirNewConfig))
|
||||
theirFirewall := theirNewConfig["firewall"].(map[interface{}]interface{})
|
||||
theirFirewall := theirNewConfig["firewall"].(map[string]any)
|
||||
theirFirewall["inbound"] = []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
@@ -1087,7 +1087,7 @@ func TestRehandshakingLoser(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
var myNewConfig m
|
||||
require.NoError(t, yaml.Unmarshal(rc, &myNewConfig))
|
||||
theirFirewall := myNewConfig["firewall"].(map[interface{}]interface{})
|
||||
theirFirewall := myNewConfig["firewall"].(map[string]any)
|
||||
theirFirewall["inbound"] = []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
@@ -1224,135 +1224,3 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestPSK(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
myPskMode nebula.PskMode
|
||||
theirPskMode nebula.PskMode
|
||||
}{
|
||||
// All accepting
|
||||
{
|
||||
name: "both accepting",
|
||||
myPskMode: nebula.PskAccepting,
|
||||
theirPskMode: nebula.PskAccepting,
|
||||
},
|
||||
|
||||
// accepting and sending both ways
|
||||
{
|
||||
name: "accepting to sending",
|
||||
myPskMode: nebula.PskAccepting,
|
||||
theirPskMode: nebula.PskSending,
|
||||
},
|
||||
{
|
||||
name: "sending to accepting",
|
||||
myPskMode: nebula.PskSending,
|
||||
theirPskMode: nebula.PskAccepting,
|
||||
},
|
||||
|
||||
// All sending
|
||||
{
|
||||
name: "sending to sending",
|
||||
myPskMode: nebula.PskSending,
|
||||
theirPskMode: nebula.PskSending,
|
||||
},
|
||||
|
||||
// enforced and sending both ways
|
||||
{
|
||||
name: "enforced to sending",
|
||||
myPskMode: nebula.PskEnforced,
|
||||
theirPskMode: nebula.PskSending,
|
||||
},
|
||||
{
|
||||
name: "sending to enforced",
|
||||
myPskMode: nebula.PskSending,
|
||||
theirPskMode: nebula.PskEnforced,
|
||||
},
|
||||
|
||||
// All enforced
|
||||
{
|
||||
name: "both enforced",
|
||||
myPskMode: nebula.PskEnforced,
|
||||
theirPskMode: nebula.PskEnforced,
|
||||
},
|
||||
|
||||
// Enforced can technically handshake with an accepting node, but it is bad to be in this state
|
||||
{
|
||||
name: "enforced to accepting",
|
||||
myPskMode: nebula.PskEnforced,
|
||||
theirPskMode: nebula.PskAccepting,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var myPskSettings, theirPskSettings m
|
||||
|
||||
switch test.myPskMode {
|
||||
case nebula.PskAccepting:
|
||||
myPskSettings = m{"psk": &m{"mode": "accepting", "keys": []string{"garbage0", "this is a key"}}}
|
||||
case nebula.PskSending:
|
||||
myPskSettings = m{"psk": &m{"mode": "sending", "keys": []string{"this is a key", "garbage1"}}}
|
||||
case nebula.PskEnforced:
|
||||
myPskSettings = m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key", "garbage2"}}}
|
||||
}
|
||||
|
||||
switch test.theirPskMode {
|
||||
case nebula.PskAccepting:
|
||||
theirPskSettings = m{"psk": &m{"mode": "accepting", "keys": []string{"garbage3", "this is a key"}}}
|
||||
case nebula.PskSending:
|
||||
theirPskSettings = m{"psk": &m{"mode": "sending", "keys": []string{"this is a key", "garbage4"}}}
|
||||
case nebula.PskEnforced:
|
||||
theirPskSettings = m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key", "garbage5"}}}
|
||||
}
|
||||
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
|
||||
myControl, myVpnIp, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.0.0.1/24", myPskSettings)
|
||||
theirControl, theirVpnIp, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.0.0.2/24", theirPskSettings)
|
||||
|
||||
myControl.InjectLightHouseAddr(theirVpnIp[0].Addr(), theirUdpAddr)
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
t.Log("Route until we see our cached packet flow")
|
||||
myControl.InjectTunUDPPacket(theirVpnIp[0].Addr(), 80, myVpnIp[0].Addr(), 80, []byte("Hi from me"))
|
||||
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
||||
h := &header.H{}
|
||||
err := h.Parse(p.Data)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// If this is the stage 1 handshake packet and I am configured to send with a psk, my cert name should
|
||||
// not appear. It would likely be more obvious to unmarshal the payload and check but this works fine for now
|
||||
if test.myPskMode == nebula.PskEnforced || test.myPskMode == nebula.PskSending {
|
||||
if h.Type == 0 && h.MessageCounter == 1 {
|
||||
assert.NotContains(t, string(p.Data), "test me")
|
||||
}
|
||||
}
|
||||
|
||||
if p.To == theirUdpAddr && h.Type == 1 {
|
||||
return router.RouteAndExit
|
||||
}
|
||||
|
||||
return router.KeepRouting
|
||||
})
|
||||
|
||||
t.Log("My cached packet should be received by them")
|
||||
myCachedPacket := theirControl.GetFromTun(true)
|
||||
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp[0].Addr(), theirVpnIp[0].Addr(), 80, 80)
|
||||
|
||||
t.Log("Test the tunnel with them")
|
||||
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
|
||||
assertTunnel(t, myVpnIp[0].Addr(), theirVpnIp[0].Addr(), myControl, theirControl, r)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
//TODO: assert hostmaps
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -22,10 +22,10 @@ import (
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/e2e/router"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/yaml.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type m map[string]interface{}
|
||||
type m = map[string]any
|
||||
|
||||
// newSimpleServer creates a nebula instance with many assumptions
|
||||
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||
|
||||
@@ -111,6 +111,10 @@ type ExitFunc func(packet *udp.Packet, receiver *nebula.Control) ExitType
|
||||
func NewR(t testing.TB, controls ...*nebula.Control) *R {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
if err := os.MkdirAll("mermaid", 0755); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
r := &R{
|
||||
controls: make(map[netip.AddrPort]*nebula.Control),
|
||||
vpnControls: make(map[netip.Addr]*nebula.Control),
|
||||
@@ -190,9 +194,6 @@ func (r *R) renderFlow() {
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(r.fn), 0755); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
f, err := os.OpenFile(r.fn, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0644)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
||||
@@ -13,43 +13,11 @@ pki:
|
||||
# disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
|
||||
#disconnect_invalid: true
|
||||
|
||||
# default_version controls which certificate version is used in handshakes.
|
||||
# initiating_version controls which certificate version is used when initiating handshakes.
|
||||
# This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`.
|
||||
# Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`.
|
||||
# After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
|
||||
# default_version: 1
|
||||
|
||||
# psk can be used to mask the contents of handshakes.
|
||||
psk:
|
||||
# `mode` defines how the pre shared keys can be used in a handshake.
|
||||
# `accepting` (the default) will initiate handshakes using an empty key and will try to use any keys provided when
|
||||
# receiving handshakes, including an empty key.
|
||||
# `sending` will initiate handshakes with the first key provided and will try to use any keys provided when
|
||||
# receiving handshakes, including an empty key.
|
||||
# `enforced` will initiate handshakes with the first psk key provided and will try to use any keys provided when
|
||||
# responding to handshakes. An empty key will not be allowed.
|
||||
#
|
||||
# To change a mesh from not using a psk to enforcing psk:
|
||||
# 1. Leave `mode` as `accepting` and configure `psk.keys` to match on all nodes in the mesh and reload.
|
||||
# 2. Change `mode` to `sending` on all nodes in the mesh and reload.
|
||||
# 3. Change `mode` to `enforced` on all nodes in the mesh and reload.
|
||||
#mode: accepting
|
||||
|
||||
# The keys provided are sent through hkdf to ensure the shared secret used in the noise protocol is the
|
||||
# correct byte length.
|
||||
#
|
||||
# Only the first key is used for outbound handshakes but all keys provided will be tried in the order specified, on
|
||||
# incoming handshakes. This is to allow for psk rotation.
|
||||
#
|
||||
# To rotate a primary key:
|
||||
# 1. Put the new key in the 2nd slot on every node in the mesh and reload.
|
||||
# 2. Move the key from the 2nd slot to the 1st slot, the old primary key is now in the 2nd slot, reload.
|
||||
# 3. Remove the old primary key once it is no longer in use on every node in the mesh and reload.
|
||||
#keys:
|
||||
# - shared secret string, this one is used in all outbound handshakes # This is the primary key used when sending handshakes
|
||||
# - this is a fallback key, received handshakes can use this
|
||||
# - another fallback, received handshakes can use this one too
|
||||
# - "\x68\x65\x6c\x6c\x6f\x20\x66\x72\x69\x65\x6e\x64\x73" # for raw bytes if you desire
|
||||
# initiating_version: 1
|
||||
|
||||
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
|
||||
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
|
||||
@@ -271,7 +239,28 @@ tun:
|
||||
|
||||
# Unsafe routes allows you to route traffic over nebula to non-nebula nodes
|
||||
# Unsafe routes should be avoided unless you have hosts/services that cannot run nebula
|
||||
# NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate
|
||||
# Supports weighted ECMP if you define a list of gateways, this can be used for load balancing or redundancy to hosts outside of nebula
|
||||
# NOTES:
|
||||
# * You will only see a single gateway in the routing table if you are not on linux
|
||||
# * If a gateway is not reachable through the overlay another gateway will be selected to send the traffic through, ignoring weights
|
||||
#
|
||||
# unsafe_routes:
|
||||
# # Multiple gateways without defining a weight defaults to a weight of 1, this will balance traffic equally between the three gateways
|
||||
# - route: 192.168.87.0/24
|
||||
# via:
|
||||
# - gateway: 10.0.0.1
|
||||
# - gateway: 10.0.0.2
|
||||
# - gateway: 10.0.0.3
|
||||
# # Multiple gateways with a weight, this will balance traffic accordingly
|
||||
# - route: 192.168.87.0/24
|
||||
# via:
|
||||
# - gateway: 10.0.0.1
|
||||
# weight: 10
|
||||
# - gateway: 10.0.0.2
|
||||
# weight: 5
|
||||
#
|
||||
# NOTE: The nebula certificate of the "via" node(s) *MUST* have the "route" defined as a subnet in its certificate
|
||||
# `via`: single node or list of gateways to use for this route
|
||||
# `mtu`: will default to tun mtu if this option is not specified
|
||||
# `metric`: will default to 0 if this option is not specified
|
||||
# `install`: will default to true, controls whether this route is installed in the systems routing table.
|
||||
@@ -345,6 +334,7 @@ logging:
|
||||
# after receiving the response for lighthouse queries
|
||||
#trigger_buffer: 64
|
||||
|
||||
|
||||
# Nebula security group configuration
|
||||
firewall:
|
||||
# Action to take when a packet is not allowed by the firewall rules.
|
||||
@@ -356,11 +346,11 @@ firewall:
|
||||
outbound_action: drop
|
||||
inbound_action: drop
|
||||
|
||||
# Controls the default value for local_cidr. Default is true, will be deprecated after v1.9 and defaulted to false.
|
||||
# This setting only affects nebula hosts with subnets encoded in their certificate. A nebula host acting as an
|
||||
# unsafe router with `default_local_cidr_any: true` will expose their unsafe routes to every inbound rule regardless
|
||||
# of the actual destination for the packet. Setting this to false requires each inbound rule to contain a `local_cidr`
|
||||
# if the intention is to allow traffic to flow to an unsafe route.
|
||||
# THIS FLAG IS DEPRECATED AND WILL BE REMOVED IN A FUTURE RELEASE. (Defaults to false.)
|
||||
# This setting only affects nebula hosts exposing unsafe_routes. When set to false, each inbound rule must contain a
|
||||
# `local_cidr` if the intention is to allow traffic to flow to an unsafe route. When set to true, every firewall rule
|
||||
# will apply to all configured unsafe_routes regardless of the actual destination of the packet, unless `local_cidr`
|
||||
# is explicitly defined. This is usually not the desired behavior and should be avoided!
|
||||
#default_local_cidr_any: false
|
||||
|
||||
conntrack:
|
||||
@@ -378,11 +368,9 @@ firewall:
|
||||
# group: `any` or a literal group name, ie `default-group`
|
||||
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
|
||||
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
|
||||
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes.
|
||||
# If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network.
|
||||
# Otherwise the default is any vpn network assigned to via the certificate.
|
||||
# `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release.
|
||||
# If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation.
|
||||
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes.
|
||||
# By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true.
|
||||
# If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case.
|
||||
# ca_name: An issuing CA name
|
||||
# ca_sha: An issuing CA shasum
|
||||
|
||||
|
||||
35
firewall.go
35
firewall.go
@@ -53,7 +53,7 @@ type Firewall struct {
|
||||
|
||||
// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
|
||||
// The vpn addresses are a full bit match while the unsafe networks only match the prefix
|
||||
routableNetworks *bart.Table[struct{}]
|
||||
routableNetworks *bart.Lite
|
||||
|
||||
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
|
||||
assignedNetworks []netip.Prefix
|
||||
@@ -125,7 +125,7 @@ type firewallPort map[int32]*FirewallCA
|
||||
|
||||
type firewallLocalCIDR struct {
|
||||
Any bool
|
||||
LocalCIDR *bart.Table[struct{}]
|
||||
LocalCIDR *bart.Lite
|
||||
}
|
||||
|
||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||
@@ -148,17 +148,17 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
||||
tmax = defaultTimeout
|
||||
}
|
||||
|
||||
routableNetworks := new(bart.Table[struct{}])
|
||||
routableNetworks := new(bart.Lite)
|
||||
var assignedNetworks []netip.Prefix
|
||||
for _, network := range c.Networks() {
|
||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||
routableNetworks.Insert(nprefix, struct{}{})
|
||||
routableNetworks.Insert(nprefix)
|
||||
assignedNetworks = append(assignedNetworks, network)
|
||||
}
|
||||
|
||||
hasUnsafeNetworks := false
|
||||
for _, n := range c.UnsafeNetworks() {
|
||||
routableNetworks.Insert(n, struct{}{})
|
||||
routableNetworks.Insert(n)
|
||||
hasUnsafeNetworks = true
|
||||
}
|
||||
|
||||
@@ -331,7 +331,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
||||
return nil
|
||||
}
|
||||
|
||||
rs, ok := r.([]interface{})
|
||||
rs, ok := r.([]any)
|
||||
if !ok {
|
||||
return fmt.Errorf("%s failed to parse, should be an array of rules", table)
|
||||
}
|
||||
@@ -431,8 +431,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
||||
|
||||
// Make sure remote address matches nebula certificate
|
||||
if h.networks != nil {
|
||||
_, ok := h.networks.Lookup(fp.RemoteAddr)
|
||||
if !ok {
|
||||
if !h.networks.Contains(fp.RemoteAddr) {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
@@ -445,8 +444,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
||||
}
|
||||
|
||||
// Make sure we are supposed to be handling this local ip address
|
||||
_, ok := f.routableNetworks.Lookup(fp.LocalAddr)
|
||||
if !ok {
|
||||
if !f.routableNetworks.Contains(fp.LocalAddr) {
|
||||
f.metrics(incoming).droppedLocalAddr.Inc(1)
|
||||
return ErrInvalidLocalIP
|
||||
}
|
||||
@@ -752,7 +750,7 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool
|
||||
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
|
||||
flc := func() *firewallLocalCIDR {
|
||||
return &firewallLocalCIDR{
|
||||
LocalCIDR: new(bart.Table[struct{}]),
|
||||
LocalCIDR: new(bart.Lite),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -879,7 +877,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
|
||||
}
|
||||
|
||||
for _, network := range f.assignedNetworks {
|
||||
flc.LocalCIDR.Insert(network, struct{}{})
|
||||
flc.LocalCIDR.Insert(network)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -888,7 +886,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
flc.LocalCIDR.Insert(localIp, struct{}{})
|
||||
flc.LocalCIDR.Insert(localIp)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -901,8 +899,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate
|
||||
return true
|
||||
}
|
||||
|
||||
_, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
|
||||
return ok
|
||||
return flc.LocalCIDR.Contains(p.LocalAddr)
|
||||
}
|
||||
|
||||
type rule struct {
|
||||
@@ -918,15 +915,15 @@ type rule struct {
|
||||
CASha string
|
||||
}
|
||||
|
||||
func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
|
||||
func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
|
||||
r := rule{}
|
||||
|
||||
m, ok := p.(map[interface{}]interface{})
|
||||
m, ok := p.(map[string]any)
|
||||
if !ok {
|
||||
return r, errors.New("could not parse rule")
|
||||
}
|
||||
|
||||
toString := func(k string, m map[interface{}]interface{}) string {
|
||||
toString := func(k string, m map[string]any) string {
|
||||
v, ok := m[k]
|
||||
if !ok {
|
||||
return ""
|
||||
@@ -944,7 +941,7 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
|
||||
r.CASha = toString("ca_sha", m)
|
||||
|
||||
// Make sure group isn't an array
|
||||
if v, ok := m["group"].([]interface{}); ok {
|
||||
if v, ok := m["group"].([]any); ok {
|
||||
if len(v) > 1 {
|
||||
return r, errors.New("group should contain a single value, an array with more than one entry was provided")
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type m map[string]interface{}
|
||||
type m = map[string]any
|
||||
|
||||
const (
|
||||
ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
|
||||
|
||||
@@ -631,53 +631,53 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
conf := config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
||||
|
||||
// Test both port and code
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
||||
|
||||
// Test missing host, group, cidr, ca_name and ca_sha
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
|
||||
|
||||
// Test code/port error
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
||||
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
||||
|
||||
// Test proto error
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
||||
|
||||
// Test cidr parse error
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||
|
||||
// Test local_cidr parse error
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||
|
||||
// Test both group and groups
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
||||
}
|
||||
@@ -687,28 +687,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||
// Test adding tcp rule
|
||||
conf := config.NewC(l)
|
||||
mf := &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test adding udp rule
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test adding icmp rule
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test adding any rule
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
@@ -716,49 +716,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||
cidr := netip.MustParsePrefix("10.0.0.0/8")
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test adding rule with local_cidr
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
|
||||
|
||||
// Test adding rule with ca_sha
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
|
||||
|
||||
// Test adding rule with ca_name
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
|
||||
|
||||
// Test single group
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test single groups
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
// Test multiple AND groups
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||
|
||||
@@ -766,7 +766,7 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||
conf = config.NewC(l)
|
||||
mf = &mockFirewall{}
|
||||
mf.nextCallReturn = errors.New("test error")
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
||||
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
|
||||
}
|
||||
|
||||
@@ -776,8 +776,8 @@ func TestFirewall_convertRule(t *testing.T) {
|
||||
l.SetOutput(ob)
|
||||
|
||||
// Ensure group array of 1 is converted and a warning is printed
|
||||
c := map[interface{}]interface{}{
|
||||
"group": []interface{}{"group1"},
|
||||
c := map[string]any{
|
||||
"group": []any{"group1"},
|
||||
}
|
||||
|
||||
r, err := convertRule(l, c, "test", 1)
|
||||
@@ -787,17 +787,17 @@ func TestFirewall_convertRule(t *testing.T) {
|
||||
|
||||
// Ensure group array of > 1 is errord
|
||||
ob.Reset()
|
||||
c = map[interface{}]interface{}{
|
||||
"group": []interface{}{"group1", "group2"},
|
||||
c = map[string]any{
|
||||
"group": []any{"group1", "group2"},
|
||||
}
|
||||
|
||||
r, err = convertRule(l, c, "test", 1)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Empty(t, ob.String())
|
||||
require.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
|
||||
|
||||
// Make sure a well formed group is alright
|
||||
ob.Reset()
|
||||
c = map[interface{}]interface{}{
|
||||
c = map[string]any{
|
||||
"group": "group1",
|
||||
}
|
||||
|
||||
|
||||
30
go.mod
30
go.mod
@@ -1,8 +1,8 @@
|
||||
module github.com/slackhq/nebula
|
||||
|
||||
go 1.23.6
|
||||
go 1.23.0
|
||||
|
||||
toolchain go1.23.7
|
||||
toolchain go1.24.1
|
||||
|
||||
require (
|
||||
dario.cat/mergo v1.0.1
|
||||
@@ -10,31 +10,31 @@ require (
|
||||
github.com/armon/go-radix v1.0.0
|
||||
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
|
||||
github.com/flynn/noise v1.1.0
|
||||
github.com/gaissmai/bart v0.18.1
|
||||
github.com/gaissmai/bart v0.20.4
|
||||
github.com/gogo/protobuf v1.3.2
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/kardianos/service v1.2.2
|
||||
github.com/miekg/dns v1.1.63
|
||||
github.com/miekg/dns v1.1.65
|
||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
|
||||
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
|
||||
github.com/prometheus/client_golang v1.21.1
|
||||
github.com/prometheus/client_golang v1.22.0
|
||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/vishvananda/netlink v1.3.0
|
||||
golang.org/x/crypto v0.36.0
|
||||
golang.org/x/crypto v0.37.0
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
||||
golang.org/x/net v0.37.0
|
||||
golang.org/x/sync v0.12.0
|
||||
golang.org/x/sys v0.31.0
|
||||
golang.org/x/term v0.30.0
|
||||
golang.org/x/net v0.39.0
|
||||
golang.org/x/sync v0.13.0
|
||||
golang.org/x/sys v0.32.0
|
||||
golang.org/x/term v0.31.0
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
google.golang.org/protobuf v1.36.5
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
google.golang.org/protobuf v1.36.6
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
|
||||
)
|
||||
|
||||
@@ -43,15 +43,13 @@ require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
github.com/klauspost/compress v1.17.11 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.62.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
golang.org/x/mod v0.18.0 // indirect
|
||||
golang.org/x/mod v0.23.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
golang.org/x/tools v0.22.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
golang.org/x/tools v0.30.0 // indirect
|
||||
)
|
||||
|
||||
54
go.sum
54
go.sum
@@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
|
||||
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
||||
github.com/gaissmai/bart v0.18.1 h1:bX2j560JC1MJpoEDevBGvXL5OZ1mkls320Vl8Igb5QQ=
|
||||
github.com/gaissmai/bart v0.18.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY=
|
||||
github.com/gaissmai/bart v0.20.4 h1:Ik47r1fy3jRVU+1eYzKSW3ho2UgBVTVnUS8O993584U=
|
||||
github.com/gaissmai/bart v0.20.4/go.mod h1:cEed+ge8dalcbpi8wtS9x9m2hn/fNJH5suhdGQOHnYk=
|
||||
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
|
||||
@@ -53,8 +53,8 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
|
||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
||||
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
||||
@@ -68,8 +68,8 @@ github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX
|
||||
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
|
||||
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
|
||||
@@ -83,8 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||
github.com/miekg/dns v1.1.63 h1:8M5aAw6OMZfFXTT7K5V0Eu5YiiL8l7nUAkyN6C9YwaY=
|
||||
github.com/miekg/dns v1.1.63/go.mod h1:6NGHfjhpmr5lt3XPLuyfDJi5AXbNIPM9PY6H6sF1Nfs=
|
||||
github.com/miekg/dns v1.1.65 h1:0+tIPHzUW0GCge7IiK3guGP57VAw7hoPDfApjkMD1Fc=
|
||||
github.com/miekg/dns v1.1.65/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck=
|
||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
|
||||
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -106,8 +106,8 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
|
||||
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
|
||||
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
|
||||
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
|
||||
github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk=
|
||||
github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
|
||||
github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
|
||||
github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
|
||||
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
||||
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
@@ -156,16 +156,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
||||
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=
|
||||
golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
|
||||
golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@@ -176,8 +176,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
|
||||
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
|
||||
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
|
||||
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -185,8 +185,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
|
||||
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
@@ -204,11 +204,11 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
|
||||
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
|
||||
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
|
||||
golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o=
|
||||
golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
@@ -219,8 +219,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
|
||||
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA=
|
||||
golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c=
|
||||
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
|
||||
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
@@ -239,8 +239,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
|
||||
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
|
||||
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
|
||||
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
|
||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
@@ -251,8 +251,6 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -25,7 +25,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
|
||||
// If we're connecting to a v6 address we must use a v2 cert
|
||||
cs := f.pki.getCertState()
|
||||
v := cs.defaultVersion
|
||||
v := cs.initiatingVersion
|
||||
for _, a := range hh.hostinfo.vpnAddrs {
|
||||
if a.Is6() {
|
||||
v = cert.Version2
|
||||
@@ -50,7 +50,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||
}
|
||||
|
||||
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX, cs.psk.primary)
|
||||
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
@@ -71,7 +71,8 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
|
||||
hsBytes, err := hs.Marshal()
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v).
|
||||
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||
WithField("certVersion", v).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
||||
return false
|
||||
}
|
||||
@@ -100,57 +101,38 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
if crt == nil {
|
||||
f.l.WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
WithField("certVersion", cs.defaultVersion).
|
||||
WithField("certVersion", cs.initiatingVersion).
|
||||
Error("Unable to handshake with host because no certificate is available")
|
||||
}
|
||||
|
||||
var (
|
||||
err error
|
||||
ci *ConnectionState
|
||||
msg []byte
|
||||
)
|
||||
|
||||
hs := &NebulaHandshake{}
|
||||
|
||||
for _, psk := range cs.psk.keys {
|
||||
ci, err = NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX, psk)
|
||||
if err != nil {
|
||||
//TODO: should be bother logging this, if we have multiple psks and the error is unrelated it will be verbose.
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Error("Failed to create connection state")
|
||||
continue
|
||||
}
|
||||
|
||||
msg, _, _, err = ci.H.ReadMessage(nil, packet[header.Len:])
|
||||
if err != nil {
|
||||
// Calls to ReadMessage with an incorrect psk should fail, try the next one if we have one
|
||||
continue
|
||||
}
|
||||
|
||||
// Sometimes ReadMessage returns fine with a nil psk even if the handshake is using a psk, ensure our protobuf
|
||||
// comes out clean as well
|
||||
err = hs.Unmarshal(msg)
|
||||
if err == nil {
|
||||
// There was no error, we can continue with this handshake
|
||||
break
|
||||
}
|
||||
|
||||
// The unmarshal failed, try the next psk if we have one
|
||||
}
|
||||
|
||||
// We finished with an error, log it and get out
|
||||
if err != nil || hs.Details == nil {
|
||||
// We aren't logging the error here because we can't be sure of the failure when using psk
|
||||
f.l.WithField("udpAddr", addr).
|
||||
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Error("Was unable to decrypt the handshake")
|
||||
Error("Failed to create connection state")
|
||||
return
|
||||
}
|
||||
|
||||
// Mark packet 1 as seen so it doesn't show up as missed
|
||||
ci.window.Update(f.l, 1)
|
||||
|
||||
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Error("Failed to call noise.ReadMessage")
|
||||
return
|
||||
}
|
||||
|
||||
hs := &NebulaHandshake{}
|
||||
err = hs.Unmarshal(msg)
|
||||
if err != nil || hs.Details == nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Error("Failed unmarshal handshake message")
|
||||
return
|
||||
}
|
||||
|
||||
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
@@ -204,15 +186,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
var vpnAddrs []netip.Addr
|
||||
var filteredNetworks []netip.Prefix
|
||||
certName := remoteCert.Certificate.Name()
|
||||
certVersion := remoteCert.Certificate.Version()
|
||||
fingerprint := remoteCert.Fingerprint
|
||||
issuer := remoteCert.Certificate.Issuer()
|
||||
|
||||
for _, network := range remoteCert.Certificate.Networks() {
|
||||
vpnAddr := network.Addr()
|
||||
_, found := f.myVpnAddrsTable.Lookup(vpnAddr)
|
||||
if found {
|
||||
if f.myVpnAddrsTable.Contains(vpnAddr) {
|
||||
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
||||
@@ -220,7 +203,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
}
|
||||
|
||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
||||
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
|
||||
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -231,6 +214,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
if len(vpnAddrs) == 0 {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
||||
@@ -250,6 +234,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
|
||||
@@ -272,6 +257,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
@@ -283,6 +269,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
if hs.Details.Cert == nil {
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
@@ -300,6 +287,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
||||
@@ -311,6 +299,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
||||
@@ -318,6 +307,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
} else if dKey == nil || eKey == nil {
|
||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
|
||||
@@ -385,6 +375,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("oldHandshakeTime", existing.lastHandshakeTime).
|
||||
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
|
||||
WithField("fingerprint", fingerprint).
|
||||
@@ -400,6 +391,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
@@ -412,6 +404,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
// And we forget to update it here
|
||||
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
@@ -428,6 +421,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
if err != nil {
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
@@ -436,6 +430,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
} else {
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
@@ -454,6 +449,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
@@ -558,6 +554,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
|
||||
vpnNetworks := remoteCert.Certificate.Networks()
|
||||
certName := remoteCert.Certificate.Name()
|
||||
certVersion := remoteCert.Certificate.Version()
|
||||
fingerprint := remoteCert.Fingerprint
|
||||
issuer := remoteCert.Certificate.Issuer()
|
||||
|
||||
@@ -581,7 +578,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
for _, network := range vpnNetworks {
|
||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
||||
vpnAddr := network.Addr()
|
||||
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
|
||||
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -592,6 +589,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
if len(vpnAddrs) == 0 {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
||||
@@ -601,7 +599,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
// Ensure the right host responded
|
||||
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
|
||||
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
||||
WithField("udpAddr", addr).WithField("certName", certName).
|
||||
WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
Info("Incorrect host responded to handshake")
|
||||
|
||||
@@ -637,6 +637,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
duration := time.Since(hh.startTime).Nanoseconds()
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
|
||||
@@ -274,8 +274,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
}
|
||||
|
||||
// Don't relay through the host I'm trying to connect to
|
||||
_, found := hm.f.myVpnAddrsTable.Lookup(relay)
|
||||
if found {
|
||||
if hm.f.myVpnAddrsTable.Contains(relay) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
||||
@@ -24,15 +23,11 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
||||
|
||||
lh := newTestLighthouse()
|
||||
|
||||
psk, err := NewPsk(PskAccepting, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
cs := &CertState{
|
||||
defaultVersion: cert.Version1,
|
||||
privateKey: []byte{},
|
||||
v1Cert: &dummyCert{version: cert.Version1},
|
||||
v1HandshakeBytes: []byte{},
|
||||
psk: psk,
|
||||
initiatingVersion: cert.Version1,
|
||||
privateKey: []byte{},
|
||||
v1Cert: &dummyCert{version: cert.Version1},
|
||||
v1HandshakeBytes: []byte{},
|
||||
}
|
||||
|
||||
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
|
||||
@@ -103,5 +98,5 @@ func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
|
||||
}
|
||||
|
||||
func (mw *mockEncWriter) GetCertState() *CertState {
|
||||
return &CertState{defaultVersion: cert.Version2}
|
||||
return &CertState{initiatingVersion: cert.Version2}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
// |-----------------------------------------------------------------------|
|
||||
// | payload... |
|
||||
|
||||
type m map[string]interface{}
|
||||
type m = map[string]any
|
||||
|
||||
const (
|
||||
Version uint8 = 1
|
||||
|
||||
@@ -223,7 +223,7 @@ type HostInfo struct {
|
||||
recvError atomic.Uint32
|
||||
|
||||
// networks are both all vpn and unsafe networks assigned to this host
|
||||
networks *bart.Table[struct{}]
|
||||
networks *bart.Lite
|
||||
relayState RelayState
|
||||
|
||||
// HandshakePacket records the packets used to create this hostinfo
|
||||
@@ -732,13 +732,13 @@ func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
||||
return
|
||||
}
|
||||
|
||||
i.networks = new(bart.Table[struct{}])
|
||||
i.networks = new(bart.Lite)
|
||||
for _, network := range networks {
|
||||
i.networks.Insert(network, struct{}{})
|
||||
i.networks.Insert(network)
|
||||
}
|
||||
|
||||
for _, network := range unsafeNetworks {
|
||||
i.networks.Insert(network, struct{}{})
|
||||
i.networks.Insert(network)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -210,8 +210,8 @@ func TestHostMap_reload(t *testing.T) {
|
||||
assert.Empty(t, hm.GetPreferredRanges())
|
||||
|
||||
c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]")
|
||||
assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))
|
||||
assert.Equal(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))
|
||||
|
||||
c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
|
||||
assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
|
||||
assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
|
||||
}
|
||||
|
||||
102
inside.go
102
inside.go
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/noiseutil"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||
@@ -21,14 +22,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
|
||||
// Ignore local broadcast packets
|
||||
if f.dropLocalBroadcast {
|
||||
_, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr)
|
||||
if found {
|
||||
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
_, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr)
|
||||
if found {
|
||||
if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) {
|
||||
// Immediately forward packets from self to self.
|
||||
// This should only happen on Darwin-based and FreeBSD hosts, which
|
||||
// routes packets from the Nebula addr to the Nebula addr through the Nebula
|
||||
@@ -49,7 +48,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) {
|
||||
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
|
||||
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
||||
})
|
||||
|
||||
@@ -121,22 +120,93 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
|
||||
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
|
||||
}
|
||||
|
||||
// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
|
||||
func (f *Interface) Handshake(vpnAddr netip.Addr) {
|
||||
f.getOrHandshake(vpnAddr, nil)
|
||||
f.getOrHandshakeNoRouting(vpnAddr, nil)
|
||||
}
|
||||
|
||||
// getOrHandshake returns nil if the vpnAddr is not routable.
|
||||
// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
|
||||
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
|
||||
func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
||||
_, found := f.myVpnNetworksTable.Lookup(vpnAddr)
|
||||
if !found {
|
||||
vpnAddr = f.inside.RouteFor(vpnAddr)
|
||||
if !vpnAddr.IsValid() {
|
||||
return nil, false
|
||||
}
|
||||
func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
||||
if f.myVpnNetworksTable.Contains(vpnAddr) {
|
||||
return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
|
||||
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
|
||||
func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
||||
|
||||
destinationAddr := fwPacket.RemoteAddr
|
||||
|
||||
hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
|
||||
|
||||
// Host is inside the mesh, no routing required
|
||||
if hostinfo != nil {
|
||||
return hostinfo, ready
|
||||
}
|
||||
|
||||
gateways := f.inside.RoutesFor(destinationAddr)
|
||||
|
||||
switch len(gateways) {
|
||||
case 0:
|
||||
return nil, false
|
||||
case 1:
|
||||
// Single gateway route
|
||||
return f.handshakeManager.GetOrHandshake(gateways[0].Addr(), cacheCallback)
|
||||
default:
|
||||
// Multi gateway route, perform ECMP categorization
|
||||
gatewayAddr, balancingOk := routing.BalancePacket(fwPacket, gateways)
|
||||
|
||||
if !balancingOk {
|
||||
// This happens if the gateway buckets were not calculated, this _should_ never happen
|
||||
f.l.Error("Gateway buckets not calculated, fallback from ECMP to random routing. Please report this bug.")
|
||||
}
|
||||
|
||||
var handshakeInfoForChosenGateway *HandshakeHostInfo
|
||||
var hhReceiver = func(hh *HandshakeHostInfo) {
|
||||
handshakeInfoForChosenGateway = hh
|
||||
}
|
||||
|
||||
// Store the handshakeHostInfo for later.
|
||||
// If this node is not reachable we will attempt other nodes, if none are reachable we will
|
||||
// cache the packet for this gateway.
|
||||
if hostinfo, ready = f.handshakeManager.GetOrHandshake(gatewayAddr, hhReceiver); ready {
|
||||
return hostinfo, true
|
||||
}
|
||||
|
||||
// It appears the selected gateway cannot be reached, find another gateway to fallback on.
|
||||
// The current implementation breaks ECMP but that seems better than no connectivity.
|
||||
// If ECMP is also required when a gateway is down then connectivity status
|
||||
// for each gateway needs to be kept and the weights recalculated when they go up or down.
|
||||
// This would also need to interact with unsafe_route updates through reloading the config or
|
||||
// use of the use_system_route_table option
|
||||
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("destination", destinationAddr).
|
||||
WithField("originalGateway", gatewayAddr).
|
||||
Debugln("Calculated gateway for ECMP not available, attempting other gateways")
|
||||
}
|
||||
|
||||
for i := range gateways {
|
||||
// Skip the gateway that failed previously
|
||||
if gateways[i].Addr() == gatewayAddr {
|
||||
continue
|
||||
}
|
||||
|
||||
// We do not need the HandshakeHostInfo since we cache the packet in the originally chosen gateway
|
||||
if hostinfo, ready = f.handshakeManager.GetOrHandshake(gateways[i].Addr(), nil); ready {
|
||||
return hostinfo, true
|
||||
}
|
||||
}
|
||||
|
||||
// No gateways reachable, cache the packet in the originally chosen gateway
|
||||
cacheCallback(handshakeInfoForChosenGateway)
|
||||
return hostinfo, false
|
||||
}
|
||||
|
||||
return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
|
||||
}
|
||||
|
||||
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
|
||||
@@ -163,7 +233,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
||||
|
||||
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
|
||||
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
|
||||
hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
|
||||
hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
|
||||
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
|
||||
})
|
||||
|
||||
|
||||
14
interface.go
14
interface.go
@@ -61,11 +61,11 @@ type Interface struct {
|
||||
serveDns bool
|
||||
createTime time.Time
|
||||
lightHouse *LightHouse
|
||||
myBroadcastAddrsTable *bart.Table[struct{}]
|
||||
myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate
|
||||
myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
|
||||
myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate
|
||||
myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate
|
||||
myBroadcastAddrsTable *bart.Lite
|
||||
myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate
|
||||
myVpnAddrsTable *bart.Lite
|
||||
myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate
|
||||
myVpnNetworksTable *bart.Lite
|
||||
dropLocalBroadcast bool
|
||||
dropMulticast bool
|
||||
routines int
|
||||
@@ -410,7 +410,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
||||
udpStats := udp.NewUDPStatsEmitter(f.writers)
|
||||
|
||||
certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
|
||||
certDefaultVersion := metrics.GetOrRegisterGauge("certificate.default_version", nil)
|
||||
certInitiatingVersion := metrics.GetOrRegisterGauge("certificate.initiating_version", nil)
|
||||
certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
|
||||
|
||||
for {
|
||||
@@ -425,7 +425,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
||||
certState := f.pki.getCertState()
|
||||
defaultCrt := certState.GetDefaultCertificate()
|
||||
certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
|
||||
certDefaultVersion.Update(int64(defaultCrt.Version()))
|
||||
certInitiatingVersion.Update(int64(defaultCrt.Version()))
|
||||
|
||||
// Report the max certificate version we are capable of using
|
||||
if certState.v2Cert != nil {
|
||||
|
||||
@@ -32,7 +32,7 @@ type LightHouse struct {
|
||||
amLighthouse bool
|
||||
|
||||
myVpnNetworks []netip.Prefix
|
||||
myVpnNetworksTable *bart.Table[struct{}]
|
||||
myVpnNetworksTable *bart.Lite
|
||||
punchConn udp.Conn
|
||||
punchy *Punchy
|
||||
|
||||
@@ -201,8 +201,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
|
||||
|
||||
//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
|
||||
addr := addrs[0].Unmap()
|
||||
_, found := lh.myVpnNetworksTable.Lookup(addr)
|
||||
if found {
|
||||
if lh.myVpnNetworksTable.Contains(addr) {
|
||||
lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
|
||||
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
|
||||
continue
|
||||
@@ -359,8 +358,7 @@ func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{
|
||||
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
|
||||
}
|
||||
|
||||
_, found := lh.myVpnNetworksTable.Lookup(addr)
|
||||
if !found {
|
||||
if !lh.myVpnNetworksTable.Contains(addr) {
|
||||
return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
|
||||
}
|
||||
lhMap[addr] = struct{}{}
|
||||
@@ -422,7 +420,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
|
||||
return err
|
||||
}
|
||||
|
||||
shm := c.GetMap("static_host_map", map[interface{}]interface{}{})
|
||||
shm := c.GetMap("static_host_map", map[string]any{})
|
||||
i := 0
|
||||
|
||||
for k, v := range shm {
|
||||
@@ -431,14 +429,13 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
|
||||
return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err)
|
||||
}
|
||||
|
||||
_, found := lh.myVpnNetworksTable.Lookup(vpnAddr)
|
||||
if !found {
|
||||
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
|
||||
return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
|
||||
}
|
||||
|
||||
vals, ok := v.([]interface{})
|
||||
vals, ok := v.([]any)
|
||||
if !ok {
|
||||
vals = []interface{}{v}
|
||||
vals = []any{v}
|
||||
}
|
||||
remoteAddrs := []string{}
|
||||
for _, v := range vals {
|
||||
@@ -653,8 +650,7 @@ func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
_, found := lh.myVpnNetworksTable.Lookup(to)
|
||||
if found {
|
||||
if lh.myVpnNetworksTable.Contains(to) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -674,8 +670,7 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo
|
||||
return false
|
||||
}
|
||||
|
||||
_, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr())
|
||||
if found {
|
||||
if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -695,8 +690,7 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
|
||||
return false
|
||||
}
|
||||
|
||||
_, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr())
|
||||
if found {
|
||||
if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -763,7 +757,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
|
||||
if hi != nil {
|
||||
v = hi.ConnectionState.myCert.Version()
|
||||
} else {
|
||||
v = lh.ifce.GetCertState().defaultVersion
|
||||
v = lh.ifce.GetCertState().initiatingVersion
|
||||
}
|
||||
|
||||
if v == cert.Version1 {
|
||||
@@ -856,8 +850,7 @@ func (lh *LightHouse) SendUpdate() {
|
||||
|
||||
lal := lh.GetLocalAllowList()
|
||||
for _, e := range localAddrs(lh.l, lal) {
|
||||
_, found := lh.myVpnNetworksTable.Lookup(e)
|
||||
if found {
|
||||
if lh.myVpnNetworksTable.Contains(e) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -883,7 +876,7 @@ func (lh *LightHouse) SendUpdate() {
|
||||
if hi != nil {
|
||||
v = hi.ConnectionState.myCert.Version()
|
||||
} else {
|
||||
v = lh.ifce.GetCertState().defaultVersion
|
||||
v = lh.ifce.GetCertState().initiatingVersion
|
||||
}
|
||||
if v == cert.Version1 {
|
||||
if v1Update == nil {
|
||||
@@ -1114,7 +1107,7 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
||||
targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
|
||||
var useVersion cert.Version
|
||||
if targetHI == nil {
|
||||
useVersion = lhh.lh.ifce.GetCertState().defaultVersion
|
||||
useVersion = lhh.lh.ifce.GetCertState().initiatingVersion
|
||||
} else {
|
||||
crt := targetHI.GetCert().Certificate
|
||||
useVersion = crt.Version()
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestOldIPv4Only(t *testing.T) {
|
||||
@@ -31,8 +31,8 @@ func TestOldIPv4Only(t *testing.T) {
|
||||
func Test_lhStaticMapping(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
|
||||
nt := new(bart.Table[struct{}])
|
||||
nt.Insert(myVpnNet, struct{}{})
|
||||
nt := new(bart.Lite)
|
||||
nt.Insert(myVpnNet)
|
||||
cs := &CertState{
|
||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||
myVpnNetworksTable: nt,
|
||||
@@ -40,15 +40,15 @@ func Test_lhStaticMapping(t *testing.T) {
|
||||
lh1 := "10.128.0.2"
|
||||
|
||||
c := config.NewC(l)
|
||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
|
||||
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
|
||||
c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1}}
|
||||
c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
|
||||
_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
lh2 := "10.128.0.3"
|
||||
c = config.NewC(l)
|
||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
|
||||
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
|
||||
c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1, lh2}}
|
||||
c.Settings["static_host_map"] = map[string]any{lh1: []any{"100.1.1.1:4242"}}
|
||||
_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||
require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
|
||||
}
|
||||
@@ -56,8 +56,8 @@ func Test_lhStaticMapping(t *testing.T) {
|
||||
func TestReloadLighthouseInterval(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
|
||||
nt := new(bart.Table[struct{}])
|
||||
nt.Insert(myVpnNet, struct{}{})
|
||||
nt := new(bart.Lite)
|
||||
nt.Insert(myVpnNet)
|
||||
cs := &CertState{
|
||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||
myVpnNetworksTable: nt,
|
||||
@@ -65,12 +65,12 @@ func TestReloadLighthouseInterval(t *testing.T) {
|
||||
lh1 := "10.128.0.2"
|
||||
|
||||
c := config.NewC(l)
|
||||
c.Settings["lighthouse"] = map[interface{}]interface{}{
|
||||
"hosts": []interface{}{lh1},
|
||||
c.Settings["lighthouse"] = map[string]any{
|
||||
"hosts": []any{lh1},
|
||||
"interval": "1s",
|
||||
}
|
||||
|
||||
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
|
||||
c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||
require.NoError(t, err)
|
||||
lh.ifce = &mockEncWriter{}
|
||||
@@ -91,8 +91,8 @@ func TestReloadLighthouseInterval(t *testing.T) {
|
||||
func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||
l := test.NewLogger()
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
|
||||
nt := new(bart.Table[struct{}])
|
||||
nt.Insert(myVpnNet, struct{}{})
|
||||
nt := new(bart.Lite)
|
||||
nt.Insert(myVpnNet)
|
||||
cs := &CertState{
|
||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||
myVpnNetworksTable: nt,
|
||||
@@ -192,12 +192,12 @@ func TestLighthouse_Memory(t *testing.T) {
|
||||
theirVpnIp := netip.MustParseAddr("10.128.0.3")
|
||||
|
||||
c := config.NewC(l)
|
||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
|
||||
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
|
||||
c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
|
||||
c.Settings["listen"] = map[string]any{"port": 4242}
|
||||
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
||||
nt := new(bart.Table[struct{}])
|
||||
nt.Insert(myVpnNet, struct{}{})
|
||||
nt := new(bart.Lite)
|
||||
nt.Insert(myVpnNet)
|
||||
cs := &CertState{
|
||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||
myVpnNetworksTable: nt,
|
||||
@@ -277,12 +277,12 @@ func TestLighthouse_Memory(t *testing.T) {
|
||||
func TestLighthouse_reload(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
c := config.NewC(l)
|
||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
|
||||
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
|
||||
c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
|
||||
c.Settings["listen"] = map[string]any{"port": 4242}
|
||||
|
||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
||||
nt := new(bart.Table[struct{}])
|
||||
nt.Insert(myVpnNet, struct{}{})
|
||||
nt := new(bart.Lite)
|
||||
nt.Insert(myVpnNet)
|
||||
cs := &CertState{
|
||||
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||
myVpnNetworksTable: nt,
|
||||
@@ -291,9 +291,9 @@ func TestLighthouse_reload(t *testing.T) {
|
||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
nc := map[interface{}]interface{}{
|
||||
"static_host_map": map[interface{}]interface{}{
|
||||
"10.128.0.2": []interface{}{"1.1.1.1:4242"},
|
||||
nc := map[string]any{
|
||||
"static_host_map": map[string]any{
|
||||
"10.128.0.2": []any{"1.1.1.1:4242"},
|
||||
},
|
||||
}
|
||||
rc, err := yaml.Marshal(nc)
|
||||
@@ -417,7 +417,7 @@ func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
|
||||
}
|
||||
|
||||
func (tw *testEncWriter) GetCertState() *CertState {
|
||||
return &CertState{defaultVersion: tw.protocolVersion}
|
||||
return &CertState{initiatingVersion: tw.protocolVersion}
|
||||
}
|
||||
|
||||
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
|
||||
|
||||
6
main.go
6
main.go
@@ -13,10 +13,10 @@ import (
|
||||
"github.com/slackhq/nebula/sshd"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"gopkg.in/yaml.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type m map[string]interface{}
|
||||
type m = map[string]any
|
||||
|
||||
func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -284,7 +284,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
var dnsStart func()
|
||||
if lightHouse.amLighthouse && serveDns {
|
||||
l.Debugln("Starting dns server")
|
||||
dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
|
||||
dnsStart = dnsMain(ctx, l, pki.getCertState(), hostMap, c)
|
||||
}
|
||||
|
||||
return &Control{
|
||||
|
||||
@@ -31,8 +31,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
|
||||
//l.Error("in packet ", header, packet[HeaderLen:])
|
||||
if ip.IsValid() {
|
||||
_, found := f.myVpnNetworksTable.Lookup(ip.Addr())
|
||||
if found {
|
||||
if f.myVpnNetworksTable.Contains(ip.Addr()) {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package overlay
|
||||
import (
|
||||
"io"
|
||||
"net/netip"
|
||||
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
type Device interface {
|
||||
@@ -10,6 +12,6 @@ type Device interface {
|
||||
Activate() error
|
||||
Networks() []netip.Prefix
|
||||
Name() string
|
||||
RouteFor(netip.Addr) netip.Addr
|
||||
RoutesFor(netip.Addr) routing.Gateways
|
||||
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
||||
}
|
||||
|
||||
@@ -11,13 +11,14 @@ import (
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
type Route struct {
|
||||
MTU int
|
||||
Metric int
|
||||
Cidr netip.Prefix
|
||||
Via netip.Addr
|
||||
Via routing.Gateways
|
||||
Install bool
|
||||
}
|
||||
|
||||
@@ -47,15 +48,17 @@ func (r Route) String() string {
|
||||
return s
|
||||
}
|
||||
|
||||
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) {
|
||||
routeTree := new(bart.Table[netip.Addr])
|
||||
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
|
||||
routeTree := new(bart.Table[routing.Gateways])
|
||||
for _, r := range routes {
|
||||
if !allowMTU && r.MTU > 0 {
|
||||
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
if r.Via.IsValid() {
|
||||
routeTree.Insert(r.Cidr, r.Via)
|
||||
gateways := r.Via
|
||||
if len(gateways) > 0 {
|
||||
routing.CalculateBucketsForGateways(gateways)
|
||||
routeTree.Insert(r.Cidr, gateways)
|
||||
}
|
||||
}
|
||||
return routeTree, nil
|
||||
@@ -69,7 +72,7 @@ func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
||||
return []Route{}, nil
|
||||
}
|
||||
|
||||
rawRoutes, ok := r.([]interface{})
|
||||
rawRoutes, ok := r.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tun.routes is not an array")
|
||||
}
|
||||
@@ -80,7 +83,7 @@ func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
||||
|
||||
routes := make([]Route, len(rawRoutes))
|
||||
for i, r := range rawRoutes {
|
||||
m, ok := r.(map[interface{}]interface{})
|
||||
m, ok := r.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("entry %v in tun.routes is invalid", i+1)
|
||||
}
|
||||
@@ -148,7 +151,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
||||
return []Route{}, nil
|
||||
}
|
||||
|
||||
rawRoutes, ok := r.([]interface{})
|
||||
rawRoutes, ok := r.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tun.unsafe_routes is not an array")
|
||||
}
|
||||
@@ -159,7 +162,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
||||
|
||||
routes := make([]Route, len(rawRoutes))
|
||||
for i, r := range rawRoutes {
|
||||
m, ok := r.(map[interface{}]interface{})
|
||||
m, ok := r.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1)
|
||||
}
|
||||
@@ -201,14 +204,63 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
||||
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1)
|
||||
}
|
||||
|
||||
via, ok := rVia.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia)
|
||||
}
|
||||
var gateways routing.Gateways
|
||||
|
||||
viaVpnIp, err := netip.ParseAddr(via)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
|
||||
switch via := rVia.(type) {
|
||||
case string:
|
||||
viaIp, err := netip.ParseAddr(via)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
|
||||
}
|
||||
|
||||
gateways = routing.Gateways{routing.NewGateway(viaIp, 1)}
|
||||
|
||||
case []any:
|
||||
gateways = make(routing.Gateways, len(via))
|
||||
for ig, v := range via {
|
||||
gatewayMap, ok := v.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("entry %v in tun.unsafe_routes[%v].via is invalid", i+1, ig+1)
|
||||
}
|
||||
|
||||
rGateway, ok := gatewayMap["gateway"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not present", i+1, ig+1)
|
||||
}
|
||||
|
||||
parsedGateway, ok := rGateway.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not a string", i+1, ig+1)
|
||||
}
|
||||
|
||||
gatewayIp, err := netip.ParseAddr(parsedGateway)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] failed to parse address: %v", i+1, ig+1, err)
|
||||
}
|
||||
|
||||
rGatewayWeight, ok := gatewayMap["weight"]
|
||||
if !ok {
|
||||
rGatewayWeight = 1
|
||||
}
|
||||
|
||||
gatewayWeight, ok := rGatewayWeight.(int)
|
||||
if !ok {
|
||||
_, err = strconv.ParseInt(rGatewayWeight.(string), 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not an integer", i+1, ig+1)
|
||||
}
|
||||
}
|
||||
|
||||
if gatewayWeight < 1 || gatewayWeight > math.MaxInt32 {
|
||||
return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not in range (1-%d) : %v", i+1, ig+1, math.MaxInt32, gatewayWeight)
|
||||
}
|
||||
|
||||
gateways[ig] = routing.NewGateway(gatewayIp, gatewayWeight)
|
||||
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string or list of gateways: found %T", i+1, rVia)
|
||||
}
|
||||
|
||||
rRoute, ok := m["route"]
|
||||
@@ -226,7 +278,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
||||
}
|
||||
|
||||
r := Route{
|
||||
Via: viaVpnIp,
|
||||
Via: gateways,
|
||||
MTU: mtu,
|
||||
Metric: metric,
|
||||
Install: install,
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -23,75 +24,75 @@ func Test_parseRoutes(t *testing.T) {
|
||||
assert.Empty(t, routes)
|
||||
|
||||
// not an array
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
|
||||
c.Settings["tun"] = map[string]any{"routes": "hi"}
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "tun.routes is not an array")
|
||||
|
||||
// no routes
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
|
||||
c.Settings["tun"] = map[string]any{"routes": []any{}}
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, routes)
|
||||
|
||||
// weird route
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
|
||||
c.Settings["tun"] = map[string]any{"routes": []any{"asdf"}}
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1 in tun.routes is invalid")
|
||||
|
||||
// no mtu
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
|
||||
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{}}}
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
|
||||
|
||||
// bad mtu
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
|
||||
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "nope"}}}
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
|
||||
|
||||
// low mtu
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
|
||||
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "499"}}}
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
|
||||
|
||||
// missing route
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
|
||||
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500"}}}
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1.route in tun.routes is not present")
|
||||
|
||||
// unparsable route
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
|
||||
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "nope"}}}
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
||||
|
||||
// below network range
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
|
||||
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "1.0.0.0/8"}}}
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
|
||||
|
||||
// above network range
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
|
||||
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "10.0.1.0/24"}}}
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
|
||||
|
||||
// Not in multiple ranges
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}}
|
||||
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "192.0.0.0/24"}}}
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
|
||||
|
||||
// happy case
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
|
||||
map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
|
||||
map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
|
||||
c.Settings["tun"] = map[string]any{"routes": []any{
|
||||
map[string]any{"mtu": "9000", "route": "10.0.0.0/29"},
|
||||
map[string]any{"mtu": "8000", "route": "10.0.0.1/32"},
|
||||
}}
|
||||
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||
require.NoError(t, err)
|
||||
@@ -128,105 +129,129 @@ func Test_parseUnsafeRoutes(t *testing.T) {
|
||||
assert.Empty(t, routes)
|
||||
|
||||
// not an array
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": "hi"}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "tun.unsafe_routes is not an array")
|
||||
|
||||
// no routes
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, routes)
|
||||
|
||||
// weird route
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{"asdf"}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
|
||||
|
||||
// no via
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{}}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
|
||||
|
||||
// invalid via
|
||||
for _, invalidValue := range []interface{}{
|
||||
for _, invalidValue := range []any{
|
||||
127, false, nil, 1.0, []string{"1", "2"},
|
||||
} {
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": invalidValue}}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
|
||||
require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string or list of gateways: found %T", invalidValue))
|
||||
}
|
||||
|
||||
// Unparsable list of via
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": []string{"1", "2"}}}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not a string or list of gateways: found []string")
|
||||
|
||||
// unparsable via
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": "nope"}}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
|
||||
|
||||
// unparsable gateway
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "1"}}}}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] failed to parse address: ParseAddr(\"1\"): unable to parse IP")
|
||||
|
||||
// missing gateway element
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"weight": "1"}}}}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] is not present")
|
||||
|
||||
// unparsable weight element
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "10.0.0.1", "weight": "a"}}}}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry .weight in tun.unsafe_routes[1].via[1] is not an integer")
|
||||
|
||||
// missing route
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500"}}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
|
||||
|
||||
// unparsable route
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
||||
|
||||
// within network range
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
|
||||
|
||||
// below network range
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Len(t, routes, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// above network range
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Len(t, routes, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// no mtu
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Len(t, routes, 1)
|
||||
assert.Equal(t, 0, routes[0].MTU)
|
||||
|
||||
// bad mtu
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "nope"}}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
|
||||
|
||||
// low mtu
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "499"}}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
|
||||
|
||||
// bad install
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
assert.Nil(t, routes)
|
||||
require.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
|
||||
|
||||
// happy case
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
|
||||
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"},
|
||||
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0},
|
||||
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
|
||||
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{
|
||||
map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"},
|
||||
map[string]any{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0},
|
||||
map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
|
||||
map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
|
||||
}}
|
||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
require.NoError(t, err)
|
||||
@@ -263,9 +288,9 @@ func Test_makeRouteTree(t *testing.T) {
|
||||
n, err := netip.ParsePrefix("10.0.0.0/24")
|
||||
require.NoError(t, err)
|
||||
|
||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
|
||||
map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
|
||||
map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
|
||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{
|
||||
map[string]any{"via": "192.168.0.1", "route": "1.0.0.0/28"},
|
||||
map[string]any{"via": "192.168.0.2", "route": "1.0.0.1/32"},
|
||||
}}
|
||||
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
require.NoError(t, err)
|
||||
@@ -280,7 +305,7 @@ func Test_makeRouteTree(t *testing.T) {
|
||||
|
||||
nip, err := netip.ParseAddr("192.168.0.1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, nip, r)
|
||||
assert.Equal(t, nip, r[0].Addr())
|
||||
|
||||
ip, err = netip.ParseAddr("1.0.0.1")
|
||||
require.NoError(t, err)
|
||||
@@ -289,10 +314,91 @@ func Test_makeRouteTree(t *testing.T) {
|
||||
|
||||
nip, err = netip.ParseAddr("192.168.0.2")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, nip, r)
|
||||
assert.Equal(t, nip, r[0].Addr())
|
||||
|
||||
ip, err = netip.ParseAddr("1.1.0.1")
|
||||
require.NoError(t, err)
|
||||
r, ok = routeTree.Lookup(ip)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func Test_makeMultipathUnsafeRouteTree(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
c := config.NewC(l)
|
||||
n, err := netip.ParsePrefix("10.0.0.0/24")
|
||||
require.NoError(t, err)
|
||||
|
||||
c.Settings["tun"] = map[string]any{
|
||||
"unsafe_routes": []any{
|
||||
map[string]any{
|
||||
"route": "192.168.86.0/24",
|
||||
"via": "192.168.100.10",
|
||||
},
|
||||
map[string]any{
|
||||
"route": "192.168.87.0/24",
|
||||
"via": []any{
|
||||
map[string]any{
|
||||
"gateway": "10.0.0.1",
|
||||
},
|
||||
map[string]any{
|
||||
"gateway": "10.0.0.2",
|
||||
},
|
||||
map[string]any{
|
||||
"gateway": "10.0.0.3",
|
||||
},
|
||||
},
|
||||
},
|
||||
map[string]any{
|
||||
"route": "192.168.89.0/24",
|
||||
"via": []any{
|
||||
map[string]any{
|
||||
"gateway": "10.0.0.1",
|
||||
"weight": 10,
|
||||
},
|
||||
map[string]any{
|
||||
"gateway": "10.0.0.2",
|
||||
"weight": 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, routes, 3)
|
||||
routeTree, err := makeRouteTree(l, routes, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
ip, err := netip.ParseAddr("192.168.86.1")
|
||||
require.NoError(t, err)
|
||||
r, ok := routeTree.Lookup(ip)
|
||||
assert.True(t, ok)
|
||||
|
||||
nip, err := netip.ParseAddr("192.168.100.10")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, nip, r[0].Addr())
|
||||
|
||||
ip, err = netip.ParseAddr("192.168.87.1")
|
||||
require.NoError(t, err)
|
||||
r, ok = routeTree.Lookup(ip)
|
||||
assert.True(t, ok)
|
||||
|
||||
expectedGateways := routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 1),
|
||||
routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 1),
|
||||
routing.NewGateway(netip.MustParseAddr("10.0.0.3"), 1)}
|
||||
|
||||
routing.CalculateBucketsForGateways(expectedGateways)
|
||||
assert.ElementsMatch(t, expectedGateways, r)
|
||||
|
||||
ip, err = netip.ParseAddr("192.168.89.1")
|
||||
require.NoError(t, err)
|
||||
r, ok = routeTree.Lookup(ip)
|
||||
assert.True(t, ok)
|
||||
|
||||
expectedGateways = routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 10),
|
||||
routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 5)}
|
||||
|
||||
routing.CalculateBucketsForGateways(expectedGateways)
|
||||
assert.ElementsMatch(t, expectedGateways, r)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
@@ -21,7 +22,7 @@ type tun struct {
|
||||
fd int
|
||||
vpnNetworks []netip.Prefix
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
@@ -56,7 +57,7 @@ func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, erro
|
||||
return nil, fmt.Errorf("newTun not supported in Android")
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
netroute "golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
@@ -28,7 +29,7 @@ type tun struct {
|
||||
vpnNetworks []netip.Prefix
|
||||
DefaultMTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
linkAddr *netroute.LinkAddr
|
||||
l *logrus.Logger
|
||||
|
||||
@@ -342,12 +343,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
r, ok := t.routeTree.Load().Lookup(ip)
|
||||
if ok {
|
||||
return r
|
||||
}
|
||||
return netip.Addr{}
|
||||
return routing.Gateways{}
|
||||
}
|
||||
|
||||
// Get the LinkAddr for the interface of the given name
|
||||
@@ -382,7 +383,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
|
||||
for _, r := range routes {
|
||||
if !r.Via.IsValid() || !r.Install {
|
||||
if len(r.Via) == 0 || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
@@ -393,7 +394,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
||||
t.l.WithField("route", r.Cidr).
|
||||
Warnf("unable to add unsafe_route, identical route already exists")
|
||||
} else {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
type disabledTun struct {
|
||||
@@ -43,8 +44,8 @@ func (*disabledTun) Activate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
|
||||
return netip.Addr{}
|
||||
func (*disabledTun) RoutesFor(addr netip.Addr) routing.Gateways {
|
||||
return routing.Gateways{}
|
||||
}
|
||||
|
||||
func (t *disabledTun) Networks() []netip.Prefix {
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
@@ -50,7 +51,7 @@ type tun struct {
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
|
||||
io.ReadWriteCloser
|
||||
@@ -242,7 +243,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
@@ -262,7 +263,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
for _, r := range routes {
|
||||
if !r.Via.IsValid() || !r.Install {
|
||||
if len(r.Via) == 0 || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
@@ -270,7 +271,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
@@ -23,7 +24,7 @@ type tun struct {
|
||||
io.ReadWriteCloser
|
||||
vpnNetworks []netip.Prefix
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
@@ -79,7 +80,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
@@ -34,7 +35,7 @@ type tun struct {
|
||||
ioctlFd uintptr
|
||||
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
routeChan chan struct{}
|
||||
useSystemRoutes bool
|
||||
|
||||
@@ -231,7 +232,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
@@ -463,7 +464,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
||||
|
||||
err := netlink.RouteReplace(&nr)
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
@@ -550,20 +551,7 @@ func (t *tun) watchRoutes() {
|
||||
}()
|
||||
}
|
||||
|
||||
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||
if r.Gw == nil {
|
||||
// Not a gateway route, ignore
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route")
|
||||
return
|
||||
}
|
||||
|
||||
gwAddr, ok := netip.AddrFromSlice(r.Gw)
|
||||
if !ok {
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
|
||||
return
|
||||
}
|
||||
|
||||
gwAddr = gwAddr.Unmap()
|
||||
func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
|
||||
withinNetworks := false
|
||||
for i := range t.vpnNetworks {
|
||||
if t.vpnNetworks[i].Contains(gwAddr) {
|
||||
@@ -571,9 +559,68 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !withinNetworks {
|
||||
// Gateway isn't in our overlay network, ignore
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, not in our networks")
|
||||
|
||||
return withinNetworks
|
||||
}
|
||||
|
||||
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
||||
|
||||
var gateways routing.Gateways
|
||||
|
||||
link, err := netlink.LinkByName(t.Device)
|
||||
if err != nil {
|
||||
t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
|
||||
return gateways
|
||||
}
|
||||
|
||||
// If this route is relevant to our interface and there is a gateway then add it
|
||||
if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
|
||||
gwAddr, ok := netip.AddrFromSlice(r.Gw)
|
||||
if !ok {
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
|
||||
} else {
|
||||
gwAddr = gwAddr.Unmap()
|
||||
|
||||
if !t.isGatewayInVpnNetworks(gwAddr) {
|
||||
// Gateway isn't in our overlay network, ignore
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
||||
} else {
|
||||
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, p := range r.MultiPath {
|
||||
// If this route is relevant to our interface and there is a gateway then add it
|
||||
if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
|
||||
gwAddr, ok := netip.AddrFromSlice(p.Gw)
|
||||
if !ok {
|
||||
t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
|
||||
} else {
|
||||
gwAddr = gwAddr.Unmap()
|
||||
|
||||
if !t.isGatewayInVpnNetworks(gwAddr) {
|
||||
// Gateway isn't in our overlay network, ignore
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
||||
} else {
|
||||
// p.Hops+1 = weight of the route
|
||||
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
routing.CalculateBucketsForGateways(gateways)
|
||||
return gateways
|
||||
}
|
||||
|
||||
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||
|
||||
gateways := t.getGatewaysFromRoute(&r.Route)
|
||||
|
||||
if len(gateways) == 0 {
|
||||
// No gateways relevant to our network, no routing changes required.
|
||||
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -589,12 +636,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||
newTree := t.routeTree.Load().Clone()
|
||||
|
||||
if r.Type == unix.RTM_NEWROUTE {
|
||||
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
|
||||
newTree.Insert(dst, gwAddr)
|
||||
t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
|
||||
newTree.Insert(dst, gateways)
|
||||
|
||||
} else {
|
||||
t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
|
||||
newTree.Delete(dst)
|
||||
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
|
||||
}
|
||||
t.routeTree.Store(newTree)
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
@@ -31,7 +32,7 @@ type tun struct {
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
|
||||
io.ReadWriteCloser
|
||||
@@ -177,7 +178,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
@@ -197,7 +198,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
for _, r := range routes {
|
||||
if !r.Via.IsValid() || !r.Install {
|
||||
if len(r.Via) == 0 || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
@@ -205,7 +206,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
@@ -25,7 +26,7 @@ type tun struct {
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
|
||||
io.ReadWriteCloser
|
||||
@@ -158,7 +159,7 @@ func (t *tun) Activate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
@@ -166,7 +167,7 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
for _, r := range routes {
|
||||
if !r.Via.IsValid() || !r.Install {
|
||||
if len(r.Via) == 0 || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
@@ -174,7 +175,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
|
||||
@@ -13,13 +13,14 @@ import (
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
type TestTun struct {
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
Routes []Route
|
||||
routeTree *bart.Table[netip.Addr]
|
||||
routeTree *bart.Table[routing.Gateways]
|
||||
l *logrus.Logger
|
||||
|
||||
closed atomic.Bool
|
||||
@@ -86,7 +87,7 @@ func (t *TestTun) Get(block bool) []byte {
|
||||
// Below this is boilerplate implementation to make nebula actually work
|
||||
//********************************************************************************************************************//
|
||||
|
||||
func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
func (t *TestTun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
r, _ := t.routeTree.Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/slackhq/nebula/wintun"
|
||||
"golang.org/x/sys/windows"
|
||||
@@ -31,7 +32,7 @@ type winTun struct {
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
|
||||
tun *wintun.NativeTun
|
||||
@@ -147,15 +148,18 @@ func (t *winTun) addRoutes(logErrors bool) error {
|
||||
foundDefault4 := false
|
||||
|
||||
for _, r := range routes {
|
||||
if !r.Via.IsValid() || !r.Install {
|
||||
if len(r.Via) == 0 || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
// Add our unsafe route
|
||||
err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
|
||||
// Windows does not support multipath routes natively, so we install only a single route.
|
||||
// This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally.
|
||||
// In effect this provides multipath routing support to windows supporting loadbalancing and redundancy.
|
||||
err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
continue
|
||||
@@ -198,7 +202,8 @@ func (t *winTun) removeRoutes(routes []Route) error {
|
||||
continue
|
||||
}
|
||||
|
||||
err := luid.DeleteRoute(r.Cidr, r.Via)
|
||||
// See comment on luid.AddRoute
|
||||
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
@@ -208,7 +213,7 @@ func (t *winTun) removeRoutes(routes []Route) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
|
||||
func (t *winTun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
r, _ := t.routeTree.Load().Lookup(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
@@ -38,9 +39,13 @@ type UserDevice struct {
|
||||
func (d *UserDevice) Activate() error {
|
||||
return nil
|
||||
}
|
||||
func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks }
|
||||
func (d *UserDevice) Name() string { return "faketun0" }
|
||||
func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
|
||||
|
||||
func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks }
|
||||
func (d *UserDevice) Name() string { return "faketun0" }
|
||||
func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
return routing.Gateways{routing.NewGateway(ip, 1)}
|
||||
}
|
||||
|
||||
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return d, nil
|
||||
}
|
||||
|
||||
70
pki.go
70
pki.go
@@ -33,18 +33,16 @@ type CertState struct {
|
||||
v2Cert cert.Certificate
|
||||
v2HandshakeBytes []byte
|
||||
|
||||
defaultVersion cert.Version
|
||||
privateKey []byte
|
||||
pkcs11Backed bool
|
||||
cipher string
|
||||
|
||||
psk *Psk
|
||||
initiatingVersion cert.Version
|
||||
privateKey []byte
|
||||
pkcs11Backed bool
|
||||
cipher string
|
||||
|
||||
myVpnNetworks []netip.Prefix
|
||||
myVpnNetworksTable *bart.Table[struct{}]
|
||||
myVpnNetworksTable *bart.Lite
|
||||
myVpnAddrs []netip.Addr
|
||||
myVpnAddrsTable *bart.Table[struct{}]
|
||||
myVpnBroadcastAddrsTable *bart.Table[struct{}]
|
||||
myVpnAddrsTable *bart.Lite
|
||||
myVpnBroadcastAddrsTable *bart.Lite
|
||||
}
|
||||
|
||||
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
|
||||
@@ -173,16 +171,6 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
||||
}
|
||||
}
|
||||
|
||||
psk, err := NewPskFromConfig(c)
|
||||
if err != nil {
|
||||
return util.NewContextualError("Failed to load psk from config", nil, err)
|
||||
}
|
||||
if len(psk.keys) > 0 {
|
||||
p.l.WithField("pskMode", psk.mode).WithField("keysLen", len(psk.keys)).
|
||||
Info("pre shared keys are in use")
|
||||
}
|
||||
newState.psk = psk
|
||||
|
||||
p.cs.Store(newState)
|
||||
|
||||
//TODO: CERT-V2 newState needs a stringer that does json
|
||||
@@ -206,7 +194,7 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
|
||||
}
|
||||
|
||||
func (cs *CertState) GetDefaultCertificate() cert.Certificate {
|
||||
c := cs.getCertificate(cs.defaultVersion)
|
||||
c := cs.getCertificate(cs.initiatingVersion)
|
||||
if c == nil {
|
||||
panic("No default certificate found")
|
||||
}
|
||||
@@ -329,37 +317,37 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
|
||||
return nil, errors.New("no certificates found in pki.cert")
|
||||
}
|
||||
|
||||
useDefaultVersion := uint32(1)
|
||||
useInitiatingVersion := uint32(1)
|
||||
if v1 == nil {
|
||||
// The only condition that requires v2 as the default is if only a v2 certificate is present
|
||||
// We do this to avoid having to configure it specifically in the config file
|
||||
useDefaultVersion = 2
|
||||
useInitiatingVersion = 2
|
||||
}
|
||||
|
||||
rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion)
|
||||
var defaultVersion cert.Version
|
||||
switch rawDefaultVersion {
|
||||
rawInitiatingVersion := c.GetUint32("pki.initiating_version", useInitiatingVersion)
|
||||
var initiatingVersion cert.Version
|
||||
switch rawInitiatingVersion {
|
||||
case 1:
|
||||
if v1 == nil {
|
||||
return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert")
|
||||
return nil, fmt.Errorf("can not use pki.initiating_version 1 without a v1 certificate in pki.cert")
|
||||
}
|
||||
defaultVersion = cert.Version1
|
||||
initiatingVersion = cert.Version1
|
||||
case 2:
|
||||
defaultVersion = cert.Version2
|
||||
initiatingVersion = cert.Version2
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion)
|
||||
return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion)
|
||||
}
|
||||
|
||||
return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey)
|
||||
return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey)
|
||||
}
|
||||
|
||||
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
|
||||
cs := CertState{
|
||||
privateKey: privateKey,
|
||||
pkcs11Backed: pkcs11backed,
|
||||
myVpnNetworksTable: new(bart.Table[struct{}]),
|
||||
myVpnAddrsTable: new(bart.Table[struct{}]),
|
||||
myVpnBroadcastAddrsTable: new(bart.Table[struct{}]),
|
||||
myVpnNetworksTable: new(bart.Lite),
|
||||
myVpnAddrsTable: new(bart.Lite),
|
||||
myVpnBroadcastAddrsTable: new(bart.Lite),
|
||||
}
|
||||
|
||||
if v1 != nil && v2 != nil {
|
||||
@@ -373,7 +361,7 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
||||
|
||||
//TODO: CERT-V2 make sure v2 has v1s address
|
||||
|
||||
cs.defaultVersion = dv
|
||||
cs.initiatingVersion = dv
|
||||
}
|
||||
|
||||
if v1 != nil {
|
||||
@@ -392,8 +380,8 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
||||
cs.v1Cert = v1
|
||||
cs.v1HandshakeBytes = v1hs
|
||||
|
||||
if cs.defaultVersion == 0 {
|
||||
cs.defaultVersion = cert.Version1
|
||||
if cs.initiatingVersion == 0 {
|
||||
cs.initiatingVersion = cert.Version1
|
||||
}
|
||||
}
|
||||
|
||||
@@ -413,8 +401,8 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
||||
cs.v2Cert = v2
|
||||
cs.v2HandshakeBytes = v2hs
|
||||
|
||||
if cs.defaultVersion == 0 {
|
||||
cs.defaultVersion = cert.Version2
|
||||
if cs.initiatingVersion == 0 {
|
||||
cs.initiatingVersion = cert.Version2
|
||||
}
|
||||
}
|
||||
|
||||
@@ -427,16 +415,16 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
||||
|
||||
for _, network := range crt.Networks() {
|
||||
cs.myVpnNetworks = append(cs.myVpnNetworks, network)
|
||||
cs.myVpnNetworksTable.Insert(network, struct{}{})
|
||||
cs.myVpnNetworksTable.Insert(network)
|
||||
|
||||
cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr())
|
||||
cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{})
|
||||
cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()))
|
||||
|
||||
if network.Addr().Is4() {
|
||||
addr := network.Masked().Addr().As4()
|
||||
mask := net.CIDRMask(network.Bits(), network.Addr().BitLen())
|
||||
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
|
||||
cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{})
|
||||
cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
150
psk.go
150
psk.go
@@ -1,150 +0,0 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
var ErrNotAPskMode = errors.New("not a psk mode")
|
||||
var ErrKeyTooShort = errors.New("key is too short")
|
||||
var ErrNotEnoughPskKeys = errors.New("at least 1 key is required")
|
||||
|
||||
// MinPskLength is the minimum bytes that we accept for a user defined psk, the choice is arbitrary
|
||||
const MinPskLength = 8
|
||||
|
||||
type PskMode int
|
||||
|
||||
const (
|
||||
PskAccepting PskMode = 0
|
||||
PskSending PskMode = 1
|
||||
PskEnforced PskMode = 2
|
||||
)
|
||||
|
||||
func NewPskMode(m string) (PskMode, error) {
|
||||
switch m {
|
||||
case "accepting":
|
||||
return PskAccepting, nil
|
||||
case "sending":
|
||||
return PskSending, nil
|
||||
case "enforced":
|
||||
return PskEnforced, nil
|
||||
}
|
||||
return PskAccepting, ErrNotAPskMode
|
||||
}
|
||||
|
||||
func (p PskMode) String() string {
|
||||
switch p {
|
||||
case PskAccepting:
|
||||
return "accepting"
|
||||
case PskSending:
|
||||
return "sending"
|
||||
case PskEnforced:
|
||||
return "enforced"
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (p PskMode) IsValid() bool {
|
||||
switch p {
|
||||
case PskAccepting, PskSending, PskEnforced:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
type Psk struct {
|
||||
// pskMode sets how psk works, ignored, allowed for incoming, or enforced for all
|
||||
mode PskMode
|
||||
|
||||
// primary is the key to use when sending, it may be nil
|
||||
primary []byte
|
||||
|
||||
// keys holds all pre-computed psk hkdfs
|
||||
// Handshakes iterate this directly
|
||||
keys [][]byte
|
||||
}
|
||||
|
||||
// NewPskFromConfig is a helper for initial boot and config reloading.
|
||||
func NewPskFromConfig(c *config.C) (*Psk, error) {
|
||||
sMode := c.GetString("psk.mode", "accepting")
|
||||
mode, err := NewPskMode(sMode)
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Could not parse psk.mode", m{"mode": mode}, err)
|
||||
}
|
||||
|
||||
return NewPsk(
|
||||
mode,
|
||||
c.GetStringSlice("psk.keys", nil),
|
||||
)
|
||||
}
|
||||
|
||||
// NewPsk creates a new Psk object and handles the caching of all accepted keys
|
||||
func NewPsk(mode PskMode, keys []string) (*Psk, error) {
|
||||
if !mode.IsValid() {
|
||||
return nil, ErrNotAPskMode
|
||||
}
|
||||
|
||||
psk := &Psk{
|
||||
mode: mode,
|
||||
}
|
||||
|
||||
err := psk.cachePsks(keys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return psk, nil
|
||||
}
|
||||
|
||||
// cachePsks generates all psks we accept and caches them to speed up handshaking
|
||||
func (p *Psk) cachePsks(keys []string) error {
|
||||
if p.mode != PskAccepting && len(keys) < 1 {
|
||||
return ErrNotEnoughPskKeys
|
||||
}
|
||||
|
||||
p.keys = [][]byte{}
|
||||
|
||||
for i, rk := range keys {
|
||||
k, err := sha256KdfFromString(rk)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate key for position %v: %w", i, err)
|
||||
}
|
||||
|
||||
p.keys = append(p.keys, k)
|
||||
}
|
||||
|
||||
if p.mode != PskAccepting {
|
||||
// We are either sending or enforcing, the primary key must the first slot
|
||||
p.primary = p.keys[0]
|
||||
}
|
||||
|
||||
if p.mode != PskEnforced {
|
||||
// If we are not enforcing psk use then a nil psk is allowed
|
||||
p.keys = append(p.keys, nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sha256KdfFromString generates a useful key to use from a provided secret
|
||||
func sha256KdfFromString(secret string) ([]byte, error) {
|
||||
if len(secret) < MinPskLength {
|
||||
return nil, ErrKeyTooShort
|
||||
}
|
||||
|
||||
hmacKey := make([]byte, sha256.Size)
|
||||
_, err := io.ReadFull(hkdf.New(sha256.New, []byte(secret), nil, nil), hmacKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hmacKey, nil
|
||||
}
|
||||
72
psk_test.go
72
psk_test.go
@@ -1,72 +0,0 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewPsk(t *testing.T) {
|
||||
t.Run("mode accepting", func(t *testing.T) {
|
||||
p, err := NewPsk(PskAccepting, nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, PskAccepting, p.mode)
|
||||
assert.Nil(t, p.keys[0])
|
||||
assert.Nil(t, p.primary)
|
||||
|
||||
p, err = NewPsk(PskAccepting, []string{"1234567"})
|
||||
require.ErrorIs(t, err, ErrKeyTooShort)
|
||||
|
||||
p, err = NewPsk(PskAccepting, []string{"hi there friends"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, PskAccepting, p.mode)
|
||||
assert.Nil(t, p.primary)
|
||||
assert.Len(t, p.keys, 2)
|
||||
assert.Nil(t, p.keys[1])
|
||||
|
||||
expectedCache := []byte{
|
||||
0xb9, 0x8c, 0xdc, 0xac, 0x77, 0xf4, 0x8c, 0xf8, 0x1d, 0xe7, 0xe7, 0xb, 0x53, 0x25, 0xd3, 0x65,
|
||||
0xa3, 0x9f, 0x78, 0xb2, 0xc7, 0x2d, 0xa5, 0xd8, 0x84, 0x81, 0x7b, 0xb5, 0xdb, 0xe0, 0x9a, 0xef,
|
||||
}
|
||||
assert.Equal(t, expectedCache, p.keys[0])
|
||||
})
|
||||
|
||||
t.Run("mode sending", func(t *testing.T) {
|
||||
p, err := NewPsk(PskSending, nil)
|
||||
require.ErrorIs(t, err, ErrNotEnoughPskKeys)
|
||||
|
||||
p, err = NewPsk(PskSending, []string{"1234567"})
|
||||
require.ErrorIs(t, err, ErrKeyTooShort)
|
||||
|
||||
p, err = NewPsk(PskSending, []string{"hi there friends"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, PskSending, p.mode)
|
||||
assert.Len(t, p.keys, 2)
|
||||
assert.Nil(t, p.keys[1])
|
||||
|
||||
expectedCache := []byte{
|
||||
0xb9, 0x8c, 0xdc, 0xac, 0x77, 0xf4, 0x8c, 0xf8, 0x1d, 0xe7, 0xe7, 0xb, 0x53, 0x25, 0xd3, 0x65,
|
||||
0xa3, 0x9f, 0x78, 0xb2, 0xc7, 0x2d, 0xa5, 0xd8, 0x84, 0x81, 0x7b, 0xb5, 0xdb, 0xe0, 0x9a, 0xef,
|
||||
}
|
||||
assert.Equal(t, expectedCache, p.keys[0])
|
||||
assert.Equal(t, p.keys[0], p.primary)
|
||||
})
|
||||
|
||||
t.Run("mode enforced", func(t *testing.T) {
|
||||
p, err := NewPsk(PskEnforced, nil)
|
||||
require.ErrorIs(t, err, ErrNotEnoughPskKeys)
|
||||
|
||||
p, err = NewPsk(PskEnforced, []string{"hi there friends"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, PskEnforced, p.mode)
|
||||
assert.Len(t, p.keys, 1)
|
||||
|
||||
expectedCache := []byte{
|
||||
0xb9, 0x8c, 0xdc, 0xac, 0x77, 0xf4, 0x8c, 0xf8, 0x1d, 0xe7, 0xe7, 0xb, 0x53, 0x25, 0xd3, 0x65,
|
||||
0xa3, 0x9f, 0x78, 0xb2, 0xc7, 0x2d, 0xa5, 0xd8, 0x84, 0x81, 0x7b, 0xb5, 0xdb, 0xe0, 0x9a, 0xef,
|
||||
}
|
||||
assert.Equal(t, expectedCache, p.keys[0])
|
||||
assert.Equal(t, p.keys[0], p.primary)
|
||||
})
|
||||
}
|
||||
@@ -27,7 +27,7 @@ func TestNewPunchyFromConfig(t *testing.T) {
|
||||
assert.True(t, p.GetPunch())
|
||||
|
||||
// punchy.punch
|
||||
c.Settings["punchy"] = map[interface{}]interface{}{"punch": true}
|
||||
c.Settings["punchy"] = map[string]any{"punch": true}
|
||||
p = NewPunchyFromConfig(l, c)
|
||||
assert.True(t, p.GetPunch())
|
||||
|
||||
@@ -37,18 +37,18 @@ func TestNewPunchyFromConfig(t *testing.T) {
|
||||
assert.True(t, p.GetRespond())
|
||||
|
||||
// punchy.respond
|
||||
c.Settings["punchy"] = map[interface{}]interface{}{"respond": true}
|
||||
c.Settings["punchy"] = map[string]any{"respond": true}
|
||||
c.Settings["punch_back"] = false
|
||||
p = NewPunchyFromConfig(l, c)
|
||||
assert.True(t, p.GetRespond())
|
||||
|
||||
// punchy.delay
|
||||
c.Settings["punchy"] = map[interface{}]interface{}{"delay": "1m"}
|
||||
c.Settings["punchy"] = map[string]any{"delay": "1m"}
|
||||
p = NewPunchyFromConfig(l, c)
|
||||
assert.Equal(t, time.Minute, p.GetDelay())
|
||||
|
||||
// punchy.respond_delay
|
||||
c.Settings["punchy"] = map[interface{}]interface{}{"respond_delay": "1m"}
|
||||
c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
|
||||
p = NewPunchyFromConfig(l, c)
|
||||
assert.Equal(t, time.Minute, p.GetRespondDelay())
|
||||
}
|
||||
|
||||
@@ -241,15 +241,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
|
||||
logMsg.Info("handleCreateRelayRequest")
|
||||
// Is the source of the relay me? This should never happen, but did happen due to
|
||||
// an issue migrating relays over to newly re-handshaked host info objects.
|
||||
_, found := f.myVpnAddrsTable.Lookup(from)
|
||||
if found {
|
||||
if f.myVpnAddrsTable.Contains(from) {
|
||||
logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
|
||||
return
|
||||
}
|
||||
|
||||
// Is the target of the relay me?
|
||||
_, found = f.myVpnAddrsTable.Lookup(target)
|
||||
if found {
|
||||
if f.myVpnAddrsTable.Contains(target) {
|
||||
existingRelay, ok := h.relayState.QueryRelayForByIp(from)
|
||||
if ok {
|
||||
switch existingRelay.State {
|
||||
|
||||
39
routing/balance.go
Normal file
39
routing/balance.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
)
|
||||
|
||||
// Hashes the packet source and destination port and always returns a positive integer
|
||||
// Based on 'Prospecting for Hash Functions'
|
||||
// - https://nullprogram.com/blog/2018/07/31/
|
||||
// - https://github.com/skeeto/hash-prospector
|
||||
// [16 21f0aaad 15 d35a2d97 15] = 0.10760229515479501
|
||||
func hashPacket(p *firewall.Packet) int {
|
||||
x := (uint32(p.LocalPort) << 16) | uint32(p.RemotePort)
|
||||
x ^= x >> 16
|
||||
x *= 0x21f0aaad
|
||||
x ^= x >> 15
|
||||
x *= 0xd35a2d97
|
||||
x ^= x >> 15
|
||||
|
||||
return int(x) & 0x7FFFFFFF
|
||||
}
|
||||
|
||||
// For this function to work correctly it requires that the buckets for the gateways have been calculated
|
||||
// If the contract is violated balancing will not work properly and the second return value will return false
|
||||
func BalancePacket(fwPacket *firewall.Packet, gateways []Gateway) (netip.Addr, bool) {
|
||||
hash := hashPacket(fwPacket)
|
||||
|
||||
for i := range gateways {
|
||||
if hash <= gateways[i].BucketUpperBound() {
|
||||
return gateways[i].Addr(), true
|
||||
}
|
||||
}
|
||||
|
||||
// If you land here then the buckets for the gateways are not properly calculated
|
||||
// Fallback to random routing and let the caller know
|
||||
return gateways[hash%len(gateways)].Addr(), false
|
||||
}
|
||||
144
routing/balance_test.go
Normal file
144
routing/balance_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPacketsAreBalancedEqually(t *testing.T) {
|
||||
|
||||
gateways := []Gateway{}
|
||||
|
||||
gw1Addr := netip.MustParseAddr("1.0.0.1")
|
||||
gw2Addr := netip.MustParseAddr("1.0.0.2")
|
||||
gw3Addr := netip.MustParseAddr("1.0.0.3")
|
||||
|
||||
gateways = append(gateways, NewGateway(gw1Addr, 1))
|
||||
gateways = append(gateways, NewGateway(gw2Addr, 1))
|
||||
gateways = append(gateways, NewGateway(gw3Addr, 1))
|
||||
|
||||
CalculateBucketsForGateways(gateways)
|
||||
|
||||
gw1count := 0
|
||||
gw2count := 0
|
||||
gw3count := 0
|
||||
|
||||
iterationCount := uint16(65535)
|
||||
for i := uint16(0); i < iterationCount; i++ {
|
||||
packet := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("192.168.1.1"),
|
||||
RemoteAddr: netip.MustParseAddr("10.0.0.1"),
|
||||
LocalPort: i,
|
||||
RemotePort: 65535 - i,
|
||||
Protocol: 6, // TCP
|
||||
Fragment: false,
|
||||
}
|
||||
|
||||
selectedGw, ok := BalancePacket(&packet, gateways)
|
||||
assert.True(t, ok)
|
||||
|
||||
switch selectedGw {
|
||||
case gw1Addr:
|
||||
gw1count += 1
|
||||
case gw2Addr:
|
||||
gw2count += 1
|
||||
case gw3Addr:
|
||||
gw3count += 1
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Assert packets are balanced, allow variation of up to 100 packets per gateway
|
||||
assert.InDeltaf(t, iterationCount/3, gw1count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count)
|
||||
assert.InDeltaf(t, iterationCount/3, gw2count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count)
|
||||
assert.InDeltaf(t, iterationCount/3, gw3count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count)
|
||||
|
||||
}
|
||||
|
||||
func TestPacketsAreBalancedByPriority(t *testing.T) {
|
||||
|
||||
gateways := []Gateway{}
|
||||
|
||||
gw1Addr := netip.MustParseAddr("1.0.0.1")
|
||||
gw2Addr := netip.MustParseAddr("1.0.0.2")
|
||||
|
||||
gateways = append(gateways, NewGateway(gw1Addr, 10))
|
||||
gateways = append(gateways, NewGateway(gw2Addr, 5))
|
||||
|
||||
CalculateBucketsForGateways(gateways)
|
||||
|
||||
gw1count := 0
|
||||
gw2count := 0
|
||||
|
||||
iterationCount := uint16(65535)
|
||||
for i := uint16(0); i < iterationCount; i++ {
|
||||
packet := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("192.168.1.1"),
|
||||
RemoteAddr: netip.MustParseAddr("10.0.0.1"),
|
||||
LocalPort: i,
|
||||
RemotePort: 65535 - i,
|
||||
Protocol: 6, // TCP
|
||||
Fragment: false,
|
||||
}
|
||||
|
||||
selectedGw, ok := BalancePacket(&packet, gateways)
|
||||
assert.True(t, ok)
|
||||
|
||||
switch selectedGw {
|
||||
case gw1Addr:
|
||||
gw1count += 1
|
||||
case gw2Addr:
|
||||
gw2count += 1
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
iterationCountAsFloat := float32(iterationCount)
|
||||
|
||||
assert.InDeltaf(t, iterationCountAsFloat*(2.0/3.0), gw1count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(2.0/3.0), gw1count)
|
||||
assert.InDeltaf(t, iterationCountAsFloat*(1.0/3.0), gw2count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(1.0/3.0), gw2count)
|
||||
}
|
||||
|
||||
func TestBalancePacketDistributsRandomlyAndReturnsFalseIfBucketsNotCalculated(t *testing.T) {
|
||||
gateways := []Gateway{}
|
||||
|
||||
gw1Addr := netip.MustParseAddr("1.0.0.1")
|
||||
gw2Addr := netip.MustParseAddr("1.0.0.2")
|
||||
|
||||
gateways = append(gateways, NewGateway(gw1Addr, 10))
|
||||
gateways = append(gateways, NewGateway(gw2Addr, 5))
|
||||
|
||||
iterationCount := uint16(65535)
|
||||
gw1count := 0
|
||||
gw2count := 0
|
||||
|
||||
for i := uint16(0); i < iterationCount; i++ {
|
||||
packet := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("192.168.1.1"),
|
||||
RemoteAddr: netip.MustParseAddr("10.0.0.1"),
|
||||
LocalPort: i,
|
||||
RemotePort: 65535 - i,
|
||||
Protocol: 6, // TCP
|
||||
Fragment: false,
|
||||
}
|
||||
|
||||
selectedGw, ok := BalancePacket(&packet, gateways)
|
||||
assert.False(t, ok)
|
||||
|
||||
switch selectedGw {
|
||||
case gw1Addr:
|
||||
gw1count += 1
|
||||
case gw2Addr:
|
||||
gw2count += 1
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
assert.Equal(t, int(iterationCount), (gw1count + gw2count))
|
||||
assert.NotEqual(t, 0, gw1count)
|
||||
assert.NotEqual(t, 0, gw2count)
|
||||
|
||||
}
|
||||
70
routing/gateway.go
Normal file
70
routing/gateway.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
const (
|
||||
// Sentinal value
|
||||
BucketNotCalculated = -1
|
||||
)
|
||||
|
||||
type Gateways []Gateway
|
||||
|
||||
func (g Gateways) String() string {
|
||||
str := ""
|
||||
for i, gw := range g {
|
||||
str += gw.String()
|
||||
if i < len(g)-1 {
|
||||
str += ", "
|
||||
}
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
type Gateway struct {
|
||||
addr netip.Addr
|
||||
weight int
|
||||
bucketUpperBound int
|
||||
}
|
||||
|
||||
func NewGateway(addr netip.Addr, weight int) Gateway {
|
||||
return Gateway{addr: addr, weight: weight, bucketUpperBound: BucketNotCalculated}
|
||||
}
|
||||
|
||||
func (g *Gateway) BucketUpperBound() int {
|
||||
return g.bucketUpperBound
|
||||
}
|
||||
|
||||
func (g *Gateway) Addr() netip.Addr {
|
||||
return g.addr
|
||||
}
|
||||
|
||||
func (g *Gateway) String() string {
|
||||
return fmt.Sprintf("{addr: %s, weight: %d}", g.addr, g.weight)
|
||||
}
|
||||
|
||||
// Divide and round to nearest integer
|
||||
func divideAndRound(v uint64, d uint64) uint64 {
|
||||
var tmp uint64 = v + d/2
|
||||
return tmp / d
|
||||
}
|
||||
|
||||
// Implements Hash-Threshold mapping, equivalent to the implementation in the linux kernel.
|
||||
// After this function returns each gateway will have a
|
||||
// positive bucketUpperBound with a maximum value of 2147483647 (INT_MAX)
|
||||
func CalculateBucketsForGateways(gateways []Gateway) {
|
||||
|
||||
var totalWeight int = 0
|
||||
for i := range gateways {
|
||||
totalWeight += gateways[i].weight
|
||||
}
|
||||
|
||||
var loopWeight int = 0
|
||||
for i := range gateways {
|
||||
loopWeight += gateways[i].weight
|
||||
gateways[i].bucketUpperBound = int(divideAndRound(uint64(loopWeight)<<31, uint64(totalWeight))) - 1
|
||||
}
|
||||
|
||||
}
|
||||
34
routing/gateway_test.go
Normal file
34
routing/gateway_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRebalance3_2Split(t *testing.T) {
|
||||
gateways := []Gateway{}
|
||||
|
||||
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 10})
|
||||
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 5})
|
||||
|
||||
CalculateBucketsForGateways(gateways)
|
||||
|
||||
assert.Equal(t, 1431655764, gateways[0].bucketUpperBound) // INT_MAX/3*2
|
||||
assert.Equal(t, 2147483647, gateways[1].bucketUpperBound) // INT_MAX
|
||||
}
|
||||
|
||||
func TestRebalanceEqualSplit(t *testing.T) {
|
||||
gateways := []Gateway{}
|
||||
|
||||
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1})
|
||||
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1})
|
||||
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1})
|
||||
|
||||
CalculateBucketsForGateways(gateways)
|
||||
|
||||
assert.Equal(t, 715827882, gateways[0].bucketUpperBound) // INT_MAX/3
|
||||
assert.Equal(t, 1431655764, gateways[1].bucketUpperBound) // INT_MAX/3*2
|
||||
assert.Equal(t, 2147483647, gateways[2].bucketUpperBound) // INT_MAX
|
||||
}
|
||||
@@ -13,10 +13,10 @@ import (
|
||||
"github.com/slackhq/nebula/cert_test"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"gopkg.in/yaml.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type m map[string]interface{}
|
||||
type m = map[string]any
|
||||
|
||||
func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
|
||||
_, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{})
|
||||
|
||||
92
ssh.go
92
ssh.go
@@ -124,10 +124,10 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
||||
}
|
||||
|
||||
rawKeys := c.Get("sshd.authorized_users")
|
||||
keys, ok := rawKeys.([]interface{})
|
||||
keys, ok := rawKeys.([]any)
|
||||
if ok {
|
||||
for _, rk := range keys {
|
||||
kDef, ok := rk.(map[interface{}]interface{})
|
||||
kDef, ok := rk.(map[string]any)
|
||||
if !ok {
|
||||
l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring")
|
||||
continue
|
||||
@@ -148,7 +148,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
|
||||
continue
|
||||
}
|
||||
|
||||
case []interface{}:
|
||||
case []any:
|
||||
for _, subK := range v {
|
||||
sk, ok := subK.(string)
|
||||
if !ok {
|
||||
@@ -190,7 +190,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
||||
ssh.RegisterCommand(&sshd.Command{
|
||||
Name: "list-hostmap",
|
||||
ShortDescription: "List all known previously connected hosts",
|
||||
Flags: func() (*flag.FlagSet, interface{}) {
|
||||
Flags: func() (*flag.FlagSet, any) {
|
||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||
s := sshListHostMapFlags{}
|
||||
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
|
||||
@@ -198,7 +198,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
||||
fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
|
||||
return fl, &s
|
||||
},
|
||||
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
||||
return sshListHostMap(f.hostMap, fs, w)
|
||||
},
|
||||
})
|
||||
@@ -206,7 +206,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
||||
ssh.RegisterCommand(&sshd.Command{
|
||||
Name: "list-pending-hostmap",
|
||||
ShortDescription: "List all handshaking hosts",
|
||||
Flags: func() (*flag.FlagSet, interface{}) {
|
||||
Flags: func() (*flag.FlagSet, any) {
|
||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||
s := sshListHostMapFlags{}
|
||||
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
|
||||
@@ -214,7 +214,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
||||
fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
|
||||
return fl, &s
|
||||
},
|
||||
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
||||
return sshListHostMap(f.handshakeManager, fs, w)
|
||||
},
|
||||
})
|
||||
@@ -222,14 +222,14 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
||||
ssh.RegisterCommand(&sshd.Command{
|
||||
Name: "list-lighthouse-addrmap",
|
||||
ShortDescription: "List all lighthouse map entries",
|
||||
Flags: func() (*flag.FlagSet, interface{}) {
|
||||
Flags: func() (*flag.FlagSet, any) {
|
||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||
s := sshListHostMapFlags{}
|
||||
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
|
||||
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
|
||||
return fl, &s
|
||||
},
|
||||
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
||||
return sshListLighthouseMap(f.lightHouse, fs, w)
|
||||
},
|
||||
})
|
||||
@@ -237,7 +237,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
||||
ssh.RegisterCommand(&sshd.Command{
|
||||
Name: "reload",
|
||||
ShortDescription: "Reloads configuration from disk, same as sending HUP to the process",
|
||||
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
||||
return sshReload(c, w)
|
||||
},
|
||||
})
|
||||
@@ -251,7 +251,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
||||
ssh.RegisterCommand(&sshd.Command{
|
||||
Name: "stop-cpu-profile",
|
||||
ShortDescription: "Stops a cpu profile and writes output to the previously provided file",
|
||||
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
||||
pprof.StopCPUProfile()
|
||||
return w.WriteLine("If a CPU profile was running it is now stopped")
|
||||
},
|
||||
@@ -278,7 +278,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
||||
ssh.RegisterCommand(&sshd.Command{
|
||||
Name: "log-level",
|
||||
ShortDescription: "Gets or sets the current log level",
|
||||
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
||||
return sshLogLevel(l, fs, a, w)
|
||||
},
|
||||
})
|
||||
@@ -286,7 +286,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
||||
ssh.RegisterCommand(&sshd.Command{
|
||||
Name: "log-format",
|
||||
ShortDescription: "Gets or sets the current log format",
|
||||
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
||||
return sshLogFormat(l, fs, a, w)
|
||||
},
|
||||
})
|
||||
@@ -294,7 +294,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
||||
ssh.RegisterCommand(&sshd.Command{
|
||||
Name: "version",
|
||||
ShortDescription: "Prints the currently running version of nebula",
|
||||
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
||||
return sshVersion(f, fs, a, w)
|
||||
},
|
||||
})
|
||||
@@ -302,14 +302,14 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
||||
ssh.RegisterCommand(&sshd.Command{
|
||||
Name: "device-info",
|
||||
ShortDescription: "Prints information about the network device.",
|
||||
Flags: func() (*flag.FlagSet, interface{}) {
|
||||
Flags: func() (*flag.FlagSet, any) {
|
||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||
s := sshDeviceInfoFlags{}
|
||||
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
|
||||
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
|
||||
return fl, &s
|
||||
},
|
||||
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
||||
return sshDeviceInfo(f, fs, w)
|
||||
},
|
||||
})
|
||||
@@ -317,7 +317,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
||||
ssh.RegisterCommand(&sshd.Command{
|
||||
Name: "print-cert",
|
||||
ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn addr",
|
||||
Flags: func() (*flag.FlagSet, interface{}) {
|
||||
Flags: func() (*flag.FlagSet, any) {
|
||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||
s := sshPrintCertFlags{}
|
||||
fl.BoolVar(&s.Json, "json", false, "outputs as json")
|
||||
@@ -325,7 +325,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
||||
fl.BoolVar(&s.Raw, "raw", false, "raw prints the PEM encoded certificate, not compatible with -json or -pretty")
|
||||
return fl, &s
|
||||
},
|
||||
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
||||
return sshPrintCert(f, fs, a, w)
|
||||
},
|
||||
})
|
||||
@@ -333,13 +333,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
||||
ssh.RegisterCommand(&sshd.Command{
|
||||
Name: "print-tunnel",
|
||||
ShortDescription: "Prints json details about a tunnel for the provided vpn addr",
|
||||
Flags: func() (*flag.FlagSet, interface{}) {
|
||||
Flags: func() (*flag.FlagSet, any) {
|
||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||
s := sshPrintTunnelFlags{}
|
||||
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json")
|
||||
return fl, &s
|
||||
},
|
||||
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
||||
return sshPrintTunnel(f, fs, a, w)
|
||||
},
|
||||
})
|
||||
@@ -347,13 +347,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
||||
ssh.RegisterCommand(&sshd.Command{
|
||||
Name: "print-relays",
|
||||
ShortDescription: "Prints json details about all relay info",
|
||||
Flags: func() (*flag.FlagSet, interface{}) {
|
||||
Flags: func() (*flag.FlagSet, any) {
|
||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||
s := sshPrintTunnelFlags{}
|
||||
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json")
|
||||
return fl, &s
|
||||
},
|
||||
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
||||
return sshPrintRelays(f, fs, a, w)
|
||||
},
|
||||
})
|
||||
@@ -361,13 +361,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
||||
ssh.RegisterCommand(&sshd.Command{
|
||||
Name: "change-remote",
|
||||
ShortDescription: "Changes the remote address used in the tunnel for the provided vpn addr",
|
||||
Flags: func() (*flag.FlagSet, interface{}) {
|
||||
Flags: func() (*flag.FlagSet, any) {
|
||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||
s := sshChangeRemoteFlags{}
|
||||
fl.StringVar(&s.Address, "address", "", "The new remote address, ip:port")
|
||||
return fl, &s
|
||||
},
|
||||
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
||||
return sshChangeRemote(f, fs, a, w)
|
||||
},
|
||||
})
|
||||
@@ -375,13 +375,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
||||
ssh.RegisterCommand(&sshd.Command{
|
||||
Name: "close-tunnel",
|
||||
ShortDescription: "Closes a tunnel for the provided vpn addr",
|
||||
Flags: func() (*flag.FlagSet, interface{}) {
|
||||
Flags: func() (*flag.FlagSet, any) {
|
||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||
s := sshCloseTunnelFlags{}
|
||||
fl.BoolVar(&s.LocalOnly, "local-only", false, "Disables notifying the remote that the tunnel is shutting down")
|
||||
return fl, &s
|
||||
},
|
||||
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
||||
return sshCloseTunnel(f, fs, a, w)
|
||||
},
|
||||
})
|
||||
@@ -390,13 +390,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
||||
Name: "create-tunnel",
|
||||
ShortDescription: "Creates a tunnel for the provided vpn address",
|
||||
Help: "The lighthouses will be queried for real addresses but you can provide one as well.",
|
||||
Flags: func() (*flag.FlagSet, interface{}) {
|
||||
Flags: func() (*flag.FlagSet, any) {
|
||||
fl := flag.NewFlagSet("", flag.ContinueOnError)
|
||||
s := sshCreateTunnelFlags{}
|
||||
fl.StringVar(&s.Address, "address", "", "Optionally provide a real remote address, ip:port ")
|
||||
return fl, &s
|
||||
},
|
||||
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
||||
return sshCreateTunnel(f, fs, a, w)
|
||||
},
|
||||
})
|
||||
@@ -405,13 +405,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
|
||||
Name: "query-lighthouse",
|
||||
ShortDescription: "Query the lighthouses for the provided vpn address",
|
||||
Help: "This command is asynchronous. Only currently known udp addresses will be printed.",
|
||||
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
Callback: func(fs any, a []string, w sshd.StringWriter) error {
|
||||
return sshQueryLighthouse(f, fs, a, w)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error {
|
||||
func sshListHostMap(hl controlHostLister, a any, w sshd.StringWriter) error {
|
||||
fs, ok := a.(*sshListHostMapFlags)
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -451,7 +451,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
|
||||
return nil
|
||||
}
|
||||
|
||||
func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWriter) error {
|
||||
func sshListLighthouseMap(lightHouse *LightHouse, a any, w sshd.StringWriter) error {
|
||||
fs, ok := a.(*sshListHostMapFlags)
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -505,7 +505,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
|
||||
return nil
|
||||
}
|
||||
|
||||
func sshStartCpuProfile(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
func sshStartCpuProfile(fs any, a []string, w sshd.StringWriter) error {
|
||||
if len(a) == 0 {
|
||||
err := w.WriteLine("No path to write profile provided")
|
||||
return err
|
||||
@@ -527,11 +527,11 @@ func sshStartCpuProfile(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func sshVersion(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
func sshVersion(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
||||
return w.WriteLine(fmt.Sprintf("%s", ifce.version))
|
||||
}
|
||||
|
||||
func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
func sshQueryLighthouse(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
||||
if len(a) == 0 {
|
||||
return w.WriteLine("No vpn address was provided")
|
||||
}
|
||||
@@ -553,7 +553,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
|
||||
return json.NewEncoder(w.GetWriter()).Encode(cm)
|
||||
}
|
||||
|
||||
func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
func sshCloseTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
||||
flags, ok := fs.(*sshCloseTunnelFlags)
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -593,7 +593,7 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
|
||||
return w.WriteLine("Closed")
|
||||
}
|
||||
|
||||
func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
func sshCreateTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
||||
flags, ok := fs.(*sshCreateTunnelFlags)
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -638,7 +638,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
|
||||
return w.WriteLine("Created")
|
||||
}
|
||||
|
||||
func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
func sshChangeRemote(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
||||
flags, ok := fs.(*sshChangeRemoteFlags)
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -675,7 +675,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
|
||||
return w.WriteLine("Changed")
|
||||
}
|
||||
|
||||
func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
func sshGetHeapProfile(fs any, a []string, w sshd.StringWriter) error {
|
||||
if len(a) == 0 {
|
||||
return w.WriteLine("No path to write profile provided")
|
||||
}
|
||||
@@ -696,7 +696,7 @@ func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func sshMutexProfileFraction(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
func sshMutexProfileFraction(fs any, a []string, w sshd.StringWriter) error {
|
||||
if len(a) == 0 {
|
||||
rate := runtime.SetMutexProfileFraction(-1)
|
||||
return w.WriteLine(fmt.Sprintf("Current value: %d", rate))
|
||||
@@ -711,7 +711,7 @@ func sshMutexProfileFraction(fs interface{}, a []string, w sshd.StringWriter) er
|
||||
return w.WriteLine(fmt.Sprintf("New value: %d. Old value: %d", newRate, oldRate))
|
||||
}
|
||||
|
||||
func sshGetMutexProfile(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error {
|
||||
if len(a) == 0 {
|
||||
return w.WriteLine("No path to write profile provided")
|
||||
}
|
||||
@@ -735,7 +735,7 @@ func sshGetMutexProfile(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a))
|
||||
}
|
||||
|
||||
func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
|
||||
if len(a) == 0 {
|
||||
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
|
||||
}
|
||||
@@ -749,7 +749,7 @@ func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWrit
|
||||
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
|
||||
}
|
||||
|
||||
func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
|
||||
if len(a) == 0 {
|
||||
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
|
||||
}
|
||||
@@ -767,7 +767,7 @@ func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWri
|
||||
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
|
||||
}
|
||||
|
||||
func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
||||
args, ok := fs.(*sshPrintCertFlags)
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -822,7 +822,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
|
||||
return w.WriteLine(cert.String())
|
||||
}
|
||||
|
||||
func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
func sshPrintRelays(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
||||
args, ok := fs.(*sshPrintTunnelFlags)
|
||||
if !ok {
|
||||
w.WriteLine(fmt.Sprintf("sshPrintRelays failed to convert args type"))
|
||||
@@ -919,7 +919,7 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
|
||||
return nil
|
||||
}
|
||||
|
||||
func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
func sshPrintTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
|
||||
args, ok := fs.(*sshPrintTunnelFlags)
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -951,7 +951,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
|
||||
return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges()))
|
||||
}
|
||||
|
||||
func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error {
|
||||
func sshDeviceInfo(ifce *Interface, fs any, w sshd.StringWriter) error {
|
||||
|
||||
data := struct {
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
// CommandFlags is a function called before help or command execution to parse command line flags
|
||||
// It should return a flag.FlagSet instance and a pointer to the struct that will contain parsed flags
|
||||
type CommandFlags func() (*flag.FlagSet, interface{})
|
||||
type CommandFlags func() (*flag.FlagSet, any)
|
||||
|
||||
// CommandCallback is the function called when your command should execute.
|
||||
// fs will be a a pointer to the struct provided by Command.Flags callback, if there was one. -h and -help are reserved
|
||||
@@ -21,7 +21,7 @@ type CommandFlags func() (*flag.FlagSet, interface{})
|
||||
// w is the writer to use when sending messages back to the client.
|
||||
// If an error is returned by the callback it is logged locally, the callback should handle messaging errors to the user
|
||||
// where appropriate
|
||||
type CommandCallback func(fs interface{}, a []string, w StringWriter) error
|
||||
type CommandCallback func(fs any, a []string, w StringWriter) error
|
||||
|
||||
type Command struct {
|
||||
Name string
|
||||
@@ -34,7 +34,7 @@ type Command struct {
|
||||
func execCommand(c *Command, args []string, w StringWriter) error {
|
||||
var (
|
||||
fl *flag.FlagSet
|
||||
fs interface{}
|
||||
fs any
|
||||
)
|
||||
|
||||
if c.Flags != nil {
|
||||
@@ -85,7 +85,7 @@ func lookupCommand(c *radix.Tree, sCmd string) (*Command, error) {
|
||||
|
||||
func matchCommand(c *radix.Tree, cmd string) []string {
|
||||
cmds := make([]string, 0)
|
||||
c.WalkPrefix(cmd, func(found string, v interface{}) bool {
|
||||
c.WalkPrefix(cmd, func(found string, v any) bool {
|
||||
cmds = append(cmds, found)
|
||||
return false
|
||||
})
|
||||
@@ -95,7 +95,7 @@ func matchCommand(c *radix.Tree, cmd string) []string {
|
||||
|
||||
func allCommands(c *radix.Tree) []*Command {
|
||||
cmds := make([]*Command, 0)
|
||||
c.WalkPrefix("", func(found string, v interface{}) bool {
|
||||
c.WalkPrefix("", func(found string, v any) bool {
|
||||
cmd, ok := v.(*Command)
|
||||
if ok {
|
||||
cmds = append(cmds, cmd)
|
||||
|
||||
@@ -86,7 +86,7 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
|
||||
s.RegisterCommand(&Command{
|
||||
Name: "help",
|
||||
ShortDescription: "prints available commands or help <command> for specific usage info",
|
||||
Callback: func(a interface{}, args []string, w StringWriter) error {
|
||||
Callback: func(a any, args []string, w StringWriter) error {
|
||||
return helpCallback(s.commands, args, w)
|
||||
},
|
||||
})
|
||||
|
||||
@@ -9,13 +9,13 @@ import (
|
||||
"github.com/armon/go-radix"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
type session struct {
|
||||
l *logrus.Entry
|
||||
c *ssh.ServerConn
|
||||
term *terminal.Terminal
|
||||
term *term.Terminal
|
||||
commands *radix.Tree
|
||||
exitChan chan bool
|
||||
}
|
||||
@@ -31,7 +31,7 @@ func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.New
|
||||
s.commands.Insert("logout", &Command{
|
||||
Name: "logout",
|
||||
ShortDescription: "Ends the current session",
|
||||
Callback: func(a interface{}, args []string, w StringWriter) error {
|
||||
Callback: func(a any, args []string, w StringWriter) error {
|
||||
s.Close()
|
||||
return nil
|
||||
},
|
||||
@@ -106,8 +106,8 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) createTerm(channel ssh.Channel) *terminal.Terminal {
|
||||
term := terminal.NewTerminal(channel, s.c.User()+"@nebula > ")
|
||||
func (s *session) createTerm(channel ssh.Channel) *term.Terminal {
|
||||
term := term.NewTerminal(channel, s.c.User()+"@nebula > ")
|
||||
term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
|
||||
// key 9 is tab
|
||||
if key == 9 {
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
// AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory
|
||||
// There is currently a special case for `time.loc` (as this code traverses into unexported fields)
|
||||
func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) {
|
||||
func AssertDeepCopyEqual(t *testing.T, a any, b any) {
|
||||
v1 := reflect.ValueOf(a)
|
||||
v2 := reflect.ValueOf(b)
|
||||
|
||||
|
||||
@@ -4,12 +4,14 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/netip"
|
||||
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
type NoopTun struct{}
|
||||
|
||||
func (NoopTun) RouteFor(addr netip.Addr) netip.Addr {
|
||||
return netip.Addr{}
|
||||
func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {
|
||||
return routing.Gateways{}
|
||||
}
|
||||
|
||||
func (NoopTun) Activate() error {
|
||||
|
||||
@@ -9,11 +9,11 @@ import (
|
||||
|
||||
type ContextualError struct {
|
||||
RealError error
|
||||
Fields map[string]interface{}
|
||||
Fields map[string]any
|
||||
Context string
|
||||
}
|
||||
|
||||
func NewContextualError(msg string, fields map[string]interface{}, realError error) *ContextualError {
|
||||
func NewContextualError(msg string, fields map[string]any, realError error) *ContextualError {
|
||||
return &ContextualError{Context: msg, Fields: fields, RealError: realError}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type m map[string]interface{}
|
||||
type m = map[string]any
|
||||
|
||||
type TestLogWriter struct {
|
||||
Logs []string
|
||||
|
||||
Reference in New Issue
Block a user