Compare commits

..

5 Commits

Author SHA1 Message Date
Jay Wren
5ceac2b078 add a little context to dns 2025-04-18 17:20:14 -04:00
Wade Simmons
b8ea55eb90 optimize usage of bart (#1395)
Some checks failed
gofmt / Run gofmt (push) Successful in 9s
smoke-extra / Run extra smoke tests (push) Failing after 19s
smoke / Run multi node smoke test (push) Failing after 1m19s
Build and test / Build all and test on ubuntu-linux (push) Failing after 18m41s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2m47s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2m47s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
Use `bart.Lite` and `.Contains` as suggested by the bart maintainer:

- 9455952eed (commitcomment-155362580)
2025-04-18 12:37:20 -04:00
dependabot[bot]
4eb056af9d Bump github.com/prometheus/client_golang from 1.21.1 to 1.22.0 (#1393)
Some checks failed
gofmt / Run gofmt (push) Successful in 24s
smoke-extra / Run extra smoke tests (push) Failing after 19s
smoke / Run multi node smoke test (push) Failing after 1m20s
Build and test / Build all and test on ubuntu-linux (push) Failing after 18m9s
Build and test / Build and test on linux with boringcrypto (push) Failing after 4m11s
Build and test / Build and test on linux with pkcs11 (push) Failing after 3m28s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
Bumps [github.com/prometheus/client_golang](https://github.com/prometheus/client_golang) from 1.21.1 to 1.22.0.
- [Release notes](https://github.com/prometheus/client_golang/releases)
- [Changelog](https://github.com/prometheus/client_golang/blob/main/CHANGELOG.md)
- [Commits](https://github.com/prometheus/client_golang/compare/v1.21.1...v1.22.0)

---
updated-dependencies:
- dependency-name: github.com/prometheus/client_golang
  dependency-version: 1.22.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-04-17 06:43:55 -04:00
dependabot[bot]
e49f279004 Bump golang.org/x/net in the golang-x-dependencies group (#1392)
Bumps the golang-x-dependencies group with 1 update: [golang.org/x/net](https://github.com/golang/net).


Updates `golang.org/x/net` from 0.38.0 to 0.39.0
- [Commits](https://github.com/golang/net/compare/v0.38.0...v0.39.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-version: 0.39.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-04-17 06:41:53 -04:00
dependabot[bot]
459cb38a6d Bump github.com/gaissmai/bart from 0.20.1 to 0.20.4 (#1391)
Some checks failed
gofmt / Run gofmt (push) Successful in 24s
smoke-extra / Run extra smoke tests (push) Failing after 28s
smoke / Run multi node smoke test (push) Failing after 1m24s
Build and test / Build all and test on ubuntu-linux (push) Failing after 18m16s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2m49s
Build and test / Build and test on linux with pkcs11 (push) Failing after 4m13s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
* Bump github.com/gaissmai/bart from 0.20.1 to 0.20.4

Bumps [github.com/gaissmai/bart](https://github.com/gaissmai/bart) from 0.20.1 to 0.20.4.
- [Release notes](https://github.com/gaissmai/bart/releases)
- [Commits](https://github.com/gaissmai/bart/compare/v0.20.1...v0.20.4)

---
updated-dependencies:
- dependency-name: github.com/gaissmai/bart
  dependency-version: 0.20.4
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* set back to go 1.23.0

We were only on 1.23.6 because of bart in the first place.

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Wade Simmons <wsimmons@slack-corp.com>
2025-04-16 11:46:46 -04:00
53 changed files with 294 additions and 263 deletions

View File

@@ -25,14 +25,14 @@ func TestNewAllowListFromConfig(t *testing.T) {
c.Settings["allowlist"] = map[string]any{ c.Settings["allowlist"] = map[string]any{
"192.168.0.0/16": "abc", "192.168.0.0/16": "abc",
} }
_, err = newAllowListFromConfig(c, "allowlist", nil) r, err = newAllowListFromConfig(c, "allowlist", nil)
require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc") require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
c.Settings["allowlist"] = map[string]any{ c.Settings["allowlist"] = map[string]any{
"192.168.0.0/16": true, "192.168.0.0/16": true,
"10.0.0.0/8": false, "10.0.0.0/8": false,
} }
_, err = newAllowListFromConfig(c, "allowlist", nil) r, err = newAllowListFromConfig(c, "allowlist", nil)
require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0") require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
c.Settings["allowlist"] = map[string]any{ c.Settings["allowlist"] = map[string]any{
@@ -42,7 +42,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
"fd00::/8": true, "fd00::/8": true,
"fd00:fd00::/16": false, "fd00:fd00::/16": false,
} }
_, err = newAllowListFromConfig(c, "allowlist", nil) r, err = newAllowListFromConfig(c, "allowlist", nil)
require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0") require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
c.Settings["allowlist"] = map[string]any{ c.Settings["allowlist"] = map[string]any{
@@ -75,7 +75,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
`docker.*`: "foo", `docker.*`: "foo",
}, },
} }
_, err = NewLocalAllowListFromConfig(c, "allowlist") lr, err := NewLocalAllowListFromConfig(c, "allowlist")
require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo") require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
c.Settings["allowlist"] = map[string]any{ c.Settings["allowlist"] = map[string]any{
@@ -84,7 +84,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
`eth.*`: true, `eth.*`: true,
}, },
} }
_, err = NewLocalAllowListFromConfig(c, "allowlist") lr, err = NewLocalAllowListFromConfig(c, "allowlist")
require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value") require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
c.Settings["allowlist"] = map[string]any{ c.Settings["allowlist"] = map[string]any{
@@ -92,7 +92,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
`docker.*`: false, `docker.*`: false,
}, },
} }
lr, err := NewLocalAllowListFromConfig(c, "allowlist") lr, err = NewLocalAllowListFromConfig(c, "allowlist")
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.NotNil(t, lr) assert.NotNil(t, lr)
} }

10
bits.go
View File

@@ -18,7 +18,7 @@ type Bits struct {
func NewBits(bits uint64) *Bits { func NewBits(bits uint64) *Bits {
return &Bits{ return &Bits{
length: bits, length: bits,
bits: make([]bool, bits), bits: make([]bool, bits, bits),
current: 0, current: 0,
lostCounter: metrics.GetOrRegisterCounter("network.packets.lost", nil), lostCounter: metrics.GetOrRegisterCounter("network.packets.lost", nil),
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil), dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
@@ -28,7 +28,7 @@ func NewBits(bits uint64) *Bits {
func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool { func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
// If i is the next number, return true. // If i is the next number, return true.
if i > b.current || (i == 0 && !b.firstSeen && b.current < b.length) { if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
return true return true
} }
@@ -51,7 +51,7 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// If i is the next number, return true and update current. // If i is the next number, return true and update current.
if i == b.current+1 { if i == b.current+1 {
// Report missed packets, we can only understand what was missed after the first window has been gone through // Report missed packets, we can only understand what was missed after the first window has been gone through
if i > b.length && !b.bits[i%b.length] { if i > b.length && b.bits[i%b.length] == false {
b.lostCounter.Inc(1) b.lostCounter.Inc(1)
} }
b.bits[i%b.length] = true b.bits[i%b.length] = true
@@ -104,7 +104,7 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
} }
// Allow for the 0 packet to come in within the first window // Allow for the 0 packet to come in within the first window
if i == 0 && !b.firstSeen && b.current < b.length { if i == 0 && b.firstSeen == false && b.current < b.length {
b.firstSeen = true b.firstSeen = true
b.bits[i%b.length] = true b.bits[i%b.length] = true
return true return true
@@ -122,7 +122,7 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
return false return false
} }
if b.bits[i%b.length] { if b.bits[i%b.length] == true {
if l.Level >= logrus.DebugLevel { if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}). l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}).
Debug("Receive window") Debug("Receive window")

View File

@@ -20,6 +20,8 @@ import (
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
) )
const publicKeyLen = 32
type certificateV1 struct { type certificateV1 struct {
details detailsV1 details detailsV1
signature []byte signature []byte

View File

@@ -113,14 +113,14 @@ func TestCertificateV2_MarshalJSON(t *testing.T) {
signature: []byte("1234567890abcedf1234567890abcedf1234567890abcedf1234567890abcedf"), signature: []byte("1234567890abcedf1234567890abcedf1234567890abcedf1234567890abcedf"),
} }
_, err := nc.MarshalJSON() b, err := nc.MarshalJSON()
require.ErrorIs(t, err, ErrMissingDetails) require.ErrorIs(t, err, ErrMissingDetails)
rd, err := nc.details.Marshal() rd, err := nc.details.Marshal()
require.NoError(t, err) require.NoError(t, err)
nc.rawDetails = rd nc.rawDetails = rd
b, err := nc.MarshalJSON() b, err = nc.MarshalJSON()
require.NoError(t, err) require.NoError(t, err)
assert.JSONEq( assert.JSONEq(
t, t,
@@ -174,9 +174,8 @@ func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
require.ErrorIs(t, err, ErrInvalidPrivateKey) require.ErrorIs(t, err, ErrInvalidPrivateKey)
c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil) c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil)
_, _, curve, err = UnmarshalPrivateKeyFromPEM(priv) rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv)
assert.Equal(t, err, nil)
assert.Equal(t, curve, Curve_P256)
err = c.VerifyPrivateKey(Curve_P256, priv[:16]) err = c.VerifyPrivateKey(Curve_P256, priv[:16])
require.ErrorIs(t, err, ErrInvalidPrivateKey) require.ErrorIs(t, err, ErrInvalidPrivateKey)
@@ -262,7 +261,6 @@ func TestCertificateV2_marshalForSigningStability(t *testing.T) {
assert.Equal(t, expectedRawDetails, db) assert.Equal(t, expectedRawDetails, db)
expectedForSigning, err := hex.DecodeString(expectedRawDetailsStr + "00313233343536373839306162636564666768696a313233343536373839306162") expectedForSigning, err := hex.DecodeString(expectedRawDetailsStr + "00313233343536373839306162636564666768696a313233343536373839306162")
require.NoError(t, err)
b, err := nc.marshalForSigning() b, err := nc.marshalForSigning()
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, expectedForSigning, b) assert.Equal(t, expectedForSigning, b)

View File

@@ -227,9 +227,6 @@ func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) {
} }
func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) { func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) {
// Are we testing the compilers types here?
// No value of int32 is lewss than math.MinInt32.
// By definition these checks can never be true.
if params.Version < math.MinInt32 || params.Version > math.MaxInt32 { if params.Version < math.MinInt32 || params.Version > math.MaxInt32 {
return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32) return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32)
} }

View File

@@ -72,14 +72,12 @@ qrlJ69wer3ZUHFXA
require.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") require.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
assert.Equal(t, curve, Curve_CURVE25519)
// Fail due to invalid banner // Fail due to invalid banner
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
require.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") require.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.Equal(t, curve, Curve_CURVE25519)
// Fail due to ivalid PEM format, because // Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary. // it's missing the requisite pre-encapsulation boundary.
@@ -87,14 +85,12 @@ qrlJ69wer3ZUHFXA
require.EqualError(t, err, "input did not contain a valid PEM encoded block") require.EqualError(t, err, "input did not contain a valid PEM encoded block")
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.Equal(t, curve, Curve_CURVE25519)
// Fail due to invalid passphrase // Fail due to invalid passphrase
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey) curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey)
require.EqualError(t, err, "invalid passphrase or corrupt private key") require.EqualError(t, err, "invalid passphrase or corrupt private key")
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, []byte{}, rest) assert.Equal(t, []byte{}, rest)
assert.Equal(t, curve, Curve_CURVE25519)
} }
func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) { func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {

View File

@@ -21,9 +21,6 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ
switch curve { switch curve {
case Curve_CURVE25519: case Curve_CURVE25519:
pub, priv, err = ed25519.GenerateKey(rand.Reader) pub, priv, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
case Curve_P256: case Curve_P256:
privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil { if err != nil {

View File

@@ -97,14 +97,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
// Fail due to short key // Fail due to short key
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
require.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key") require.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key")
// Fail due to invalid banner // Fail due to invalid banner
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
require.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner") require.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner")
@@ -112,7 +110,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
// it's missing the requisite pre-encapsulation boundary. // it's missing the requisite pre-encapsulation boundary.
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
require.EqualError(t, err, "input did not contain a valid PEM encoded block") require.EqualError(t, err, "input did not contain a valid PEM encoded block")
} }
@@ -162,14 +159,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
// Fail due to short key // Fail due to short key
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key") require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key")
// Fail due to invalid banner // Fail due to invalid banner
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
require.EqualError(t, err, "bytes did not contain a proper private key banner") require.EqualError(t, err, "bytes did not contain a proper private key banner")
@@ -177,7 +172,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
// it's missing the requisite pre-encapsulation boundary. // it's missing the requisite pre-encapsulation boundary.
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
require.EqualError(t, err, "input did not contain a valid PEM encoded block") require.EqualError(t, err, "input did not contain a valid PEM encoded block")
} }
@@ -281,14 +275,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
// Fail due to short key // Fail due to short key
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
// Fail due to invalid banner // Fail due to invalid banner
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
require.EqualError(t, err, "bytes did not contain a proper public key banner") require.EqualError(t, err, "bytes did not contain a proper public key banner")
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
@@ -296,7 +288,6 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
// it's missing the requisite pre-encapsulation boundary. // it's missing the requisite pre-encapsulation boundary.
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
require.EqualError(t, err, "input did not contain a valid PEM encoded block") require.EqualError(t, err, "input did not contain a valid PEM encoded block")
} }

