Compare commits

...

14 Commits

Author SHA1 Message Date
Nate Brown
662a3358ee Slight improvement to hot path benchmark, add a relay hot path benchmark 2025-11-21 23:20:20 -06:00
Nate Brown
64f202fa17 Make 0.0.0.0/0 and ::/0 not mean any address family, add any for that (#1538)
Some checks failed
Build and test / Build and test on macos-latest (push) Waiting to run
Build and test / Build and test on windows-latest (push) Waiting to run
gofmt / Run gofmt (push) Failing after 13s
smoke-extra / Run extra smoke tests (push) Failing after 2s
smoke / Run multi node smoke test (push) Failing after 2s
Build and test / Build all and test on ubuntu-linux (push) Failing after 2s
Build and test / Build and test on linux with boringcrypto (push) Failing after 3s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2s
2025-11-21 13:46:36 -06:00
Jack Doan
6d7cf611c9 improve nebula-cert sign version auto-select (#1535) 2025-11-20 13:27:27 -06:00
Nate Brown
83ae8077f5 No need to clear counter 0 (#1537) 2025-11-20 13:22:58 -06:00
Bryan Lee
12cf348c80 feat: support via gateway for v6 multihop for v4 routes (#1521)
Co-authored-by: Nate Brown <nbrown.us@gmail.com>
2025-11-19 22:21:03 -06:00
dependabot[bot]
a5ee928990 Bump golang.org/x/crypto in the golang-x-dependencies group (#1536)
Bumps the golang-x-dependencies group with 1 update: [golang.org/x/crypto](https://github.com/golang/crypto).


Updates `golang.org/x/crypto` from 0.44.0 to 0.45.0
- [Commits](https://github.com/golang/crypto/compare/v0.44.0...v0.45.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-version: 0.45.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-19 15:48:53 -06:00
Nate Brown
7aff313a17 Relax the restriction on routines from the config (#1531) 2025-11-19 13:10:11 -06:00
Jack Doan
297767b2e3 warn user if they configure a firewall rule that will allow way more traffic than you might expect (#1513)
* warn user if they accidentally configure a firewall rule that will allow way more traffic than you might expect

* add groups-contains-any warning
2025-11-19 11:06:34 -06:00
Nate Brown
99faab505c Fix a potential bug with udp ipv4 only on darwin (#1532) 2025-11-19 09:56:58 -06:00
dependabot[bot]
584c2668b3 Bump golang.org/x/net in the golang-x-dependencies group (#1530)
Bumps the golang-x-dependencies group with 1 update: [golang.org/x/net](https://github.com/golang/net).


Updates `golang.org/x/net` from 0.46.0 to 0.47.0
- [Commits](https://github.com/golang/net/compare/v0.46.0...v0.47.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-version: 0.47.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-18 21:25:03 -06:00
Wade Simmons
27ea667aee add more tests around bits counters (#1441)
Co-authored-by: Nate Brown <nbrown.us@gmail.com>
2025-11-18 16:42:21 -06:00
Hal Martin
4df8bcb1f5 nebula-cert: support reading CA passphrase from env (#1421)
* nebula-cert: support reading CA passphrase from env

This patch extends the `nebula-cert` command to support reading
the CA passphrase from the environment variable `CA_PASSPHRASE`.

Currently `nebula-cert` depends in an interactive session to obtain
the CA passphrase. This presents a challenge for automation tools like
ansible. With this change, ansible can store the CA passphrase in a
vault and supply it to `nebula-cert` via the `CA_PASSPHRASE`
environment variable for non-interactive signing.

Signed-off-by: Hal Martin <1230969+halmartin@users.noreply.github.com>

* name the variable NEBULA_CA_PASSPHRASE

---------

Signed-off-by: Hal Martin <1230969+halmartin@users.noreply.github.com>
Co-authored-by: JackDoan <me@jackdoan.com>
2025-11-17 14:41:08 -06:00
Wade Simmons
36c890eaad populate default Build version if missing (#1386)
* populate default Build version if missing

Use the Go module information built into the binary if the Build var
wasn't set during the build.

This means if you install via a specific tag, you get:

    go install github.com/slackhq/nebula/cmd/nebula@v1.9.5

    $ nebula -version
    Version: 1.9.5

And if you install master, you get:

    go install github.com/slackhq/nebula/cmd/nebula@master

    $ nebula -version
    Version: 1.9.5-0.20250408154034-18279ed17b10

* also default in the library

* cleanup
2025-11-14 08:58:15 -05:00
dependabot[bot]
44001244f2 Bump github.com/gaissmai/bart from 0.25.0 to 0.26.0 (#1508)
* Bump github.com/gaissmai/bart from 0.25.0 to 0.26.0

Bumps [github.com/gaissmai/bart](https://github.com/gaissmai/bart) from 0.25.0 to 0.26.0.
- [Release notes](https://github.com/gaissmai/bart/releases)
- [Commits](https://github.com/gaissmai/bart/compare/v0.25.0...v0.26.0)

---
updated-dependencies:
- dependency-name: github.com/gaissmai/bart
  dependency-version: 0.26.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* fix tests

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Wade Simmons <wsimmons@slack-corp.com>
2025-11-13 13:16:48 -05:00
38 changed files with 732 additions and 336 deletions

109
bits.go
View File

@@ -9,14 +9,13 @@ type Bits struct {
length uint64
current uint64
bits []bool
firstSeen bool
lostCounter metrics.Counter
dupeCounter metrics.Counter
outOfWindowCounter metrics.Counter
}
func NewBits(bits uint64) *Bits {
return &Bits{
b := &Bits{
length: bits,
bits: make([]bool, bits, bits),
current: 0,
@@ -24,34 +23,37 @@ func NewBits(bits uint64) *Bits {
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
}
// There is no counter value 0, mark it to avoid counting a lost packet later.
b.bits[0] = true
b.current = 0
return b
}
func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
// If i is the next number, return true.
if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
if i > b.current {
return true
}
// If i is within the window, check if it's been set already. The first window will fail this check
if i > b.current-b.length {
return !b.bits[i%b.length]
}
// If i is within the first window
if i < b.length {
// If i is within the window, check if it's been set already.
if i > b.current-b.length || i < b.length && b.current < b.length {
return !b.bits[i%b.length]
}
// Not within the window
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
if l.Level >= logrus.DebugLevel {
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
}
return false
}
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// If i is the next number, return true and update current.
if i == b.current+1 {
// Report missed packets, we can only understand what was missed after the first window has been gone through
if i > b.length && b.bits[i%b.length] == false {
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
// The very first window can only be tracked as lost once we are on the 2nd window or greater
if b.bits[i%b.length] == false && i > b.length {
b.lostCounter.Inc(1)
}
b.bits[i%b.length] = true
@@ -59,61 +61,32 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
return true
}
// If i packet is greater than current but less than the maximum length of our bitmap,
// flip everything in between to false and move ahead.
if i > b.current && i < b.current+b.length {
// In between current and i need to be zero'd to allow those packets to come in later
for n := b.current + 1; n < i; n++ {
// If i is a jump, adjust the window, record lost, update current, and return true
if i > b.current {
lost := int64(0)
// Zero out the bits between the current and the new counter value, limited by the window size,
// since the window is shifting
for n := b.current + 1; n <= min(i, b.current+b.length); n++ {
if b.bits[n%b.length] == false && n > b.length {
lost++
}
b.bits[n%b.length] = false
}
b.bits[i%b.length] = true
b.current = i
//l.Debugf("missed %d packets between %d and %d\n", i-b.current, i, b.current)
return true
}
// If i is greater than the delta between current and the total length of our bitmap,
// just flip everything in the map and move ahead.
if i >= b.current+b.length {
// The current window loss will be accounted for later, only record the jump as loss up until then
lost := maxInt64(0, int64(i-b.current-b.length))
//TODO: explain this
if b.current == 0 {
lost++
}
for n := range b.bits {
// Don't want to count the first window as a loss
//TODO: this is likely wrong, we are wanting to track only the bit slots that we aren't going to track anymore and this is marking everything as missed
//if b.bits[n] == false {
// lost++
//}
b.bits[n] = false
}
// Only record any skipped packets as a result of the window moving further than the window length
// Any loss within the new window will be accounted for in future calls
lost += max(0, int64(i-b.current-b.length))
b.lostCounter.Inc(lost)
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": true, "currentCounter": b.current, "incomingCounter": i, "reason": "window shifting"}).
Debug("Receive window")
}
b.bits[i%b.length] = true
b.current = i
return true
}
// Allow for the 0 packet to come in within the first window
if i == 0 && b.firstSeen == false && b.current < b.length {
b.firstSeen = true
b.bits[i%b.length] = true
return true
}
// If i is within the window of current minus length (the total pat window size),
// allow it and flip to true but to NOT change current. We also have to account for the first window
if ((b.current >= b.length && i > b.current-b.length) || (b.current < b.length && i < b.length)) && i <= b.current {
if b.current == i {
// If i is within the current window but below the current counter,
// Check to see if it's a duplicate
if i > b.current-b.length || i < b.length && b.current < b.length {
if b.current == i || b.bits[i%b.length] == true {
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
Debug("Receive window")
@@ -122,18 +95,8 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
return false
}
if b.bits[i%b.length] == true {
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}).
Debug("Receive window")
}
b.dupeCounter.Inc(1)
return false
}
b.bits[i%b.length] = true
return true
}
// In all other cases, fail and don't change current.
@@ -147,11 +110,3 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
}
return false
}
func maxInt64(a, b int64) int64 {
if a > b {
return a
}
return b
}

View File

@@ -15,48 +15,41 @@ func TestBits(t *testing.T) {
assert.Len(t, b.bits, 10)
// This is initialized to zero - receive one. This should work.
assert.True(t, b.Check(l, 1))
u := b.Update(l, 1)
assert.True(t, u)
assert.True(t, b.Update(l, 1))
assert.EqualValues(t, 1, b.current)
g := []bool{false, true, false, false, false, false, false, false, false, false}
g := []bool{true, true, false, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits)
// Receive two
assert.True(t, b.Check(l, 2))
u = b.Update(l, 2)
assert.True(t, u)
assert.True(t, b.Update(l, 2))
assert.EqualValues(t, 2, b.current)
g = []bool{false, true, true, false, false, false, false, false, false, false}
g = []bool{true, true, true, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits)
// Receive two again - it will fail
assert.False(t, b.Check(l, 2))
u = b.Update(l, 2)
assert.False(t, u)
assert.False(t, b.Update(l, 2))
assert.EqualValues(t, 2, b.current)
// Jump ahead to 15, which should clear everything and set the 6th element
assert.True(t, b.Check(l, 15))
u = b.Update(l, 15)
assert.True(t, u)
assert.True(t, b.Update(l, 15))
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, false, true, false, false, false, false}
assert.Equal(t, g, b.bits)
// Mark 14, which is allowed because it is in the window
assert.True(t, b.Check(l, 14))
u = b.Update(l, 14)
assert.True(t, u)
assert.True(t, b.Update(l, 14))
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits)
// Mark 5, which is not allowed because it is not in the window
assert.False(t, b.Check(l, 5))
u = b.Update(l, 5)
assert.False(t, u)
assert.False(t, b.Update(l, 5))
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits)
@@ -69,10 +62,29 @@ func TestBits(t *testing.T) {
// Walk through a few windows in order
b = NewBits(10)
for i := uint64(0); i <= 100; i++ {
for i := uint64(1); i <= 100; i++ {
assert.True(t, b.Check(l, i), "Error while checking %v", i)
assert.True(t, b.Update(l, i), "Error while updating %v", i)
}
assert.False(t, b.Check(l, 1), "Out of window check")
}
func TestBitsLargeJumps(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b.lostCounter.Clear()
b = NewBits(10)
b.lostCounter.Clear()
assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54
assert.Equal(t, int64(45), b.lostCounter.Count())
assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99
assert.Equal(t, int64(89), b.lostCounter.Count())
assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199
assert.Equal(t, int64(188), b.lostCounter.Count())
}
func TestBitsDupeCounter(t *testing.T) {
@@ -124,8 +136,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
assert.False(t, b.Update(l, 0))
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
//tODO: make sure lostcounter doesn't increase in orderly increment
assert.Equal(t, int64(20), b.lostCounter.Count())
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
}
@@ -137,8 +148,6 @@ func TestBitsLostCounter(t *testing.T) {
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
//assert.True(t, b.Update(0))
assert.True(t, b.Update(l, 0))
assert.True(t, b.Update(l, 20))
assert.True(t, b.Update(l, 21))
assert.True(t, b.Update(l, 22))
@@ -149,7 +158,7 @@ func TestBitsLostCounter(t *testing.T) {
assert.True(t, b.Update(l, 27))
assert.True(t, b.Update(l, 28))
assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(20), b.lostCounter.Count())
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
@@ -158,8 +167,6 @@ func TestBitsLostCounter(t *testing.T) {
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(l, 0))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 9))
assert.Equal(t, int64(0), b.lostCounter.Count())
// 10 will set 0 index, 0 was already set, no lost packets
@@ -214,6 +221,62 @@ func TestBitsLostCounter(t *testing.T) {
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func TestBitsLostCounterIssue1(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(l, 4))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 1))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 9))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 2))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 3))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 5))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 6))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 7))
assert.Equal(t, int64(0), b.lostCounter.Count())
// assert.True(t, b.Update(l, 8))
assert.True(t, b.Update(l, 10))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 11))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 14))
assert.Equal(t, int64(0), b.lostCounter.Count())
// Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter
assert.True(t, b.Update(l, 19))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 12))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 13))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 15))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 16))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 17))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 18))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 20))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 21))
// We missed packet 8 above
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func BenchmarkBits(b *testing.B) {
z := NewBits(10)
for n := 0; n < b.N; n++ {

View File

@@ -173,23 +173,26 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
var passphrase []byte
if !isP11 && *cf.encryption {
for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword()
if err == ErrNoTerminal {
return fmt.Errorf("out-key must be encrypted interactively")
} else if err != nil {
return fmt.Errorf("error reading passphrase: %s", err)
}
if len(passphrase) > 0 {
break
}
}
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
if len(passphrase) == 0 {
return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext")
for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword()
if err == ErrNoTerminal {
return fmt.Errorf("out-key must be encrypted interactively")
} else if err != nil {
return fmt.Errorf("error reading passphrase: %s", err)
}
if len(passphrase) > 0 {
break
}
}
if len(passphrase) == 0 {
return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext")
}
}
}

