Compare commits

..

10 Commits

Author SHA1 Message Date
Nate Brown
064831cf21 Fix e2e tests 2025-11-03 22:32:10 -06:00
Nate Brown
9ded90c6e8 Fix tests 2025-11-03 22:25:33 -06:00
Nate Brown
19600f257f Rebase cleanup 2025-11-03 22:21:20 -06:00
Nate Brown
3e2a6e0a5d Cleanup and note more work 2025-11-03 22:03:27 -06:00
Nate Brown
bc62f5ec82 Try the timeout 2025-11-03 22:02:24 -06:00
Nate Brown
012fcf40fe Revert "More playing" way too much garbage emitted
This reverts commit fa098c551a.
2025-11-03 22:02:24 -06:00
Nate Brown
ad319b964d More playing 2025-11-03 22:02:24 -06:00
Nate Brown
f42878c5fc Playing 2025-11-03 22:02:24 -06:00
Nate Brown
f2b3ef4b3e non-blocking io for linux 2025-11-03 22:02:22 -06:00
Nate Brown
c3ec96d9c2 Remove more os.Exit calls and give a more reliable wait for stop function 2025-11-03 21:59:27 -06:00
58 changed files with 692 additions and 1512 deletions

View File

@@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
with:

View File

@@ -10,7 +10,7 @@ jobs:
name: Build Linux/BSD All
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
with:
@@ -24,7 +24,7 @@ jobs:
mv build/*.tar.gz release
- name: Upload artifacts
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v4
with:
name: linux-latest
path: release
@@ -33,7 +33,7 @@ jobs:
name: Build Windows
runs-on: windows-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
with:
@@ -55,7 +55,7 @@ jobs:
mv dist\windows\wintun build\dist\windows\
- name: Upload artifacts
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v4
with:
name: windows-latest
path: build
@@ -66,7 +66,7 @@ jobs:
HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
runs-on: macos-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
with:
@@ -104,7 +104,7 @@ jobs:
fi
- name: Upload artifacts
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v4
with:
name: darwin-latest
path: ./release/*
@@ -124,11 +124,11 @@ jobs:
# be overwritten
- name: Checkout code
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
uses: actions/checkout@v5
uses: actions/checkout@v4
- name: Download artifacts
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
uses: actions/download-artifact@v6
uses: actions/download-artifact@v4
with:
name: linux-latest
path: artifacts
@@ -160,10 +160,10 @@ jobs:
needs: [build-linux, build-darwin, build-windows]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- name: Download artifacts
uses: actions/download-artifact@v6
uses: actions/download-artifact@v4
with:
path: artifacts

View File

@@ -20,7 +20,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
with:

View File

@@ -18,7 +18,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
with:

View File

@@ -18,7 +18,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
with:
@@ -32,7 +32,7 @@ jobs:
run: make vet
- name: golangci-lint
uses: golangci/golangci-lint-action@v9
uses: golangci/golangci-lint-action@v8
with:
version: v2.5
@@ -45,7 +45,7 @@ jobs:
- name: Build test mobile
run: make build-test-mobile
- uses: actions/upload-artifact@v5
- uses: actions/upload-artifact@v4
with:
name: e2e packet flow linux-latest
path: e2e/mermaid/linux-latest
@@ -56,7 +56,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
with:
@@ -77,7 +77,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
with:
@@ -98,7 +98,7 @@ jobs:
os: [windows-latest, macos-latest]
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- uses: actions/setup-go@v6
with:
@@ -115,7 +115,7 @@ jobs:
run: make vet
- name: golangci-lint
uses: golangci/golangci-lint-action@v9
uses: golangci/golangci-lint-action@v8
with:
version: v2.5
@@ -125,7 +125,7 @@ jobs:
- name: End 2 end
run: make e2evv
- uses: actions/upload-artifact@v5
- uses: actions/upload-artifact@v4
with:
name: e2e packet flow ${{ matrix.os }}
path: e2e/mermaid/${{ matrix.os }}

109
bits.go
View File

@@ -9,13 +9,14 @@ type Bits struct {
length uint64
current uint64
bits []bool
firstSeen bool
lostCounter metrics.Counter
dupeCounter metrics.Counter
outOfWindowCounter metrics.Counter
}
func NewBits(bits uint64) *Bits {
b := &Bits{
return &Bits{
length: bits,
bits: make([]bool, bits, bits),
current: 0,
@@ -23,37 +24,34 @@ func NewBits(bits uint64) *Bits {
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
}
// There is no counter value 0, mark it to avoid counting a lost packet later.
b.bits[0] = true
b.current = 0
return b
}
func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
// If i is the next number, return true.
if i > b.current {
if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
return true
}
// If i is within the window, check if it's been set already.
if i > b.current-b.length || i < b.length && b.current < b.length {
// If i is within the window, check if it's been set already. The first window will fail this check
if i > b.current-b.length {
return !b.bits[i%b.length]
}
// If i is within the first window
if i < b.length {
return !b.bits[i%b.length]
}
// Not within the window
if l.Level >= logrus.DebugLevel {
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
}
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
return false
}
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// If i is the next number, return true and update current.
if i == b.current+1 {
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
// The very first window can only be tracked as lost once we are on the 2nd window or greater
if b.bits[i%b.length] == false && i > b.length {
// Report missed packets, we can only understand what was missed after the first window has been gone through
if i > b.length && b.bits[i%b.length] == false {
b.lostCounter.Inc(1)
}
b.bits[i%b.length] = true
@@ -61,32 +59,61 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
return true
}
// If i is a jump, adjust the window, record lost, update current, and return true
if i > b.current {
lost := int64(0)
// Zero out the bits between the current and the new counter value, limited by the window size,
// since the window is shifting
for n := b.current + 1; n <= min(i, b.current+b.length); n++ {
if b.bits[n%b.length] == false && n > b.length {
lost++
}
// If i packet is greater than current but less than the maximum length of our bitmap,
// flip everything in between to false and move ahead.
if i > b.current && i < b.current+b.length {
// In between current and i need to be zero'd to allow those packets to come in later
for n := b.current + 1; n < i; n++ {
b.bits[n%b.length] = false
}
// Only record any skipped packets as a result of the window moving further than the window length
// Any loss within the new window will be accounted for in future calls
lost += max(0, int64(i-b.current-b.length))
b.bits[i%b.length] = true
b.current = i
//l.Debugf("missed %d packets between %d and %d\n", i-b.current, i, b.current)
return true
}
// If i is greater than the delta between current and the total length of our bitmap,
// just flip everything in the map and move ahead.
if i >= b.current+b.length {
// The current window loss will be accounted for later, only record the jump as loss up until then
lost := maxInt64(0, int64(i-b.current-b.length))
//TODO: explain this
if b.current == 0 {
lost++
}
for n := range b.bits {
// Don't want to count the first window as a loss
//TODO: this is likely wrong, we are wanting to track only the bit slots that we aren't going to track anymore and this is marking everything as missed
//if b.bits[n] == false {
// lost++
//}
b.bits[n] = false
}
b.lostCounter.Inc(lost)
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": true, "currentCounter": b.current, "incomingCounter": i, "reason": "window shifting"}).
Debug("Receive window")
}
b.bits[i%b.length] = true
b.current = i
return true
}
// If i is within the current window but below the current counter,
// Check to see if it's a duplicate
if i > b.current-b.length || i < b.length && b.current < b.length {
if b.current == i || b.bits[i%b.length] == true {
// Allow for the 0 packet to come in within the first window
if i == 0 && b.firstSeen == false && b.current < b.length {
b.firstSeen = true
b.bits[i%b.length] = true
return true
}
// If i is within the window of current minus length (the total pat window size),
// allow it and flip to true but to NOT change current. We also have to account for the first window
if ((b.current >= b.length && i > b.current-b.length) || (b.current < b.length && i < b.length)) && i <= b.current {
if b.current == i {
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
Debug("Receive window")
@@ -95,8 +122,18 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
return false
}
if b.bits[i%b.length] == true {
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}).
Debug("Receive window")
}
b.dupeCounter.Inc(1)
return false
}
b.bits[i%b.length] = true
return true
}
// In all other cases, fail and don't change current.
@@ -110,3 +147,11 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
}
return false
}
func maxInt64(a, b int64) int64 {
if a > b {
return a
}
return b
}

View File

@@ -15,41 +15,48 @@ func TestBits(t *testing.T) {
assert.Len(t, b.bits, 10)
// This is initialized to zero - receive one. This should work.
assert.True(t, b.Check(l, 1))
assert.True(t, b.Update(l, 1))
u := b.Update(l, 1)
assert.True(t, u)
assert.EqualValues(t, 1, b.current)
g := []bool{true, true, false, false, false, false, false, false, false, false}
g := []bool{false, true, false, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits)
// Receive two
assert.True(t, b.Check(l, 2))
assert.True(t, b.Update(l, 2))
u = b.Update(l, 2)
assert.True(t, u)
assert.EqualValues(t, 2, b.current)
g = []bool{true, true, true, false, false, false, false, false, false, false}
g = []bool{false, true, true, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits)
// Receive two again - it will fail
assert.False(t, b.Check(l, 2))
assert.False(t, b.Update(l, 2))
u = b.Update(l, 2)
assert.False(t, u)
assert.EqualValues(t, 2, b.current)
// Jump ahead to 15, which should clear everything and set the 6th element
assert.True(t, b.Check(l, 15))
assert.True(t, b.Update(l, 15))
u = b.Update(l, 15)
assert.True(t, u)
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, false, true, false, false, false, false}
assert.Equal(t, g, b.bits)
// Mark 14, which is allowed because it is in the window
assert.True(t, b.Check(l, 14))
assert.True(t, b.Update(l, 14))
u = b.Update(l, 14)
assert.True(t, u)
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits)
// Mark 5, which is not allowed because it is not in the window
assert.False(t, b.Check(l, 5))
assert.False(t, b.Update(l, 5))
u = b.Update(l, 5)
assert.False(t, u)
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits)
@@ -62,29 +69,10 @@ func TestBits(t *testing.T) {
// Walk through a few windows in order
b = NewBits(10)
for i := uint64(1); i <= 100; i++ {
for i := uint64(0); i <= 100; i++ {
assert.True(t, b.Check(l, i), "Error while checking %v", i)
assert.True(t, b.Update(l, i), "Error while updating %v", i)
}
assert.False(t, b.Check(l, 1), "Out of window check")
}
func TestBitsLargeJumps(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b.lostCounter.Clear()
b = NewBits(10)
b.lostCounter.Clear()
assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54
assert.Equal(t, int64(45), b.lostCounter.Count())
assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99
assert.Equal(t, int64(89), b.lostCounter.Count())
assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199
assert.Equal(t, int64(188), b.lostCounter.Count())
}
func TestBitsDupeCounter(t *testing.T) {
@@ -136,7 +124,8 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
assert.False(t, b.Update(l, 0))
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
//tODO: make sure lostcounter doesn't increase in orderly increment
assert.Equal(t, int64(20), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
}
@@ -148,6 +137,8 @@ func TestBitsLostCounter(t *testing.T) {
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
//assert.True(t, b.Update(0))
assert.True(t, b.Update(l, 0))
assert.True(t, b.Update(l, 20))
assert.True(t, b.Update(l, 21))
assert.True(t, b.Update(l, 22))
@@ -158,7 +149,7 @@ func TestBitsLostCounter(t *testing.T) {
assert.True(t, b.Update(l, 27))
assert.True(t, b.Update(l, 28))
assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
assert.Equal(t, int64(20), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
@@ -167,6 +158,8 @@ func TestBitsLostCounter(t *testing.T) {
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(l, 0))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 9))
assert.Equal(t, int64(0), b.lostCounter.Count())
// 10 will set 0 index, 0 was already set, no lost packets
@@ -221,62 +214,6 @@ func TestBitsLostCounter(t *testing.T) {
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func TestBitsLostCounterIssue1(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(l, 4))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 1))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 9))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 2))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 3))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 5))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 6))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 7))
assert.Equal(t, int64(0), b.lostCounter.Count())
// assert.True(t, b.Update(l, 8))
assert.True(t, b.Update(l, 10))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 11))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 14))
assert.Equal(t, int64(0), b.lostCounter.Count())
// Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter
assert.True(t, b.Update(l, 19))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 12))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 13))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 15))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 16))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 17))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 18))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 20))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 21))
// We missed packet 8 above
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func BenchmarkBits(b *testing.B) {
z := NewBits(10)
for n := 0; n < b.N; n++ {

View File

@@ -173,26 +173,23 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
var passphrase []byte
if !isP11 && *cf.encryption {
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword()
if err == ErrNoTerminal {
return fmt.Errorf("out-key must be encrypted interactively")
} else if err != nil {
return fmt.Errorf("error reading passphrase: %s", err)
}
if len(passphrase) > 0 {
break
}
}
if len(passphrase) == 0 {
for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword()
if err == ErrNoTerminal {
return fmt.Errorf("out-key must be encrypted interactively")
} else if err != nil {
return fmt.Errorf("error reading passphrase: %s", err)
}
if len(passphrase) > 0 {
break
}
}
if len(passphrase) == 0 {
return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext")
}
return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext")
}
}

View File

@@ -171,17 +171,6 @@ func Test_ca(t *testing.T) {
assert.Equal(t, pwPromptOb, ob.String())
assert.Empty(t, eb.String())
// test encrypted key with passphrase environment variable
os.Remove(keyF.Name())
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
require.NoError(t, ca(args, ob, eb, testpw))
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
// read encrypted key file and verify default params
rb, _ = os.ReadFile(keyF.Name())
k, _ := pem.Decode(rb)

View File

@@ -5,28 +5,10 @@ import (
"fmt"
"io"
"os"
"runtime/debug"
"strings"
)
// A version string that can be set with
//
// -ldflags "-X main.Build=SOMEVERSION"
//
// at compile-time.
var Build string
func init() {
if Build == "" {
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
Build = strings.TrimPrefix(info.Main.Version, "v")
}
}
type helpError struct {
s string
}

View File

@@ -43,7 +43,7 @@ type signFlags struct {
func newSignFlags() *signFlags {
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
sf.set.Usage = func() {}
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use. The default is to match the version of the signing CA")
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
@@ -116,28 +116,26 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
// naively attempt to decode the private key as though it is not encrypted
caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
// ask for a passphrase until we get one
var passphrase []byte
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
if len(passphrase) == 0 {
// ask for a passphrase until we get one
for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword()
for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword()
if errors.Is(err, ErrNoTerminal) {
return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
} else if err != nil {
return fmt.Errorf("error reading password: %s", err)
}
if len(passphrase) > 0 {
break
}
if errors.Is(err, ErrNoTerminal) {
return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
} else if err != nil {
return fmt.Errorf("error reading password: %s", err)
}
if len(passphrase) == 0 {
return fmt.Errorf("cannot open encrypted ca-key without passphrase")
if len(passphrase) > 0 {
break
}
}
if len(passphrase) == 0 {
return fmt.Errorf("cannot open encrypted ca-key without passphrase")
}
curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
if err != nil {
return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
@@ -167,10 +165,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return fmt.Errorf("ca certificate is expired")
}
if version == 0 {
version = caCert.Version()
}
// if no duration is given, expire one second before the root expires
if *sf.duration <= 0 {
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
@@ -283,19 +277,21 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
notBefore := time.Now()
notAfter := notBefore.Add(*sf.duration)
switch version {
case cert.Version1:
// Make sure we have only one ipv4 address
if version == 0 || version == cert.Version1 {
// Make sure we at least have an ip
if len(v4Networks) != 1 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
}
if len(v6Networks) > 0 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only contain ipv4 addresses")
}
if version == cert.Version1 {
// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
if len(v6Networks) > 0 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
}
if len(v6UnsafeNetworks) > 0 {
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses")
if len(v6UnsafeNetworks) > 0 {
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
}
}
t := &cert.TBSCertificate{
@@ -325,8 +321,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
}
crts = append(crts, nc)
}
case cert.Version2:
if version == 0 || version == cert.Version2 {
t := &cert.TBSCertificate{
Version: cert.Version2,
Name: *sf.name,
@@ -354,9 +351,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
}
crts = append(crts, nc)
default:
// this should be unreachable
return fmt.Errorf("invalid version: %d", version)
}
if !isP11 && *sf.inPubPath == "" {

View File

@@ -55,7 +55,7 @@ func Test_signHelp(t *testing.T) {
" -unsafe-networks string\n"+
" \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
" -version uint\n"+
" \tOptional: version of the certificate format to use. The default is to match the version of the signing CA\n",
" \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
ob.String(),
)
}
@@ -204,7 +204,7 @@ func Test_signCert(t *testing.T) {
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses")
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@@ -379,15 +379,6 @@ func Test_signCert(t *testing.T) {
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
// test with the proper password in the environment
os.Remove(crtF.Name())
os.Remove(keyF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
require.NoError(t, signCert(args, ob, eb, testpw))
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
// test with the wrong password
ob.Reset()
eb.Reset()
@@ -398,17 +389,6 @@ func Test_signCert(t *testing.T) {
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
// test with the wrong password in environment
ob.Reset()
eb.Reset()
os.Setenv("NEBULA_CA_PASSPHRASE", "invalid password")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing encrypted ca-key: invalid passphrase or corrupt private key")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
// test with the user not entering a password
ob.Reset()
eb.Reset()

View File

@@ -4,8 +4,6 @@ import (
"flag"
"fmt"
"os"
"runtime/debug"
"strings"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
@@ -20,17 +18,6 @@ import (
// at compile-time.
var Build string
func init() {
if Build == "" {
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
Build = strings.TrimPrefix(info.Main.Version, "v")
}
}
func main() {
serviceFlag := flag.String("service", "", "Control the system service.")
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
@@ -78,8 +65,16 @@ func main() {
}
if !*configTest {
ctrl.Start()
ctrl.ShutdownBlock()
wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
wait()
l.Info("Goodbye")
}
os.Exit(0)

View File

@@ -3,9 +3,10 @@ package main
import (
"flag"
"fmt"
"log"
"net/http"
_ "net/http/pprof"
"os"
"runtime/debug"
"strings"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
@@ -20,17 +21,6 @@ import (
// at compile-time.
var Build string
func init() {
if Build == "" {
info, ok := debug.ReadBuildInfo()
if !ok {
return
}
Build = strings.TrimPrefix(info.Main.Version, "v")
}
}
func main() {
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
configTest := flag.Bool("test", false, "Test the config and print the end result. Non zero exit indicates a faulty config")
@@ -71,10 +61,22 @@ func main() {
os.Exit(1)
}
go func() {
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
}()
if !*configTest {
ctrl.Start()
wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
notifyReady(l)
ctrl.ShutdownBlock()
wait()
l.Info("Goodbye")
}
os.Exit(0)

View File

@@ -17,7 +17,7 @@ import (
"dario.cat/mergo"
"github.com/sirupsen/logrus"
"go.yaml.in/yaml/v3"
"gopkg.in/yaml.v3"
)
type C struct {

View File

@@ -10,7 +10,7 @@ import (
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3"
"gopkg.in/yaml.v3"
)
func TestConfig_Load(t *testing.T) {

View File

@@ -50,6 +50,11 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
}
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
b.Update(l, 0)
hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: ncs,
Random: rand.Reader,
@@ -69,7 +74,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
ci := &ConnectionState{
H: hs,
initiator: initiator,
window: NewBits(ReplayWindow),
window: b,
myCert: crt,
}
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.

View File

@@ -2,9 +2,11 @@ package nebula
import (
"context"
"errors"
"net/netip"
"os"
"os/signal"
"sync"
"syscall"
"github.com/sirupsen/logrus"
@@ -13,6 +15,16 @@ import (
"github.com/slackhq/nebula/overlay"
)
type RunState int
const (
Stopped RunState = 0 // The control has yet to be started
Started RunState = 1 // The control has been started
Stopping RunState = 2 // The control is stopping
)
var ErrAlreadyStarted = errors.New("nebula is already started")
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
@@ -26,6 +38,9 @@ type controlHostLister interface {
}
type Control struct {
stateLock sync.Mutex
state RunState
f *Interface
l *logrus.Logger
ctx context.Context
@@ -49,10 +64,21 @@ type ControlHostInfo struct {
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
}
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
func (c *Control) Start() {
// Start actually runs nebula, this is a nonblocking call.
// The returned function can be used to wait for nebula to fully stop.
func (c *Control) Start() (func(), error) {
c.stateLock.Lock()
if c.state != Stopped {
c.stateLock.Unlock()
return nil, ErrAlreadyStarted
}
// Activate the interface
c.f.activate()
err := c.f.activate()
if err != nil {
c.stateLock.Unlock()
return nil, err
}
// Call all the delayed funcs that waited patiently for the interface to be created.
if c.sshStart != nil {
@@ -72,15 +98,33 @@ func (c *Control) Start() {
}
// Start reading packets.
c.f.run()
c.state = Started
c.stateLock.Unlock()
return c.f.run()
}
func (c *Control) State() RunState {
c.stateLock.Lock()
defer c.stateLock.Unlock()
return c.state
}
func (c *Control) Context() context.Context {
return c.ctx
}
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
// Stop is a non-blocking call that signals nebula to close all tunnels and shut down
func (c *Control) Stop() {
c.stateLock.Lock()
if c.state != Started {
c.stateLock.Unlock()
// We are stopping or stopped already
return
}
c.state = Stopping
c.stateLock.Unlock()
// Stop the handshakeManager (and other services), to prevent new tunnels from
// being created while we're shutting them all down.
c.cancel()
@@ -89,7 +133,7 @@ func (c *Control) Stop() {
if err := c.f.Close(); err != nil {
c.l.WithError(err).Error("Close interface failed")
}
c.l.Info("Goodbye")
c.state = Stopped
}
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled

View File

@@ -174,10 +174,6 @@ func (c *Control) GetHostmap() *HostMap {
return c.f.hostMap
}
func (c *Control) GetF() *Interface {
return c.f
}
func (c *Control) GetCertState() *CertState {
return c.f.pki.getCertState()
}

View File

@@ -20,7 +20,7 @@ import (
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3"
"gopkg.in/yaml.v3"
)
func BenchmarkHotPath(b *testing.B) {
@@ -97,41 +97,6 @@ func TestGoodHandshake(t *testing.T) {
theirControl.Stop()
}
func TestGoodHandshakeNoOverlap(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack!
// Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
empty := []byte{}
t.Log("do something to cause a handshake")
myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, empty, empty)
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
t.Log("Get their stage 1 packet")
stage1Packet := theirControl.GetFromUDP(true)
t.Log("Have me consume their stage 1 packet. I have a tunnel now")
myControl.InjectUDPPacket(stage1Packet)
t.Log("Wait until we see a test packet come through to make sure we give the tunnel time to complete")
myControl.WaitForType(header.Test, 0, theirControl)
t.Log("Make sure our host infos are correct")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestWrongResponderHandshake(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
@@ -499,35 +464,6 @@ func TestRelays(t *testing.T) {
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
}
func TestRelaysDontCareAboutIps(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
defer r.RenderFlow()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
}
func TestReestablishRelays(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
@@ -1291,116 +1227,3 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
myControl.Stop()
theirControl.Stop()
}
func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}})
o := m{
"static_host_map": m{
lhVpnIpNet[0].Addr().String(): []string{lhUdpAddr.String()},
},
"lighthouse": m{
"hosts": []string{lhVpnIpNet[0].Addr().String()},
"local_allow_list": m{
// Try and block our lighthouse updates from using the actual addresses assigned to this computer
// If we start discovering addresses the test router doesn't know about then test traffic cant flow
"10.0.0.0/24": true,
"::/0": false,
},
},
}
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o)
theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, lhControl, myControl, theirControl)
defer r.RenderFlow()
// Start the servers
lhControl.Start()
myControl.Start()
theirControl.Start()
t.Log("Stand up an ipv6 tunnel between me and them")
assert.True(t, myVpnIpNet[1].Addr().Is6())
assert.True(t, theirVpnIpNet[1].Addr().Is6())
assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)
lhControl.Stop()
myControl.Stop()
theirControl.Stop()
}
func TestGoodHandshakeUnsafeDest(t *testing.T) {
unsafePrefix := "192.168.6.0/24"
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil)
route := m{"route": unsafePrefix, "via": theirVpnIpNet[0].Addr().String()}
myCfg := m{
"tun": m{
"unsafe_routes": []m{route},
},
"firewall": m{
"unsafe_outbound": []m{{
"port": "any",
"proto": "any",
"host": "any",
}},
},
}
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg)
t.Logf("my config %v", myConfig)
// Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
spookyDest := netip.MustParseAddr("192.168.6.4")
// Start the servers
myControl.Start()
theirControl.Start()
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
t.Log("Get their stage 1 packet so that we can play with it")
stage1Packet := theirControl.GetFromUDP(true)
t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
badPacket := stage1Packet.Copy()
badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
myControl.InjectUDPPacket(badPacket)
t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
myControl.InjectUDPPacket(stage1Packet)
t.Log("Wait until we see my cached packet come through")
myControl.WaitForType(1, 0, theirControl)
t.Log("Make sure our host infos are correct")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
t.Log("Get that cached packet and make sure it looks right")
myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)
//reply
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))
//wait for reply
theirControl.WaitForType(1, 0, myControl)
theirCachedPacket := myControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from the spookyman"), theirCachedPacket, spookyDest, myVpnIpNet[0].Addr(), 80, 80)
t.Log("Do a bidirectional tunnel test")
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}

View File

@@ -22,14 +22,15 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3"
"gopkg.in/yaml.v3"
)
type m = map[string]any
// newSimpleServer creates a nebula instance with many assumptions
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
var vpnNetworks []netip.Prefix
for _, sn := range strings.Split(sVpnNetworks, ",") {
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
@@ -55,55 +56,7 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
budpIp[3] = 239
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
}
return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
}
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
return newSimpleServerWithUdpAndUnsafeNetworks(v, caCrt, caKey, name, sVpnNetworks, udpAddr, "", overrides)
}
func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, sUnsafeNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
var vpnNetworks []netip.Prefix
for _, sn := range strings.Split(sVpnNetworks, ",") {
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
if err != nil {
panic(err)
}
vpnNetworks = append(vpnNetworks, vpnIpNet)
}
if len(vpnNetworks) == 0 {
panic("no vpn networks")
}
firewallInbound := []m{{
"proto": "any",
"port": "any",
"host": "any",
}}
var unsafeNetworks []netip.Prefix
var firewallUnsafeInbound []m
if sUnsafeNetworks != "" {
firewallUnsafeInbound = []m{{
"proto": "any",
"port": "any",
"host": "any",
"local_cidr": "0.0.0.0/0",
}}
for _, sn := range strings.Split(sUnsafeNetworks, ",") {
x, err := netip.ParsePrefix(strings.TrimSpace(sn))
if err != nil {
panic(err)
}
unsafeNetworks = append(unsafeNetworks, x)
}
}
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{})
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
caB, err := caCrt.MarshalPEM()
if err != nil {
@@ -123,8 +76,11 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
"port": "any",
"host": "any",
}},
"inbound": firewallInbound,
"unsafe_inbound": firewallUnsafeInbound,
"inbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
},
//"handshakes": m{
// "try_interval": "1s",
@@ -310,10 +266,10 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
// Get both host infos
//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
require.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
require.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
// Check that both vpn and real addr are correct
assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")

View File

@@ -318,50 +318,3 @@ func TestCertMismatchCorrection(t *testing.T) {
myControl.Stop()
theirControl.Stop()
}
func TestCrossStackRelaysWork(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
//myVpnV4 := myVpnIpNet[0]
myVpnV6 := myVpnIpNet[1]
relayVpnV4 := relayVpnIpNet[0]
relayVpnV6 := relayVpnIpNet[1]
theirVpnV6 := theirVpnIpNet[0]
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
defer r.RenderFlow()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
t.Log("reply?")
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
p = r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
//t.Log("finish up")
//myControl.Stop()
//theirControl.Stop()
//relayControl.Stop()
}

View File

@@ -383,9 +383,8 @@ firewall:
# host: `any` or a literal hostname, ie `test-host`
# group: `any` or a literal group name, ie `default-group`
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address.
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address.
# This can be used to filter destinations when using unsafe_routes.
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes.
# By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true.
# If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case.
# ca_name: An issuing CA name

View File

@@ -8,7 +8,6 @@ import (
"hash/fnv"
"net/netip"
"reflect"
"slices"
"strconv"
"strings"
"sync"
@@ -23,15 +22,16 @@ import (
)
type FirewallInterface interface {
AddRule(unsafe, incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
}
type conn struct {
Expires time.Time // Time when this conntrack entry will expire
// record why the original connection passed the firewall, so we can re-validate after ruleset changes.
// record why the original connection passed the firewall, so we can re-validate
// after ruleset changes. Note, rulesVersion is a uint16 so that these two
// fields pack for free after the uint32 above
incoming bool
unsafe bool
rulesVersion uint16
}
@@ -39,10 +39,8 @@ type conn struct {
type Firewall struct {
Conntrack *FirewallConntrack
InRules *FirewallTable
OutRules *FirewallTable
UnsafeInRules *FirewallTable
UnsafeOutRules *FirewallTable
InRules *FirewallTable
OutRules *FirewallTable
InSendReject bool
OutSendReject bool
@@ -55,7 +53,7 @@ type Firewall struct {
// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
// The vpn addresses are a full bit match while the unsafe networks only match the prefix
routableNetworks *bart.Table[NetworkType]
routableNetworks *bart.Lite
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
assignedNetworks []netip.Prefix
@@ -150,16 +148,17 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
tmax = defaultTimeout
}
routableNetworks := new(bart.Table[NetworkType])
routableNetworks := new(bart.Lite)
var assignedNetworks []netip.Prefix
for _, network := range c.Networks() {
routableNetworks.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), NetworkTypeVPN)
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
routableNetworks.Insert(nprefix)
assignedNetworks = append(assignedNetworks, network)
}
hasUnsafeNetworks := false
for _, n := range c.UnsafeNetworks() {
routableNetworks.Insert(n, NetworkTypeUnsafe)
routableNetworks.Insert(n)
hasUnsafeNetworks = true
}
@@ -170,8 +169,6 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
},
InRules: newFirewallTable(),
OutRules: newFirewallTable(),
UnsafeInRules: newFirewallTable(),
UnsafeOutRules: newFirewallTable(),
TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout,
@@ -214,7 +211,6 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
//TODO: do we also need firewall.unsafe_inbound_action and firewall.unsafe_outbound_action?
inboundAction := c.GetString("firewall.inbound_action", "drop")
switch inboundAction {
case "reject":
@@ -237,26 +233,12 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
fw.OutSendReject = false
}
// outbound rules
err := AddFirewallRulesFromConfig(l, false, false, c, fw)
err := AddFirewallRulesFromConfig(l, false, c, fw)
if err != nil {
return nil, err
}
// unsafe outbound rules
err = AddFirewallRulesFromConfig(l, true, false, c, fw)
if err != nil {
return nil, err
}
// inbound rules
err = AddFirewallRulesFromConfig(l, false, true, c, fw)
if err != nil {
return nil, err
}
// unsafe inbound rules
err = AddFirewallRulesFromConfig(l, true, true, c, fw)
err = AddFirewallRulesFromConfig(l, true, c, fw)
if err != nil {
return nil, err
}
@@ -265,11 +247,22 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
}
// AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(unsafe, incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
// https://github.com/golang/go/issues/14131
sIp := ""
if ip.IsValid() {
sIp = ip.String()
}
lIp := ""
if localIp.IsValid() {
lIp = localIp.String()
}
// We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf(
"unsafe: %v, incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
unsafe, incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha,
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha,
)
f.rules += ruleString + "\n"
@@ -277,12 +270,8 @@ func (f *Firewall) AddRule(unsafe, incoming bool, proto uint8, startPort int32,
if !incoming {
direction = "outgoing"
}
fields := m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}
if unsafe {
fields["unsafe"] = true
}
f.l.WithField("firewallRule", fields).Info("Firewall rule added")
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}).
Info("Firewall rule added")
var (
ft *FirewallTable
@@ -290,18 +279,9 @@ func (f *Firewall) AddRule(unsafe, incoming bool, proto uint8, startPort int32,
)
if incoming {
if unsafe {
ft = f.UnsafeInRules
} else {
ft = f.InRules
}
ft = f.InRules
} else {
if unsafe {
ft = f.UnsafeOutRules
} else {
ft = f.OutRules
}
ft = f.OutRules
}
switch proto {
@@ -317,7 +297,7 @@ func (f *Firewall) AddRule(unsafe, incoming bool, proto uint8, startPort int32,
return fmt.Errorf("unknown protocol %v", proto)
}
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha)
return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha)
}
// GetRuleHash returns a hash representation of all inbound and outbound rules
@@ -338,21 +318,12 @@ func (f *Firewall) GetRuleHashes() string {
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
}
func AddFirewallRulesFromConfig(l *logrus.Logger, unsafe, inbound bool, c *config.C, fw FirewallInterface) error {
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
var table string
if inbound {
if unsafe {
table = "firewall.unsafe_inbound"
} else {
table = "firewall.inbound"
}
table = "firewall.inbound"
} else {
if unsafe {
table = "firewall.unsafe_outbound"
} else {
table = "firewall.outbound"
}
table = "firewall.outbound"
}
r := c.Get(table)
@@ -366,6 +337,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, unsafe, inbound bool, c *confi
}
for i, t := range rs {
var groups []string
r, err := convertRule(l, t, table, i)
if err != nil {
return fmt.Errorf("%s rule #%v; %s", table, i, err)
@@ -375,10 +347,23 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, unsafe, inbound bool, c *confi
return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
}
if r.Host == "" && len(r.Groups) == 0 && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
}
if len(r.Groups) > 0 {
groups = r.Groups
}
if r.Group != "" {
// Check if we have both groups and group provided in the rule config
if len(groups) > 0 {
return fmt.Errorf("%s rule #%v; only one of group or groups should be defined, both provided", table, i)
}
groups = []string{r.Group}
}
var sPort, errPort string
if r.Code != "" {
errPort = "code"
@@ -407,25 +392,23 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, unsafe, inbound bool, c *confi
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
}
if r.Cidr != "" && r.Cidr != "any" {
_, err = netip.ParsePrefix(r.Cidr)
var cidr netip.Prefix
if r.Cidr != "" {
cidr, err = netip.ParsePrefix(r.Cidr)
if err != nil {
return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
}
}
if r.LocalCidr != "" && r.LocalCidr != "any" {
_, err = netip.ParsePrefix(r.LocalCidr)
var localCidr netip.Prefix
if r.LocalCidr != "" {
localCidr, err = netip.ParsePrefix(r.LocalCidr)
if err != nil {
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
}
}
if warning := r.sanity(); warning != nil {
l.Warnf("%s rule #%v; %s", table, i, warning)
}
err = fw.AddRule(unsafe, inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha)
if err != nil {
return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
}
@@ -434,10 +417,8 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, unsafe, inbound bool, c *confi
return nil
}
var ErrUnknownNetworkType = errors.New("unknown network type")
var ErrPeerRejected = errors.New("remote address is not within a network that we handle")
var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate networks")
var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It
@@ -448,59 +429,29 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
return nil
}
var remoteNetworkType NetworkType
var ok bool
// Make sure remote address matches nebula certificate, and determine how to treat it
if h.networks == nil {
// Make sure remote address matches nebula certificate
if h.networks != nil {
if !h.networks.Contains(fp.RemoteAddr) {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
} else {
// Simple case: Certificate has one address and no unsafe networks
if h.vpnAddrs[0] != fp.RemoteAddr {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
remoteNetworkType = NetworkTypeVPN
} else {
remoteNetworkType, ok = h.networks.Lookup(fp.RemoteAddr)
if !ok {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
switch remoteNetworkType {
case NetworkTypeVPN:
break // nothing special
case NetworkTypeVPNPeer:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrPeerRejected // reject for now, one day this may have different FW rules
case NetworkTypeUnsafe:
break // nothing special, one day this may have different FW rules
default:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrUnknownNetworkType //should never happen
}
}
// Make sure we are supposed to be handling this local ip address
localNetworkType, ok := f.routableNetworks.Lookup(fp.LocalAddr)
if !ok {
if !f.routableNetworks.Contains(fp.LocalAddr) {
f.metrics(incoming).droppedLocalAddr.Inc(1)
return ErrInvalidLocalIP
}
useUnsafe := remoteNetworkType == NetworkTypeUnsafe || localNetworkType == NetworkTypeUnsafe
var table *FirewallTable
table := f.OutRules
if incoming {
if useUnsafe {
table = f.UnsafeInRules
} else {
table = f.InRules
}
} else {
if useUnsafe {
table = f.UnsafeOutRules
} else {
table = f.OutRules
}
table = f.InRules
}
// We now know which firewall table to check against
@@ -510,13 +461,12 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
}
// We always want to conntrack since it is a faster operation
f.addConn(fp, useUnsafe, incoming)
f.addConn(fp, incoming)
return nil
}
func (f *Firewall) metrics(incoming bool) firewallMetrics {
//TODO: need unsafe metrics too
if incoming {
return f.incomingMetrics
} else {
@@ -556,6 +506,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
}
c, ok := conntrack.Conns[fp]
if !ok {
conntrack.Unlock()
return false
@@ -564,19 +515,9 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
if c.rulesVersion != f.rulesVersion {
// This conntrack entry was for an older rule set, validate
// it still passes with the current rule set
var table *FirewallTable
table := f.OutRules
if c.incoming {
if c.unsafe {
table = f.UnsafeInRules
} else {
table = f.InRules
}
} else {
if c.unsafe {
table = f.UnsafeOutRules
} else {
table = f.OutRules
}
table = f.InRules
}
// We now know which firewall table to check against
@@ -585,7 +526,6 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
h.logger(f.l).
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("unsafe", c.unsafe).
WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion).
Debugln("dropping old conntrack entry, does not match new ruleset")
@@ -599,7 +539,6 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
h.logger(f.l).
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("unsafe", c.unsafe).
WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion).
Debugln("keeping old conntrack entry, does match new ruleset")
@@ -626,7 +565,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
return true
}
func (f *Firewall) addConn(fp firewall.Packet, unsafe, incoming bool) {
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
var timeout time.Duration
c := &conn{}
@@ -649,7 +588,6 @@ func (f *Firewall) addConn(fp firewall.Packet, unsafe, incoming bool) {
// Record which rulesVersion allowed this connection, so we can retest after
// firewall reload
c.incoming = incoming
c.unsafe = unsafe
c.rulesVersion = f.rulesVersion
c.Expires = time.Now().Add(timeout)
conntrack.Conns[fp] = c
@@ -702,7 +640,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC
return false
}
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
if startPort > endPort {
return fmt.Errorf("start port was lower than end port")
}
@@ -715,7 +653,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou
}
}
if err := fp[i].addRule(f, groups, host, cidr, localCidr, caName, caSha); err != nil {
if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil {
return err
}
}
@@ -746,7 +684,7 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
return fp[firewall.PortAny].match(p, c, caPool)
}
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, localCidr, caName, caSha string) error {
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error {
fr := func() *FirewallRule {
return &FirewallRule{
Hosts: make(map[string]*firewallLocalCIDR),
@@ -760,14 +698,14 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, l
fc.Any = fr()
}
return fc.Any.addRule(f, groups, host, cidr, localCidr)
return fc.Any.addRule(f, groups, host, ip, localIp)
}
if caSha != "" {
if _, ok := fc.CAShas[caSha]; !ok {
fc.CAShas[caSha] = fr()
}
err := fc.CAShas[caSha].addRule(f, groups, host, cidr, localCidr)
err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp)
if err != nil {
return err
}
@@ -777,7 +715,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, l
if _, ok := fc.CANames[caName]; !ok {
fc.CANames[caName] = fr()
}
err := fc.CANames[caName].addRule(f, groups, host, cidr, localCidr)
err := fc.CANames[caName].addRule(f, groups, host, ip, localIp)
if err != nil {
return err
}
@@ -809,24 +747,24 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool
return fc.CANames[s.Certificate.Name()].match(p, c)
}
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localCidr string) error {
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
flc := func() *firewallLocalCIDR {
return &firewallLocalCIDR{
LocalCIDR: new(bart.Lite),
}
}
if fr.isAny(groups, host, cidr) {
if fr.isAny(groups, host, ip) {
if fr.Any == nil {
fr.Any = flc()
}
return fr.Any.addRule(f, localCidr)
return fr.Any.addRule(f, localCIDR)
}
if len(groups) > 0 {
nlc := flc()
err := nlc.addRule(f, localCidr)
err := nlc.addRule(f, localCIDR)
if err != nil {
return err
}
@@ -842,34 +780,30 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localC
if nlc == nil {
nlc = flc()
}
err := nlc.addRule(f, localCidr)
err := nlc.addRule(f, localCIDR)
if err != nil {
return err
}
fr.Hosts[host] = nlc
}
if cidr != "" {
c, err := netip.ParsePrefix(cidr)
if err != nil {
return err
}
nlc, _ := fr.CIDR.Get(c)
if ip.IsValid() {
nlc, _ := fr.CIDR.Get(ip)
if nlc == nil {
nlc = flc()
}
err = nlc.addRule(f, localCidr)
err := nlc.addRule(f, localCIDR)
if err != nil {
return err
}
fr.CIDR.Insert(c, nlc)
fr.CIDR.Insert(ip, nlc)
}
return nil
}
func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool {
if len(groups) == 0 && host == "" && cidr == "" {
func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool {
if len(groups) == 0 && host == "" && !ip.IsValid() {
return true
}
@@ -883,7 +817,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool {
return true
}
if cidr == "any" {
if ip.IsValid() && ip.Bits() == 0 {
return true
}
@@ -935,13 +869,8 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
return false
}
func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
if localCidr == "any" {
flc.Any = true
return nil
}
if localCidr == "" {
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
if !localIp.IsValid() {
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
flc.Any = true
return nil
@@ -952,13 +881,12 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
}
return nil
} else if localIp.Bits() == 0 {
flc.Any = true
return nil
}
c, err := netip.ParsePrefix(localCidr)
if err != nil {
return err
}
flc.LocalCIDR.Insert(c)
flc.LocalCIDR.Insert(localIp)
return nil
}
@@ -979,6 +907,7 @@ type rule struct {
Code string
Proto string
Host string
Group string
Groups []string
Cidr string
LocalCidr string
@@ -1006,7 +935,6 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
r.Code = toString("code", m)
r.Proto = toString("proto", m)
r.Host = toString("host", m)
//TODO: create an alias to remote_cidr and deprecate cidr?
r.Cidr = toString("cidr", m)
r.LocalCidr = toString("local_cidr", m)
r.CAName = toString("ca_name", m)
@@ -1021,8 +949,7 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
m["group"] = v[0]
}
singleGroup := toString("group", m)
r.Group = toString("group", m)
if rg, ok := m["groups"]; ok {
switch reflect.TypeOf(rg).Kind() {
@@ -1039,60 +966,9 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
}
}
//flatten group vs groups
if singleGroup != "" {
// Check if we have both groups and group provided in the rule config
if len(r.Groups) > 0 {
return r, fmt.Errorf("only one of group or groups should be defined, both provided")
}
r.Groups = []string{singleGroup}
}
return r, nil
}
// sanity returns an error if the rule would be evaluated in a way that would short-circuit a configured check on a wildcard value
// rules are evaluated as "port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND local_cidr"
func (r *rule) sanity() error {
//port, proto, local_cidr are AND, no need to check here
//ca_sha and ca_name don't have a wildcard value, no need to check here
groupsEmpty := len(r.Groups) == 0
hostEmpty := r.Host == ""
cidrEmpty := r.Cidr == ""
if (groupsEmpty && hostEmpty && cidrEmpty) == true {
return nil //no content!
}
groupsHasAny := slices.Contains(r.Groups, "any")
if groupsHasAny && len(r.Groups) > 1 {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the other groups specified", r.Groups)
}
if r.Host == "any" {
if !groupsEmpty {
return fmt.Errorf("groups specified as %s, but host=any will match any host, regardless of groups", r.Groups)
}
if !cidrEmpty {
return fmt.Errorf("cidr specified as %s, but host=any will match any host, regardless of cidr", r.Cidr)
}
}
if groupsHasAny {
if !hostEmpty && r.Host != "any" {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified host %s", r.Groups, r.Host)
}
if !cidrEmpty {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified cidr %s", r.Groups, r.Cidr)
}
}
//todo alert on cidr-any
return nil
}
func parsePort(s string) (startPort, endPort int32, err error) {
if s == "any" {
startPort = firewall.PortAny

View File

@@ -8,8 +8,6 @@ import (
"testing"
"time"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
@@ -73,114 +71,85 @@ func TestFirewall_AddRule(t *testing.T) {
ti6, err := netip.ParsePrefix("fd12::34/128")
require.NoError(t, err)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
// An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha"))
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp, err := netip.ParsePrefix("0.0.0.0/0")
require.NoError(t, err)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
table, ok := fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
assert.True(t, table.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
assert.False(t, ok)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp6, err := netip.ParsePrefix("::/0")
require.NoError(t, err)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
assert.True(t, table.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.Error(t, fw.AddRule(false, true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", ""))
require.Error(t, fw.AddRule(false, true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", ""))
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
}
func TestFirewall_Drop(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
@@ -205,10 +174,10 @@ func TestFirewall_Drop(t *testing.T) {
},
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
}
h.buildNetworks(myVpnNetworksTable, &c)
h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Drop outbound
@@ -227,28 +196,28 @@ func TestFirewall_Drop(t *testing.T) {
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
@@ -257,9 +226,6 @@ func TestFirewall_DropV6(t *testing.T) {
ob := &bytes.Buffer{}
l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"),
RemoteAddr: netip.MustParseAddr("fd12::34"),
@@ -284,10 +250,10 @@ func TestFirewall_DropV6(t *testing.T) {
},
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
}
h.buildNetworks(myVpnNetworksTable, &c)
h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Drop outbound
@@ -306,28 +272,28 @@ func TestFirewall_DropV6(t *testing.T) {
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
@@ -338,12 +304,12 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}
pfix := netip.MustParsePrefix("172.1.1.1/32")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix.String(), "", "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix.String(), "", "")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
pfix6 := netip.MustParsePrefix("fd11::11/128")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6.String(), "", "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix6.String(), "", "")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) {
@@ -487,8 +453,6 @@ func TestFirewall_Drop2(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -514,7 +478,7 @@ func TestFirewall_Drop2(t *testing.T) {
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h.buildNetworks(myVpnNetworksTable, c.Certificate)
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
c1 := cert.CachedCertificate{
Certificate: &dummyCert{
@@ -529,10 +493,10 @@ func TestFirewall_Drop2(t *testing.T) {
peerCert: &c1,
},
}
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// h1/c1 lacks the proper groups
@@ -546,8 +510,6 @@ func TestFirewall_Drop3(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -579,7 +541,7 @@ func TestFirewall_Drop3(t *testing.T) {
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
c2 := cert.CachedCertificate{
Certificate: &dummyCert{
@@ -594,7 +556,7 @@ func TestFirewall_Drop3(t *testing.T) {
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h2.buildNetworks(myVpnNetworksTable, c2.Certificate)
h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
c3 := cert.CachedCertificate{
Certificate: &dummyCert{
@@ -609,11 +571,11 @@ func TestFirewall_Drop3(t *testing.T) {
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", ""))
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
cp := cert.NewCAPool()
// c1 should pass because host match
@@ -627,7 +589,7 @@ func TestFirewall_Drop3(t *testing.T) {
// Test a remote address match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
}
@@ -635,8 +597,6 @@ func TestFirewall_Drop3V6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"),
@@ -660,12 +620,12 @@ func TestFirewall_Drop3V6(t *testing.T) {
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h.buildNetworks(myVpnNetworksTable, c.Certificate)
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
// Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
@@ -673,8 +633,6 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -701,10 +659,10 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h.buildNetworks(myVpnNetworksTable, c.Certificate)
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Drop outbound
@@ -717,7 +675,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
@@ -726,7 +684,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
@@ -738,8 +696,6 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
c := cert.CachedCertificate{
Certificate: &dummyCert{
@@ -761,11 +717,11 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
},
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
}
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Packet spoofed by `c1`. Note that the remote addr is not a valid one.
@@ -958,127 +914,101 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
conf := config.NewC(l)
mf := &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding udp rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding icmp rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding any rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with cidr
cidr := netip.MustParsePrefix("10.0.0.0/8")
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr.String()}, mf.lastCall)
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
// Test adding rule with cidr ipv6
cidr6 := netip.MustParsePrefix("fd00::/8")
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
// Test adding rule with any cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
// Test adding rule with junk cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with local_cidr ipv6
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
// Test adding rule with any local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
// Test adding rule with junk local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
// Test adding rule with ca_sha
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
// Test single group
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test single groups
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test multiple AND groups
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test Add error
conf = config.NewC(l)
mf = &mockFirewall{}
mf.nextCallReturn = errors.New("test error")
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf), "firewall.inbound rule #0; `test error`")
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
}
func TestFirewall_convertRule(t *testing.T) {
@@ -1094,7 +1024,7 @@ func TestFirewall_convertRule(t *testing.T) {
r, err := convertRule(l, c, "test", 1)
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
require.NoError(t, err)
assert.Equal(t, []string{"group1"}, r.Groups)
assert.Equal(t, "group1", r.Group)
// Ensure group array of > 1 is errord
ob.Reset()
@@ -1114,240 +1044,18 @@ func TestFirewall_convertRule(t *testing.T) {
r, err = convertRule(l, c, "test", 1)
require.NoError(t, err)
assert.Equal(t, []string{"group1"}, r.Groups)
}
func TestFirewall_convertRuleSanity(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
noWarningPlease := []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"host": "bob"},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range noWarningPlease {
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
require.NoError(t, r.sanity(), "should not generate a sanity warning, %+v", c)
}
yesWarningPlease := []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range yesWarningPlease {
c["host"] = "any"
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
err = r.sanity()
require.Error(t, err, "I wanted a warning: %+v", c)
}
//reset the list
yesWarningPlease = []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range yesWarningPlease {
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
r.Groups = append(r.Groups, "any")
err = r.sanity()
require.Error(t, err, "I wanted a warning: %+v", c)
}
}
type testcase struct {
h *HostInfo
p firewall.Packet
c cert.Certificate
err error
}
func (c *testcase) Test(t *testing.T, fw *Firewall) {
t.Helper()
cp := cert.NewCAPool()
resetConntrack(fw)
err := fw.Drop(c.p, true, c.h, cp, nil)
if c.err == nil {
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
} else {
require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr)
}
}
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
c1 := dummyCert{
name: "host1",
networks: theirPrefixes,
groups: []string{"default-group"},
issuer: "signer-shasum",
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &cert.CachedCertificate{
Certificate: &c1,
InvertedGroups: map[string]struct{}{"default-group": {}},
},
},
vpnAddrs: make([]netip.Addr, len(theirPrefixes)),
}
for i := range theirPrefixes {
h.vpnAddrs[i] = theirPrefixes[i].Addr()
}
h.buildNetworks(setup.myVpnNetworksTable, &c1)
p := firewall.Packet{
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
RemoteAddr: theirPrefixes[0].Addr(),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
return testcase{
h: &h,
p: p,
c: &c1,
err: err,
}
}
type testsetup struct {
c dummyCert
myVpnNetworksTable *bart.Lite
fw *Firewall
}
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
c := dummyCert{
name: "me",
networks: myPrefixes,
groups: []string{"default-group"},
issuer: "signer-shasum",
}
return newSetupFromCert(t, l, c)
}
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
myVpnNetworksTable := new(bart.Lite)
for _, prefix := range c.Networks() {
myVpnNetworksTable.Insert(prefix)
}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
return testsetup{
c: c,
fw: fw,
myVpnNetworksTable: myVpnNetworksTable,
}
}
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
t.Parallel()
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
t.Run("allow inbound all matching", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24"))
tc.Test(t, setup.fw)
})
t.Run("allow inbound local matching", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24"))
tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8")
tc.Test(t, setup.fw)
})
t.Run("block inbound remote mismatched", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24"))
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
tc.Test(t, setup.fw)
})
t.Run("Block a vpn peer packet", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24"))
tc.Test(t, setup.fw)
})
twoPrefixes := []netip.Prefix{
netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"),
}
t.Run("allow inbound one matching", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, nil, twoPrefixes...)
tc.Test(t, setup.fw)
})
t.Run("block inbound multimismatch", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...)
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
tc.Test(t, setup.fw)
})
t.Run("allow inbound 2nd one matching", func(t *testing.T) {
t.Parallel()
setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24"))
tc := buildTestCase(setup2, nil, twoPrefixes...)
tc.p.RemoteAddr = twoPrefixes[1].Addr()
tc.Test(t, setup2.fw)
})
t.Run("allow inbound unsafe route", func(t *testing.T) {
t.Parallel()
unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
c := dummyCert{
name: "me",
networks: []netip.Prefix{myPrefix},
unsafeNetworks: []netip.Prefix{unsafePrefix},
groups: []string{"default-group"},
issuer: "signer-shasum",
}
unsafeSetup := newSetupFromCert(t, l, c)
tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
tc.err = ErrNoMatchingRule
tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
require.NoError(t, unsafeSetup.fw.AddRule(true, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
tc.err = nil
tc.Test(t, unsafeSetup.fw) //should pass
})
assert.Equal(t, "group1", r.Group)
}
type addRuleCall struct {
unsafe bool
incoming bool
proto uint8
startPort int32
endPort int32
groups []string
host string
ip string
localIp string
ip netip.Prefix
localIp netip.Prefix
caName string
caSha string
}
@@ -1357,9 +1065,8 @@ type mockFirewall struct {
nextCallReturn error
}
func (mf *mockFirewall) AddRule(unsafe, incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error {
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
mf.lastCall = addRuleCall{
unsafe: unsafe,
incoming: incoming,
proto: proto,
startPort: startPort,

15
go.mod
View File

@@ -8,7 +8,7 @@ require (
github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.1.0
github.com/gaissmai/bart v0.26.0
github.com/gaissmai/bart v0.25.0
github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.4
@@ -22,17 +22,16 @@ require (
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
github.com/stretchr/testify v1.11.1
github.com/vishvananda/netlink v1.3.1
go.yaml.in/yaml/v3 v3.0.4
golang.org/x/crypto v0.45.0
golang.org/x/crypto v0.43.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.47.0
golang.org/x/sync v0.18.0
golang.org/x/sys v0.38.0
golang.org/x/term v0.37.0
golang.org/x/net v0.45.0
golang.org/x/sync v0.17.0
golang.org/x/sys v0.37.0
golang.org/x/term v0.36.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/protobuf v1.36.10
google.golang.org/protobuf v1.36.8
gopkg.in/yaml.v3 v3.0.1
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
)

30
go.sum
View File

@@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0=
github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/gaissmai/bart v0.25.0 h1:eqiokVPqM3F94vJ0bTHXHtH91S8zkKL+bKh+BsGOsJM=
github.com/gaissmai/bart v0.25.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -155,15 +155,13 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
@@ -182,8 +180,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -191,8 +189,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -209,11 +207,11 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -244,8 +242,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -2,6 +2,7 @@ package nebula
import (
"net/netip"
"slices"
"time"
"github.com/flynn/noise"
@@ -191,17 +192,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return
}
var vpnAddrs []netip.Addr
var filteredNetworks []netip.Prefix
certName := remoteCert.Certificate.Name()
certVersion := remoteCert.Certificate.Version()
fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer()
vpnNetworks := remoteCert.Certificate.Networks()
anyVpnAddrsInCommon := false
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
for i, network := range vpnNetworks {
if f.myVpnAddrsTable.Contains(network.Addr()) {
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
for _, network := range remoteCert.Certificate.Networks() {
vpnAddr := network.Addr()
if f.myVpnAddrsTable.Contains(vpnAddr) {
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
@@ -209,10 +210,24 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
return
}
vpnAddrs[i] = network.Addr()
if f.myVpnNetworksTable.Contains(network.Addr()) {
anyVpnAddrsInCommon = true
// vpnAddrs outside our vpn networks are of no use to us, filter them out
if !f.myVpnNetworksTable.Contains(vpnAddr) {
continue
}
filteredNetworks = append(filteredNetworks, network)
vpnAddrs = append(vpnAddrs, vpnAddr)
}
if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
return
}
if addr.IsValid() {
@@ -249,30 +264,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
},
}
msgRxL := f.l.WithFields(m{
"vpnAddrs": vpnAddrs,
"udpAddr": addr,
"certName": certName,
"certVersion": certVersion,
"fingerprint": fingerprint,
"issuer": issuer,
"initiatorIndex": hs.Details.InitiatorIndex,
"responderIndex": hs.Details.ResponderIndex,
"remoteIndex": h.RemoteIndex,
"handshake": m{"stage": 1, "style": "ix_psk0"},
})
if anyVpnAddrsInCommon {
msgRxL.Info("Handshake message received")
} else {
//todo warn if not lighthouse or relay?
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
}
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message received")
hs.Details.ResponderIndex = myIndex
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
if hs.Details.Cert == nil {
msgRxL.WithField("myCertVersion", ci.myCert.Version()).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVersion", ci.myCert.Version()).
Error("Unable to handshake with host because no certificate handshake bytes is available")
return
}
@@ -330,7 +341,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
hostinfo.SetRemote(addr)
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
if err != nil {
@@ -571,22 +582,31 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
}
correctHostResponded := false
anyVpnAddrsInCommon := false
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
for i, network := range vpnNetworks {
vpnAddrs[i] = network.Addr()
if f.myVpnNetworksTable.Contains(network.Addr()) {
anyVpnAddrsInCommon = true
}
if hostinfo.vpnAddrs[0] == network.Addr() {
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
correctHostResponded = true
var vpnAddrs []netip.Addr
var filteredNetworks []netip.Prefix
for _, network := range vpnNetworks {
// vpnAddrs outside our vpn networks are of no use to us, filter them out
vpnAddr := network.Addr()
if !f.myVpnNetworksTable.Contains(vpnAddr) {
continue
}
filteredNetworks = append(filteredNetworks, network)
vpnAddrs = append(vpnAddrs, vpnAddr)
}
if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
return true
}
// Ensure the right host responded
if !correctHostResponded {
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("udpAddr", addr).
WithField("certName", certName).
@@ -598,7 +618,6 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
f.handshakeManager.DeleteHostInfo(hostinfo)
// Create a new hostinfo/handshake for the intended vpn ip
//TODO is hostinfo.vpnAddrs[0] always the address to use?
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
// Block the current used address
newHH.hostinfo.remotes = hostinfo.remotes
@@ -625,7 +644,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds()
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
@@ -633,17 +652,12 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("durationNs", duration).
WithField("sentCachedPackets", len(hh.packetStore))
if anyVpnAddrsInCommon {
msgRxL.Info("Handshake message received")
} else {
//todo warn if not lighthouse or relay?
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
}
WithField("sentCachedPackets", len(hh.packetStore)).
Info("Handshake message received")
// Build up the radix for the firewall if we have subnets in the cert
hostinfo.vpnAddrs = vpnAddrs
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
f.handshakeManager.Complete(hostinfo, f)

View File

@@ -269,12 +269,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
// Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays {
// Don't relay through the host I'm trying to connect to
// Don't relay to myself
if relay == vpnIp {
continue
}
// Don't relay to myself
// Don't relay through the host I'm trying to connect to
if hm.f.myVpnAddrsTable.Contains(relay) {
continue
}

View File

@@ -212,18 +212,6 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
rs.relayForByIdx[idx] = r
}
type NetworkType uint8
const (
NetworkTypeUnknown NetworkType = iota
// NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
NetworkTypeVPN
// NetworkTypeVPNPeer is a network that does not overlap one of our networks
NetworkTypeVPNPeer
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
NetworkTypeUnsafe
)
type HostInfo struct {
remote netip.AddrPort
remotes *RemoteList
@@ -237,8 +225,8 @@ type HostInfo struct {
// vpn networks but were removed because they are not usable
vpnAddrs []netip.Addr
// networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host.
networks *bart.Table[NetworkType]
// networks are both all vpn and unsafe networks assigned to this host
networks *bart.Lite
relayState RelayState
// HandshakePacket records the packets used to create this hostinfo
@@ -742,26 +730,20 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
return false
}
// buildNetworks fills in the networks field of HostInfo. It accepts a cert.Certificate so you never ever mix the network types up.
func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certificate) {
if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
if myVpnNetworksTable.Contains(c.Networks()[0].Addr()) {
return // Simple case, no BART needed
}
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
if len(networks) == 1 && len(unsafeNetworks) == 0 {
// Simple case, no CIDRTree needed
return
}
i.networks = new(bart.Table[NetworkType])
for _, network := range c.Networks() {
i.networks = new(bart.Lite)
for _, network := range networks {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
if myVpnNetworksTable.Contains(network.Addr()) {
i.networks.Insert(nprefix, NetworkTypeVPN)
} else {
i.networks.Insert(nprefix, NetworkTypeVPNPeer)
}
i.networks.Insert(nprefix)
}
for _, network := range c.UnsafeNetworks() {
i.networks.Insert(network, NetworkTypeUnsafe)
for _, network := range unsafeNetworks {
i.networks.Insert(network)
}
}

View File

@@ -120,10 +120,9 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
}
// Handshake will attempt to initiate a tunnel with the provided vpn address. This is a no-op if the tunnel is already established or being established
// it does not check if it is within our vpn networks!
// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
func (f *Interface) Handshake(vpnAddr netip.Addr) {
f.handshakeManager.GetOrHandshake(vpnAddr, nil)
f.getOrHandshakeNoRouting(vpnAddr, nil)
}
// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
@@ -139,6 +138,7 @@ func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback fu
// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
destinationAddr := fwPacket.RemoteAddr
hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
@@ -231,10 +231,9 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
}
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr.
// This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
})

View File

@@ -6,8 +6,8 @@ import (
"fmt"
"io"
"net/netip"
"os"
"runtime"
"sync"
"sync/atomic"
"time"
@@ -87,6 +87,7 @@ type Interface struct {
writers []udp.Conn
readers []io.ReadWriteCloser
wg sync.WaitGroup
metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
@@ -209,7 +210,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
// activate creates the interface on the host. After the interface is created, any
// other services that want to bind listeners to its IP may do so successfully. However,
// the interface isn't going to process anything until run() is called.
func (f *Interface) activate() {
func (f *Interface) activate() error {
// actually turn on tun dev
addr, err := f.outside.LocalAddr()
@@ -222,13 +223,6 @@ func (f *Interface) activate() {
WithField("boringcrypto", boringEnabled()).
Info("Nebula interface is active")
if f.routines > 1 {
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
f.routines = 1
f.l.Warn("routines is not supported on this platform, falling back to a single routine")
}
}
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
// Prepare n tun queues
@@ -237,33 +231,38 @@ func (f *Interface) activate() {
if i > 0 {
reader, err = f.inside.NewMultiQueueReader()
if err != nil {
f.l.Fatal(err)
return err
}
}
f.readers[i] = reader
}
if err := f.inside.Activate(); err != nil {
if err = f.inside.Activate(); err != nil {
f.inside.Close()
f.l.Fatal(err)
return err
}
return nil
}
func (f *Interface) run() {
func (f *Interface) run() (func(), error) {
// Launch n queues to read packets from udp
for i := 0; i < f.routines; i++ {
f.wg.Add(1)
go f.listenOut(i)
}
// Launch n queues to read packets from tun dev
for i := 0; i < f.routines; i++ {
f.wg.Add(1)
go f.listenIn(f.readers[i], i)
}
return f.wg.Wait, nil
}
func (f *Interface) listenOut(i int) {
runtime.LockOSThread()
var li udp.Conn
if i > 0 {
li = f.writers[i]
@@ -278,14 +277,21 @@ func (f *Interface) listenOut(i int) {
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
})
if err != nil && !f.closed.Load() {
f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
//TODO: Trigger Control to close
}
f.l.Debugf("underlay reader %v is done", i)
f.wg.Done()
}
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread()
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
@@ -296,17 +302,18 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
for {
n, err := reader.Read(packet)
if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return
if !f.closed.Load() {
f.l.WithError(err).Error("Error while reading outbound packet, closing")
//TODO: Trigger Control to close
}
f.l.WithError(err).Error("Error while reading outbound packet")
// This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2)
break
}
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
}
f.l.Debugf("overlay reader %v is done", i)
f.wg.Done()
}
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
@@ -458,6 +465,7 @@ func (f *Interface) GetCertState() *CertState {
func (f *Interface) Close() error {
f.closed.Store(true)
// Release the udp readers
for _, u := range f.writers {
err := u.Close()
if err != nil {
@@ -465,6 +473,13 @@ func (f *Interface) Close() error {
}
}
// Release the tun device
return f.inside.Close()
// Release the tun readers
for _, u := range f.readers {
err := u.Close()
if err != nil {
f.l.WithError(err).Error("Error while closing tun device")
}
}
return nil
}

View File

@@ -360,8 +360,7 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
}
if !lh.myVpnNetworksTable.Contains(addr) {
lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}).
Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
}
out[i] = addr
}
@@ -432,8 +431,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
}
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}).
Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
}
vals, ok := v.([]any)
@@ -1339,19 +1337,12 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
}
}
remoteAllowList := lhh.lh.GetRemoteAllowList()
for _, a := range n.Details.V4AddrPorts {
b := protoV4AddrPortToNetAddrPort(a)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr)
}
punch(protoV4AddrPortToNetAddrPort(a), detailsVpnAddr)
}
for _, a := range n.Details.V6AddrPorts {
b := protoV6AddrPortToNetAddrPort(a)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr)
}
punch(protoV6AddrPortToNetAddrPort(a), detailsVpnAddr)
}
// This sends a nebula test packet to the host trying to contact us. In the case

View File

@@ -14,7 +14,7 @@ import (
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3"
"gopkg.in/yaml.v3"
)
func TestOldIPv4Only(t *testing.T) {

44
main.go
View File

@@ -5,8 +5,6 @@ import (
"fmt"
"net"
"net/netip"
"runtime/debug"
"strings"
"time"
"github.com/sirupsen/logrus"
@@ -15,7 +13,7 @@ import (
"github.com/slackhq/nebula/sshd"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
"go.yaml.in/yaml/v3"
"gopkg.in/yaml.v3"
)
type m = map[string]any
@@ -29,10 +27,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
}()
if buildVersion == "" {
buildVersion = moduleVersion()
}
l := logger
l.Formatter = &logrus.TextFormatter{
FullTimestamp: true,
@@ -81,8 +75,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if c.GetBool("sshd.enabled", false) {
sshStart, err = configSSH(l, ssh, c)
if err != nil {
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
sshStart = nil
return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
}
}
@@ -291,29 +284,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
return &Control{
ifce,
l,
ctx,
cancel,
sshStart,
statsStart,
dnsStart,
lightHouse.StartUpdateWorker,
connManager.Start,
f: ifce,
l: l,
ctx: ctx,
cancel: cancel,
sshStart: sshStart,
statsStart: statsStart,
dnsStart: dnsStart,
lighthouseStart: lightHouse.StartUpdateWorker,
connectionManagerStart: connManager.Start,
}, nil
}
func moduleVersion() string {
info, ok := debug.ReadBuildInfo()
if !ok {
return ""
}
for _, dep := range info.Deps {
if dep.Path == "github.com/slackhq/nebula" {
return strings.TrimPrefix(dep.Version, "v")
}
}
return ""
}

View File

@@ -29,7 +29,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return
}
//l.Error("in packet ", header, packet[HeaderLen:])
//f.l.Error("in packet ", h)
if ip.IsValid() {
if f.myVpnNetworksTable.Contains(ip.Addr()) {
if f.l.Level >= logrus.DebugLevel {

View File

@@ -13,6 +13,5 @@ type Device interface {
Networks() []netip.Prefix
Name() string
RoutesFor(netip.Addr) routing.Gateways
SupportsMultiqueue() bool
NewMultiQueueReader() (io.ReadWriteCloser, error)
}

View File

@@ -95,10 +95,6 @@ func (t *tun) Name() string {
return "android"
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
}

View File

@@ -549,10 +549,6 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
}

View File

@@ -105,10 +105,6 @@ func (t *disabledTun) Write(b []byte) (int, error) {
return len(b), nil
}
func (t *disabledTun) SupportsMultiqueue() bool {
return true
}
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return t, nil
}

View File

@@ -450,10 +450,6 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
}

View File

@@ -151,10 +151,6 @@ func (t *tun) Name() string {
return "iOS"
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
}

View File

@@ -216,10 +216,6 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) SupportsMultiqueue() bool {
return true
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
@@ -586,42 +582,48 @@ func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
}
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
var gateways routing.Gateways
link, err := netlink.LinkByName(t.Device)
if err != nil {
t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name")
t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
return gateways
}
// If this route is relevant to our interface and there is a gateway then add it
if r.LinkIndex == link.Attrs().Index {
gwAddr, ok := getGatewayAddr(r.Gw, r.Via)
if ok {
if t.isGatewayInVpnNetworks(gwAddr) {
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
} else {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
}
if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
gwAddr, ok := netip.AddrFromSlice(r.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
} else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
}
}
}
for _, p := range r.MultiPath {
// If this route is relevant to our interface and there is a gateway then add it
if p.LinkIndex == link.Attrs().Index {
gwAddr, ok := getGatewayAddr(p.Gw, p.Via)
if ok {
if t.isGatewayInVpnNetworks(gwAddr) {
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
} else {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
}
if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
gwAddr, ok := netip.AddrFromSlice(p.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
} else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
// p.Hops+1 = weight of the route
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
}
}
}
}
@@ -630,27 +632,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
return gateways
}
func getGatewayAddr(gw net.IP, via netlink.Destination) (netip.Addr, bool) {
// Try to use the old RTA_GATEWAY first
gwAddr, ok := netip.AddrFromSlice(gw)
if !ok {
// Fallback to the new RTA_VIA
rVia, ok := via.(*netlink.Via)
if ok {
gwAddr, ok = netip.AddrFromSlice(rVia.Addr)
}
}
if gwAddr.IsValid() {
gwAddr = gwAddr.Unmap()
return gwAddr, true
}
return netip.Addr{}, false
}
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
gateways := t.getGatewaysFromRoute(&r.Route)
if len(gateways) == 0 {
// No gateways relevant to our network, no routing changes required.
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")

View File

@@ -390,10 +390,6 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
}

View File

@@ -310,10 +310,6 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
}

View File

@@ -132,10 +132,6 @@ func (t *TestTun) Read(b []byte) (int, error) {
return len(p), nil
}
func (t *TestTun) SupportsMultiqueue() bool {
return false
}
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented")
}

View File

@@ -234,10 +234,6 @@ func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0)
}
func (t *winTun) SupportsMultiqueue() bool {
return false
}
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}

View File

@@ -46,10 +46,6 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
return routing.Gateways{routing.NewGateway(ip, 1)}
}
func (d *UserDevice) SupportsMultiqueue() bool {
return true
}
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return d, nil
}

10
pki.go
View File

@@ -523,13 +523,9 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
}
bl := c.GetStringSlice("pki.blocklist", []string{})
if len(bl) > 0 {
for _, fp := range bl {
caPool.BlocklistFingerprint(fp)
}
l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates")
for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
l.WithField("fingerprint", fp).Info("Blocklisting cert")
caPool.BlocklistFingerprint(fp)
}
return caPool, nil

View File

@@ -44,7 +44,10 @@ type Service struct {
}
func New(control *nebula.Control) (*Service, error) {
control.Start()
wait, err := control.Start()
if err != nil {
return nil, err
}
ctx := control.Context()
eg, ctx := errgroup.WithContext(ctx)
@@ -141,6 +144,12 @@ func New(control *nebula.Control) (*Service, error) {
}
})
// Add the nebula wait function to the group
eg.Go(func() error {
wait()
return nil
})
return &s, nil
}

View File

@@ -16,8 +16,8 @@ import (
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay"
"go.yaml.in/yaml/v3"
"golang.org/x/sync/errgroup"
"gopkg.in/yaml.v3"
)
type m = map[string]any

View File

@@ -34,10 +34,6 @@ func (NoopTun) Write([]byte) (int, error) {
return 0, nil
}
func (NoopTun) SupportsMultiqueue() bool {
return false
}
func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, errors.New("unsupported")
}

View File

@@ -16,10 +16,9 @@ type EncReader func(
type Conn interface {
Rebind() error
LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader)
ListenOut(r EncReader) error
WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C)
SupportsMultipleReaders() bool
Close() error
}
@@ -31,11 +30,8 @@ func (NoopConn) Rebind() error {
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
return netip.AddrPort{}, nil
}
func (NoopConn) ListenOut(_ EncReader) {
return
}
func (NoopConn) SupportsMultipleReaders() bool {
return false
func (NoopConn) ListenOut(_ EncReader) error {
return nil
}
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil

View File

@@ -98,9 +98,9 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
return ErrInvalidIPv6RemoteForSocket
}
var rsa unix.RawSockaddrInet4
rsa.Family = unix.AF_INET
rsa.Addr = ap.Addr().As4()
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
rsa.Addr = ap.Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
sa = unsafe.Pointer(&rsa)
addrLen = syscall.SizeofSockaddrInet4
@@ -165,7 +165,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
return func() {}
}
func (u *StdConn) ListenOut(r EncReader) {
func (u *StdConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU)
for {
@@ -173,8 +173,7 @@ func (u *StdConn) ListenOut(r EncReader) {
n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
return err
}
u.l.WithError(err).Error("unexpected udp socket receive error")
@@ -184,10 +183,6 @@ func (u *StdConn) ListenOut(r EncReader) {
}
}
func (u *StdConn) SupportsMultipleReaders() bool {
return false
}
func (u *StdConn) Rebind() error {
var err error
if u.isV4 {

View File

@@ -71,21 +71,16 @@ type rawMessage struct {
Len uint32
}
func (u *GenericConn) ListenOut(r EncReader) {
func (u *GenericConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU)
for {
// Just read one packet at a time
n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
return err
}
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
}
}
func (u *GenericConn) SupportsMultipleReaders() bool {
return false
}

View File

@@ -9,6 +9,7 @@ import (
"net"
"net/netip"
"syscall"
"time"
"unsafe"
"github.com/rcrowley/go-metrics"
@@ -17,6 +18,8 @@ import (
"golang.org/x/sys/unix"
)
var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
type StdConn struct {
sysFd int
isV4 bool
@@ -24,14 +27,6 @@ type StdConn struct {
batch int
}
func maybeIPV4(ip net.IP) (net.IP, bool) {
ip4 := ip.To4()
if ip4 != nil {
return ip4, true
}
return ip, false
}
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
af := unix.AF_INET6
if ip.Is4() {
@@ -55,6 +50,11 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
}
}
// Set a read timeout
if err = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &readTimeout); err != nil {
return nil, fmt.Errorf("unable to set SO_RCVTIMEO: %s", err)
}
var sa unix.Sockaddr
if ip.Is4() {
sa4 := &unix.SockaddrInet4{Port: port}
@@ -72,10 +72,6 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
}
func (u *StdConn) SupportsMultipleReaders() bool {
return true
}
func (u *StdConn) Rebind() error {
return nil
}
@@ -122,7 +118,7 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
}
}
func (u *StdConn) ListenOut(r EncReader) {
func (u *StdConn) ListenOut(r EncReader) error {
var ip netip.Addr
msgs, buffers, names := u.PrepareRawMessages(u.batch)
@@ -134,8 +130,7 @@ func (u *StdConn) ListenOut(r EncReader) {
for {
n, err := read(msgs)
if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
return err
}
for i := 0; i < n; i++ {
@@ -163,6 +158,9 @@ func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
)
if err != 0 {
if err == unix.EAGAIN || err == unix.EINTR {
continue
}
return 0, &net.OpError{Op: "recvmsg", Err: err}
}
@@ -184,6 +182,9 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
)
if err != 0 {
if err == unix.EAGAIN || err == unix.EINTR {
continue
}
return 0, &net.OpError{Op: "recvmmsg", Err: err}
}

View File

@@ -134,7 +134,7 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
return nil
}
func (u *RIOConn) ListenOut(r EncReader) {
func (u *RIOConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU)
for {
@@ -142,8 +142,7 @@ func (u *RIOConn) ListenOut(r EncReader) {
n, rua, err := u.receive(buffer)
if err != nil {
if errors.Is(err, net.ErrClosed) {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
return err
}
u.l.WithError(err).Error("unexpected udp socket receive error")
continue
@@ -315,10 +314,6 @@ func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
}
func (u *RIOConn) SupportsMultipleReaders() bool {
return false
}
func (u *RIOConn) Rebind() error {
return nil
}

View File

@@ -6,6 +6,7 @@ package udp
import (
"io"
"net/netip"
"os"
"sync/atomic"
"github.com/sirupsen/logrus"
@@ -106,11 +107,11 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
return nil
}
func (u *TesterConn) ListenOut(r EncReader) {
func (u *TesterConn) ListenOut(r EncReader) error {
for {
p, ok := <-u.RxPackets
if !ok {
return
return os.ErrClosed
}
r(p.From, p.Data)
}
@@ -127,10 +128,6 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
return u.Addr, nil
}
func (u *TesterConn) SupportsMultipleReaders() bool {
return false
}
func (u *TesterConn) Rebind() error {
return nil
}