View File

@@ -37,7 +37,6 @@ func TestCertificateV1_Sign(t *testing.T) {
} }
pub, priv, err := ed25519.GenerateKey(rand.Reader) pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv) c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv)
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, c) assert.NotNil(t, c)

View File

@@ -22,9 +22,6 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti
switch curve { switch curve {
case cert.Curve_CURVE25519: case cert.Curve_CURVE25519:
pub, priv, err = ed25519.GenerateKey(rand.Reader) pub, priv, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
case cert.Curve_P256: case cert.Curve_P256:
privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil { if err != nil {

View File

@@ -81,7 +81,7 @@ func parseArgonParameters(memory uint, parallelism uint, iterations uint) (*cert
return cert.NewArgon2Parameters(uint32(memory), uint8(parallelism), uint32(iterations)), nil return cert.NewArgon2Parameters(uint32(memory), uint8(parallelism), uint32(iterations)), nil
} }
func ca(args []string, out io.Writer, _ io.Writer, pr PasswordReader) error { func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error {
cf := newCaFlags() cf := newCaFlags()
err := cf.set.Parse(args) err := cf.set.Parse(args)
if err != nil { if err != nil {

View File

@@ -29,7 +29,7 @@ func newKeygenFlags() *keygenFlags {
return &cf return &cf
} }
func keygen(args []string, _ io.Writer, _ io.Writer) error { func keygen(args []string, out io.Writer, errOut io.Writer) error {
cf := newKeygenFlags() cf := newKeygenFlags()
err := cf.set.Parse(args) err := cf.set.Parse(args)
if err != nil { if err != nil {

View File

@@ -3,6 +3,7 @@ package main
import ( import (
"bytes" "bytes"
"errors" "errors"
"fmt"
"io" "io"
"os" "os"
"testing" "testing"
@@ -76,7 +77,7 @@ func assertHelpError(t *testing.T, err error, msg string) {
case *helpError: case *helpError:
// good // good
default: default:
t.Fatalf("err was not a helpError: %q, expected %q", err, msg) t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg))
} }
require.EqualError(t, err, msg) require.EqualError(t, err, msg)

View File

@@ -10,7 +10,7 @@ func p11Supported() bool {
return false return false
} }
func p11Flag(_ *flag.FlagSet) *string { func p11Flag(set *flag.FlagSet) *string {
var ret = "" var ret = ""
return &ret return &ret
} }

View File

@@ -1,12 +1,12 @@
package main package main
import ( import (
"bytes"
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"os" "os"
"strings"
"github.com/skip2/go-qrcode" "github.com/skip2/go-qrcode"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
@@ -29,7 +29,7 @@ func newPrintFlags() *printFlags {
return &pf return &pf
} }
func printCert(args []string, out io.Writer, _ io.Writer) error { func printCert(args []string, out io.Writer, errOut io.Writer) error {
pf := newPrintFlags() pf := newPrintFlags()
err := pf.set.Parse(args) err := pf.set.Parse(args)
if err != nil { if err != nil {
@@ -72,7 +72,7 @@ func printCert(args []string, out io.Writer, _ io.Writer) error {
qrBytes = append(qrBytes, b...) qrBytes = append(qrBytes, b...)
} }
if len(rawCert) == 0 || len(bytes.TrimSpace(rawCert)) == 0 { if rawCert == nil || len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" {
break break
} }

View File

@@ -1,12 +1,12 @@
package main package main
import ( import (
"bytes"
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"os" "os"
"strings"
"time" "time"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
@@ -52,7 +52,7 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
return fmt.Errorf("error while adding ca cert to pool: %w", err) return fmt.Errorf("error while adding ca cert to pool: %w", err)
} }
if len(rawCACert) == 0 || len(bytes.TrimSpace(rawCACert)) == 0 { if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" {
break break
} }
} }

View File

@@ -97,7 +97,7 @@ func Test_verify(t *testing.T) {
crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
// Slightly evil hack to modify the certificate after it was sealed to generate an invalid signature // Slightly evil hack to modify the certificate after it was sealed to generate an invalid signature
pub := crt.PublicKey() pub := crt.PublicKey()
for i := range pub { for i, _ := range pub {
pub[i] = 0 pub[i] = 0
} }
b, _ = crt.MarshalPEM() b, _ = crt.MarshalPEM()

View File

@@ -51,7 +51,10 @@ func (p *program) Stop(s service.Service) error {
func fileExists(filename string) bool { func fileExists(filename string) bool {
_, err := os.Stat(filename) _, err := os.Stat(filename)
return !os.IsNotExist(err) if os.IsNotExist(err) {
return false
}
return true
} }
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) { func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {

View File

@@ -63,7 +63,7 @@ func (c *C) Load(path string) error {
func (c *C) LoadString(raw string) error { func (c *C) LoadString(raw string) error {
if raw == "" { if raw == "" {
return errors.New("empty configuration") return errors.New("Empty configuration")
} }
return c.parseRaw([]byte(raw)) return c.parseRaw([]byte(raw))
} }

View File

@@ -154,7 +154,7 @@ func (n *connectionManager) Run(ctx context.Context) {
defer clockSource.Stop() defer clockSource.Stop()
p := []byte("") p := []byte("")
nb := make([]byte, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
for { for {
@@ -355,7 +355,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
decision = tryRehandshake decision = tryRehandshake
} else { } else {
if n.shouldSwapPrimary(hostinfo) { if n.shouldSwapPrimary(hostinfo, primary) {
decision = swapPrimary decision = swapPrimary
} else { } else {
// migrate the relays to the primary, if in use. // migrate the relays to the primary, if in use.
@@ -384,7 +384,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
} }
decision := doNothing decision := doNothing
if hostinfo.ConnectionState != nil && mainHostInfo { if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
if !outTraffic { if !outTraffic {
// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel. // If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
// Just maintain NAT state if configured to do so. // Just maintain NAT state if configured to do so.
@@ -421,7 +421,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
return decision, hostinfo, nil return decision, hostinfo, nil
} }
func (n *connectionManager) shouldSwapPrimary(current *HostInfo) bool { func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
// The primary tunnel is the most recent handshake to complete locally and should work entirely fine. // The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
// Let's sort this out. // Let's sort this out.
@@ -498,7 +498,7 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
cs := n.intf.pki.getCertState() cs := n.intf.pki.getCertState()
curCrt := hostinfo.ConnectionState.myCert curCrt := hostinfo.ConnectionState.myCert
myCrt := cs.getCertificate(curCrt.Version()) myCrt := cs.getCertificate(curCrt.Version())
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) { if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
// The current tunnel is using the latest certificate and version, no need to rehandshake. // The current tunnel is using the latest certificate and version, no need to rehandshake.
return return
} }

View File

@@ -69,7 +69,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
punchy := NewPunchyFromConfig(l, config.NewC(l)) punchy := NewPunchyFromConfig(l, config.NewC(l))
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
p := []byte("") p := []byte("")
nb := make([]byte, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
// Add an ip we have established a connection w/ to hostmap // Add an ip we have established a connection w/ to hostmap
@@ -151,7 +151,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
punchy := NewPunchyFromConfig(l, config.NewC(l)) punchy := NewPunchyFromConfig(l, config.NewC(l))
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
p := []byte("") p := []byte("")
nb := make([]byte, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
// Add an ip we have established a connection w/ to hostmap // Add an ip we have established a connection w/ to hostmap
@@ -241,7 +241,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert) cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
require.NoError(t, err)
cs := &CertState{ cs := &CertState{
privateKey: []byte{}, privateKey: []byte{},
v1Cert: &dummyCert{}, v1Cert: &dummyCert{},

View File

@@ -131,8 +131,7 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate { func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
_, found := c.f.myVpnAddrsTable.Lookup(vpnIp) if c.f.myVpnAddrsTable.Contains(vpnIp) {
if found {
// Only returning the default certificate since its impossible // Only returning the default certificate since its impossible
// for any other host but ourselves to have more than 1 // for any other host but ourselves to have more than 1
return c.f.pki.getCertState().GetDefaultCertificate().Copy() return c.f.pki.getCertState().GetDefaultCertificate().Copy()
@@ -215,7 +214,7 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
hostInfo.ConnectionState, hostInfo.ConnectionState,
hostInfo, hostInfo,
[]byte{}, []byte{},
make([]byte, 12), make([]byte, 12, 12),
make([]byte, mtu), make([]byte, mtu),
) )
} }
@@ -231,7 +230,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) { if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) {
return return
} }
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12), make([]byte, mtu)) c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.f.closeTunnel(h) c.f.closeTunnel(h)
c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote). c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote).
@@ -282,7 +281,9 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
CurrentRemote: h.remote, CurrentRemote: h.remote,
} }
copy(chi.VpnAddrs, h.vpnAddrs) for i, a := range h.vpnAddrs {
chi.VpnAddrs[i] = a
}
if h.ConnectionState != nil { if h.ConnectionState != nil {
chi.MessageCounter = h.ConnectionState.messageCounter.Load() chi.MessageCounter = h.ConnectionState.messageCounter.Load()

View File

@@ -27,10 +27,12 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
ipNet := net.IPNet{ ipNet := net.IPNet{
IP: remote1.Addr().AsSlice(), IP: remote1.Addr().AsSlice(),
Mask: net.IPMask{255, 255, 255, 0},
} }
ipNet2 := net.IPNet{ ipNet2 := net.IPNet{
IP: remote2.Addr().AsSlice(), IP: remote2.Addr().AsSlice(),
Mask: net.IPMask{255, 255, 255, 0},
} }
remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil) remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil)

View File

@@ -1,6 +1,7 @@
package nebula package nebula
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -26,7 +27,7 @@ type dnsRecords struct {
dnsMap4 map[string]netip.Addr dnsMap4 map[string]netip.Addr
dnsMap6 map[string]netip.Addr dnsMap6 map[string]netip.Addr
hostMap *HostMap hostMap *HostMap
myVpnAddrsTable *bart.Table[struct{}] myVpnAddrsTable *bart.Lite
} }
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords { func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
@@ -39,7 +40,7 @@ func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecord
} }
} }
func (d *dnsRecords) Query(q uint16, data string) netip.Addr { func (d *dnsRecords) query(q uint16, data string) netip.Addr {
data = strings.ToLower(data) data = strings.ToLower(data)
d.RLock() d.RLock()
defer d.RUnlock() defer d.RUnlock()
@@ -57,7 +58,7 @@ func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
return netip.Addr{} return netip.Addr{}
} }
func (d *dnsRecords) QueryCert(data string) string { func (d *dnsRecords) queryCert(data string) string {
ip, err := netip.ParseAddr(data[:len(data)-1]) ip, err := netip.ParseAddr(data[:len(data)-1])
if err != nil { if err != nil {
return "" return ""
@@ -112,8 +113,8 @@ func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
return true return true
} }
_, found := d.myVpnAddrsTable.Lookup(b) //if we found it in this table, it's good
return found //if we found it in this table, it's good return d.myVpnAddrsTable.Contains(b)
} }
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
@@ -122,7 +123,7 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
case dns.TypeA, dns.TypeAAAA: case dns.TypeA, dns.TypeAAAA:
qType := dns.TypeToString[q.Qtype] qType := dns.TypeToString[q.Qtype]
d.l.Debugf("Query for %s %s", qType, q.Name) d.l.Debugf("Query for %s %s", qType, q.Name)
ip := d.Query(q.Qtype, q.Name) ip := d.query(q.Qtype, q.Name)
if ip.IsValid() { if ip.IsValid() {
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip)) rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
if err == nil { if err == nil {
@@ -135,7 +136,7 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
return return
} }
d.l.Debugf("Query for TXT %s", q.Name) d.l.Debugf("Query for TXT %s", q.Name)
ip := d.QueryCert(q.Name) ip := d.queryCert(q.Name)
if ip != "" { if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip)) rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
if err == nil { if err == nil {
@@ -163,18 +164,18 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
w.WriteMsg(m) w.WriteMsg(m)
} }
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() { func dnsMain(ctx context.Context, l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
dnsR = newDnsRecords(l, cs, hostMap) dnsR = newDnsRecords(l, cs, hostMap)
// attach request handler func // attach request handler func
dns.HandleFunc(".", dnsR.handleDnsRequest) dns.HandleFunc(".", dnsR.handleDnsRequest)
c.RegisterReloadCallback(func(c *config.C) { c.RegisterReloadCallback(func(c *config.C) {
reloadDns(l, c) reloadDns(ctx, l, c)
}) })
return func() { return func() {
startDns(l, c) startDns(ctx, l, c)
} }
} }
@@ -187,24 +188,24 @@ func getDnsServerAddr(c *config.C) string {
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))) return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
} }
func startDns(l *logrus.Logger, c *config.C) { func startDns(ctx context.Context, l *logrus.Logger, c *config.C) {
dnsAddr = getDnsServerAddr(c) dnsAddr = getDnsServerAddr(c)
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"} dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder") l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
err := dnsServer.ListenAndServe() err := dnsServer.ListenAndServe()
defer dnsServer.Shutdown() defer dnsServer.ShutdownContext(ctx)
if err != nil { if err != nil {
l.Errorf("Failed to start server: %s\n ", err.Error()) l.Errorf("Failed to start server: %s\n ", err.Error())
} }
} }
func reloadDns(l *logrus.Logger, c *config.C) { func reloadDns(ctx context.Context, l *logrus.Logger, c *config.C) {
if dnsAddr == getDnsServerAddr(c) { if dnsAddr == getDnsServerAddr(c) {
l.Debug("No DNS server config change detected") l.Debug("No DNS server config change detected")
return return
} }
l.Debug("Restarting DNS server") l.Debug("Restarting DNS server")
dnsServer.Shutdown() dnsServer.ShutdownContext(ctx)
go startDns(l, c) go startDns(ctx, l, c)
} }