View File

@@ -171,6 +171,17 @@ func Test_ca(t *testing.T) {
assert.Equal(t, pwPromptOb, ob.String())
assert.Empty(t, eb.String())
// test encrypted key with passphrase environment variable
os.Remove(keyF.Name())
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
require.NoError(t, ca(args, ob, eb, testpw))
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
// read encrypted key file and verify default params
rb, _ = os.ReadFile(keyF.Name())
k, _ := pem.Decode(rb)

View File

@@ -5,10 +5,28 @@ import (
"fmt"
"io"
"os"
"runtime/debug"
"strings"
)
// A version string that can be set with
//
// -ldflags "-X main.Build=SOMEVERSION"
//
// at compile-time.
var Build string
func init() {
if Build == "" {
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
Build = strings.TrimPrefix(info.Main.Version, "v")
}
}
type helpError struct {
s string
}

View File

@@ -43,7 +43,7 @@ type signFlags struct {
func newSignFlags() *signFlags {
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
sf.set.Usage = func() {}
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use. The default is to match the version of the signing CA")
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
@@ -116,26 +116,28 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
// naively attempt to decode the private key as though it is not encrypted
caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
// ask for a passphrase until we get one
var passphrase []byte
for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword()
if errors.Is(err, ErrNoTerminal) {
return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
} else if err != nil {
return fmt.Errorf("error reading password: %s", err)
}
if len(passphrase) > 0 {
break
}
}
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
if len(passphrase) == 0 {
return fmt.Errorf("cannot open encrypted ca-key without passphrase")
}
// ask for a passphrase until we get one
for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword()
if errors.Is(err, ErrNoTerminal) {
return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
} else if err != nil {
return fmt.Errorf("error reading password: %s", err)
}
if len(passphrase) > 0 {
break
}
}
if len(passphrase) == 0 {
return fmt.Errorf("cannot open encrypted ca-key without passphrase")
}
}
curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
if err != nil {
return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
@@ -165,6 +167,10 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return fmt.Errorf("ca certificate is expired")
}
if version == 0 {
version = caCert.Version()
}
// if no duration is given, expire one second before the root expires
if *sf.duration <= 0 {
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
@@ -277,21 +283,19 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
notBefore := time.Now()
notAfter := notBefore.Add(*sf.duration)
if version == 0 || version == cert.Version1 {
// Make sure we at least have an ip
switch version {
case cert.Version1:
// Make sure we have only one ipv4 address
if len(v4Networks) != 1 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
}
if version == cert.Version1 {
// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
if len(v6Networks) > 0 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
}
if len(v6Networks) > 0 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only contain ipv4 addresses")
}
if len(v6UnsafeNetworks) > 0 {
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
}
if len(v6UnsafeNetworks) > 0 {
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses")
}
t := &cert.TBSCertificate{
@@ -321,9 +325,8 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
}
crts = append(crts, nc)
}
if version == 0 || version == cert.Version2 {
case cert.Version2:
t := &cert.TBSCertificate{
Version: cert.Version2,
Name: *sf.name,
@@ -351,6 +354,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
}
crts = append(crts, nc)
default:
// this should be unreachable
return fmt.Errorf("invalid version: %d", version)
}
if !isP11 && *sf.inPubPath == "" {

View File

@@ -55,7 +55,7 @@ func Test_signHelp(t *testing.T) {
" -unsafe-networks string\n"+
" \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
" -version uint\n"+
" \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
" \tOptional: version of the certificate format to use. The default is to match the version of the signing CA\n",
ob.String(),
)
}
@@ -204,7 +204,7 @@ func Test_signCert(t *testing.T) {
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@@ -379,6 +379,15 @@ func Test_signCert(t *testing.T) {
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
// test with the proper password in the environment
os.Remove(crtF.Name())
os.Remove(keyF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
require.NoError(t, signCert(args, ob, eb, testpw))
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
// test with the wrong password
ob.Reset()
eb.Reset()
@@ -389,6 +398,17 @@ func Test_signCert(t *testing.T) {
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
// test with the wrong password in environment
ob.Reset()
eb.Reset()
os.Setenv("NEBULA_CA_PASSPHRASE", "invalid password")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing encrypted ca-key: invalid passphrase or corrupt private key")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
// test with the user not entering a password
ob.Reset()
eb.Reset()

View File

@@ -4,6 +4,8 @@ import (
"flag"
"fmt"
"os"
"runtime/debug"
"strings"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
@@ -18,6 +20,17 @@ import (
// at compile-time.
var Build string
func init() {
if Build == "" {
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
Build = strings.TrimPrefix(info.Main.Version, "v")
}
}
func main() {
serviceFlag := flag.String("service", "", "Control the system service.")
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")

View File

@@ -4,6 +4,8 @@ import (
"flag"
"fmt"
"os"
"runtime/debug"
"strings"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
@@ -18,6 +20,17 @@ import (
// at compile-time.
var Build string
func init() {
if Build == "" {
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
Build = strings.TrimPrefix(info.Main.Version, "v")
}
}
func main() {
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
configTest := flag.Bool("test", false, "Test the config and print the end result. Non zero exit indicates a faulty config")

View File

@@ -50,11 +50,6 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
}
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
b.Update(l, 0)
hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: ncs,
Random: rand.Reader,
@@ -74,7 +69,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
ci := &ConnectionState{
H: hs,
initiator: initiator,
window: b,
window: NewBits(ReplayWindow),
myCert: crt,
}
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.

View File

@@ -25,11 +25,12 @@ import (
func BenchmarkHotPath(b *testing.B) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
@@ -38,6 +39,9 @@ func BenchmarkHotPath(b *testing.B) {
r := router.NewR(b, myControl, theirControl)
r.CancelFlowLogs()
assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
b.ResetTimer()
for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl)
@@ -47,6 +51,39 @@ func BenchmarkHotPath(b *testing.B) {
theirControl.Stop()
}
func BenchmarkHotPathRelay(b *testing.B) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(b, myControl, relayControl, theirControl)
r.CancelFlowLogs()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
b.ResetTimer()
for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl)
}
myControl.Stop()
theirControl.Stop()
relayControl.Stop()
}
func TestGoodHandshake(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)

View File

@@ -292,7 +292,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
}
}
func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
// Send a packet from them to me
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
bPacket := r.RouteForAllUntilTxTun(controlA)
@@ -304,7 +304,7 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *n
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
}
func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) {
func assertHostInfoPair(t testing.TB, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) {
// Get both host infos
//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
@@ -325,7 +325,7 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
}
func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
func assertUdpPacket(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
if toIp.Is6() {
assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
} else {
@@ -333,7 +333,7 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
}
}
func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
func assertUdpPacket6(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
assert.NotNil(t, v6, "No ipv6 data found")
@@ -352,7 +352,7 @@ func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
}
func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
func assertUdpPacket4(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
assert.NotNil(t, v4, "No ipv4 data found")

View File

@@ -383,8 +383,9 @@ firewall:
# host: `any` or a literal hostname, ie `test-host`
# group: `any` or a literal group name, ie `default-group`
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes.
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address.
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address.
# This can be used to filter destinations when using unsafe_routes.
# By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true.
# If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case.
# ca_name: An issuing CA name

View File

@@ -8,6 +8,7 @@ import (
"hash/fnv"
"net/netip"
"reflect"
"slices"
"strconv"
"strings"
"sync"
@@ -22,7 +23,7 @@ import (
)
type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error
}
type conn struct {
@@ -247,22 +248,11 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
}
// AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
// https://github.com/golang/go/issues/14131
sIp := ""
if ip.IsValid() {
sIp = ip.String()
}
lIp := ""
if localIp.IsValid() {
lIp = localIp.String()
}
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
// We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf(
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha,
incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha,
)
f.rules += ruleString + "\n"
@@ -270,7 +260,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
if !incoming {
direction = "outgoing"
}
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}).
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}).
Info("Firewall rule added")
var (
@@ -297,7 +287,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
return fmt.Errorf("unknown protocol %v", proto)
}
return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha)
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha)
}
// GetRuleHash returns a hash representation of all inbound and outbound rules
@@ -337,7 +327,6 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
}
for i, t := range rs {
var groups []string
r, err := convertRule(l, t, table, i)
if err != nil {
return fmt.Errorf("%s rule #%v; %s", table, i, err)
@@ -347,23 +336,10 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
}
if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
if r.Host == "" && len(r.Groups) == 0 && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
}
if len(r.Groups) > 0 {
groups = r.Groups
}
if r.Group != "" {
// Check if we have both groups and group provided in the rule config
if len(groups) > 0 {
return fmt.Errorf("%s rule #%v; only one of group or groups should be defined, both provided", table, i)
}
groups = []string{r.Group}
}
var sPort, errPort string
if r.Code != "" {
errPort = "code"
@@ -392,23 +368,25 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
}
var cidr netip.Prefix
if r.Cidr != "" {
cidr, err = netip.ParsePrefix(r.Cidr)
if r.Cidr != "" && r.Cidr != "any" {
_, err = netip.ParsePrefix(r.Cidr)
if err != nil {
return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
}
}
var localCidr netip.Prefix
if r.LocalCidr != "" {
localCidr, err = netip.ParsePrefix(r.LocalCidr)
if r.LocalCidr != "" && r.LocalCidr != "any" {
_, err = netip.ParsePrefix(r.LocalCidr)
if err != nil {
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
}
}
err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha)
if warning := r.sanity(); warning != nil {
l.Warnf("%s rule #%v; %s", table, i, warning)
}
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
if err != nil {
return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
}
@@ -655,7 +633,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC
return false
}
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
if startPort > endPort {
return fmt.Errorf("start port was lower than end port")
}
@@ -668,7 +646,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou
}
}
if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil {
if err := fp[i].addRule(f, groups, host, cidr, localCidr, caName, caSha); err != nil {
return err
}
}
@@ -699,7 +677,7 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
return fp[firewall.PortAny].match(p, c, caPool)
}
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error {
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, localCidr, caName, caSha string) error {
fr := func() *FirewallRule {
return &FirewallRule{
Hosts: make(map[string]*firewallLocalCIDR),
@@ -713,14 +691,14 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc
fc.Any = fr()
}
return fc.Any.addRule(f, groups, host, ip, localIp)
return fc.Any.addRule(f, groups, host, cidr, localCidr)
}
if caSha != "" {
if _, ok := fc.CAShas[caSha]; !ok {
fc.CAShas[caSha] = fr()
}
err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp)
err := fc.CAShas[caSha].addRule(f, groups, host, cidr, localCidr)
if err != nil {
return err
}
@@ -730,7 +708,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc
if _, ok := fc.CANames[caName]; !ok {
fc.CANames[caName] = fr()
}
err := fc.CANames[caName].addRule(f, groups, host, ip, localIp)
err := fc.CANames[caName].addRule(f, groups, host, cidr, localCidr)
if err != nil {
return err
}
@@ -762,24 +740,24 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool
return fc.CANames[s.Certificate.Name()].match(p, c)
}
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localCidr string) error {
flc := func() *firewallLocalCIDR {
return &firewallLocalCIDR{
LocalCIDR: new(bart.Lite),
}
}
if fr.isAny(groups, host, ip) {
if fr.isAny(groups, host, cidr) {
if fr.Any == nil {
fr.Any = flc()
}
return fr.Any.addRule(f, localCIDR)
return fr.Any.addRule(f, localCidr)
}
if len(groups) > 0 {
nlc := flc()
err := nlc.addRule(f, localCIDR)
err := nlc.addRule(f, localCidr)
if err != nil {
return err
}
@@ -795,30 +773,34 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, l
if nlc == nil {
nlc = flc()
}
err := nlc.addRule(f, localCIDR)
err := nlc.addRule(f, localCidr)
if err != nil {
return err
}
fr.Hosts[host] = nlc
}
if ip.IsValid() {
nlc, _ := fr.CIDR.Get(ip)
if nlc == nil {
nlc = flc()
}
err := nlc.addRule(f, localCIDR)
if cidr != "" {
c, err := netip.ParsePrefix(cidr)
if err != nil {
return err
}
fr.CIDR.Insert(ip, nlc)
nlc, _ := fr.CIDR.Get(c)
if nlc == nil {
nlc = flc()
}
err = nlc.addRule(f, localCidr)
if err != nil {
return err
}
fr.CIDR.Insert(c, nlc)
}
return nil
}
func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool {
if len(groups) == 0 && host == "" && !ip.IsValid() {
func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool {
if len(groups) == 0 && host == "" && cidr == "" {
return true
}
@@ -832,7 +814,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) boo
return true
}
if ip.IsValid() && ip.Bits() == 0 {
if cidr == "any" {
return true
}
@@ -884,8 +866,13 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
return false
}
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
if !localIp.IsValid() {
func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
if localCidr == "any" {
flc.Any = true
return nil
}
if localCidr == "" {
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
flc.Any = true
return nil
@@ -896,12 +883,13 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
}
return nil
} else if localIp.Bits() == 0 {
flc.Any = true
return nil
}
flc.LocalCIDR.Insert(localIp)
c, err := netip.ParsePrefix(localCidr)
if err != nil {
return err
}
flc.LocalCIDR.Insert(c)
return nil
}
@@ -922,7 +910,6 @@ type rule struct {
Code string
Proto string
Host string
Group string
Groups []string
Cidr string
LocalCidr string
@@ -964,7 +951,8 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
m["group"] = v[0]
}
r.Group = toString("group", m)
singleGroup := toString("group", m)
if rg, ok := m["groups"]; ok {
switch reflect.TypeOf(rg).Kind() {
@@ -981,9 +969,60 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
}
}
//flatten group vs groups
if singleGroup != "" {
// Check if we have both groups and group provided in the rule config
if len(r.Groups) > 0 {
return r, fmt.Errorf("only one of group or groups should be defined, both provided")
}
r.Groups = []string{singleGroup}
}
return r, nil
}
// sanity returns an error if the rule would be evaluated in a way that would short-circuit a configured check on a wildcard value
// rules are evaluated as "port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND local_cidr"
func (r *rule) sanity() error {
//port, proto, local_cidr are AND, no need to check here
//ca_sha and ca_name don't have a wildcard value, no need to check here
groupsEmpty := len(r.Groups) == 0
hostEmpty := r.Host == ""
cidrEmpty := r.Cidr == ""
if (groupsEmpty && hostEmpty && cidrEmpty) == true {
return nil //no content!
}
groupsHasAny := slices.Contains(r.Groups, "any")
if groupsHasAny && len(r.Groups) > 1 {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the other groups specified", r.Groups)
}
if r.Host == "any" {
if !groupsEmpty {
return fmt.Errorf("groups specified as %s, but host=any will match any host, regardless of groups", r.Groups)
}
if !cidrEmpty {
return fmt.Errorf("cidr specified as %s, but host=any will match any host, regardless of cidr", r.Cidr)
}
}
if groupsHasAny {
if !hostEmpty && r.Host != "any" {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified host %s", r.Groups, r.Host)
}
if !cidrEmpty {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified cidr %s", r.Groups, r.Cidr)
}
}
//todo alert on cidr-any
return nil
}
func parsePort(s string) (startPort, endPort int32, err error) {
if s == "any" {
startPort = firewall.PortAny

View File

@@ -73,78 +73,106 @@ func TestFirewall_AddRule(t *testing.T) {
ti6, err := netip.ParsePrefix("fd12::34/128")
require.NoError(t, err)
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", ""))
// An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", ""))
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp, err := netip.ParsePrefix("0.0.0.0/0")
require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
table, ok := fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
assert.True(t, table.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp6, err := netip.ParsePrefix("::/0")
require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
assert.True(t, table.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", ""))
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", ""))
}
func TestFirewall_Drop(t *testing.T) {
@@ -180,7 +208,7 @@ func TestFirewall_Drop(t *testing.T) {
h.buildNetworks(myVpnNetworksTable, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool()
// Drop outbound
@@ -199,28 +227,28 @@ func TestFirewall_Drop(t *testing.T) {
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
@@ -259,7 +287,7 @@ func TestFirewall_DropV6(t *testing.T) {
h.buildNetworks(myVpnNetworksTable, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool()
// Drop outbound
@@ -278,28 +306,28 @@ func TestFirewall_DropV6(t *testing.T) {
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
@@ -310,12 +338,12 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}
pfix := netip.MustParsePrefix("172.1.1.1/32")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix.String(), "", "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix.String(), "", "")
pfix6 := netip.MustParsePrefix("fd11::11/128")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6.String(), "", "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix6.String(), "", "")
cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) {
@@ -504,7 +532,7 @@ func TestFirewall_Drop2(t *testing.T) {
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", ""))
cp := cert.NewCAPool()
// h1/c1 lacks the proper groups
@@ -584,8 +612,8 @@ func TestFirewall_Drop3(t *testing.T) {
h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha"))
cp := cert.NewCAPool()
// c1 should pass because host match
@@ -599,7 +627,7 @@ func TestFirewall_Drop3(t *testing.T) {
// Test a remote address match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
}
@@ -637,7 +665,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
// Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
@@ -676,7 +704,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
h.buildNetworks(myVpnNetworksTable, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool()
// Drop outbound
@@ -689,7 +717,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
@@ -698,7 +726,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
@@ -737,7 +765,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", ""))
cp := cert.NewCAPool()
// Packet spoofed by `c1`. Note that the remote addr is not a valid one.
@@ -931,28 +959,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf := &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding udp rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding icmp rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding any rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding rule with cidr
cidr := netip.MustParsePrefix("10.0.0.0/8")
@@ -960,14 +988,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
// Test adding rule with local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr.String()}, mf.lastCall)
// Test adding rule with cidr ipv6
cidr6 := netip.MustParsePrefix("fd00::/8")
@@ -975,49 +1003,75 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
// Test adding rule with any cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
// Test adding rule with junk cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with local_cidr ipv6
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
// Test adding rule with any local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
// Test adding rule with junk local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with ca_sha
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
// Test single group
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
// Test single groups
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
// Test multiple AND groups
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
// Test Add error
conf = config.NewC(l)
@@ -1040,7 +1094,7 @@ func TestFirewall_convertRule(t *testing.T) {
r, err := convertRule(l, c, "test", 1)
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
require.NoError(t, err)
assert.Equal(t, "group1", r.Group)
assert.Equal(t, []string{"group1"}, r.Groups)
// Ensure group array of > 1 is errord
ob.Reset()
@@ -1060,7 +1114,63 @@ func TestFirewall_convertRule(t *testing.T) {
r, err = convertRule(l, c, "test", 1)
require.NoError(t, err)
assert.Equal(t, "group1", r.Group)
assert.Equal(t, []string{"group1"}, r.Groups)
}
func TestFirewall_convertRuleSanity(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
noWarningPlease := []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"host": "bob"},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range noWarningPlease {
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
require.NoError(t, r.sanity(), "should not generate a sanity warning, %+v", c)
}
yesWarningPlease := []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range yesWarningPlease {
c["host"] = "any"
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
err = r.sanity()
require.Error(t, err, "I wanted a warning: %+v", c)
}
//reset the list
yesWarningPlease = []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range yesWarningPlease {
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
r.Groups = append(r.Groups, "any")
err = r.sanity()
require.Error(t, err, "I wanted a warning: %+v", c)
}
}
type testcase struct {
@@ -1141,7 +1251,7 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
myVpnNetworksTable.Insert(prefix)
}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
return testsetup{
c: c,
@@ -1222,7 +1332,7 @@ func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
tc.err = ErrNoMatchingRule
tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, unsafePrefix, "", ""))
require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
tc.err = nil
tc.Test(t, unsafeSetup.fw) //should pass
})
@@ -1235,8 +1345,8 @@ type addRuleCall struct {
endPort int32
groups []string
host string
ip netip.Prefix
localIp netip.Prefix
ip string
localIp string
caName string
caSha string
}
@@ -1246,7 +1356,7 @@ type mockFirewall struct {
nextCallReturn error
}
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error {
mf.lastCall = addRuleCall{
incoming: incoming,
proto: proto,

6
go.mod
View File

@@ -8,7 +8,7 @@ require (
github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.1.0
github.com/gaissmai/bart v0.25.0
github.com/gaissmai/bart v0.26.0
github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.4
@@ -23,9 +23,9 @@ require (
github.com/stretchr/testify v1.11.1
github.com/vishvananda/netlink v1.3.1
go.yaml.in/yaml/v3 v3.0.4
golang.org/x/crypto v0.44.0
golang.org/x/crypto v0.45.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.46.0
golang.org/x/net v0.47.0
golang.org/x/sync v0.18.0
golang.org/x/sys v0.38.0
golang.org/x/term v0.37.0

12
go.sum
View File

@@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/gaissmai/bart v0.25.0 h1:eqiokVPqM3F94vJ0bTHXHtH91S8zkKL+bKh+BsGOsJM=
github.com/gaissmai/bart v0.25.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0=
github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -162,8 +162,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
@@ -182,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

View File

@@ -222,6 +222,13 @@ func (f *Interface) activate() {
WithField("boringcrypto", boringEnabled()).
Info("Nebula interface is active")
if f.routines > 1 {
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
f.routines = 1
f.l.Warn("routines is not supported on this platform, falling back to a single routine")
}
}
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
// Prepare n tun queues

21
main.go
View File

@@ -5,6 +5,8 @@ import (
"fmt"
"net"
"net/netip"
"runtime/debug"
"strings"
"time"
"github.com/sirupsen/logrus"
@@ -27,6 +29,10 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
}()
if buildVersion == "" {
buildVersion = moduleVersion()
}
l := logger
l.Formatter = &logrus.TextFormatter{
FullTimestamp: true,
@@ -296,3 +302,18 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
connManager.Start,
}, nil
}
func moduleVersion() string {
info, ok := debug.ReadBuildInfo()
if !ok {
return ""
}
for _, dep := range info.Deps {
if dep.Path == "github.com/slackhq/nebula" {
return strings.TrimPrefix(dep.Version, "v")
}
}
return ""
}

View File

@@ -13,5 +13,6 @@ type Device interface {
Networks() []netip.Prefix
Name() string
RoutesFor(netip.Addr) routing.Gateways
SupportsMultiqueue() bool
NewMultiQueueReader() (io.ReadWriteCloser, error)
}

View File

@@ -95,6 +95,10 @@ func (t *tun) Name() string {
return "android"
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
}

View File

@@ -549,6 +549,10 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
}

View File

@@ -105,6 +105,10 @@ func (t *disabledTun) Write(b []byte) (int, error) {
return len(b), nil
}
func (t *disabledTun) SupportsMultiqueue() bool {
return true
}
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return t, nil
}

View File

@@ -450,6 +450,10 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
}

View File

@@ -151,6 +151,10 @@ func (t *tun) Name() string {
return "iOS"
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
}

View File

@@ -216,6 +216,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) SupportsMultiqueue() bool {
return true
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
@@ -582,48 +586,42 @@ func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
}
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
var gateways routing.Gateways
link, err := netlink.LinkByName(t.Device)
if err != nil {
t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name")
return gateways
}
// If this route is relevant to our interface and there is a gateway then add it
if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
gwAddr, ok := netip.AddrFromSlice(r.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
} else {
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
if r.LinkIndex == link.Attrs().Index {
gwAddr, ok := getGatewayAddr(r.Gw, r.Via)
if ok {
if t.isGatewayInVpnNetworks(gwAddr) {
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
} else {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
}
} else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
}
}
for _, p := range r.MultiPath {
// If this route is relevant to our interface and there is a gateway then add it
if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
gwAddr, ok := netip.AddrFromSlice(p.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
} else {
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
// p.Hops+1 = weight of the route
if p.LinkIndex == link.Attrs().Index {
gwAddr, ok := getGatewayAddr(p.Gw, p.Via)
if ok {
if t.isGatewayInVpnNetworks(gwAddr) {
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
} else {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
}
} else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
}
}
}
@@ -632,10 +630,27 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
return gateways
}
func getGatewayAddr(gw net.IP, via netlink.Destination) (netip.Addr, bool) {
// Try to use the old RTA_GATEWAY first
gwAddr, ok := netip.AddrFromSlice(gw)
if !ok {
// Fallback to the new RTA_VIA
rVia, ok := via.(*netlink.Via)
if ok {
gwAddr, ok = netip.AddrFromSlice(rVia.Addr)
}
}
if gwAddr.IsValid() {
gwAddr = gwAddr.Unmap()
return gwAddr, true
}
return netip.Addr{}, false
}
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
gateways := t.getGatewaysFromRoute(&r.Route)
if len(gateways) == 0 {
// No gateways relevant to our network, no routing changes required.
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")

View File

@@ -390,6 +390,10 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
}

View File

@@ -310,6 +310,10 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
}

View File

@@ -132,6 +132,10 @@ func (t *TestTun) Read(b []byte) (int, error) {
return len(p), nil
}
func (t *TestTun) SupportsMultiqueue() bool {
return false
}
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented")
}

View File

@@ -234,6 +234,10 @@ func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0)
}
func (t *winTun) SupportsMultiqueue() bool {
return false
}
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}

View File

@@ -46,6 +46,10 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
return routing.Gateways{routing.NewGateway(ip, 1)}
}
func (d *UserDevice) SupportsMultiqueue() bool {
return true
}
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return d, nil
}

View File

@@ -34,6 +34,10 @@ func (NoopTun) Write([]byte) (int, error) {
return 0, nil
}
func (NoopTun) SupportsMultiqueue() bool {
return false
}
func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, errors.New("unsupported")
}

View File

@@ -19,6 +19,7 @@ type Conn interface {
ListenOut(r EncReader)
WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C)
SupportsMultipleReaders() bool
Close() error
}
@@ -33,6 +34,9 @@ func (NoopConn) LocalAddr() (netip.AddrPort, error) {
func (NoopConn) ListenOut(_ EncReader) {
return
}
func (NoopConn) SupportsMultipleReaders() bool {
return false
}
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil
}