View File

@@ -53,7 +53,7 @@ type Firewall struct {
// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate. // routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
// The vpn addresses are a full bit match while the unsafe networks only match the prefix // The vpn addresses are a full bit match while the unsafe networks only match the prefix
routableNetworks *bart.Table[struct{}] routableNetworks *bart.Lite
// assignedNetworks is a list of vpn networks assigned to us in the certificate. // assignedNetworks is a list of vpn networks assigned to us in the certificate.
assignedNetworks []netip.Prefix assignedNetworks []netip.Prefix
@@ -125,7 +125,7 @@ type firewallPort map[int32]*FirewallCA
type firewallLocalCIDR struct { type firewallLocalCIDR struct {
Any bool Any bool
LocalCIDR *bart.Table[struct{}] LocalCIDR *bart.Lite
} }
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
@@ -148,17 +148,17 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
tmax = defaultTimeout tmax = defaultTimeout
} }
routableNetworks := new(bart.Table[struct{}]) routableNetworks := new(bart.Lite)
var assignedNetworks []netip.Prefix var assignedNetworks []netip.Prefix
for _, network := range c.Networks() { for _, network := range c.Networks() {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
routableNetworks.Insert(nprefix, struct{}{}) routableNetworks.Insert(nprefix)
assignedNetworks = append(assignedNetworks, network) assignedNetworks = append(assignedNetworks, network)
} }
hasUnsafeNetworks := false hasUnsafeNetworks := false
for _, n := range c.UnsafeNetworks() { for _, n := range c.UnsafeNetworks() {
routableNetworks.Insert(n, struct{}{}) routableNetworks.Insert(n)
hasUnsafeNetworks = true hasUnsafeNetworks = true
} }
@@ -431,8 +431,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
// Make sure remote address matches nebula certificate // Make sure remote address matches nebula certificate
if h.networks != nil { if h.networks != nil {
_, ok := h.networks.Lookup(fp.RemoteAddr) if !h.networks.Contains(fp.RemoteAddr) {
if !ok {
f.metrics(incoming).droppedRemoteAddr.Inc(1) f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
} }
@@ -445,8 +444,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
} }
// Make sure we are supposed to be handling this local ip address // Make sure we are supposed to be handling this local ip address
_, ok := f.routableNetworks.Lookup(fp.LocalAddr) if !f.routableNetworks.Contains(fp.LocalAddr) {
if !ok {
f.metrics(incoming).droppedLocalAddr.Inc(1) f.metrics(incoming).droppedLocalAddr.Inc(1)
return ErrInvalidLocalIP return ErrInvalidLocalIP
} }
@@ -606,7 +604,7 @@ func (f *Firewall) evict(p firewall.Packet) {
return return
} }
newT := time.Until(t.Expires) newT := t.Expires.Sub(time.Now())
// Timeout is in the future, re-add the timer // Timeout is in the future, re-add the timer
if newT > 0 { if newT > 0 {
@@ -752,7 +750,7 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error { func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
flc := func() *firewallLocalCIDR { flc := func() *firewallLocalCIDR {
return &firewallLocalCIDR{ return &firewallLocalCIDR{
LocalCIDR: new(bart.Table[struct{}]), LocalCIDR: new(bart.Lite),
} }
} }
@@ -832,7 +830,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
} }
// Shortcut path for if groups, hosts, or cidr contained an `any` // Shortcut path for if groups, hosts, or cidr contained an `any`
if fr.Any.match(p) { if fr.Any.match(p, c) {
return true return true
} }
@@ -849,21 +847,21 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
found = true found = true
} }
if found && sg.LocalCIDR.match(p) { if found && sg.LocalCIDR.match(p, c) {
return true return true
} }
} }
if fr.Hosts != nil { if fr.Hosts != nil {
if flc, ok := fr.Hosts[c.Certificate.Name()]; ok { if flc, ok := fr.Hosts[c.Certificate.Name()]; ok {
if flc.match(p) { if flc.match(p, c) {
return true return true
} }
} }
} }
for _, v := range fr.CIDR.Supernets(netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())) { for _, v := range fr.CIDR.Supernets(netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())) {
if v.match(p) { if v.match(p, c) {
return true return true
} }
} }
@@ -879,7 +877,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
} }
for _, network := range f.assignedNetworks { for _, network := range f.assignedNetworks {
flc.LocalCIDR.Insert(network, struct{}{}) flc.LocalCIDR.Insert(network)
} }
return nil return nil
@@ -888,11 +886,11 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
return nil return nil
} }
flc.LocalCIDR.Insert(localIp, struct{}{}) flc.LocalCIDR.Insert(localIp)
return nil return nil
} }
func (flc *firewallLocalCIDR) match(p firewall.Packet) bool { func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate) bool {
if flc == nil { if flc == nil {
return false return false
} }
@@ -901,8 +899,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet) bool {
return true return true
} }
_, ok := flc.LocalCIDR.Lookup(p.LocalAddr) return flc.LocalCIDR.Contains(p.LocalAddr)
return ok
} }
type rule struct { type rule struct {

View File

@@ -35,27 +35,22 @@ func TestNewFirewall(t *testing.T) {
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c) fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
conntrack = fw.Conntrack
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c) fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
conntrack = fw.Conntrack
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c) fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
conntrack = fw.Conntrack
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c) fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
conntrack = fw.Conntrack
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c) fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
conntrack = fw.Conntrack
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
} }

9
go.mod
View File

@@ -1,6 +1,6 @@
module github.com/slackhq/nebula module github.com/slackhq/nebula
go 1.23.6 go 1.23.0
toolchain go1.24.1 toolchain go1.24.1
@@ -10,14 +10,14 @@ require (
github.com/armon/go-radix v1.0.0 github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.1.0 github.com/flynn/noise v1.1.0
github.com/gaissmai/bart v0.20.1 github.com/gaissmai/bart v0.20.4
github.com/gogo/protobuf v1.3.2 github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19 github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.2 github.com/kardianos/service v1.2.2
github.com/miekg/dns v1.1.65 github.com/miekg/dns v1.1.65
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
github.com/prometheus/client_golang v1.21.1 github.com/prometheus/client_golang v1.22.0
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
@@ -26,7 +26,7 @@ require (
github.com/vishvananda/netlink v1.3.0 github.com/vishvananda/netlink v1.3.0
golang.org/x/crypto v0.37.0 golang.org/x/crypto v0.37.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.38.0 golang.org/x/net v0.39.0
golang.org/x/sync v0.13.0 golang.org/x/sync v0.13.0
golang.org/x/sys v0.32.0 golang.org/x/sys v0.32.0
golang.org/x/term v0.31.0 golang.org/x/term v0.31.0
@@ -43,7 +43,6 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/btree v1.1.2 // indirect github.com/google/btree v1.1.2 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/client_model v0.6.1 // indirect

20
go.sum
View File

@@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/gaissmai/bart v0.20.1 h1:igNss0zDsSY8e+ophKgD9KJVPKBOo7uSVjyKCL7nIzo= github.com/gaissmai/bart v0.20.4 h1:Ik47r1fy3jRVU+1eYzKSW3ho2UgBVTVnUS8O993584U=
github.com/gaissmai/bart v0.20.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY= github.com/gaissmai/bart v0.20.4/go.mod h1:cEed+ge8dalcbpi8wtS9x9m2hn/fNJH5suhdGQOHnYk=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -53,8 +53,8 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
@@ -68,8 +68,8 @@ github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
@@ -106,8 +106,8 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk= github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg= github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
@@ -176,8 +176,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

View File

@@ -192,8 +192,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
for _, network := range remoteCert.Certificate.Networks() { for _, network := range remoteCert.Certificate.Networks() {
vpnAddr := network.Addr() vpnAddr := network.Addr()
_, found := f.myVpnAddrsTable.Lookup(vpnAddr) if f.myVpnAddrsTable.Contains(vpnAddr) {
if found {
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr). f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
@@ -204,7 +203,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
} }
// vpnAddrs outside our vpn networks are of no use to us, filter them out // vpnAddrs outside our vpn networks are of no use to us, filter them out
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok { if !f.myVpnNetworksTable.Contains(vpnAddr) {
continue continue
} }
@@ -343,7 +342,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if existing.SetRemoteIfPreferred(f.hostMap, addr) { if existing.SetRemoteIfPreferred(f.hostMap, addr) {
// Send a test packet to ensure the other side has also switched to // Send a test packet to ensure the other side has also switched to
// the preferred remote // the preferred remote
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12), make([]byte, mtu)) f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
} }
msg = existing.HandshakePacket[2] msg = existing.HandshakePacket[2]
@@ -386,7 +385,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
Info("Handshake too old") Info("Handshake too old")
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12), make([]byte, mtu)) f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
return return
case ErrLocalIndexCollision: case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
@@ -461,6 +460,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
hostinfo.remotes.ResetBlockedRemotes() hostinfo.remotes.ResetBlockedRemotes()
return
} }
func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
@@ -577,7 +578,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
for _, network := range vpnNetworks { for _, network := range vpnNetworks {
// vpnAddrs outside our vpn networks are of no use to us, filter them out // vpnAddrs outside our vpn networks are of no use to us, filter them out
vpnAddr := network.Addr() vpnAddr := network.Addr()
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok { if !f.myVpnNetworksTable.Contains(vpnAddr) {
continue continue
} }
@@ -658,7 +659,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
} }
if len(hh.packetStore) > 0 { if len(hh.packetStore) > 0 {
nb := make([]byte, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
for _, cp := range hh.packetStore { for _, cp := range hh.packetStore {
cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out) cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out)

View File

@@ -274,8 +274,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
} }
// Don't relay through the host I'm trying to connect to // Don't relay through the host I'm trying to connect to
_, found := hm.f.myVpnAddrsTable.Lookup(relay) if hm.f.myVpnAddrsTable.Contains(relay) {
if found {
continue continue
} }

View File

@@ -65,16 +65,30 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
assert.NotContains(t, blah.vpnIps, ip) assert.NotContains(t, blah.vpnIps, ip)
} }
func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
for _, i := range tw.t.wheel {
n := i.Head
for n != nil {
c++
n = n.Next
}
}
return c
}
type mockEncWriter struct { type mockEncWriter struct {
} }
func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) { func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) {
return
} }
func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) { func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) {
return
} }
func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) { func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) {
return
} }
func (mw *mockEncWriter) Handshake(_ netip.Addr) {} func (mw *mockEncWriter) Handshake(_ netip.Addr) {}