View File

@@ -98,9 +98,9 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
return ErrInvalidIPv6RemoteForSocket
}
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
rsa.Addr = ap.Addr().As16()
var rsa unix.RawSockaddrInet4
rsa.Family = unix.AF_INET
rsa.Addr = ap.Addr().As4()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
sa = unsafe.Pointer(&rsa)
addrLen = syscall.SizeofSockaddrInet4
@@ -184,6 +184,10 @@ func (u *StdConn) ListenOut(r EncReader) {
}
}
func (u *StdConn) SupportsMultipleReaders() bool {
return false
}
func (u *StdConn) Rebind() error {
var err error
if u.isV4 {

View File

@@ -85,3 +85,7 @@ func (u *GenericConn) ListenOut(r EncReader) {
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
}
}
func (u *GenericConn) SupportsMultipleReaders() bool {
return false
}

View File

@@ -72,6 +72,10 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
}
func (u *StdConn) SupportsMultipleReaders() bool {
return true
}
func (u *StdConn) Rebind() error {
return nil
}

View File

@@ -315,6 +315,10 @@ func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
}
func (u *RIOConn) SupportsMultipleReaders() bool {
return false
}
func (u *RIOConn) Rebind() error {
return nil
}

View File

@@ -127,6 +127,10 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
return u.Addr, nil
}
func (u *TesterConn) SupportsMultipleReaders() bool {
return false
}
func (u *TesterConn) Rebind() error {
return nil
}