View File

@@ -23,7 +23,7 @@ type m = map[string]any
const ( const (
Version uint8 = 1 Version uint8 = 1
Len int = 16 Len = 16
) )
type MessageType uint8 type MessageType uint8

View File

@@ -223,7 +223,7 @@ type HostInfo struct {
recvError atomic.Uint32 recvError atomic.Uint32
// networks are both all vpn and unsafe networks assigned to this host // networks are both all vpn and unsafe networks assigned to this host
networks *bart.Table[struct{}] networks *bart.Lite
relayState RelayState relayState RelayState
// HandshakePacket records the packets used to create this hostinfo // HandshakePacket records the packets used to create this hostinfo
@@ -568,7 +568,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs) dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
} }
for _, addr := range hostinfo.vpnAddrs { for _, addr := range hostinfo.vpnAddrs {
hm.unlockedInnerAddHostInfo(addr, hostinfo) hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
} }
hm.Indexes[hostinfo.localIndexId] = hostinfo hm.Indexes[hostinfo.localIndexId] = hostinfo
@@ -581,7 +581,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
} }
} }
func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo) { func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) {
existing := hm.Hosts[vpnAddr] existing := hm.Hosts[vpnAddr]
hm.Hosts[vpnAddr] = hostinfo hm.Hosts[vpnAddr] = hostinfo
@@ -648,7 +648,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac
// Try to send a test packet to that host, this should // Try to send a test packet to that host, this should
// cause it to detect a roaming event and switch remotes // cause it to detect a roaming event and switch remotes
ifce.sendTo(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12), make([]byte, mtu)) ifce.sendTo(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}) })
} }
@@ -732,13 +732,13 @@ func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
return return
} }
i.networks = new(bart.Table[struct{}]) i.networks = new(bart.Lite)
for _, network := range networks { for _, network := range networks {
i.networks.Insert(network, struct{}{}) i.networks.Insert(network)
} }
for _, network := range unsafeNetworks { for _, network := range unsafeNetworks {
i.networks.Insert(network, struct{}{}) i.networks.Insert(network)
} }
} }
@@ -794,7 +794,7 @@ func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
} }
addr = addr.Unmap() addr = addr.Unmap()
if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
isAllowed := allowList.Allow(addr) isAllowed := allowList.Allow(addr)
if l.Level >= logrus.TraceLevel { if l.Level >= logrus.TraceLevel {
l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow") l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")

View File

@@ -22,14 +22,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
// Ignore local broadcast packets // Ignore local broadcast packets
if f.dropLocalBroadcast { if f.dropLocalBroadcast {
_, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr) if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
if found {
return return
} }
} }
_, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr) if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) {
if found {
// Immediately forward packets from self to self. // Immediately forward packets from self to self.
// This should only happen on Darwin-based and FreeBSD hosts, which // This should only happen on Darwin-based and FreeBSD hosts, which
// routes packets from the Nebula addr to the Nebula addr through the Nebula // routes packets from the Nebula addr to the Nebula addr through the Nebula
@@ -130,8 +128,7 @@ func (f *Interface) Handshake(vpnAddr netip.Addr) {
// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable. // getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
_, found := f.myVpnNetworksTable.Lookup(vpnAddr) if f.myVpnNetworksTable.Contains(vpnAddr) {
if found {
return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback) return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
} }

View File

@@ -61,11 +61,11 @@ type Interface struct {
serveDns bool serveDns bool
createTime time.Time createTime time.Time
lightHouse *LightHouse lightHouse *LightHouse
myBroadcastAddrsTable *bart.Table[struct{}] myBroadcastAddrsTable *bart.Lite
myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate
myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate myVpnAddrsTable *bart.Lite
myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate
myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate myVpnNetworksTable *bart.Lite
dropLocalBroadcast bool dropLocalBroadcast bool
dropMulticast bool dropMulticast bool
routines int routines int
@@ -266,7 +266,7 @@ func (f *Interface) listenOut(i int) {
plaintext := make([]byte, udp.MTU) plaintext := make([]byte, udp.MTU)
h := &header.H{} h := &header.H{}
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12) nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
@@ -279,7 +279,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
packet := make([]byte, mtu) packet := make([]byte, mtu)
out := make([]byte, mtu) out := make([]byte, mtu)
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12) nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
@@ -322,7 +322,7 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
func (f *Interface) reloadFirewall(c *config.C) { func (f *Interface) reloadFirewall(c *config.C) {
//TODO: need to trigger/detect if the certificate changed too //TODO: need to trigger/detect if the certificate changed too
if !c.HasChanged("firewall") { if c.HasChanged("firewall") == false {
f.l.Debug("No firewall config change detected") f.l.Debug("No firewall config change detected")
return return
} }
@@ -424,7 +424,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
certState := f.pki.getCertState() certState := f.pki.getCertState()
defaultCrt := certState.GetDefaultCertificate() defaultCrt := certState.GetDefaultCertificate()
certExpirationGauge.Update(int64(time.Until(defaultCrt.NotAfter()) / time.Second)) certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
certInitiatingVersion.Update(int64(defaultCrt.Version())) certInitiatingVersion.Update(int64(defaultCrt.Version()))
// Report the max certificate version we are capable of using // Report the max certificate version we are capable of using

View File

@@ -32,7 +32,7 @@ type LightHouse struct {
amLighthouse bool amLighthouse bool
myVpnNetworks []netip.Prefix myVpnNetworks []netip.Prefix
myVpnNetworksTable *bart.Table[struct{}] myVpnNetworksTable *bart.Lite
punchConn udp.Conn punchConn udp.Conn
punchy *Punchy punchy *Punchy
@@ -201,8 +201,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used //TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
addr := addrs[0].Unmap() addr := addrs[0].Unmap()
_, found := lh.myVpnNetworksTable.Lookup(addr) if lh.myVpnNetworksTable.Contains(addr) {
if found {
lh.l.WithField("addr", rawAddr).WithField("entry", i+1). lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range") Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
continue continue
@@ -359,8 +358,7 @@ func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{
return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err) return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
} }
_, found := lh.myVpnNetworksTable.Lookup(addr) if !lh.myVpnNetworksTable.Contains(addr) {
if !found {
return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil) return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
} }
lhMap[addr] = struct{}{} lhMap[addr] = struct{}{}
@@ -371,7 +369,7 @@ func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{
} }
staticList := lh.GetStaticHostList() staticList := lh.GetStaticHostList()
for lhAddr := range lhMap { for lhAddr, _ := range lhMap {
if _, ok := staticList[lhAddr]; !ok { if _, ok := staticList[lhAddr]; !ok {
return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr) return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr)
} }
@@ -431,8 +429,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err) return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err)
} }
_, found := lh.myVpnNetworksTable.Lookup(vpnAddr) if !lh.myVpnNetworksTable.Contains(vpnAddr) {
if !found {
return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil) return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
} }
@@ -653,9 +650,11 @@ func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool {
return false return false
} }
_, found := lh.myVpnNetworksTable.Lookup(to) if lh.myVpnNetworksTable.Contains(to) {
return false
}
return !found return true
} }
// unlockedShouldAddV4 checks if to is allowed by our allow list // unlockedShouldAddV4 checks if to is allowed by our allow list
@@ -671,8 +670,11 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo
return false return false
} }
_, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr()) if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) {
return !found return false
}
return true
} }
// unlockedShouldAddV6 checks if to is allowed by our allow list // unlockedShouldAddV6 checks if to is allowed by our allow list
@@ -688,9 +690,11 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
return false return false
} }
_, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr()) if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) {
return false
}
return !found return true
} }
func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool { func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
@@ -718,7 +722,7 @@ func (lh *LightHouse) startQueryWorker() {
} }
go func() { go func() {
nb := make([]byte, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
for { for {
@@ -846,8 +850,7 @@ func (lh *LightHouse) SendUpdate() {
lal := lh.GetLocalAllowList() lal := lh.GetLocalAllowList()
for _, e := range localAddrs(lh.l, lal) { for _, e := range localAddrs(lh.l, lal) {
_, found := lh.myVpnNetworksTable.Lookup(e) if lh.myVpnNetworksTable.Contains(e) {
if found {
continue continue
} }
@@ -859,7 +862,7 @@ func (lh *LightHouse) SendUpdate() {
} }
} }
nb := make([]byte, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
var v1Update, v2Update []byte var v1Update, v2Update []byte
@@ -961,7 +964,7 @@ type LightHouseHandler struct {
func (lh *LightHouse) NewRequestHandler() *LightHouseHandler { func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
lhh := &LightHouseHandler{ lhh := &LightHouseHandler{
lh: lh, lh: lh,
nb: make([]byte, 12), nb: make([]byte, 12, 12),
out: make([]byte, mtu), out: make([]byte, mtu),
l: lh.l, l: lh.l,
pb: make([]byte, mtu), pb: make([]byte, mtu),
@@ -1152,7 +1155,7 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
if c.v4.learned != nil { if c.v4.learned != nil {
n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.learned) n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.learned)
} }
if len(c.v4.reported) > 0 { if c.v4.reported != nil && len(c.v4.reported) > 0 {
n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.reported...) n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.reported...)
} }
} }
@@ -1161,7 +1164,7 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
if c.v6.learned != nil { if c.v6.learned != nil {
n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.learned) n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.learned)
} }
if len(c.v6.reported) > 0 { if c.v6.reported != nil && len(c.v6.reported) > 0 {
n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.reported...) n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.reported...)
} }
} }
@@ -1359,7 +1362,7 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine //NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine // for each punchBack packet. We should move this into a timerwheel or a single goroutine
// managed by a channel. // managed by a channel.
w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12), make([]byte, mtu)) w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}() }()
} }
} }

View File

@@ -31,8 +31,8 @@ func TestOldIPv4Only(t *testing.T) {
func Test_lhStaticMapping(t *testing.T) { func Test_lhStaticMapping(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/16") myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
nt := new(bart.Table[struct{}]) nt := new(bart.Lite)
nt.Insert(myVpnNet, struct{}{}) nt.Insert(myVpnNet)
cs := &CertState{ cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt, myVpnNetworksTable: nt,
@@ -56,8 +56,8 @@ func Test_lhStaticMapping(t *testing.T) {
func TestReloadLighthouseInterval(t *testing.T) { func TestReloadLighthouseInterval(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/16") myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
nt := new(bart.Table[struct{}]) nt := new(bart.Lite)
nt.Insert(myVpnNet, struct{}{}) nt.Insert(myVpnNet)
cs := &CertState{ cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt, myVpnNetworksTable: nt,
@@ -91,8 +91,8 @@ func TestReloadLighthouseInterval(t *testing.T) {
func BenchmarkLighthouseHandleRequest(b *testing.B) { func BenchmarkLighthouseHandleRequest(b *testing.B) {
l := test.NewLogger() l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/0") myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
nt := new(bart.Table[struct{}]) nt := new(bart.Lite)
nt.Insert(myVpnNet, struct{}{}) nt.Insert(myVpnNet)
cs := &CertState{ cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt, myVpnNetworksTable: nt,
@@ -196,8 +196,8 @@ func TestLighthouse_Memory(t *testing.T) {
c.Settings["listen"] = map[string]any{"port": 4242} c.Settings["listen"] = map[string]any{"port": 4242}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24") myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Table[struct{}]) nt := new(bart.Lite)
nt.Insert(myVpnNet, struct{}{}) nt.Insert(myVpnNet)
cs := &CertState{ cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt, myVpnNetworksTable: nt,
@@ -281,8 +281,8 @@ func TestLighthouse_reload(t *testing.T) {
c.Settings["listen"] = map[string]any{"port": 4242} c.Settings["listen"] = map[string]any{"port": 4242}
myVpnNet := netip.MustParsePrefix("10.128.0.1/24") myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Table[struct{}]) nt := new(bart.Lite)
nt.Insert(myVpnNet, struct{}{}) nt.Insert(myVpnNet)
cs := &CertState{ cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt, myVpnNetworksTable: nt,
@@ -484,12 +484,12 @@ func Test_findNetworkUnion(t *testing.T) {
assert.Equal(t, out, afe81) assert.Equal(t, out, afe81)
//falsey cases //falsey cases
_, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1}) out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1})
assert.False(t, ok) assert.False(t, ok)
_, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1}) out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1})
assert.False(t, ok) assert.False(t, ok)
_, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81}) out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok) assert.False(t, ok)
_, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81}) out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok) assert.False(t, ok)
} }

View File

@@ -284,7 +284,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
var dnsStart func() var dnsStart func()
if lightHouse.amLighthouse && serveDns { if lightHouse.amLighthouse && serveDns {
l.Debugln("Starting dns server") l.Debugln("Starting dns server")
dnsStart = dnsMain(l, pki.getCertState(), hostMap, c) dnsStart = dnsMain(ctx, l, pki.getCertState(), hostMap, c)
} }
return &Control{ return &Control{

View File

@@ -17,7 +17,7 @@ type MessageMetrics struct {
func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) { func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) {
if m != nil { if m != nil {
if int(t) < len(m.rx) && int(s) < len(m.rx[t]) { if t >= 0 && int(t) < len(m.rx) && s >= 0 && int(s) < len(m.rx[t]) {
m.rx[t][s].Inc(i) m.rx[t][s].Inc(i)
} else if m.rxUnknown != nil { } else if m.rxUnknown != nil {
m.rxUnknown.Inc(i) m.rxUnknown.Inc(i)
@@ -26,7 +26,7 @@ func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int
} }
func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int64) { func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int64) {
if m != nil { if m != nil {
if int(t) < len(m.tx) && int(s) < len(m.tx[t]) { if t >= 0 && int(t) < len(m.tx) && s >= 0 && int(s) < len(m.tx[t]) {
m.tx[t][s].Inc(i) m.tx[t][s].Inc(i)
} else if m.txUnknown != nil { } else if m.txUnknown != nil {
m.txUnknown.Inc(i) m.txUnknown.Inc(i)

View File

@@ -31,8 +31,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
//l.Error("in packet ", header, packet[HeaderLen:]) //l.Error("in packet ", header, packet[HeaderLen:])
if ip.IsValid() { if ip.IsValid() {
_, found := f.myVpnNetworksTable.Lookup(ip.Addr()) if f.myVpnNetworksTable.Contains(ip.Addr()) {
if found {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
} }
@@ -228,7 +227,7 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) {
// sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote // sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote
func (f *Interface) sendCloseTunnel(h *HostInfo) { func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12), make([]byte, mtu)) f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
} }
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) {

View File

@@ -3,6 +3,7 @@ package overlay
import ( import (
"fmt" "fmt"
"math" "math"
"net"
"net/netip" "net/netip"
"runtime" "runtime"
"strconv" "strconv"
@@ -304,3 +305,29 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
return routes, nil return routes, nil
} }
func ipWithin(o *net.IPNet, i *net.IPNet) bool {
// Make sure o contains the lowest form of i
if !o.Contains(i.IP.Mask(i.Mask)) {
return false
}
// Find the max ip in i
ip4 := i.IP.To4()
if ip4 == nil {
return false
}
last := make(net.IP, len(ip4))
copy(last, ip4)
for x := range ip4 {
last[x] |= ^i.Mask[x]
}
// Make sure o contains the max
if !o.Contains(last) {
return false
}
return true
}

View File

@@ -225,7 +225,6 @@ func Test_parseUnsafeRoutes(t *testing.T) {
// no mtu // no mtu
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Len(t, routes, 1) assert.Len(t, routes, 1)
assert.Equal(t, 0, routes[0].MTU) assert.Equal(t, 0, routes[0].MTU)
@@ -319,7 +318,7 @@ func Test_makeRouteTree(t *testing.T) {
ip, err = netip.ParseAddr("1.1.0.1") ip, err = netip.ParseAddr("1.1.0.1")
require.NoError(t, err) require.NoError(t, err)
_, ok = routeTree.Lookup(ip) r, ok = routeTree.Lookup(ip)
assert.False(t, ok) assert.False(t, ok)
} }

View File

@@ -1,6 +1,8 @@
package pkclient package pkclient
import ( import (
"crypto/ecdsa"
"crypto/x509"
"fmt" "fmt"
"io" "io"
"strconv" "strconv"
@@ -48,6 +50,27 @@ func FromUrl(pkurl string) (*PKClient, error) {
return New(module, uint(slotid), pin, id, label) return New(module, uint(slotid), pin, id, label)
} }
func ecKeyToArray(key *ecdsa.PublicKey) []byte {
x := make([]byte, 32)
y := make([]byte, 32)
key.X.FillBytes(x)
key.Y.FillBytes(y)
return append([]byte{0x04}, append(x, y...)...)
}
func formatPubkeyFromPublicKeyInfoAttr(d []byte) ([]byte, error) {
e, err := x509.ParsePKIXPublicKey(d)
if err != nil {
return nil, err
}
switch t := e.(type) {
case *ecdsa.PublicKey:
return ecKeyToArray(e.(*ecdsa.PublicKey)), nil
default:
return nil, fmt.Errorf("unknown public key type: %T", t)
}
}
func (c *PKClient) Test() error { func (c *PKClient) Test() error {
pub, err := c.GetPubKey() pub, err := c.GetPubKey()
if err != nil { if err != nil {

View File

@@ -3,8 +3,6 @@
package pkclient package pkclient
import ( import (
"crypto/ecdsa"
"crypto/x509"
"encoding/asn1" "encoding/asn1"
"errors" "errors"
"fmt" "fmt"
@@ -229,24 +227,3 @@ func (c *PKClient) GetPubKey() ([]byte, error) {
return nil, fmt.Errorf("unknown public key length: %d", len(d)) return nil, fmt.Errorf("unknown public key length: %d", len(d))
} }
} }
func ecKeyToArray(key *ecdsa.PublicKey) []byte {
x := make([]byte, 32)
y := make([]byte, 32)
key.X.FillBytes(x)
key.Y.FillBytes(y)
return append([]byte{0x04}, append(x, y...)...)
}
func formatPubkeyFromPublicKeyInfoAttr(d []byte) ([]byte, error) {
e, err := x509.ParsePKIXPublicKey(d)
if err != nil {
return nil, err
}
switch t := e.(type) {
case *ecdsa.PublicKey:
return ecKeyToArray(e.(*ecdsa.PublicKey)), nil
default:
return nil, fmt.Errorf("unknown public key type: %T", t)
}
}

View File

@@ -7,10 +7,10 @@ import "errors"
type PKClient struct { type PKClient struct {
} }
var errNotImplemented = errors.New("not implemented") var notImplemented = errors.New("not implemented")
func New(hsmPath string, slotId uint, pin string, id string, label string) (*PKClient, error) { func New(hsmPath string, slotId uint, pin string, id string, label string) (*PKClient, error) {
return nil, errNotImplemented return nil, notImplemented
} }
func (c *PKClient) Close() error { func (c *PKClient) Close() error {
@@ -18,13 +18,13 @@ func (c *PKClient) Close() error {
} }
func (c *PKClient) SignASN1(data []byte) ([]byte, error) { func (c *PKClient) SignASN1(data []byte) ([]byte, error) {
return nil, errNotImplemented return nil, notImplemented
} }
func (c *PKClient) DeriveNoise(_ []byte) ([]byte, error) { func (c *PKClient) DeriveNoise(_ []byte) ([]byte, error) {
return nil, errNotImplemented return nil, notImplemented
} }
func (c *PKClient) GetPubKey() ([]byte, error) { func (c *PKClient) GetPubKey() ([]byte, error) {
return nil, errNotImplemented return nil, notImplemented
} }

18
pki.go
View File

@@ -39,10 +39,10 @@ type CertState struct {
cipher string cipher string
myVpnNetworks []netip.Prefix myVpnNetworks []netip.Prefix
myVpnNetworksTable *bart.Table[struct{}] myVpnNetworksTable *bart.Lite
myVpnAddrs []netip.Addr myVpnAddrs []netip.Addr
myVpnAddrsTable *bart.Table[struct{}] myVpnAddrsTable *bart.Lite
myVpnBroadcastAddrsTable *bart.Table[struct{}] myVpnBroadcastAddrsTable *bart.Lite
} }
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
@@ -345,9 +345,9 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
cs := CertState{ cs := CertState{
privateKey: privateKey, privateKey: privateKey,
pkcs11Backed: pkcs11backed, pkcs11Backed: pkcs11backed,
myVpnNetworksTable: new(bart.Table[struct{}]), myVpnNetworksTable: new(bart.Lite),
myVpnAddrsTable: new(bart.Table[struct{}]), myVpnAddrsTable: new(bart.Lite),
myVpnBroadcastAddrsTable: new(bart.Table[struct{}]), myVpnBroadcastAddrsTable: new(bart.Lite),
} }
if v1 != nil && v2 != nil { if v1 != nil && v2 != nil {
@@ -415,16 +415,16 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
for _, network := range crt.Networks() { for _, network := range crt.Networks() {
cs.myVpnNetworks = append(cs.myVpnNetworks, network) cs.myVpnNetworks = append(cs.myVpnNetworks, network)
cs.myVpnNetworksTable.Insert(network, struct{}{}) cs.myVpnNetworksTable.Insert(network)
cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr()) cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr())
cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{}) cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()))
if network.Addr().Is4() { if network.Addr().Is4() {
addr := network.Masked().Addr().As4() addr := network.Masked().Addr().As4()
mask := net.CIDRMask(network.Bits(), network.Addr().BitLen()) mask := net.CIDRMask(network.Bits(), network.Addr().BitLen())
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask)) binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{}) cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()))
} }
} }

View File

@@ -241,15 +241,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
logMsg.Info("handleCreateRelayRequest") logMsg.Info("handleCreateRelayRequest")
// Is the source of the relay me? This should never happen, but did happen due to // Is the source of the relay me? This should never happen, but did happen due to
// an issue migrating relays over to newly re-handshaked host info objects. // an issue migrating relays over to newly re-handshaked host info objects.
_, found := f.myVpnAddrsTable.Lookup(from) if f.myVpnAddrsTable.Contains(from) {
if found {
logMsg.WithField("myIP", from).Error("Discarding relay request from myself") logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
return return
} }
// Is the target of the relay me? // Is the target of the relay me?
_, found = f.myVpnAddrsTable.Lookup(target) if f.myVpnAddrsTable.Contains(target) {
if found {
existingRelay, ok := h.relayState.QueryRelayForByIp(from) existingRelay, ok := h.relayState.QueryRelayForByIp(from)
if ok { if ok {
switch existingRelay.State { switch existingRelay.State {

View File

@@ -263,7 +263,9 @@ func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort
r.RLock() r.RLock()
defer r.RUnlock() defer r.RUnlock()
c := make([]netip.AddrPort, len(r.addrs)) c := make([]netip.AddrPort, len(r.addrs))
copy(c, r.addrs) for i, v := range r.addrs {
c[i] = v
}
return c return c
} }
@@ -324,7 +326,9 @@ func (r *RemoteList) CopyCache() *CacheMap {
} }
if mc.relay != nil { if mc.relay != nil {
c.Relay = append(c.Relay, mc.relay.relay...) for _, a := range mc.relay.relay {
c.Relay = append(c.Relay, a)
}
} }
} }
@@ -358,7 +362,9 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
defer r.RUnlock() defer r.RUnlock()
c := make([]netip.AddrPort, len(r.badRemotes)) c := make([]netip.AddrPort, len(r.badRemotes))
copy(c, r.badRemotes) for i, v := range r.badRemotes {
c[i] = v
}
return c return c
} }
@@ -563,7 +569,9 @@ func (r *RemoteList) unlockedCollect() {
} }
if c.relay != nil { if c.relay != nil {
relays = append(relays, c.relay.relay...) for _, v := range c.relay.relay {
relays = append(relays, v)
}
} }
} }
@@ -627,15 +635,15 @@ func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) {
a4 := a.Addr().Is4() a4 := a.Addr().Is4()
b4 := b.Addr().Is4() b4 := b.Addr().Is4()
switch { switch {
case !a4 && b4: case a4 == false && b4 == true:
// If i is v6 and j is v4, i is less than j // If i is v6 and j is v4, i is less than j
return true return true
case a4 && !b4: case a4 == true && b4 == false:
// If j is v6 and i is v4, i is not less than j // If j is v6 and i is v4, i is not less than j
return false return false
case a4 && b4: case a4 == true && b4 == true:
// i and j are both ipv4 // i and j are both ipv4
aPrivate := a.Addr().IsPrivate() aPrivate := a.Addr().IsPrivate()
bPrivate := b.Addr().IsPrivate() bPrivate := b.Addr().IsPrivate()
@@ -683,6 +691,7 @@ func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) {
} }
r.addrs = r.addrs[:a+1] r.addrs = r.addrs[:a+1]
return
} }
// minInt returns the minimum integer of a or b // minInt returns the minimum integer of a or b

20
ssh.go
View File

@@ -527,11 +527,11 @@ func sshStartCpuProfile(fs any, a []string, w sshd.StringWriter) error {
return err return err
} }
func sshVersion(ifce *Interface, _ any, _ []string, w sshd.StringWriter) error { func sshVersion(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
return w.WriteLine(ifce.version) return w.WriteLine(fmt.Sprintf("%s", ifce.version))
} }
func sshQueryLighthouse(ifce *Interface, _ any, a []string, w sshd.StringWriter) error { func sshQueryLighthouse(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
if len(a) == 0 { if len(a) == 0 {
return w.WriteLine("No vpn address was provided") return w.WriteLine("No vpn address was provided")
} }
@@ -584,7 +584,7 @@ func sshCloseTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) er
hostInfo.ConnectionState, hostInfo.ConnectionState,
hostInfo, hostInfo,
[]byte{}, []byte{},
make([]byte, 12), make([]byte, 12, 12),
make([]byte, mtu), make([]byte, mtu),
) )
} }
@@ -614,12 +614,12 @@ func sshCreateTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) e
hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
if hostInfo != nil { if hostInfo != nil {
return w.WriteLine("Tunnel already exists") return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
} }
hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnAddr) hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnAddr)
if hostInfo != nil { if hostInfo != nil {
return w.WriteLine("Tunnel already handshaking") return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
} }
var addr netip.AddrPort var addr netip.AddrPort
@@ -735,7 +735,7 @@ func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error {
return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a)) return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a))
} }
func sshLogLevel(l *logrus.Logger, _ any, a []string, w sshd.StringWriter) error { func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
if len(a) == 0 { if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
} }
@@ -749,7 +749,7 @@ func sshLogLevel(l *logrus.Logger, _ any, a []string, w sshd.StringWriter) error
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
} }
func sshLogFormat(l *logrus.Logger, _ any, a []string, w sshd.StringWriter) error { func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
if len(a) == 0 { if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
} }
@@ -822,10 +822,10 @@ func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) erro
return w.WriteLine(cert.String()) return w.WriteLine(cert.String())
} }
func sshPrintRelays(ifce *Interface, fs any, _ []string, w sshd.StringWriter) error { func sshPrintRelays(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
args, ok := fs.(*sshPrintTunnelFlags) args, ok := fs.(*sshPrintTunnelFlags)
if !ok { if !ok {
w.WriteLine("sshPrintRelays failed to convert args type") w.WriteLine(fmt.Sprintf("sshPrintRelays failed to convert args type"))
return nil return nil
} }

View File

@@ -23,6 +23,7 @@ type SSHServer struct {
trustedCAs []ssh.PublicKey trustedCAs []ssh.PublicKey
// List of available commands // List of available commands
helpCommand *Command
commands *radix.Tree commands *radix.Tree
listener net.Listener listener net.Listener
@@ -42,7 +43,7 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
conns: make(map[int]*session), conns: make(map[int]*session),
} }
cc := &ssh.CertChecker{ cc := ssh.CertChecker{
IsUserAuthority: func(auth ssh.PublicKey) bool { IsUserAuthority: func(auth ssh.PublicKey) bool {
for _, ca := range s.trustedCAs { for _, ca := range s.trustedCAs {
if bytes.Equal(ca.Marshal(), auth.Marshal()) { if bytes.Equal(ca.Marshal(), auth.Marshal()) {
@@ -76,11 +77,10 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
}, },
} }
s.certChecker = cc
s.config = &ssh.ServerConfig{ s.config = &ssh.ServerConfig{
PublicKeyCallback: cc.Authenticate, PublicKeyCallback: cc.Authenticate,
ServerVersion: "SSH-2.0-Nebula???", ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"),
} }
s.RegisterCommand(&Command{ s.RegisterCommand(&Command{

View File

@@ -170,6 +170,7 @@ func (s *session) dispatchCommand(line string, w StringWriter) {
} }
_ = execCommand(c, args[1:], w) _ = execCommand(c, args[1:], w)
return
} }
func (s *session) Close() { func (s *session) Close() {

View File

@@ -30,11 +30,15 @@ func (NoopConn) Rebind() error {
func (NoopConn) LocalAddr() (netip.AddrPort, error) { func (NoopConn) LocalAddr() (netip.AddrPort, error) {
return netip.AddrPort{}, nil return netip.AddrPort{}, nil
} }
func (NoopConn) ListenOut(_ EncReader) {} func (NoopConn) ListenOut(_ EncReader) {
return
}
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil return nil
} }
func (NoopConn) ReloadConfig(_ *config.C) {} func (NoopConn) ReloadConfig(_ *config.C) {
return
}
func (NoopConn) Close() error { func (NoopConn) Close() error {
return nil return nil
} }

View File

@@ -33,7 +33,7 @@ func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, b
if uc, ok := pc.(*net.UDPConn); ok { if uc, ok := pc.(*net.UDPConn); ok {
return &GenericConn{UDPConn: uc, l: l}, nil return &GenericConn{UDPConn: uc, l: l}, nil
} }
return nil, fmt.Errorf("unexpected PacketConn: %T %#v", pc, pc) return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
} }
func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error { func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
@@ -66,6 +66,10 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
return func() {} return func() {}
} }
type rawMessage struct {
Len uint32
}
func (u *GenericConn) ListenOut(r EncReader) { func (u *GenericConn) ListenOut(r EncReader) {
buffer := make([]byte, MTU) buffer := make([]byte, MTU)