Compare commits


2 Commits

Author     SHA1         Message                                              Date
JackDoan   3583a3f7ab   feedback                                             2025-10-14 13:46:21 -05:00
JackDoan   36daea9551   linux: opt out of naming your tun device yourself    2025-10-14 13:33:01 -05:00
75 changed files with 418 additions and 6045 deletions

View File

@@ -14,7 +14,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
       - uses: actions/setup-go@v6
         with:

View File

@@ -10,7 +10,7 @@ jobs:
     name: Build Linux/BSD All
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
      - uses: actions/setup-go@v6
         with:
@@ -24,7 +24,7 @@ jobs:
           mv build/*.tar.gz release
       - name: Upload artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v4
         with:
           name: linux-latest
           path: release
@@ -33,7 +33,7 @@ jobs:
     name: Build Windows
     runs-on: windows-latest
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
       - uses: actions/setup-go@v6
         with:
@@ -55,7 +55,7 @@ jobs:
           mv dist\windows\wintun build\dist\windows\
       - name: Upload artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v4
         with:
           name: windows-latest
           path: build
@@ -66,7 +66,7 @@ jobs:
       HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
     runs-on: macos-latest
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
       - uses: actions/setup-go@v6
         with:
@@ -104,7 +104,7 @@ jobs:
           fi
       - name: Upload artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v4
         with:
           name: darwin-latest
           path: ./release/*
@@ -124,11 +124,11 @@ jobs:
           # be overwritten
       - name: Checkout code
         if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
-        uses: actions/checkout@v5
+        uses: actions/checkout@v4
       - name: Download artifacts
         if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
-        uses: actions/download-artifact@v6
+        uses: actions/download-artifact@v4
         with:
           name: linux-latest
           path: artifacts
@@ -160,10 +160,10 @@ jobs:
     needs: [build-linux, build-darwin, build-windows]
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
       - name: Download artifacts
-        uses: actions/download-artifact@v6
+        uses: actions/download-artifact@v4
         with:
           path: artifacts

View File

@@ -20,7 +20,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
       - uses: actions/setup-go@v6
         with:

View File

@@ -18,7 +18,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
       - uses: actions/setup-go@v6
         with:

View File

@@ -18,7 +18,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
       - uses: actions/setup-go@v6
         with:
@@ -32,7 +32,7 @@ jobs:
         run: make vet
       - name: golangci-lint
-        uses: golangci/golangci-lint-action@v9
+        uses: golangci/golangci-lint-action@v8
         with:
           version: v2.5
@@ -45,7 +45,7 @@ jobs:
       - name: Build test mobile
         run: make build-test-mobile
-      - uses: actions/upload-artifact@v5
+      - uses: actions/upload-artifact@v4
         with:
           name: e2e packet flow linux-latest
           path: e2e/mermaid/linux-latest
@@ -56,7 +56,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
       - uses: actions/setup-go@v6
         with:
@@ -77,7 +77,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
       - uses: actions/setup-go@v6
         with:
@@ -98,7 +98,7 @@ jobs:
         os: [windows-latest, macos-latest]
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
       - uses: actions/setup-go@v6
         with:
@@ -115,7 +115,7 @@ jobs:
         run: make vet
       - name: golangci-lint
-        uses: golangci/golangci-lint-action@v9
+        uses: golangci/golangci-lint-action@v8
         with:
           version: v2.5
@@ -125,7 +125,7 @@ jobs:
       - name: End 2 end
         run: make e2evv
-      - uses: actions/upload-artifact@v5
+      - uses: actions/upload-artifact@v4
         with:
           name: e2e packet flow ${{ matrix.os }}
           path: e2e/mermaid/${{ matrix.os }}

View File

@@ -43,7 +43,7 @@ func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
 	}
 	// Not within the window
-	l.Debugf("rejected a packet (top) %d %d delta %d\n", b.current, i, b.current-i)
+	l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
 	return false
 }
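The Bits.Check change above only touches a debug log inside Nebula's anti-replay window. For readers unfamiliar with the mechanism, here is a small, self-contained sketch of a sliding-window replay filter in Go; the type, sizes, and method names are hypothetical and illustrate the general idea, not Nebula's actual Bits implementation.

package main

import "fmt"

// replayWindow is a toy sliding-window replay filter: it tracks the highest
// counter seen plus a fixed-size bitmap of recently seen counters behind it.
// Illustrative only; Nebula's Bits type is different.
type replayWindow struct {
	current uint64
	length  uint64
	seen    []bool
}

func newReplayWindow(length uint64) *replayWindow {
	return &replayWindow{length: length, seen: make([]bool, length)}
}

// check reports whether counter i should be accepted and records it if so.
func (w *replayWindow) check(i uint64) bool {
	if i > w.current {
		// Ahead of the window: clear the slots we skip so stale entries
		// from an earlier wrap are not mistaken for this new range.
		steps := i - w.current
		if steps > w.length {
			steps = w.length
		}
		for k := uint64(1); k <= steps; k++ {
			w.seen[(w.current+k)%w.length] = false
		}
		w.current = i
		w.seen[i%w.length] = true
		return true
	}
	if w.current >= w.length && i <= w.current-w.length {
		// Too far behind the window: reject, as in the "Not within the window" branch above.
		return false
	}
	if w.seen[i%w.length] {
		return false // counter already seen inside the window: replay
	}
	w.seen[i%w.length] = true
	return true
}

func main() {
	w := newReplayWindow(1024)
	fmt.Println(w.check(5), w.check(5), w.check(3)) // true false true
}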

View File

@@ -114,33 +114,6 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
 	return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
 }
-func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
-	nc := &cert.TBSCertificate{
-		Version: v,
-		Curve: c.Curve(),
-		Name: c.Name(),
-		Networks: c.Networks(),
-		UnsafeNetworks: c.UnsafeNetworks(),
-		Groups: c.Groups(),
-		NotBefore: time.Unix(c.NotBefore().Unix(), 0),
-		NotAfter: time.Unix(c.NotAfter().Unix(), 0),
-		PublicKey: c.PublicKey(),
-		IsCA: false,
-	}
-	c, err := nc.Sign(ca, ca.Curve(), key)
-	if err != nil {
-		panic(err)
-	}
-	pem, err := c.MarshalPEM()
-	if err != nil {
-		panic(err)
-	}
-	return c, pem
-}
 func X25519Keypair() ([]byte, []byte) {
 	privkey := make([]byte, 32)
 	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {

View File

@@ -3,9 +3,6 @@ package main
 import (
 	"flag"
 	"fmt"
-	"log"
-	"net/http"
-	_ "net/http/pprof"
 	"os"
 	"github.com/sirupsen/logrus"
@@ -61,10 +58,6 @@ func main() {
 		os.Exit(1)
 	}
-	go func() {
-		log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
-	}()
 	if !*configTest {
 		ctrl.Start()
 		notifyReady(l)

View File

@@ -17,7 +17,7 @@ import (
 	"dario.cat/mergo"
 	"github.com/sirupsen/logrus"
-	"go.yaml.in/yaml/v3"
+	"gopkg.in/yaml.v3"
 )
 type C struct {
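The config change above swaps the YAML module import. As a reminder of what the package does for the config loader, here is a minimal sketch of parsing a YAML document into a nested map with gopkg.in/yaml.v3; the settings shown are invented for illustration and are not the real config.C, which adds reload logic, path handling, and typed accessors on top.

package main

import (
	"fmt"

	"gopkg.in/yaml.v3"
)

func main() {
	// Hypothetical settings blob, shaped like a Nebula config fragment.
	raw := []byte(`
listen:
  host: 0.0.0.0
  port: 4242
pki:
  disconnect_invalid: true
`)

	var settings map[string]any
	if err := yaml.Unmarshal(raw, &settings); err != nil {
		panic(err)
	}

	listen := settings["listen"].(map[string]any)
	fmt.Println(listen["port"], settings["pki"])
}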

View File

@@ -10,7 +10,7 @@ import (
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-	"go.yaml.in/yaml/v3"
+	"gopkg.in/yaml.v3"
 )
 func TestConfig_Load(t *testing.T) {

View File

@@ -354,6 +354,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
 		if mainHostInfo {
 			decision = tryRehandshake
 		} else {
 			if cm.shouldSwapPrimary(hostinfo) {
 				decision = swapPrimary
@@ -460,10 +461,6 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
 	}
 	crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
-	if crt == nil {
-		//my cert was reloaded away. We should definitely swap from this tunnel
-		return true
-	}
 	// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
 	// settle down.
 	return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
@@ -478,34 +475,31 @@ func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
 	cm.hostMap.Unlock()
 }
-// isInvalidCertificate decides if we should destroy a tunnel.
-// returns true if pki.disconnect_invalid is true and the certificate is no longer valid.
-// Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
+// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
+// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
+// check and return true.
 func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
 	remoteCert := hostinfo.GetCert()
 	if remoteCert == nil {
-		return false //don't tear down tunnels for handshakes in progress
+		return false
 	}
 	caPool := cm.intf.pki.GetCAPool()
 	err := caPool.VerifyCachedCertificate(now, remoteCert)
 	if err == nil {
-		return false //cert is still valid! yay!
-	} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
+		return false
+	}
+	if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
 		// Block listed certificates should always be disconnected
-		hostinfo.logger(cm.l).WithError(err).
-			WithField("fingerprint", remoteCert.Fingerprint).
-			Info("Remote certificate is blocked, tearing down the tunnel")
-		return true
-	} else if cm.intf.disconnectInvalid.Load() {
+		return false
+	}
 	hostinfo.logger(cm.l).WithError(err).
 		WithField("fingerprint", remoteCert.Fingerprint).
 		Info("Remote certificate is no longer valid, tearing down the tunnel")
 	return true
-	} else {
-		//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
-		return false
-	}
 }
 func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
@@ -536,45 +530,15 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
 func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
 	cs := cm.intf.pki.getCertState()
 	curCrt := hostinfo.ConnectionState.myCert
-	curCrtVersion := curCrt.Version()
-	myCrt := cs.getCertificate(curCrtVersion)
-	if myCrt == nil {
-		cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
-			WithField("version", curCrtVersion).
-			WithField("reason", "local certificate removed").
-			Info("Re-handshaking with remote")
-		cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
+	myCrt := cs.getCertificate(curCrt.Version())
+	if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
+		// The current tunnel is using the latest certificate and version, no need to rehandshake.
 		return
 	}
-	peerCrt := hostinfo.ConnectionState.peerCert
-	if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
-		// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
-		if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
-			cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
-				WithField("version", curCrtVersion).
-				WithField("peerVersion", peerCrt.Certificate.Version()).
-				WithField("reason", "local certificate version lower than peer, attempting to correct").
-				Info("Re-handshaking with remote")
-			cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
-				hh.initiatingVersionOverride = peerCrt.Certificate.Version()
-			})
-			return
-		}
-	}
-	if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
 	cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
 		WithField("reason", "local certificate is not current").
 		Info("Re-handshaking with remote")
 	cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
-		return
-	}
-	if curCrtVersion < cs.initiatingVersion {
-		cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
-			WithField("reason", "current cert version < pki.initiatingVersion").
-			Info("Re-handshaking with remote")
-		cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
-		return
-	}
 }
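Both versions of tryRehandshake above reduce to the same core question: does the certificate this tunnel was built with still match what we are currently configured to present? A stripped-down sketch of that decision, using placeholder types rather than Nebula's CertState and HostInfo; it mirrors the shape of the checks in the diff, not their exact behavior.

package main

import (
	"bytes"
	"fmt"
)

// fakeCert stands in for a parsed certificate; only what the check needs.
type fakeCert struct {
	version   uint8
	signature []byte
}

// needsRehandshake reports whether a tunnel built with tunnelCert should be
// renegotiated, given the certificate we currently have loaded for that
// version and the version we now prefer to initiate with.
func needsRehandshake(tunnelCert, loadedCert *fakeCert, initiatingVersion uint8) bool {
	if loadedCert == nil {
		// The certificate the tunnel was built from is gone; renegotiate.
		return true
	}
	if tunnelCert.version < initiatingVersion {
		// We now prefer a newer certificate version than the tunnel is using.
		return true
	}
	// Same version: renegotiate only if the loaded cert was swapped out from under us.
	return !bytes.Equal(tunnelCert.signature, loadedCert.signature)
}

func main() {
	oldCrt := &fakeCert{version: 1, signature: []byte("aaa")}
	newCrt := &fakeCert{version: 1, signature: []byte("bbb")}
	fmt.Println(needsRehandshake(oldCrt, newCrt, 1)) // true: signature changed
	fmt.Println(needsRehandshake(newCrt, newCrt, 1)) // false: already current
}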

View File

@@ -13,7 +13,7 @@ import (
 	"github.com/slackhq/nebula/noiseutil"
 )
-const ReplayWindow = 4096
+const ReplayWindow = 1024
 type ConnectionState struct {
 	eKey *NebulaCipherState

View File

@@ -174,10 +174,6 @@ func (c *Control) GetHostmap() *HostMap {
 	return c.f.hostMap
 }
-func (c *Control) GetF() *Interface {
-	return c.f
-}
 func (c *Control) GetCertState() *CertState {
 	return c.f.pki.getCertState()
 }

View File

@@ -20,7 +20,7 @@ import (
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3" "gopkg.in/yaml.v3"
) )
func BenchmarkHotPath(b *testing.B) { func BenchmarkHotPath(b *testing.B) {
@@ -97,41 +97,6 @@ func TestGoodHandshake(t *testing.T) {
theirControl.Stop() theirControl.Stop()
} }
func TestGoodHandshakeNoOverlap(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack!
// Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
empty := []byte{}
t.Log("do something to cause a handshake")
myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, empty, empty)
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
t.Log("Get their stage 1 packet")
stage1Packet := theirControl.GetFromUDP(true)
t.Log("Have me consume their stage 1 packet. I have a tunnel now")
myControl.InjectUDPPacket(stage1Packet)
t.Log("Wait until we see a test packet come through to make sure we give the tunnel time to complete")
myControl.WaitForType(header.Test, 0, theirControl)
t.Log("Make sure our host infos are correct")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestWrongResponderHandshake(t *testing.T) { func TestWrongResponderHandshake(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
@@ -499,35 +464,6 @@ func TestRelays(t *testing.T) {
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
} }
func TestRelaysDontCareAboutIps(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
defer r.RenderFlow()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
}
func TestReestablishRelays(t *testing.T) { func TestReestablishRelays(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
@@ -1291,109 +1227,3 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
myControl.Stop() myControl.Stop()
theirControl.Stop() theirControl.Stop()
} }
func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}})
o := m{
"static_host_map": m{
lhVpnIpNet[0].Addr().String(): []string{lhUdpAddr.String()},
},
"lighthouse": m{
"hosts": []string{lhVpnIpNet[0].Addr().String()},
"local_allow_list": m{
// Try and block our lighthouse updates from using the actual addresses assigned to this computer
// If we start discovering addresses the test router doesn't know about then test traffic cant flow
"10.0.0.0/24": true,
"::/0": false,
},
},
}
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o)
theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, lhControl, myControl, theirControl)
defer r.RenderFlow()
// Start the servers
lhControl.Start()
myControl.Start()
theirControl.Start()
t.Log("Stand up an ipv6 tunnel between me and them")
assert.True(t, myVpnIpNet[1].Addr().Is6())
assert.True(t, theirVpnIpNet[1].Addr().Is6())
assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)
lhControl.Stop()
myControl.Stop()
theirControl.Stop()
}
func TestGoodHandshakeUnsafeDest(t *testing.T) {
unsafePrefix := "192.168.6.0/24"
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil)
route := m{"route": unsafePrefix, "via": theirVpnIpNet[0].Addr().String()}
myCfg := m{
"tun": m{
"unsafe_routes": []m{route},
},
}
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg)
t.Logf("my config %v", myConfig)
// Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
spookyDest := netip.MustParseAddr("192.168.6.4")
// Start the servers
myControl.Start()
theirControl.Start()
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
t.Log("Get their stage 1 packet so that we can play with it")
stage1Packet := theirControl.GetFromUDP(true)
t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
badPacket := stage1Packet.Copy()
badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
myControl.InjectUDPPacket(badPacket)
t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
myControl.InjectUDPPacket(stage1Packet)
t.Log("Wait until we see my cached packet come through")
myControl.WaitForType(1, 0, theirControl)
t.Log("Make sure our host infos are correct")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
t.Log("Get that cached packet and make sure it looks right")
myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)
//reply
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))
//wait for reply
theirControl.WaitForType(1, 0, myControl)
theirCachedPacket := myControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from the spookyman"), theirCachedPacket, spookyDest, myVpnIpNet[0].Addr(), 80, 80)
t.Log("Do a bidirectional tunnel test")
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}

View File

@@ -22,14 +22,15 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "gopkg.in/yaml.v3"
"go.yaml.in/yaml/v3"
) )
type m = map[string]any type m = map[string]any
// newSimpleServer creates a nebula instance with many assumptions // newSimpleServer creates a nebula instance with many assumptions
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
var vpnNetworks []netip.Prefix var vpnNetworks []netip.Prefix
for _, sn := range strings.Split(sVpnNetworks, ",") { for _, sn := range strings.Split(sVpnNetworks, ",") {
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn)) vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
@@ -55,54 +56,7 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
budpIp[3] = 239 budpIp[3] = 239
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
} }
return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides) _, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
}
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
return newSimpleServerWithUdpAndUnsafeNetworks(v, caCrt, caKey, name, sVpnNetworks, udpAddr, "", overrides)
}
func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, sUnsafeNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
var vpnNetworks []netip.Prefix
for _, sn := range strings.Split(sVpnNetworks, ",") {
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
if err != nil {
panic(err)
}
vpnNetworks = append(vpnNetworks, vpnIpNet)
}
if len(vpnNetworks) == 0 {
panic("no vpn networks")
}
firewallInbound := []m{{
"proto": "any",
"port": "any",
"host": "any",
}}
var unsafeNetworks []netip.Prefix
if sUnsafeNetworks != "" {
firewallInbound = []m{{
"proto": "any",
"port": "any",
"host": "any",
"local_cidr": "0.0.0.0/0",
}}
for _, sn := range strings.Split(sUnsafeNetworks, ",") {
x, err := netip.ParsePrefix(strings.TrimSpace(sn))
if err != nil {
panic(err)
}
unsafeNetworks = append(unsafeNetworks, x)
}
}
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{})
caB, err := caCrt.MarshalPEM() caB, err := caCrt.MarshalPEM()
if err != nil { if err != nil {
@@ -122,7 +76,11 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
"port": "any", "port": "any",
"host": "any", "host": "any",
}}, }},
"inbound": firewallInbound, "inbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
}, },
//"handshakes": m{ //"handshakes": m{
// "try_interval": "1s", // "try_interval": "1s",
@@ -171,109 +129,6 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
return control, vpnNetworks, udpAddr, c return control, vpnNetworks, udpAddr, c
} }
// newServer creates a nebula instance with fewer assumptions
func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
vpnNetworks := certs[len(certs)-1].Networks()
var udpAddr netip.AddrPort
if vpnNetworks[0].Addr().Is4() {
budpIp := vpnNetworks[0].Addr().As4()
budpIp[1] -= 128
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
} else {
budpIp := vpnNetworks[0].Addr().As16()
// beef for funsies
budpIp[2] = 190
budpIp[3] = 239
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
}
caStr := ""
for _, ca := range caCrt {
x, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
caStr += string(x)
}
certStr := ""
for _, c := range certs {
x, err := c.MarshalPEM()
if err != nil {
panic(err)
}
certStr += string(x)
}
mc := m{
"pki": m{
"ca": caStr,
"cert": certStr,
"key": string(key),
},
//"tun": m{"disabled": true},
"firewall": m{
"outbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
"inbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
},
//"handshakes": m{
// "try_interval": "1s",
//},
"listen": m{
"host": udpAddr.Addr().String(),
"port": udpAddr.Port(),
},
"logging": m{
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
"level": l.Level.String(),
},
"timers": m{
"pending_deletion_interval": 2,
"connection_alive_interval": 2,
},
}
if overrides != nil {
final := m{}
err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
mc = final
}
cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}
c := config.NewC(l)
cStr := string(cb)
c.LoadString(cStr)
control, err := nebula.Main(c, false, "e2e-test", l, nil)
if err != nil {
panic(err)
}
return control, vpnNetworks, udpAddr, c
}
type doneCb func() type doneCb func()
func deadline(t *testing.T, seconds time.Duration) doneCb { func deadline(t *testing.T, seconds time.Duration) doneCb {
@@ -308,10 +163,10 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
// Get both host infos // Get both host infos
//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things //TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false) hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
require.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA") assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false) hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
require.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB") assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
// Check that both vpn and real addr are correct // Check that both vpn and real addr are correct
assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A") assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
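The removed newServer helper above layers test overrides on top of a base config map with mergo before marshaling the result to YAML. A small standalone sketch of that merge pattern; the keys and values here are illustrative only.

package main

import (
	"fmt"

	"dario.cat/mergo"
)

type m = map[string]any

func main() {
	// Shapes mirror the helper above; the concrete values are made up.
	base := m{
		"listen":   m{"host": "10.128.0.1", "port": 4242},
		"firewall": m{"outbound": []m{{"proto": "any", "port": "any", "host": "any"}}},
	}
	overrides := m{
		"relay":  m{"use_relays": true},
		"listen": m{"port": 4243},
	}

	// Merge the overrides into an empty map first, then fill in whatever the
	// base still provides; keys already set by the overrides are left alone.
	final := m{}
	if err := mergo.Merge(&final, overrides, mergo.WithAppendSlice); err != nil {
		panic(err)
	}
	if err := mergo.Merge(&final, base, mergo.WithAppendSlice); err != nil {
		panic(err)
	}

	fmt.Println(final["listen"], final["relay"], final["firewall"])
}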

View File

@@ -4,16 +4,12 @@
package e2e package e2e
import ( import (
"fmt"
"net/netip"
"testing" "testing"
"time" "time"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3"
) )
func TestDropInactiveTunnels(t *testing.T) { func TestDropInactiveTunnels(t *testing.T) {
@@ -59,309 +55,3 @@ func TestDropInactiveTunnels(t *testing.T) {
myControl.Stop() myControl.Stop()
theirControl.Stop() theirControl.Stop()
} }
func TestCertUpgrade(t *testing.T) {
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
caB, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca2B, err := ca2.MarshalPEM()
if err != nil {
panic(err)
}
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
r.Log("Assert the tunnel between me and them works")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("yay")
//todo ???
time.Sleep(1 * time.Second)
r.FlushAll()
mc := m{
"pki": m{
"ca": caStr,
"cert": string(myCert2Pem),
"key": string(myPrivKey),
},
//"tun": m{"disabled": true},
"firewall": myC.Settings["firewall"],
//"handshakes": m{
// "try_interval": "1s",
//},
"listen": myC.Settings["listen"],
"logging": myC.Settings["logging"],
"timers": myC.Settings["timers"],
}
cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}
r.Logf("reload new v2-only config")
err = myC.ReloadConfigString(string(cb))
assert.NoError(t, err)
r.Log("yay, spin until their sees it")
waitStart := time.Now()
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
if c == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
r.Logf("version %d", version)
if version == cert.Version2 {
break
}
}
since := time.Since(waitStart)
if since > time.Second*10 {
t.Fatal("Cert should be new by now")
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestCertDowngrade(t *testing.T) {
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
caB, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca2B, err := ca2.MarshalPEM()
if err != nil {
panic(err)
}
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
r.Log("Assert the tunnel between me and them works")
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
//r.Log("yay")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("yay")
//todo ???
time.Sleep(1 * time.Second)
r.FlushAll()
mc := m{
"pki": m{
"ca": caStr,
"cert": string(myCertPem),
"key": string(myPrivKey),
},
"firewall": myC.Settings["firewall"],
"listen": myC.Settings["listen"],
"logging": myC.Settings["logging"],
"timers": myC.Settings["timers"],
}
cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}
r.Logf("reload new v1-only config")
err = myC.ReloadConfigString(string(cb))
assert.NoError(t, err)
r.Log("yay, spin until their sees it")
waitStart := time.Now()
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
if c == nil || c2 == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
theirVersion := c2.Cert.Version()
r.Logf("version %d,%d", version, theirVersion)
if version == cert.Version1 {
break
}
}
since := time.Since(waitStart)
if since > time.Second*5 {
r.Log("it is unusual that the cert is not new yet, but not a failure yet")
}
if since > time.Second*10 {
r.Log("wtf")
t.Fatal("Cert should be new by now")
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestCertMismatchCorrection(t *testing.T) {
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
r.Log("Assert the tunnel between me and them works")
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
//r.Log("yay")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("yay")
//todo ???
time.Sleep(1 * time.Second)
r.FlushAll()
waitStart := time.Now()
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
if c == nil || c2 == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
theirVersion := c2.Cert.Version()
r.Logf("version %d,%d", version, theirVersion)
if version == theirVersion {
break
}
}
since := time.Since(waitStart)
if since > time.Second*5 {
r.Log("wtf")
}
if since > time.Second*10 {
r.Log("wtf")
t.Fatal("Cert should be new by now")
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestCrossStackRelaysWork(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
//myVpnV4 := myVpnIpNet[0]
myVpnV6 := myVpnIpNet[1]
relayVpnV4 := relayVpnIpNet[0]
relayVpnV6 := relayVpnIpNet[1]
theirVpnV6 := theirVpnIpNet[0]
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
defer r.RenderFlow()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
t.Log("reply?")
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
p = r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
//t.Log("finish up")
//myControl.Stop()
//theirControl.Stop()
//relayControl.Stop()
}

View File

@@ -417,45 +417,30 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return nil return nil
} }
var ErrUnknownNetworkType = errors.New("unknown network type") var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
var ErrPeerRejected = errors.New("remote address is not within a network that we handle") var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate networks")
var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
var ErrNoMatchingRule = errors.New("no matching rule in firewall table") var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It // Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped. // returns nil if the packet should not be dropped.
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) error { func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
// Check if we spoke to this tuple, if we did then allow this packet // Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(fp, h, caPool, localCache, now) { if f.inConns(fp, h, caPool, localCache) {
return nil return nil
} }
// Make sure remote address matches nebula certificate, and determine how to treat it // Make sure remote address matches nebula certificate
if h.networks == nil { if h.networks != nil {
// Simple case: Certificate has one address and no unsafe networks if !h.networks.Contains(fp.RemoteAddr) {
if h.vpnAddrs[0] != fp.RemoteAddr {
f.metrics(incoming).droppedRemoteAddr.Inc(1) f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
} }
} else { } else {
nwType, ok := h.networks.Lookup(fp.RemoteAddr) // Simple case: Certificate has one address and no unsafe networks
if !ok { if h.vpnAddrs[0] != fp.RemoteAddr {
f.metrics(incoming).droppedRemoteAddr.Inc(1) f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
} }
switch nwType {
case NetworkTypeVPN:
break // nothing special
case NetworkTypeVPNPeer:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrPeerRejected // reject for now, one day this may have different FW rules
case NetworkTypeUnsafe:
break // nothing special, one day this may have different FW rules
default:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrUnknownNetworkType //should never happen
}
} }
// Make sure we are supposed to be handling this local ip address // Make sure we are supposed to be handling this local ip address
@@ -476,7 +461,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
} }
// We always want to conntrack since it is a faster operation // We always want to conntrack since it is a faster operation
f.addConn(fp, incoming, now) f.addConn(fp, incoming)
return nil return nil
} }
@@ -505,7 +490,7 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
} }
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) bool { func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
if localCache != nil { if localCache != nil {
if _, ok := localCache[fp]; ok { if _, ok := localCache[fp]; ok {
return true return true
@@ -517,7 +502,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
// Purge every time we test // Purge every time we test
ep, has := conntrack.TimerWheel.Purge() ep, has := conntrack.TimerWheel.Purge()
if has { if has {
f.evict(ep, now) f.evict(ep)
} }
c, ok := conntrack.Conns[fp] c, ok := conntrack.Conns[fp]
@@ -564,11 +549,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
switch fp.Protocol { switch fp.Protocol {
case firewall.ProtoTCP: case firewall.ProtoTCP:
c.Expires = now.Add(f.TCPTimeout) c.Expires = time.Now().Add(f.TCPTimeout)
case firewall.ProtoUDP: case firewall.ProtoUDP:
c.Expires = now.Add(f.UDPTimeout) c.Expires = time.Now().Add(f.UDPTimeout)
default: default:
c.Expires = now.Add(f.DefaultTimeout) c.Expires = time.Now().Add(f.DefaultTimeout)
} }
conntrack.Unlock() conntrack.Unlock()
@@ -580,7 +565,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
return true return true
} }
func (f *Firewall) addConn(fp firewall.Packet, incoming bool, now time.Time) { func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
var timeout time.Duration var timeout time.Duration
c := &conn{} c := &conn{}
@@ -596,7 +581,7 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool, now time.Time) {
conntrack := f.Conntrack conntrack := f.Conntrack
conntrack.Lock() conntrack.Lock()
if _, ok := conntrack.Conns[fp]; !ok { if _, ok := conntrack.Conns[fp]; !ok {
conntrack.TimerWheel.Advance(now) conntrack.TimerWheel.Advance(time.Now())
conntrack.TimerWheel.Add(fp, timeout) conntrack.TimerWheel.Add(fp, timeout)
} }
@@ -604,14 +589,14 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool, now time.Time) {
// firewall reload // firewall reload
c.incoming = incoming c.incoming = incoming
c.rulesVersion = f.rulesVersion c.rulesVersion = f.rulesVersion
c.Expires = now.Add(timeout) c.Expires = time.Now().Add(timeout)
conntrack.Conns[fp] = c conntrack.Conns[fp] = c
conntrack.Unlock() conntrack.Unlock()
} }
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
// Caller must own the connMutex lock! // Caller must own the connMutex lock!
func (f *Firewall) evict(p firewall.Packet, now time.Time) { func (f *Firewall) evict(p firewall.Packet) {
// Are we still tracking this conn? // Are we still tracking this conn?
conntrack := f.Conntrack conntrack := f.Conntrack
t, ok := conntrack.Conns[p] t, ok := conntrack.Conns[p]
@@ -619,11 +604,11 @@ func (f *Firewall) evict(p firewall.Packet, now time.Time) {
return return
} }
newT := t.Expires.Sub(now) newT := t.Expires.Sub(time.Now())
// Timeout is in the future, re-add the timer // Timeout is in the future, re-add the timer
if newT > 0 { if newT > 0 {
conntrack.TimerWheel.Advance(now) conntrack.TimerWheel.Advance(time.Now())
conntrack.TimerWheel.Add(p, newT) conntrack.TimerWheel.Add(p, newT)
return return
} }
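Most of the firewall changes above toggle between calling time.Now() inline and threading a now parameter through Drop, inConns, addConn, and evict. A toy sketch of why an injected clock is convenient for conntrack-style expiry; none of the names below are Nebula's, and the real Firewall tracks far more state.

package main

import (
	"fmt"
	"time"
)

// entry is a minimal stand-in for a conntrack record.
type entry struct {
	expires time.Time
}

// table is a toy conntrack map keyed by an opaque flow string. Taking "now"
// as a parameter, instead of reading the clock at every step, lets tests
// drive expiry with fixed timestamps.
type table struct {
	conns   map[string]*entry
	timeout time.Duration
}

// seen reports whether the flow is already tracked and unexpired at "now",
// creating or refreshing the record as a side effect.
func (t *table) seen(flow string, now time.Time) bool {
	e, ok := t.conns[flow]
	if !ok || now.After(e.expires) {
		t.conns[flow] = &entry{expires: now.Add(t.timeout)}
		return false
	}
	e.expires = now.Add(t.timeout)
	return true
}

func main() {
	tr := &table{conns: map[string]*entry{}, timeout: time.Minute}
	start := time.Unix(0, 0)
	fmt.Println(tr.seen("udp 10.1.1.1:80->10.1.1.2:90", start))                  // false: first sighting
	fmt.Println(tr.seen("udp 10.1.1.1:80->10.1.1.2:90", start.Add(time.Second))) // true: still tracked
	fmt.Println(tr.seen("udp 10.1.1.1:80->10.1.1.2:90", start.Add(2*time.Hour))) // false: expired
}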

View File

@@ -8,8 +8,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
@@ -151,8 +149,7 @@ func TestFirewall_Drop(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteAddr: netip.MustParseAddr("1.2.3.4"), RemoteAddr: netip.MustParseAddr("1.2.3.4"),
@@ -177,7 +174,7 @@ func TestFirewall_Drop(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
} }
h.buildNetworks(myVpnNetworksTable, &c) h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -229,9 +226,6 @@ func TestFirewall_DropV6(t *testing.T) {
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"), LocalAddr: netip.MustParseAddr("fd12::34"),
RemoteAddr: netip.MustParseAddr("fd12::34"), RemoteAddr: netip.MustParseAddr("fd12::34"),
@@ -256,7 +250,7 @@ func TestFirewall_DropV6(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")}, vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
} }
h.buildNetworks(myVpnNetworksTable, &c) h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -398,7 +392,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
c := &cert.CachedCertificate{ c := &cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
name: "nope", name: "nope",
networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")}, networks: []netip.Prefix{netip.MustParsePrefix("fd99:99/128")},
}, },
InvertedGroups: map[string]struct{}{"nope": {}}, InvertedGroups: map[string]struct{}{"nope": {}},
} }
@@ -459,8 +453,6 @@ func TestFirewall_Drop2(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -486,7 +478,7 @@ func TestFirewall_Drop2(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h.buildNetworks(myVpnNetworksTable, c.Certificate) h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
c1 := cert.CachedCertificate{ c1 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@@ -501,7 +493,7 @@ func TestFirewall_Drop2(t *testing.T) {
peerCert: &c1, peerCert: &c1,
}, },
} }
h1.buildNetworks(myVpnNetworksTable, c1.Certificate) h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -518,8 +510,6 @@ func TestFirewall_Drop3(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -551,7 +541,7 @@ func TestFirewall_Drop3(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h1.buildNetworks(myVpnNetworksTable, c1.Certificate) h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
c2 := cert.CachedCertificate{ c2 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@@ -566,7 +556,7 @@ func TestFirewall_Drop3(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h2.buildNetworks(myVpnNetworksTable, c2.Certificate) h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
c3 := cert.CachedCertificate{ c3 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@@ -581,7 +571,7 @@ func TestFirewall_Drop3(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h3.buildNetworks(myVpnNetworksTable, c3.Certificate) h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -607,8 +597,6 @@ func TestFirewall_Drop3V6(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"), LocalAddr: netip.MustParseAddr("fd12::34"),
@@ -632,7 +620,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h.buildNetworks(myVpnNetworksTable, c.Certificate) h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
// Test a remote address match // Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
@@ -645,8 +633,6 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{ p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -673,7 +659,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{network.Addr()}, vpnAddrs: []netip.Addr{network.Addr()},
} }
h.buildNetworks(myVpnNetworksTable, c.Certificate) h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -710,8 +696,6 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
l.SetOutput(ob) l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
c := cert.CachedCertificate{ c := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@@ -733,7 +717,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
}, },
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()}, vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
} }
h1.buildNetworks(myVpnNetworksTable, c1.Certificate) h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
@@ -1063,171 +1047,6 @@ func TestFirewall_convertRule(t *testing.T) {
assert.Equal(t, "group1", r.Group) assert.Equal(t, "group1", r.Group)
} }
type testcase struct {
h *HostInfo
p firewall.Packet
c cert.Certificate
err error
}
func (c *testcase) Test(t *testing.T, fw *Firewall) {
t.Helper()
cp := cert.NewCAPool()
resetConntrack(fw)
err := fw.Drop(c.p, true, c.h, cp, nil)
if c.err == nil {
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
} else {
require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr)
}
}
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
c1 := dummyCert{
name: "host1",
networks: theirPrefixes,
groups: []string{"default-group"},
issuer: "signer-shasum",
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &cert.CachedCertificate{
Certificate: &c1,
InvertedGroups: map[string]struct{}{"default-group": {}},
},
},
vpnAddrs: make([]netip.Addr, len(theirPrefixes)),
}
for i := range theirPrefixes {
h.vpnAddrs[i] = theirPrefixes[i].Addr()
}
h.buildNetworks(setup.myVpnNetworksTable, &c1)
p := firewall.Packet{
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
RemoteAddr: theirPrefixes[0].Addr(),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
return testcase{
h: &h,
p: p,
c: &c1,
err: err,
}
}
type testsetup struct {
c dummyCert
myVpnNetworksTable *bart.Lite
fw *Firewall
}
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
c := dummyCert{
name: "me",
networks: myPrefixes,
groups: []string{"default-group"},
issuer: "signer-shasum",
}
return newSetupFromCert(t, l, c)
}
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
myVpnNetworksTable := new(bart.Lite)
for _, prefix := range c.Networks() {
myVpnNetworksTable.Insert(prefix)
}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
return testsetup{
c: c,
fw: fw,
myVpnNetworksTable: myVpnNetworksTable,
}
}
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
t.Parallel()
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
t.Run("allow inbound all matching", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24"))
tc.Test(t, setup.fw)
})
t.Run("allow inbound local matching", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24"))
tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8")
tc.Test(t, setup.fw)
})
t.Run("block inbound remote mismatched", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24"))
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
tc.Test(t, setup.fw)
})
t.Run("Block a vpn peer packet", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24"))
tc.Test(t, setup.fw)
})
twoPrefixes := []netip.Prefix{
netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"),
}
t.Run("allow inbound one matching", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, nil, twoPrefixes...)
tc.Test(t, setup.fw)
})
t.Run("block inbound multimismatch", func(t *testing.T) {
t.Parallel()
setup := newSetup(t, l, myPrefix)
tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...)
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
tc.Test(t, setup.fw)
})
t.Run("allow inbound 2nd one matching", func(t *testing.T) {
t.Parallel()
setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24"))
tc := buildTestCase(setup2, nil, twoPrefixes...)
tc.p.RemoteAddr = twoPrefixes[1].Addr()
tc.Test(t, setup2.fw)
})
t.Run("allow inbound unsafe route", func(t *testing.T) {
t.Parallel()
unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
c := dummyCert{
name: "me",
networks: []netip.Prefix{myPrefix},
unsafeNetworks: []netip.Prefix{unsafePrefix},
groups: []string{"default-group"},
issuer: "signer-shasum",
}
unsafeSetup := newSetupFromCert(t, l, c)
tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
tc.err = ErrNoMatchingRule
tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, unsafePrefix, "", ""))
tc.err = nil
tc.Test(t, unsafeSetup.fw) //should pass
})
}
type addRuleCall struct { type addRuleCall struct {
incoming bool incoming bool
proto uint8 proto uint8
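
A note on the removed test plumbing above: newSetup/newSetupFromCert build the same kind of prefix table the firewall code consults, a bart.Lite loaded with the certificate's networks and then queried with Contains. Below is a minimal sketch of that lookup pattern, assuming github.com/gaissmai/bart as the import path (the path itself is not visible in this diff) and reusing prefixes and addresses from the deleted tests:

package main

import (
	"fmt"
	"net/netip"

	"github.com/gaissmai/bart"
)

func main() {
	// Build a lookup table of our vpn networks, as the removed newSetupFromCert
	// helper above does from the certificate's Networks().
	myVpnNetworksTable := new(bart.Lite)
	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
	myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))

	// Addresses covered by an inserted prefix hit the table; everything else misses.
	fmt.Println(myVpnNetworksTable.Contains(netip.MustParseAddr("1.2.3.4")))  // true
	fmt.Println(myVpnNetworksTable.Contains(netip.MustParseAddr("9.9.9.9")))  // false
	fmt.Println(myVpnNetworksTable.Contains(netip.MustParseAddr("fd12::34"))) // true
}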

15
go.mod
View File

@@ -22,17 +22,16 @@ require (
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
github.com/vishvananda/netlink v1.3.1 github.com/vishvananda/netlink v1.3.1
go.yaml.in/yaml/v3 v3.0.4 golang.org/x/crypto v0.43.0
golang.org/x/crypto v0.44.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.46.0 golang.org/x/net v0.45.0
golang.org/x/sync v0.18.0 golang.org/x/sync v0.17.0
golang.org/x/sys v0.38.0 golang.org/x/sys v0.37.0
golang.org/x/term v0.37.0 golang.org/x/term v0.36.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/protobuf v1.36.10 google.golang.org/protobuf v1.36.8
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
) )
@@ -50,6 +49,6 @@ require (
github.com/vishvananda/netns v0.0.5 // indirect github.com/vishvananda/netns v0.0.5 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/mod v0.24.0 // indirect golang.org/x/mod v0.24.0 // indirect
golang.org/x/time v0.7.0 // indirect golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.33.0 // indirect golang.org/x/tools v0.33.0 // indirect
) )

30
go.sum
View File

@@ -155,15 +155,13 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU= golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc= golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
@@ -182,8 +180,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -191,8 +189,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -209,16 +207,16 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
@@ -244,8 +242,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -2,6 +2,7 @@ package nebula
import ( import (
"net/netip" "net/netip"
"slices"
"time" "time"
"github.com/flynn/noise" "github.com/flynn/noise"
@@ -22,19 +23,15 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
return false return false
} }
// If we're connecting to a v6 address we must use a v2 cert
cs := f.pki.getCertState() cs := f.pki.getCertState()
v := cs.initiatingVersion v := cs.initiatingVersion
if hh.initiatingVersionOverride != cert.VersionPre1 {
v = hh.initiatingVersionOverride
} else if v < cert.Version2 {
// If we're connecting to a v6 address we should encourage use of a V2 cert
for _, a := range hh.hostinfo.vpnAddrs { for _, a := range hh.hostinfo.vpnAddrs {
if a.Is6() { if a.Is6() {
v = cert.Version2 v = cert.Version2
break break
} }
} }
}
crt := cs.getCertificate(v) crt := cs.getCertificate(v)
if crt == nil { if crt == nil {
@@ -51,7 +48,6 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v). WithField("certVersion", v).
Error("Unable to handshake with host because no certificate handshake bytes is available") Error("Unable to handshake with host because no certificate handshake bytes is available")
return false
} }
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX) ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
@@ -107,7 +103,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.initiatingVersion). WithField("certVersion", cs.initiatingVersion).
Error("Unable to handshake with host because no certificate is available") Error("Unable to handshake with host because no certificate is available")
return
} }
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
@@ -148,8 +143,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc) remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
if err != nil { if err != nil {
fp, fperr := rc.Fingerprint() fp, err := rc.Fingerprint()
if fperr != nil { if err != nil {
fp = "<error generating certificate fingerprint>" fp = "<error generating certificate fingerprint>"
} }
@@ -168,19 +163,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if remoteCert.Certificate.Version() != ci.myCert.Version() { if remoteCert.Certificate.Version() != ci.myCert.Version() {
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us // We started off using the wrong certificate version, lets see if we can match the version that was sent to us
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version()) rc := cs.getCertificate(remoteCert.Certificate.Version())
if myCertOtherVersion == nil { if rc == nil {
if f.l.Level >= logrus.DebugLevel { f.l.WithError(err).WithField("udpAddr", addr).
f.l.WithError(err).WithFields(m{ WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
"udpAddr": addr, Info("Unable to handshake with host due to missing certificate version")
"handshake": m{"stage": 1, "style": "ix_psk0"}, return
"cert": remoteCert,
}).Debug("Might be unable to handshake with host due to missing certificate version")
} }
} else {
// Record the certificate we are actually using // Record the certificate we are actually using
ci.myCert = myCertOtherVersion ci.myCert = rc
}
} }
if len(remoteCert.Certificate.Networks()) == 0 { if len(remoteCert.Certificate.Networks()) == 0 {
@@ -191,17 +183,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return return
} }
var vpnAddrs []netip.Addr
var filteredNetworks []netip.Prefix
certName := remoteCert.Certificate.Name() certName := remoteCert.Certificate.Name()
certVersion := remoteCert.Certificate.Version() certVersion := remoteCert.Certificate.Version()
fingerprint := remoteCert.Fingerprint fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer() issuer := remoteCert.Certificate.Issuer()
vpnNetworks := remoteCert.Certificate.Networks()
anyVpnAddrsInCommon := false for _, network := range remoteCert.Certificate.Networks() {
vpnAddrs := make([]netip.Addr, len(vpnNetworks)) vpnAddr := network.Addr()
for i, network := range vpnNetworks { if f.myVpnAddrsTable.Contains(vpnAddr) {
if f.myVpnAddrsTable.Contains(network.Addr()) { f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -209,10 +201,24 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
return return
} }
vpnAddrs[i] = network.Addr()
if f.myVpnNetworksTable.Contains(network.Addr()) { // vpnAddrs outside our vpn networks are of no use to us, filter them out
anyVpnAddrsInCommon = true if !f.myVpnNetworksTable.Contains(vpnAddr) {
continue
} }
filteredNetworks = append(filteredNetworks, network)
vpnAddrs = append(vpnAddrs, vpnAddr)
}
if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
return
} }
if addr.IsValid() { if addr.IsValid() {
@@ -249,30 +255,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
}, },
} }
msgRxL := f.l.WithFields(m{ f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
"vpnAddrs": vpnAddrs, WithField("certName", certName).
"udpAddr": addr, WithField("certVersion", certVersion).
"certName": certName, WithField("fingerprint", fingerprint).
"certVersion": certVersion, WithField("issuer", issuer).
"fingerprint": fingerprint, WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
"issuer": issuer, WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
"initiatorIndex": hs.Details.InitiatorIndex, Info("Handshake message received")
"responderIndex": hs.Details.ResponderIndex,
"remoteIndex": h.RemoteIndex,
"handshake": m{"stage": 1, "style": "ix_psk0"},
})
if anyVpnAddrsInCommon {
msgRxL.Info("Handshake message received")
} else {
//todo warn if not lighthouse or relay?
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
}
hs.Details.ResponderIndex = myIndex hs.Details.ResponderIndex = myIndex
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version()) hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
if hs.Details.Cert == nil { if hs.Details.Cert == nil {
msgRxL.WithField("myCertVersion", ci.myCert.Version()). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVersion", ci.myCert.Version()).
Error("Unable to handshake with host because no certificate handshake bytes is available") Error("Unable to handshake with host because no certificate handshake bytes is available")
return return
} }
@@ -330,7 +332,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
hostinfo.SetRemote(addr) hostinfo.SetRemote(addr)
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
if err != nil { if err != nil {
@@ -571,22 +573,31 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
} }
correctHostResponded := false var vpnAddrs []netip.Addr
anyVpnAddrsInCommon := false var filteredNetworks []netip.Prefix
vpnAddrs := make([]netip.Addr, len(vpnNetworks)) for _, network := range vpnNetworks {
for i, network := range vpnNetworks { // vpnAddrs outside our vpn networks are of no use to us, filter them out
vpnAddrs[i] = network.Addr() vpnAddr := network.Addr()
if f.myVpnNetworksTable.Contains(network.Addr()) { if !f.myVpnNetworksTable.Contains(vpnAddr) {
anyVpnAddrsInCommon = true continue
} }
if hostinfo.vpnAddrs[0] == network.Addr() {
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not? filteredNetworks = append(filteredNetworks, network)
correctHostResponded = true vpnAddrs = append(vpnAddrs, vpnAddr)
} }
if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
return true
} }
// Ensure the right host responded // Ensure the right host responded
if !correctHostResponded { if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("udpAddr", addr). WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
@@ -598,7 +609,6 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
f.handshakeManager.DeleteHostInfo(hostinfo) f.handshakeManager.DeleteHostInfo(hostinfo)
// Create a new hostinfo/handshake for the intended vpn ip // Create a new hostinfo/handshake for the intended vpn ip
//TODO is hostinfo.vpnAddrs[0] always the address to use?
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
// Block the current used address // Block the current used address
newHH.hostinfo.remotes = hostinfo.remotes newHH.hostinfo.remotes = hostinfo.remotes
@@ -625,7 +635,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci.window.Update(f.l, 2) ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds() duration := time.Since(hh.startTime).Nanoseconds()
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -633,17 +643,12 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("durationNs", duration). WithField("durationNs", duration).
WithField("sentCachedPackets", len(hh.packetStore)) WithField("sentCachedPackets", len(hh.packetStore)).
if anyVpnAddrsInCommon { Info("Handshake message received")
msgRxL.Info("Handshake message received")
} else {
//todo warn if not lighthouse or relay?
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
}
// Build up the radix for the firewall if we have subnets in the cert // Build up the radix for the firewall if we have subnets in the cert
hostinfo.vpnAddrs = vpnAddrs hostinfo.vpnAddrs = vpnAddrs
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
f.handshakeManager.Complete(hostinfo, f) f.handshakeManager.Complete(hostinfo, f)
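
The two sides of the stage-1/stage-2 hunks above treat the peer certificate's networks differently: one filters out every address that falls outside our own vpn networks and refuses the handshake when nothing usable remains, while the other keeps all addresses and only records whether any overlapped. A standalone sketch of the filtering variant, with plain prefix slices standing in for f.myVpnNetworksTable (the function and variable names here are illustrative, not the actual Interface fields):

package main

import (
	"errors"
	"fmt"
	"net/netip"
)

// filterVpnAddrs keeps only the peer addresses that fall inside one of our vpn
// networks, mirroring the filtering loops in ixHandshakeStage1/Stage2 above.
func filterVpnAddrs(theirNetworks, myNetworks []netip.Prefix) ([]netip.Addr, error) {
	var vpnAddrs []netip.Addr
	for _, network := range theirNetworks {
		vpnAddr := network.Addr()
		usable := false
		for _, mine := range myNetworks {
			if mine.Contains(vpnAddr) {
				usable = true
				break
			}
		}
		// vpnAddrs outside our vpn networks are of no use to us, filter them out
		if !usable {
			continue
		}
		vpnAddrs = append(vpnAddrs, vpnAddr)
	}
	if len(vpnAddrs) == 0 {
		return nil, errors.New("no usable vpn addresses from host")
	}
	return vpnAddrs, nil
}

func main() {
	mine := []netip.Prefix{netip.MustParsePrefix("1.0.0.0/8")}
	theirs := []netip.Prefix{
		netip.MustParsePrefix("1.2.3.4/24"), // kept: inside 1.0.0.0/8
		netip.MustParsePrefix("9.9.9.9/24"), // dropped: outside our networks
	}
	addrs, err := filterVpnAddrs(theirs, mine)
	fmt.Println(addrs, err) // [1.2.3.4] <nil>
}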

View File

@@ -70,7 +70,6 @@ type HandshakeHostInfo struct {
startTime time.Time // Time that we first started trying with this handshake startTime time.Time // Time that we first started trying with this handshake
ready bool // Is the handshake ready ready bool // Is the handshake ready
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
counter int64 // How many attempts have we made so far counter int64 // How many attempts have we made so far
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
@@ -269,12 +268,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
// Send a RelayRequest to all known Relay IP's // Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays { for _, relay := range hostinfo.remotes.relays {
// Don't relay through the host I'm trying to connect to // Don't relay to myself
if relay == vpnIp { if relay == vpnIp {
continue continue
} }
// Don't relay to myself // Don't relay through the host I'm trying to connect to
if hm.f.myVpnAddrsTable.Contains(relay) { if hm.f.myVpnAddrsTable.Contains(relay) {
continue continue
} }

View File

@@ -212,18 +212,6 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
rs.relayForByIdx[idx] = r rs.relayForByIdx[idx] = r
} }
type NetworkType uint8
const (
NetworkTypeUnknown NetworkType = iota
// NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
NetworkTypeVPN
// NetworkTypeVPNPeer is a network that does not overlap one of our networks
NetworkTypeVPNPeer
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
NetworkTypeUnsafe
)
type HostInfo struct { type HostInfo struct {
remote netip.AddrPort remote netip.AddrPort
remotes *RemoteList remotes *RemoteList
@@ -237,8 +225,8 @@ type HostInfo struct {
// vpn networks but were removed because they are not usable // vpn networks but were removed because they are not usable
vpnAddrs []netip.Addr vpnAddrs []netip.Addr
// networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host. // networks are both all vpn and unsafe networks assigned to this host
networks *bart.Table[NetworkType] networks *bart.Lite
relayState RelayState relayState RelayState
// HandshakePacket records the packets used to create this hostinfo // HandshakePacket records the packets used to create this hostinfo
@@ -742,26 +730,20 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
return false return false
} }
// buildNetworks fills in the networks field of HostInfo. It accepts a cert.Certificate so you never ever mix the network types up. func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certificate) { if len(networks) == 1 && len(unsafeNetworks) == 0 {
if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 { // Simple case, no CIDRTree needed
if myVpnNetworksTable.Contains(c.Networks()[0].Addr()) { return
return // Simple case, no BART needed
}
} }
i.networks = new(bart.Table[NetworkType]) i.networks = new(bart.Lite)
for _, network := range c.Networks() { for _, network := range networks {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
if myVpnNetworksTable.Contains(network.Addr()) { i.networks.Insert(nprefix)
i.networks.Insert(nprefix, NetworkTypeVPN)
} else {
i.networks.Insert(nprefix, NetworkTypeVPNPeer)
}
} }
for _, network := range c.UnsafeNetworks() { for _, network := range unsafeNetworks {
i.networks.Insert(network, NetworkTypeUnsafe) i.networks.Insert(network)
} }
} }
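
The hunk above swaps how a HostInfo's networks are stored: one side keeps a plain presence set (bart.Lite) over the peer's networks and unsafe networks, the other tags each entry in a bart.Table[NetworkType] so later lookups can distinguish an address inside our vpn ranges (NetworkTypeVPN) from a peer-only address (NetworkTypeVPNPeer) or an unsafe route (NetworkTypeUnsafe). A compact sketch of the tagged variant, assuming github.com/gaissmai/bart and its Table.Lookup method (neither appears verbatim in this hunk) and taking prefix slices instead of a cert.Certificate:

package main

import (
	"fmt"
	"net/netip"

	"github.com/gaissmai/bart"
)

// NetworkType mirrors the constants shown in the hunk above.
type NetworkType uint8

const (
	NetworkTypeUnknown NetworkType = iota
	NetworkTypeVPN                 // overlaps one of the vpnNetworks in our certificate
	NetworkTypeVPNPeer             // a network that does not overlap one of ours
	NetworkTypeUnsafe              // a network from Certificate.UnsafeNetworks()
)

// buildNetworks classifies a peer's networks the way the tagged variant of
// HostInfo.buildNetworks above does, but over plain prefix slices.
func buildNetworks(mine *bart.Lite, networks, unsafeNetworks []netip.Prefix) *bart.Table[NetworkType] {
	t := new(bart.Table[NetworkType])
	for _, network := range networks {
		// Store the specific address as a host route (/32 or /128), not the whole prefix.
		nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
		if mine.Contains(network.Addr()) {
			t.Insert(nprefix, NetworkTypeVPN)
		} else {
			t.Insert(nprefix, NetworkTypeVPNPeer)
		}
	}
	for _, network := range unsafeNetworks {
		t.Insert(network, NetworkTypeUnsafe)
	}
	return t
}

func main() {
	mine := new(bart.Lite)
	mine.Insert(netip.MustParsePrefix("10.0.0.0/8"))

	t := buildNetworks(mine,
		[]netip.Prefix{netip.MustParsePrefix("10.1.2.3/24"), netip.MustParsePrefix("172.16.0.1/16")},
		[]netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
	)

	nt, ok := t.Lookup(netip.MustParseAddr("192.168.0.7"))
	fmt.Println(nt == NetworkTypeUnsafe, ok) // true true
}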

104
inside.go
View File

@@ -2,18 +2,16 @@ package nebula
import ( import (
"net/netip" "net/netip"
"time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache, now time.Time) { func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket) err := newPacket(packet, false, fwPacket)
if err != nil { if err != nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
@@ -55,7 +53,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
}) })
if hostinfo == nil { if hostinfo == nil {
f.rejectInside(packet, out.Payload, q) //todo vector? f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddr", fwPacket.RemoteAddr). f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
WithField("fwPacket", fwPacket). WithField("fwPacket", fwPacket).
@@ -68,11 +66,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return return
} }
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache, now) dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil { if dropReason == nil {
f.sendNoMetricsDelayed(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
} else { } else {
f.rejectInside(packet, out.Payload, q) //todo vector? f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l). hostinfo.logger(f.l).
WithField("fwPacket", fwPacket). WithField("fwPacket", fwPacket).
@@ -121,10 +120,9 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
} }
// Handshake will attempt to initiate a tunnel with the provided vpn address. This is a no-op if the tunnel is already established or being established // Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
// it does not check if it is within our vpn networks!
func (f *Interface) Handshake(vpnAddr netip.Addr) { func (f *Interface) Handshake(vpnAddr netip.Addr) {
f.handshakeManager.GetOrHandshake(vpnAddr, nil) f.getOrHandshakeNoRouting(vpnAddr, nil)
} }
// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable. // getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
@@ -140,6 +138,7 @@ func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback fu
// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary. // getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
destinationAddr := fwPacket.RemoteAddr destinationAddr := fwPacket.RemoteAddr
hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback) hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
@@ -219,7 +218,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
} }
// check if packet is in outbound fw rules // check if packet is in outbound fw rules
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil, time.Now()) dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
if dropReason != nil { if dropReason != nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("fwPacket", fp). f.l.WithField("fwPacket", fp).
@@ -232,10 +231,9 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
} }
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr. // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
// This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) { func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) { hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
}) })
@@ -411,81 +409,3 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
} }
} }
} }
func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.Packet, q int) {
if ci.eKey == nil {
return
}
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
fullOut := out.Payload
if useRelay {
if len(out.Payload) < header.Len {
// out always has a capacity of mtu, but not always a length greater than the header.Len.
// Grow it to make sure the next operation works.
out.Payload = out.Payload[:header.Len]
}
// Save a header's worth of data at the front of the 'out' buffer.
out.Payload = out.Payload[header.Len:]
}
if noiseutil.EncryptLockNeeded {
// NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
ci.writeLock.Lock()
}
c := ci.messageCounter.Add(1)
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
out.Payload = header.Encode(out.Payload, header.Version, t, st, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo)
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
// all our addrs and enable a faster roaming.
if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
}
}
var err error
out.Payload, err = ci.eKey.EncryptDanger(out.Payload, out.Payload, p, c, nb)
if noiseutil.EncryptLockNeeded {
ci.writeLock.Unlock()
}
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).WithField("counter", c).
WithField("attemptedCounter", c).
Error("Failed to encrypt outgoing packet")
return
}
if remote.IsValid() {
err = f.writers[q].Prep(out, remote)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet")
}
} else if hostinfo.remote.IsValid() {
err = f.writers[q].Prep(out, hostinfo.remote)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet")
}
} else {
// Try to send via a relay
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
if err != nil {
hostinfo.relayState.DeleteRelay(relayIP)
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
continue
}
//todo vector!!
f.SendVia(relayHostInfo, relay, out.Payload, nb, fullOut[:header.Len+len(out.Payload)], true)
break
}
}
}
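
The removed sendNoMetricsDelayed above relies on a small buffer trick for the relay path: it slices the payload forward by header.Len so a Nebula header can later be written into the reserved space and the whole buffer sent as fullOut[:header.Len+len(out.Payload)]. A generic sketch of that reserve-then-prepend pattern, with a hard-coded 16-byte header size (matching the 16-byte Nebula header noted elsewhere in this diff) and made-up contents:

package main

import "fmt"

const headerLen = 16 // stands in for header.Len

func main() {
	// 'full' plays the role of the reusable out.Payload buffer above: capacity for
	// a header plus an MTU-sized payload, length managed by re-slicing.
	full := make([]byte, 0, 1500)

	// Reserve a header's worth of space at the front ...
	full = full[:headerLen]
	payload := full[headerLen:]

	// ... then build (here: fake-encrypt) the payload into the same backing array.
	payload = append(payload, []byte("encrypted body goes here")...)

	// Later, write the header into the reserved space and send header+payload together,
	// the same shape as fullOut[:header.Len+len(out.Payload)] in the removed code.
	copy(full[:headerLen], "NEBULA-HEADER-16")
	wire := full[:headerLen+len(payload)]
	fmt.Printf("%d header bytes + %d payload bytes = %d bytes on the wire\n",
		headerLen, len(payload), len(wire))
}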

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"net/netip" "net/netip"
"os" "os"
"runtime" "runtime"
@@ -17,12 +18,10 @@ import (
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
) )
const mtu = 9001 const mtu = 9001
const batch = 1024 //todo config!
type InterfaceConfig struct { type InterfaceConfig struct {
HostMap *HostMap HostMap *HostMap
@@ -87,18 +86,12 @@ type Interface struct {
conntrackCacheTimeout time.Duration conntrackCacheTimeout time.Duration
writers []udp.Conn writers []udp.Conn
readers []overlay.TunDev readers []io.ReadWriteCloser
metricHandshakes metrics.Histogram metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics cachedPacketMetrics *cachedPacketMetrics
listenInN int
listenOutN int
listenInMetric metrics.Histogram
listenOutMetric metrics.Histogram
l *logrus.Logger l *logrus.Logger
} }
@@ -184,7 +177,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
routines: c.routines, routines: c.routines,
version: c.version, version: c.version,
writers: make([]udp.Conn, c.routines), writers: make([]udp.Conn, c.routines),
readers: make([]overlay.TunDev, c.routines), readers: make([]io.ReadWriteCloser, c.routines),
myVpnNetworks: cs.myVpnNetworks, myVpnNetworks: cs.myVpnNetworks,
myVpnNetworksTable: cs.myVpnNetworksTable, myVpnNetworksTable: cs.myVpnNetworksTable,
myVpnAddrs: cs.myVpnAddrs, myVpnAddrs: cs.myVpnAddrs,
@@ -203,8 +196,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
l: c.l, l: c.l,
} }
ifce.listenInMetric = metrics.GetOrRegisterHistogram("vhost.listenIn.n", nil, metrics.NewExpDecaySample(1028, 0.015))
ifce.listenOutMetric = metrics.GetOrRegisterHistogram("vhost.listenOut.n", nil, metrics.NewExpDecaySample(1028, 0.015))
ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryEvery.Store(c.reQueryEvery)
@@ -234,7 +225,7 @@ func (f *Interface) activate() {
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
// Prepare n tun queues // Prepare n tun queues
var reader overlay.TunDev = f.inside var reader io.ReadWriteCloser = f.inside
for i := 0; i < f.routines; i++ { for i := 0; i < f.routines; i++ {
if i > 0 { if i > 0 {
reader, err = f.inside.NewMultiQueueReader() reader, err = f.inside.NewMultiQueueReader()
@@ -263,71 +254,40 @@ func (f *Interface) run() {
} }
} }
func (f *Interface) listenOut(q int) { func (f *Interface) listenOut(i int) {
runtime.LockOSThread() runtime.LockOSThread()
var li udp.Conn var li udp.Conn
if q > 0 { if i > 0 {
li = f.writers[q] li = f.writers[i]
} else { } else {
li = f.outside li = f.outside
} }
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler() lhh := f.lightHouse.NewRequestHandler()
plaintext := make([]byte, udp.MTU)
outPackets := make([]*packet.OutPacket, batch)
for i := 0; i < batch; i++ {
outPackets[i] = packet.NewOut()
}
h := &header.H{} h := &header.H{}
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
toSend := make([][]byte, batch) li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
li.ListenOut(func(pkts []*packet.Packet) {
toSend = toSend[:0]
for i := range outPackets {
outPackets[i].Valid = false
outPackets[i].SegCounter = 0
}
f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l), time.Now())
//we opportunistically tx, but try to also send stragglers
if _, err := f.readers[q].WriteMany(outPackets, q); err != nil {
f.l.WithError(err).Error("Failed to send packets")
}
//todo I broke this
//n := len(toSend)
//if f.l.Level == logrus.DebugLevel {
// f.listenOutMetric.Update(int64(n))
//}
//f.listenOutN = n
}) })
} }
func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) { func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread() runtime.LockOSThread()
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
packets := make([]*packet.VirtIOPacket, batch)
outPackets := make([]*packet.Packet, batch)
for i := 0; i < batch; i++ {
packets[i] = packet.NewVIO()
outPackets[i] = packet.New(false) //todo?
}
for { for {
n, err := reader.ReadMany(packets, queueNum) n, err := reader.Read(packet)
//todo!!
if err != nil { if err != nil {
if errors.Is(err, os.ErrClosed) && f.closed.Load() { if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return return
@@ -338,22 +298,7 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
os.Exit(2) os.Exit(2)
} }
if f.l.Level == logrus.DebugLevel { f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
f.listenInMetric.Update(int64(n))
}
f.listenInN = n
now := time.Now()
for i, pkt := range packets[:n] {
outPackets[i].OutLen = -1
f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now)
reader.RecycleRxSeg(pkt, i == (n-1), queueNum) //todo handle err?
pkt.Reset()
}
_, err = f.writers[queueNum].WriteBatch(outPackets[:n])
if err != nil {
f.l.WithError(err).Error("Error while writing outbound packets")
}
} }
} }
@@ -491,11 +436,6 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
} else { } else {
certMaxVersion.Update(int64(certState.v1Cert.Version())) certMaxVersion.Update(int64(certState.v1Cert.Version()))
} }
if f.l.Level != logrus.DebugLevel {
f.listenInMetric.Update(int64(f.listenInN))
f.listenOutMetric.Update(int64(f.listenOutN))
}
} }
} }
} }

View File

@@ -360,8 +360,7 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
} }
if !lh.myVpnNetworksTable.Contains(addr) { if !lh.myVpnNetworksTable.Contains(addr) {
lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}). return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
} }
out[i] = addr out[i] = addr
} }
@@ -432,8 +431,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
} }
if !lh.myVpnNetworksTable.Contains(vpnAddr) { if !lh.myVpnNetworksTable.Contains(vpnAddr) {
lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}). return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
} }
vals, ok := v.([]any) vals, ok := v.([]any)
@@ -1339,19 +1337,12 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
} }
} }
remoteAllowList := lhh.lh.GetRemoteAllowList()
for _, a := range n.Details.V4AddrPorts { for _, a := range n.Details.V4AddrPorts {
b := protoV4AddrPortToNetAddrPort(a) punch(protoV4AddrPortToNetAddrPort(a), detailsVpnAddr)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr)
}
} }
for _, a := range n.Details.V6AddrPorts { for _, a := range n.Details.V6AddrPorts {
b := protoV6AddrPortToNetAddrPort(a) punch(protoV6AddrPortToNetAddrPort(a), detailsVpnAddr)
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
punch(b, detailsVpnAddr)
}
} }
// This sends a nebula test packet to the host trying to contact us. In the case // This sends a nebula test packet to the host trying to contact us. In the case

View File

@@ -14,7 +14,7 @@ import (
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.yaml.in/yaml/v3" "gopkg.in/yaml.v3"
) )
func TestOldIPv4Only(t *testing.T) { func TestOldIPv4Only(t *testing.T) {

View File

@@ -13,7 +13,7 @@ import (
"github.com/slackhq/nebula/sshd" "github.com/slackhq/nebula/sshd"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
"go.yaml.in/yaml/v3" "gopkg.in/yaml.v3"
) )
type m = map[string]any type m = map[string]any
@@ -75,8 +75,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if c.GetBool("sshd.enabled", false) { if c.GetBool("sshd.enabled", false) {
sshStart, err = configSSH(l, ssh, c) sshStart, err = configSSH(l, ssh, c)
if err != nil { if err != nil {
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available") return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
sshStart = nil
} }
} }

View File

@@ -7,7 +7,6 @@ import (
"time" "time"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/slackhq/nebula/packet"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -20,7 +19,7 @@ const (
minFwPacketLen = 4 minFwPacketLen = 4
) )
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) { func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet) err := h.Parse(packet)
if err != nil { if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
@@ -62,7 +61,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
switch h.Subtype { switch h.Subtype {
case header.MessageNone: case header.MessageNone:
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, now) { if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) {
return return
} }
case header.MessageRelay: case header.MessageRelay:
@@ -97,7 +96,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
case TerminalType: case TerminalType:
// If I am the target of this relay, process the unwrapped packet // If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now) f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return return
case ForwardingType: case ForwardingType:
// Find the target HostInfo relay object // Find the target HostInfo relay object
@@ -217,217 +216,6 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
f.connectionManager.In(hostinfo) f.connectionManager.In(hostinfo)
} }
func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
for i, pkt := range packets {
out[i].Scratch = out[i].Scratch[:0]
ip := pkt.AddrPort()
//l.Error("in packet ", header, packet[HeaderLen:])
if ip.IsValid() {
if f.myVpnNetworksTable.Contains(ip.Addr()) {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
}
return
}
}
//todo per-segment!
for segment := range pkt.Segments() {
err := h.Parse(segment)
if err != nil {
// Hole punch packets are 0 or 1 byte big, so let's ignore printing those errors
if len(segment) > 1 {
f.l.WithField("packet", pkt).Infof("Error while parsing inbound packet from %s: %s", ip, err)
}
return
}
var hostinfo *HostInfo
// verify if we've seen this index before, otherwise respond to the handshake initiation
if h.Type == header.Message && h.Subtype == header.MessageRelay {
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
} else {
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
}
var ci *ConnectionState
if hostinfo != nil {
ci = hostinfo.ConnectionState
}
switch h.Type {
case header.Message:
// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
if !f.handleEncrypted(ci, ip, h) {
return
}
switch h.Subtype {
case header.MessageNone:
if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i], pkt, segment, fwPacket, nb, q, localCache, now) {
return
}
case header.MessageRelay:
// The entire body is sent as AD, not encrypted.
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
// which will gracefully fail in the DecryptDanger call.
signedPayload := segment[:len(segment)-hostinfo.ConnectionState.dKey.Overhead()]
signatureValue := segment[len(segment)-hostinfo.ConnectionState.dKey.Overhead():]
out[i].Scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(out[i].Scratch, signedPayload, signatureValue, h.MessageCounter, nb)
if err != nil {
return
}
// Successfully validated the thing. Get rid of the Relay header.
signedPayload = signedPayload[header.Len:]
// Pull the Roaming parts up here, and return in all call paths.
f.handleHostRoaming(hostinfo, ip)
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
f.connectionManager.In(hostinfo)
f.connectionManager.RelayUsed(h.RemoteIndex)
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
if !ok {
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
// its internal mapping. This should never happen.
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
return
}
switch relay.Type {
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i].Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
return
case ForwardingType:
// Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
if err != nil {
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
return
}
// If that relay is Established, forward the payload through it
if targetRelay.State == Established {
switch targetRelay.Type {
case ForwardingType:
// Forward this packet through the relay tunnel
// Find the target HostInfo
f.SendVia(targetHI, targetRelay, signedPayload, nb, out[i].Scratch, false)
return
case TerminalType:
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
}
} else {
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
return
}
}
}
case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", segment).
Error("Failed to decrypt lighthouse packet")
return
}
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic
case header.Test:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", segment).
Error("Failed to decrypt test packet")
return
}
if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, ip)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out[i].Scratch)
}
// Fallthrough to the bottom to record incoming traffic
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
// are unauthenticated
case header.Handshake:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handshakeManager.HandleIncoming(ip, nil, segment, h)
return
case header.RecvError:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handleRecvError(ip, h)
return
case header.CloseTunnel:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) {
return
}
hostinfo.logger(f.l).WithField("udpAddr", ip).
Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo)
return
case header.Control:
if !f.handleEncrypted(ci, ip, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", segment).
Error("Failed to decrypt Control packet")
return
}
f.relayManager.HandleControlMsg(hostinfo, d, f)
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
return
}
f.handleHostRoaming(hostinfo, ip)
f.connectionManager.In(hostinfo)
}
_, err := f.readers[q].WriteOne(out[i], false, q)
if err != nil {
f.l.WithError(err).Error("Failed to write packet")
}
}
}
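The comment block above spells out the relay framing: a 16-byte nebula header and the payload are carried as associated data, followed by a trailing AEAD tag that DecryptDanger verifies. A standalone sketch of that slicing, assuming a 16-byte tag (as with AES-GCM or ChaCha20-Poly1305); the buffer here is a placeholder, not real cipher output:

package main

import "fmt"

func main() {
	const headerLen, tagLen = 16, 16
	pkt := make([]byte, headerLen+40+tagLen) // relay header + 40-byte inner packet + AEAD tag

	signedPayload := pkt[:len(pkt)-tagLen] // header plus inner packet, all treated as AD
	signature := pkt[len(pkt)-tagLen:]     // tag verified by DecryptDanger

	// Once the tag verifies, the relay header is stripped and the inner
	// packet is fed back through readOutsidePackets.
	inner := signedPayload[headerLen:]
	fmt.Println(len(signedPayload), len(signature), len(inner)) // 56 16 40
}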
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
func (f *Interface) closeTunnel(hostInfo *HostInfo) { func (f *Interface) closeTunnel(hostInfo *HostInfo) {
final := f.hostMap.DeleteHostInfo(hostInfo) final := f.hostMap.DeleteHostInfo(hostInfo)
@@ -677,55 +465,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return out, nil return out, nil
} }
func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool { func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
var err error
seg, err := f.readers[q].AllocSeg(out, q)
if err != nil {
f.l.WithError(err).Errorln("decryptToTunDelayWrite: failed to allocate segment")
return false
}
out.SegmentPayloads[seg] = out.SegmentPayloads[seg][:0]
out.SegmentPayloads[seg], err = hostinfo.ConnectionState.dKey.DecryptDanger(out.SegmentPayloads[seg], inSegment[:header.Len], inSegment[header.Len:], messageCounter, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
return false
}
err = newPacket(out.SegmentPayloads[seg], true, fwPacket)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
Warnf("Error while validating inbound packet")
return false
}
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
Debugln("dropping out of window packet")
return false
}
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now)
if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in
f.rejectOutside(out.SegmentPayloads[seg], hostinfo.ConnectionState, hostinfo, nb, inSegment, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping inbound packet")
}
return false
}
f.connectionManager.In(hostinfo)
pkt.OutLen += len(inSegment)
out.Segments[seg] = out.Segments[seg][:len(out.SegmentHeaders[seg])+len(out.SegmentPayloads[seg])]
return true
}
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
var err error var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
@@ -747,7 +487,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false return false
} }
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now) dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil { if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in // This gives us a buffer to build the reject packet in

View File

@@ -1,16 +1,17 @@
package overlay package overlay
import ( import (
"io"
"net/netip" "net/netip"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
type Device interface { type Device interface {
TunDev io.ReadWriteCloser
Activate() error Activate() error
Networks() []netip.Prefix Networks() []netip.Prefix
Name() string Name() string
RoutesFor(netip.Addr) routing.Gateways RoutesFor(netip.Addr) routing.Gateways
NewMultiQueueReader() (TunDev, error) NewMultiQueueReader() (io.ReadWriteCloser, error)
} }
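With this hunk the overlay Device interface goes back to embedding io.ReadWriteCloser, and NewMultiQueueReader returns a plain io.ReadWriteCloser again. A minimal stand-in that satisfies the revised interface as shown here can be handy in tests; the type name, the fixed "fake0" name, and the suggestion to back it with net.Pipe are assumptions of this sketch, not anything from the tree:

package overlay_test

import (
	"io"
	"net/netip"

	"github.com/slackhq/nebula/overlay"
	"github.com/slackhq/nebula/routing"
)

// fakeDevice exercises code that only needs the Device interface, not a real tun.
// The embedded ReadWriteCloser is supplied by the test, e.g. one side of net.Pipe().
type fakeDevice struct {
	io.ReadWriteCloser
	nets []netip.Prefix
}

func (d *fakeDevice) Activate() error          { return nil }
func (d *fakeDevice) Networks() []netip.Prefix { return d.nets }
func (d *fakeDevice) Name() string             { return "fake0" }
func (d *fakeDevice) RoutesFor(ip netip.Addr) routing.Gateways {
	return routing.Gateways{routing.NewGateway(ip, 1)}
}
func (d *fakeDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { return d, nil }

var _ overlay.Device = (*fakeDevice)(nil)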

View File

@@ -1,91 +0,0 @@
package eventfd
import (
"encoding/binary"
"syscall"
"golang.org/x/sys/unix"
)
type EventFD struct {
fd int
buf [8]byte
}
func New() (EventFD, error) {
fd, err := unix.Eventfd(0, unix.EFD_NONBLOCK)
if err != nil {
return EventFD{}, err
}
return EventFD{
fd: fd,
buf: [8]byte{},
}, nil
}
func (e *EventFD) Kick() error {
binary.LittleEndian.PutUint64(e.buf[:], 1) // eventfd consumes an 8-byte counter value; writing 1 increments it
_, err := syscall.Write(int(e.fd), e.buf[:])
return err
}
func (e *EventFD) Close() error {
if e.fd != 0 {
return unix.Close(e.fd)
}
return nil
}
func (e *EventFD) FD() int {
return e.fd
}
type Epoll struct {
fd int
buf [8]byte
events []syscall.EpollEvent
}
func NewEpoll() (Epoll, error) {
fd, err := unix.EpollCreate1(0)
if err != nil {
return Epoll{}, err
}
return Epoll{
fd: fd,
buf: [8]byte{},
events: make([]syscall.EpollEvent, 1),
}, nil
}
func (ep *Epoll) AddEvent(fdToAdd int) error {
event := syscall.EpollEvent{
Events: syscall.EPOLLIN,
Fd: int32(fdToAdd),
}
return syscall.EpollCtl(ep.fd, syscall.EPOLL_CTL_ADD, fdToAdd, &event)
}
func (ep *Epoll) Block() (int, error) {
n, err := syscall.EpollWait(ep.fd, ep.events, -1)
if err != nil {
//goland:noinspection GoDirectComparisonOfErrors
if err == syscall.EINTR {
return 0, nil // EINTR is a spurious wakeup; report no events and let the caller retry
}
return -1, err
}
return n, nil
}
func (ep *Epoll) Clear() error {
_, err := syscall.Read(int(ep.events[0].Fd), ep.buf[:])
return err
}
func (ep *Epoll) Close() error {
if ep.fd != 0 {
return unix.Close(ep.fd)
}
return nil
}
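The file removed here wrapped the eventfd/epoll plumbing used for virtqueue kicks. The kick encoding is the standard eventfd protocol: the kernel exchanges an 8-byte counter, so writing 1 increments it and a read drains it. A Linux-only standalone sketch of that round trip (no epoll, little-endian host assumed, matching what the removed code did):

package main

import (
	"encoding/binary"
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	fd, err := unix.Eventfd(0, 0)
	if err != nil {
		panic(err)
	}
	defer unix.Close(fd)

	var buf [8]byte
	binary.LittleEndian.PutUint64(buf[:], 1) // the "kick": add 1 to the counter
	if _, err := unix.Write(fd, buf[:]); err != nil {
		panic(err)
	}

	// A read drains the counter; here it returns immediately with 1.
	if _, err := unix.Read(fd, buf[:]); err != nil {
		panic(err)
	}
	fmt.Println("counter was", binary.LittleEndian.Uint64(buf[:]))
}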

View File

@@ -2,29 +2,16 @@ package overlay
import ( import (
"fmt" "fmt"
"io"
"net" "net"
"net/netip" "net/netip"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
const DefaultMTU = 1300 const DefaultMTU = 1300
type TunDev interface {
io.WriteCloser
ReadMany(x []*packet.VirtIOPacket, q int) (int, error)
//todo: this interface needs a redesign
AllocSeg(pkt *packet.OutPacket, q int) (int, error)
WriteOne(x *packet.OutPacket, kick bool, q int) (int, error)
WriteMany(x []*packet.OutPacket, q int) (int, error)
RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error
}
// TODO: We may be able to remove routines // TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
@@ -39,11 +26,11 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref
} }
} }
//func NewFdDeviceFromConfig(fd *int) DeviceFactory { func NewFdDeviceFromConfig(fd *int) DeviceFactory {
// return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
// return newTunFromFd(c, l, *fd, vpnNetworks) return newTunFromFd(c, l, *fd, vpnNetworks)
// } }
//} }
func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) { func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) {
if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") { if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {

View File

@@ -9,8 +9,6 @@ import (
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
@@ -24,10 +22,6 @@ type disabledTun struct {
l *logrus.Logger l *logrus.Logger
} }
func (*disabledTun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
return nil
}
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
tun := &disabledTun{ tun := &disabledTun{
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
@@ -46,10 +40,6 @@ func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled boo
return tun return tun
} }
func (*disabledTun) GetQueues() []*virtqueue.SplitQueue {
return nil
}
func (*disabledTun) Activate() error { func (*disabledTun) Activate() error {
return nil return nil
} }
@@ -115,23 +105,7 @@ func (t *disabledTun) Write(b []byte) (int, error) {
return len(b), nil return len(b), nil
} }
func (t *disabledTun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) { func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return 0, fmt.Errorf("tun_disabled: AllocSeg not implemented")
}
func (t *disabledTun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
return 0, fmt.Errorf("tun_disabled: WriteOne not implemented")
}
func (t *disabledTun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
return 0, fmt.Errorf("tun_disabled: WriteMany not implemented")
}
func (t *disabledTun) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
return t.Read(b[0].Payload)
}
func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {
return t, nil return t, nil
} }

View File

@@ -4,7 +4,9 @@
package overlay package overlay
import ( import (
"errors"
"fmt" "fmt"
"io"
"net" "net"
"net/netip" "net/netip"
"os" "os"
@@ -16,19 +18,15 @@ import (
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/vhostnet"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/util/virtio"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
type tun struct { type tun struct {
file *os.File io.ReadWriteCloser
fd int fd int
vdev []*vhostnet.Device
Device string Device string
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
MaxMTU int MaxMTU int
@@ -43,7 +41,6 @@ type tun struct {
useSystemRoutes bool useSystemRoutes bool
useSystemRoutesBufferSize int useSystemRoutesBufferSize int
isV6 bool
l *logrus.Logger l *logrus.Logger
} }
@@ -105,58 +102,75 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
} }
} }
tunNameTemplate := c.GetString("tun.dev", "nebula%d")
tunName, err := findNextTunName(tunNameTemplate)
if err != nil {
return nil, err
}
var req ifReq var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_TUN_EXCL | unix.IFF_VNET_HDR | unix.IFF_NAPI) req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
if multiqueue { if multiqueue {
req.Flags |= unix.IFF_MULTI_QUEUE req.Flags |= unix.IFF_MULTI_QUEUE
} }
copy(req.Name[:], c.GetString("tun.dev", "")) copy(req.Name[:], tunName)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
return nil, err return nil, err
} }
name := strings.Trim(string(req.Name[:]), "\x00") name := strings.Trim(string(req.Name[:]), "\x00")
if err = unix.SetNonblock(fd, true); err != nil {
_ = unix.Close(fd)
return nil, fmt.Errorf("make file descriptor non-blocking: %w", err)
}
file := os.NewFile(uintptr(fd), "/dev/net/tun") file := os.NewFile(uintptr(fd), "/dev/net/tun")
err = unix.IoctlSetPointerInt(fd, unix.TUNSETVNETHDRSZ, virtio.NetHdrSize)
if err != nil {
return nil, fmt.Errorf("set vnethdr size: %w", err)
}
flags := 0
//flags = //unix.TUN_F_CSUM //| unix.TUN_F_TSO4 | unix.TUN_F_USO4 | unix.TUN_F_TSO6 | unix.TUN_F_USO6
err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, flags)
if err != nil {
return nil, fmt.Errorf("set offloads: %w", err)
}
t, err := newTunGeneric(c, l, file, vpnNetworks) t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.fd = fd
t.Device = name
vdev, err := vhostnet.NewDevice( t.Device = name
vhostnet.WithBackendFD(fd),
vhostnet.WithQueueSize(8192), //todo config
)
if err != nil {
return nil, err
}
t.vdev = []*vhostnet.Device{vdev}
return t, nil return t, nil
} }
func findNextTunName(tunName string) (string, error) {
if !strings.HasSuffix(tunName, "%d") {
return tunName, nil
}
if len(tunName) == 2 {
return "", errors.New("please don't name your tun device '%d'")
}
if (len(tunName) - len("%d") + len("0")) > unix.IFNAMSIZ {
return "", fmt.Errorf("your tun device name template %s would result in a name longer than the maximum allowed length of %d", tunName, unix.IFNAMSIZ)
}
tunNameTemplate := tunName[:len(tunName)-len("%d")]
links, err := netlink.LinkList()
if err != nil {
return "", err
}
var candidateName string
for i := 0; ; i++ {
candidateName = fmt.Sprintf("%s%d", tunNameTemplate, i)
good := true
for _, link := range links {
if candidateName == link.Attrs().Name {
good = false
break
}
}
if len(candidateName) > unix.IFNAMSIZ {
return "", fmt.Errorf("first available tun device is %s, which is longer than the max allowed size of %d", candidateName, unix.IFNAMSIZ)
}
if good {
return candidateName, nil
}
}
return "", errors.New("failed to find a tun device name")
}
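In short, the new name picker takes the tun.dev template, returns it untouched when it has no "%d" suffix, and otherwise finds the first unused numeric suffix. A standalone sketch of the same loop against a plain list of names instead of netlink.LinkList(), with the IFNAMSIZ checks omitted; illustrative only:

package main

import (
	"fmt"
	"strings"
)

// nextName mirrors findNextTunName's selection logic for a given list of
// existing interface names.
func nextName(template string, existing []string) string {
	if !strings.HasSuffix(template, "%d") {
		return template // explicit names are used as-is
	}
	prefix := strings.TrimSuffix(template, "%d")
	taken := make(map[string]bool, len(existing))
	for _, n := range existing {
		taken[n] = true
	}
	for i := 0; ; i++ {
		if candidate := fmt.Sprintf("%s%d", prefix, i); !taken[candidate] {
			return candidate
		}
	}
}

func main() {
	existing := []string{"nebula0", "nebula1"}
	fmt.Println(nextName("nebula%d", existing)) // nebula2
	fmt.Println(nextName("tun7", existing))     // tun7
}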
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
t := &tun{ t := &tun{
file: file, ReadWriteCloser: file,
fd: int(file.Fd()), fd: int(file.Fd()),
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500), TXQueueLen: c.GetInt("tun.tx_queue", 500),
@@ -164,9 +178,6 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0), useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
l: l, l: l,
} }
if len(vpnNetworks) != 0 {
t.isV6 = vpnNetworks[0].Addr().Is6() //todo what about multi-IP?
}
err := t.reload(c, true) err := t.reload(c, true)
if err != nil { if err != nil {
@@ -250,7 +261,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (t *tun) NewMultiQueueReader() (TunDev, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -263,17 +274,9 @@ func (t *tun) NewMultiQueueReader() (TunDev, error) {
return nil, err return nil, err
} }
vdev, err := vhostnet.NewDevice( file := os.NewFile(uintptr(fd), "/dev/net/tun")
vhostnet.WithBackendFD(fd),
vhostnet.WithQueueSize(8192), //todo config
)
if err != nil {
return nil, err
}
t.vdev = append(t.vdev, vdev) return file, nil
return t, nil
} }
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -281,6 +284,29 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
return r return r
} }
func (t *tun) Write(b []byte) (int, error) {
var nn int
maximum := len(b)
for {
n, err := unix.Write(t.fd, b[nn:maximum])
if n > 0 {
nn += n
}
if nn == len(b) {
return nn, err
}
if err != nil {
return nn, err
}
if n == 0 {
return nn, io.ErrUnexpectedEOF
}
}
}
func (t *tun) deviceBytes() (o [16]byte) { func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device { for i, c := range t.Device {
o[i] = byte(c) o[i] = byte(c)
@@ -601,9 +627,7 @@ func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
} }
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
var gateways routing.Gateways var gateways routing.Gateways
link, err := netlink.LinkByName(t.Device) link, err := netlink.LinkByName(t.Device)
if err != nil { if err != nil {
t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name") t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
@@ -652,9 +676,7 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
} }
func (t *tun) updateRoutes(r netlink.RouteUpdate) { func (t *tun) updateRoutes(r netlink.RouteUpdate) {
gateways := t.getGatewaysFromRoute(&r.Route) gateways := t.getGatewaysFromRoute(&r.Route)
if len(gateways) == 0 { if len(gateways) == 0 {
// No gateways relevant to our network, no routing changes required. // No gateways relevant to our network, no routing changes required.
t.l.WithField("route", r).Debug("Ignoring route update, no gateways") t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
@@ -693,14 +715,8 @@ func (t *tun) Close() error {
close(t.routeChan) close(t.routeChan)
} }
for _, v := range t.vdev { if t.ReadWriteCloser != nil {
if v != nil { _ = t.ReadWriteCloser.Close()
_ = v.Close()
}
}
if t.file != nil {
_ = t.file.Close()
} }
if t.ioctlFd > 0 { if t.ioctlFd > 0 {
@@ -709,65 +725,3 @@ func (t *tun) Close() error {
return nil return nil
} }
func (t *tun) ReadMany(p []*packet.VirtIOPacket, q int) (int, error) {
n, err := t.vdev[q].ReceivePackets(p) //we are TXing
if err != nil {
return 0, err
}
return n, nil
}
func (t *tun) Write(b []byte) (int, error) {
maximum := len(b) //we are RXing
//todo: this allocates a new OutPacket on every call
out := packet.NewOut()
x, err := t.AllocSeg(out, 0)
if err != nil {
return 0, err
}
copy(out.SegmentPayloads[x], b)
err = t.vdev[0].TransmitPacket(out, true)
if err != nil {
t.l.WithError(err).Error("Transmitting packet")
return 0, err
}
return maximum, nil
}
func (t *tun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
idx, buf, err := t.vdev[q].GetPacketForTx()
if err != nil {
return 0, err
}
x := pkt.UseSegment(idx, buf, t.isV6)
return x, nil
}
func (t *tun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
if err := t.vdev[q].TransmitPacket(x, kick); err != nil {
t.l.WithError(err).Error("Transmitting packet")
return 0, err
}
return 1, nil
}
func (t *tun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
maximum := len(x) //we are RXing
if maximum == 0 {
return 0, nil
}
err := t.vdev[q].TransmitPackets(x)
if err != nil {
t.l.WithError(err).Error("Transmitting packet")
return 0, err
}
return maximum, nil
}
func (t *tun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
return t.vdev[q].ReceiveQueue.OfferDescriptorChains(pkt.Chains, kick)
}

View File

@@ -1,13 +1,11 @@
package overlay package overlay
import ( import (
"fmt"
"io" "io"
"net/netip" "net/netip"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
@@ -38,10 +36,6 @@ type UserDevice struct {
inboundWriter *io.PipeWriter inboundWriter *io.PipeWriter
} }
func (d *UserDevice) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
return nil
}
func (d *UserDevice) Activate() error { func (d *UserDevice) Activate() error {
return nil return nil
} }
@@ -52,7 +46,7 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
return routing.Gateways{routing.NewGateway(ip, 1)} return routing.Gateways{routing.NewGateway(ip, 1)}
} }
func (d *UserDevice) NewMultiQueueReader() (TunDev, error) { func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return d, nil return d, nil
} }
@@ -71,19 +65,3 @@ func (d *UserDevice) Close() error {
d.outboundWriter.Close() d.outboundWriter.Close()
return nil return nil
} }
func (d *UserDevice) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
return d.Read(b[0].Payload)
}
func (d *UserDevice) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
return 0, fmt.Errorf("user: AllocSeg not implemented")
}
func (d *UserDevice) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
return 0, fmt.Errorf("user: WriteOne not implemented")
}
func (d *UserDevice) WriteMany(x []*packet.OutPacket, q int) (int, error) {
return 0, fmt.Errorf("user: WriteMany not implemented")
}

View File

@@ -1,23 +0,0 @@
Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
MIT License
Copyright (c) 2025 Hetzner Cloud GmbH
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,4 +0,0 @@
// Package vhost implements the basic ioctl requests needed to interact with the
// kernel-level virtio server that provides accelerated virtio devices for
// networking and more.
package vhost

View File

@@ -1,218 +0,0 @@
package vhost
import (
"fmt"
"unsafe"
"github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/util/virtio"
"golang.org/x/sys/unix"
)
const (
// vhostIoctlGetFeatures can be used to retrieve the features supported by
// the vhost implementation in the kernel.
//
// Response payload: [virtio.Feature]
// Kernel name: VHOST_GET_FEATURES
vhostIoctlGetFeatures = 0x8008af00
// vhostIoctlSetFeatures can be used to communicate the features supported
// by this virtio implementation to the kernel.
//
// Request payload: [virtio.Feature]
// Kernel name: VHOST_SET_FEATURES
vhostIoctlSetFeatures = 0x4008af00
// vhostIoctlSetOwner can be used to set the current process as the
// exclusive owner of a control file descriptor.
//
// Request payload: none
// Kernel name: VHOST_SET_OWNER
vhostIoctlSetOwner = 0x0000af01
// vhostIoctlSetMemoryLayout can be used to set up or modify the memory
// layout which describes the IOTLB mappings in the kernel.
//
// Request payload: [MemoryLayout] with custom serialization
// Kernel name: VHOST_SET_MEM_TABLE
vhostIoctlSetMemoryLayout = 0x4008af03
// vhostIoctlSetQueueSize can be used to set the size of the virtqueue.
//
// Request payload: [QueueState]
// Kernel name: VHOST_SET_VRING_NUM
vhostIoctlSetQueueSize = 0x4008af10
// vhostIoctlSetQueueAddress can be used to set the addresses of the
// different parts of the virtqueue.
//
// Request payload: [QueueAddresses]
// Kernel name: VHOST_SET_VRING_ADDR
vhostIoctlSetQueueAddress = 0x4028af11
// vhostIoctlSetAvailableRingBase can be used to set the index of the next
// available ring entry the device will process.
//
// Request payload: [QueueState]
// Kernel name: VHOST_SET_VRING_BASE
vhostIoctlSetAvailableRingBase = 0x4008af12
// vhostIoctlSetQueueKickEventFD can be used to set the event file
// descriptor to signal the device when descriptor chains were added to the
// available ring.
//
// Request payload: [QueueFile]
// Kernel name: VHOST_SET_VRING_KICK
vhostIoctlSetQueueKickEventFD = 0x4008af20
// vhostIoctlSetQueueCallEventFD can be used to set the event file
// descriptor that gets signaled by the device when descriptor chains have
// been used by it.
//
// Request payload: [QueueFile]
// Kernel name: VHOST_SET_VRING_CALL
vhostIoctlSetQueueCallEventFD = 0x4008af21
)
// QueueState is an ioctl request payload that can hold a queue index and any
// 32-bit number.
//
// Kernel name: vhost_vring_state
type QueueState struct {
// QueueIndex is the index of the virtqueue.
QueueIndex uint32
// Num is any 32-bit number, depending on the request.
Num uint32
}
// QueueAddresses is an ioctl request payload that can hold the addresses of the
// different parts of a virtqueue.
//
// Kernel name: vhost_vring_addr
type QueueAddresses struct {
// QueueIndex is the index of the virtqueue.
QueueIndex uint32
// Flags that are not used in this implementation.
Flags uint32
// DescriptorTableAddress is the address of the descriptor table in user
// space memory. It must be 16-byte aligned.
DescriptorTableAddress uintptr
// UsedRingAddress is the address of the used ring in user space memory. It
// must be 4-byte aligned.
UsedRingAddress uintptr
// AvailableRingAddress is the address of the available ring in user space
// memory. It must be 2-byte aligned.
AvailableRingAddress uintptr
// LogAddress is used for an optional logging support, not supported by this
// implementation.
LogAddress uintptr
}
// QueueFile is an ioctl request payload that can hold a queue index and a file
// descriptor.
//
// Kernel name: vhost_vring_file
type QueueFile struct {
// QueueIndex is the index of the virtqueue.
QueueIndex uint32
// FD is the file descriptor of the file. Pass -1 to unbind from a file.
FD int32
}
// IoctlPtr is a copy of the similarly named unexported function from the Go
// unix package. This is needed to do custom ioctl requests not supported by the
// standard library.
func IoctlPtr(fd int, req uint, arg unsafe.Pointer) error {
_, _, err := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(arg))
if err != 0 {
return fmt.Errorf("ioctl request %d: %w", req, err)
}
return nil
}
// GetFeatures requests the supported feature bits from the virtio device
// associated with the given control file descriptor.
func GetFeatures(controlFD int) (virtio.Feature, error) {
var features virtio.Feature
if err := IoctlPtr(controlFD, vhostIoctlGetFeatures, unsafe.Pointer(&features)); err != nil {
return 0, fmt.Errorf("get features: %w", err)
}
return features, nil
}
// SetFeatures communicates the feature bits supported by this implementation
// to the virtio device associated with the given control file descriptor.
func SetFeatures(controlFD int, features virtio.Feature) error {
if err := IoctlPtr(controlFD, vhostIoctlSetFeatures, unsafe.Pointer(&features)); err != nil {
return fmt.Errorf("set features: %w", err)
}
return nil
}
// OwnControlFD sets the current process as the exclusive owner for the
// given control file descriptor. This must be called before interacting with
// the control file descriptor in any other way.
func OwnControlFD(controlFD int) error {
if err := IoctlPtr(controlFD, vhostIoctlSetOwner, unsafe.Pointer(nil)); err != nil {
return fmt.Errorf("set control file descriptor owner: %w", err)
}
return nil
}
// SetMemoryLayout sets up or modifies the memory layout for the kernel-level
// virtio device associated with the given control file descriptor.
func SetMemoryLayout(controlFD int, layout MemoryLayout) error {
payload := layout.serializePayload()
if err := IoctlPtr(controlFD, vhostIoctlSetMemoryLayout, unsafe.Pointer(&payload[0])); err != nil {
return fmt.Errorf("set memory layout: %w", err)
}
return nil
}
// RegisterQueue registers a virtio queue with the kernel-level virtio server.
// The virtqueue will be linked to the given control file descriptor and will
// have the given index. The kernel will use this queue until the control file
// descriptor is closed.
func RegisterQueue(controlFD int, queueIndex uint32, queue *virtqueue.SplitQueue) error {
if err := IoctlPtr(controlFD, vhostIoctlSetQueueSize, unsafe.Pointer(&QueueState{
QueueIndex: queueIndex,
Num: uint32(queue.Size()),
})); err != nil {
return fmt.Errorf("set queue size: %w", err)
}
if err := IoctlPtr(controlFD, vhostIoctlSetQueueAddress, unsafe.Pointer(&QueueAddresses{
QueueIndex: queueIndex,
Flags: 0,
DescriptorTableAddress: queue.DescriptorTable().Address(),
UsedRingAddress: queue.UsedRing().Address(),
AvailableRingAddress: queue.AvailableRing().Address(),
LogAddress: 0,
})); err != nil {
return fmt.Errorf("set queue addresses: %w", err)
}
if err := IoctlPtr(controlFD, vhostIoctlSetAvailableRingBase, unsafe.Pointer(&QueueState{
QueueIndex: queueIndex,
Num: 0,
})); err != nil {
return fmt.Errorf("set available ring base: %w", err)
}
if err := IoctlPtr(controlFD, vhostIoctlSetQueueKickEventFD, unsafe.Pointer(&QueueFile{
QueueIndex: queueIndex,
FD: int32(queue.KickEventFD()),
})); err != nil {
return fmt.Errorf("set kick event file descriptor: %w", err)
}
if err := IoctlPtr(controlFD, vhostIoctlSetQueueCallEventFD, unsafe.Pointer(&QueueFile{
QueueIndex: queueIndex,
FD: int32(queue.CallEventFD()),
})); err != nil {
return fmt.Errorf("set call event file descriptor: %w", err)
}
return nil
}
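The hex constants in the file removed above follow the standard Linux ioctl encoding: direction in the top two bits, payload size in the next 14, the vhost magic 0xAF, then the request number. A standalone sketch that reproduces a few of them, so the magic numbers are checkable:

package main

import "fmt"

const (
	iocWrite = 1 // _IOW direction bit
	iocRead  = 2 // _IOR direction bit
)

// ioc builds an ioctl request number: dir<<30 | size<<16 | type<<8 | nr.
func ioc(dir, typ, nr, size uintptr) uintptr {
	return dir<<30 | size<<16 | typ<<8 | nr
}

func main() {
	// VHOST_GET_FEATURES = _IOR(0xAF, 0x00, uint64) -> 0x8008af00
	fmt.Printf("%#x\n", ioc(iocRead, 0xAF, 0x00, 8))
	// VHOST_SET_FEATURES = _IOW(0xAF, 0x00, uint64) -> 0x4008af00
	fmt.Printf("%#x\n", ioc(iocWrite, 0xAF, 0x00, 8))
	// VHOST_SET_OWNER = _IO(0xAF, 0x01) -> 0xaf01
	fmt.Printf("%#x\n", ioc(0, 0xAF, 0x01, 0))
}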

View File

@@ -1,21 +0,0 @@
package vhost_test
import (
"testing"
"unsafe"
"github.com/slackhq/nebula/overlay/vhost"
"github.com/stretchr/testify/assert"
)
func TestQueueState_Size(t *testing.T) {
assert.EqualValues(t, 8, unsafe.Sizeof(vhost.QueueState{}))
}
func TestQueueAddresses_Size(t *testing.T) {
assert.EqualValues(t, 40, unsafe.Sizeof(vhost.QueueAddresses{}))
}
func TestQueueFile_Size(t *testing.T) {
assert.EqualValues(t, 8, unsafe.Sizeof(vhost.QueueFile{}))
}

View File

@@ -1,73 +0,0 @@
package vhost
import (
"encoding/binary"
"fmt"
"unsafe"
"github.com/slackhq/nebula/overlay/virtqueue"
)
// MemoryRegion describes a region of userspace memory which is being made
// accessible to a vhost device.
//
// Kernel name: vhost_memory_region
type MemoryRegion struct {
// GuestPhysicalAddress is the physical address of the memory region within
// the guest, when virtualization is used. When no virtualization is used,
// this should be the same as UserspaceAddress.
GuestPhysicalAddress uintptr
// Size is the size of the memory region.
Size uint64
// UserspaceAddress is the virtual address in the userspace of the host
// where the memory region can be found.
UserspaceAddress uintptr
// Padding and room for flags. Currently unused.
_ uint64
}
// MemoryLayout is a list of [MemoryRegion]s.
type MemoryLayout []MemoryRegion
// NewMemoryLayoutForQueues returns a new [MemoryLayout] that describes the
// memory pages used by the descriptor tables of the given queues.
func NewMemoryLayoutForQueues(queues []*virtqueue.SplitQueue) MemoryLayout {
regions := make([]MemoryRegion, 0)
for _, queue := range queues {
for address, size := range queue.DescriptorTable().BufferAddresses() {
regions = append(regions, MemoryRegion{
// There is no virtualization in play here, so the guest address
// is the same as in the host's userspace.
GuestPhysicalAddress: address,
Size: uint64(size),
UserspaceAddress: address,
})
}
}
return regions
}
// serializePayload serializes the list of memory regions into a format that is
// compatible to the vhost_memory kernel struct. The returned byte slice can be
// used as a payload for the vhostIoctlSetMemoryLayout ioctl.
func (regions MemoryLayout) serializePayload() []byte {
regionCount := len(regions)
regionSize := int(unsafe.Sizeof(MemoryRegion{}))
payload := make([]byte, 8+regionCount*regionSize)
// The first 32 bits contain the number of memory regions. The following 32
// bits are padding.
binary.LittleEndian.PutUint32(payload[0:4], uint32(regionCount))
if regionCount > 0 {
// The underlying byte array of the slice should already have the correct
// format, so just copy that.
copied := copy(payload[8:], unsafe.Slice((*byte)(unsafe.Pointer(&regions[0])), regionCount*regionSize))
if copied != regionCount*regionSize {
panic(fmt.Sprintf("copied only %d bytes of the memory regions, but expected %d",
copied, regionCount*regionSize))
}
}
return payload
}

View File

@@ -1,42 +0,0 @@
package vhost
import (
"testing"
"unsafe"
"github.com/stretchr/testify/assert"
)
func TestMemoryRegion_Size(t *testing.T) {
assert.EqualValues(t, 32, unsafe.Sizeof(MemoryRegion{}))
}
func TestMemoryLayout_SerializePayload(t *testing.T) {
layout := MemoryLayout([]MemoryRegion{
{
GuestPhysicalAddress: 42,
Size: 100,
UserspaceAddress: 142,
}, {
GuestPhysicalAddress: 99,
Size: 100,
UserspaceAddress: 99,
},
})
payload := layout.serializePayload()
assert.Equal(t, []byte{
0x02, 0x00, 0x00, 0x00, // nregions
0x00, 0x00, 0x00, 0x00, // padding
// region 0
0x2a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // guest_phys_addr
0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // memory_size
0x8e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // userspace_addr
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // flags_padding
// region 1
0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // guest_phys_addr
0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // memory_size
0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // userspace_addr
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // flags_padding
}, payload)
}

View File

@@ -1,23 +0,0 @@
Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
MIT License
Copyright (c) 2025 Hetzner Cloud GmbH
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,427 +0,0 @@
package vhostnet
import (
"context"
"errors"
"fmt"
"os"
"runtime"
"github.com/slackhq/nebula/overlay/vhost"
"github.com/slackhq/nebula/overlay/virtqueue"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/util/virtio"
"golang.org/x/sys/unix"
)
// ErrDeviceClosed is returned when the [Device] is closed while operations are
// still running.
var ErrDeviceClosed = errors.New("device was closed")
// The indexes for the receive and transmit queues.
const (
receiveQueueIndex = 0
transmitQueueIndex = 1
)
// Device represents a vhost networking device within the kernel-level virtio
// implementation and provides methods to interact with it.
type Device struct {
initialized bool
controlFD int
fullTable bool
ReceiveQueue *virtqueue.SplitQueue
TransmitQueue *virtqueue.SplitQueue
}
// NewDevice initializes a new vhost networking device within the
// kernel-level virtio implementation, sets up the virtqueues and returns a
// [Device] instance that can be used to communicate with that vhost device.
//
// There are multiple options that can be passed to this constructor to
// influence device creation:
// - [WithQueueSize]
// - [WithBackendFD]
// - [WithBackendDevice]
//
// Remember to call [Device.Close] after use to free up resources.
func NewDevice(options ...Option) (*Device, error) {
var err error
opts := optionDefaults
opts.apply(options)
if err = opts.validate(); err != nil {
return nil, fmt.Errorf("invalid options: %w", err)
}
dev := Device{
controlFD: -1,
}
// Clean up a partially initialized device when something fails.
defer func() {
if err != nil {
_ = dev.Close()
}
}()
// Retrieve a new control file descriptor. This will be used to configure
// the vhost networking device in the kernel.
dev.controlFD, err = unix.Open("/dev/vhost-net", os.O_RDWR, 0666)
if err != nil {
return nil, fmt.Errorf("get control file descriptor: %w", err)
}
if err = vhost.OwnControlFD(dev.controlFD); err != nil {
return nil, fmt.Errorf("own control file descriptor: %w", err)
}
// Advertise the supported features. This isn't much for now.
// TODO: Add feature options and implement proper feature negotiation.
getFeatures, err := vhost.GetFeatures(dev.controlFD)
if err != nil {
return nil, fmt.Errorf("get features: %w", err)
}
_ = getFeatures // informational only; nothing consumes the kernel's advertised feature set yet
//const funky = virtio.Feature(1 << 27)
//features := virtio.FeatureVersion1 | funky // | todo virtio.FeatureNetMergeRXBuffers
features := virtio.FeatureVersion1 | virtio.FeatureNetMergeRXBuffers
if err = vhost.SetFeatures(dev.controlFD, features); err != nil {
return nil, fmt.Errorf("set features: %w", err)
}
itemSize := os.Getpagesize() * 4 //todo config
// Initialize and register the queues needed for the networking device.
if dev.ReceiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize, itemSize); err != nil {
return nil, fmt.Errorf("create receive queue: %w", err)
}
if dev.TransmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize, itemSize); err != nil {
return nil, fmt.Errorf("create transmit queue: %w", err)
}
// Set up memory mappings for all buffers used by the queues. This has to
// happen before a backend for the queues can be registered.
memoryLayout := vhost.NewMemoryLayoutForQueues(
[]*virtqueue.SplitQueue{dev.ReceiveQueue, dev.TransmitQueue},
)
if err = vhost.SetMemoryLayout(dev.controlFD, memoryLayout); err != nil {
return nil, fmt.Errorf("setup memory layout: %w", err)
}
// Set the queue backends. This activates the queues within the kernel.
if err = SetQueueBackend(dev.controlFD, receiveQueueIndex, opts.backendFD); err != nil {
return nil, fmt.Errorf("set receive queue backend: %w", err)
}
if err = SetQueueBackend(dev.controlFD, transmitQueueIndex, opts.backendFD); err != nil {
return nil, fmt.Errorf("set transmit queue backend: %w", err)
}
// Fully populate the receive queue with available buffers which the device
// can write new packets into.
if err = dev.refillReceiveQueue(); err != nil {
return nil, fmt.Errorf("refill receive queue: %w", err)
}
if err = dev.refillTransmitQueue(); err != nil {
return nil, fmt.Errorf("refill receive queue: %w", err)
}
dev.initialized = true
// Make sure to clean up even when the device gets garbage collected without
// Close being called first.
devPtr := &dev
runtime.SetFinalizer(devPtr, (*Device).Close)
return devPtr, nil
}
// refillReceiveQueue offers as many new device-writable buffers to the device
// as the queue can fit. The device will then use these to write received
// packets.
func (dev *Device) refillReceiveQueue() error {
for {
_, err := dev.ReceiveQueue.OfferInDescriptorChains()
if err != nil {
if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
// Queue is full, job is done.
return nil
}
return fmt.Errorf("offer descriptor chain: %w", err)
}
}
}
func (dev *Device) refillTransmitQueue() error {
//for {
// desc, err := dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
// if err != nil {
// if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
// // Queue is full, job is done.
// return nil
// }
// return fmt.Errorf("offer descriptor chain: %w", err)
// } else {
// dev.TransmitQueue.UsedRing().InitOfferSingle(desc, 0)
// }
//}
return nil
}
// Close cleans up the vhost networking device within the kernel and releases
// all resources used for it.
// The implementation will try to release as many resources as possible and
// collect potential errors before returning them.
func (dev *Device) Close() error {
dev.initialized = false
// Closing the control file descriptor will unregister all queues from the
// kernel.
if dev.controlFD >= 0 {
if err := unix.Close(dev.controlFD); err != nil {
// Return an error and do not continue, because the memory used for
// the queues should not be released before they were unregistered
// from the kernel.
return fmt.Errorf("close control file descriptor: %w", err)
}
dev.controlFD = -1
}
var errs []error
if dev.ReceiveQueue != nil {
if err := dev.ReceiveQueue.Close(); err == nil {
dev.ReceiveQueue = nil
} else {
errs = append(errs, fmt.Errorf("close receive queue: %w", err))
}
}
if dev.TransmitQueue != nil {
if err := dev.TransmitQueue.Close(); err == nil {
dev.TransmitQueue = nil
} else {
errs = append(errs, fmt.Errorf("close transmit queue: %w", err))
}
}
if len(errs) == 0 {
// Everything was cleaned up. No need to run the finalizer anymore.
runtime.SetFinalizer(dev, nil)
}
return errors.Join(errs...)
}
// ensureInitialized is used as a guard to prevent methods to be called on an
// uninitialized instance.
func (dev *Device) ensureInitialized() {
if !dev.initialized {
panic("device is not initialized")
}
}
// createQueue creates a new virtqueue and registers it with the vhost device
// using the given index.
func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) {
var (
queue *virtqueue.SplitQueue
err error
)
if queue, err = virtqueue.NewSplitQueue(queueSize, itemSize); err != nil {
return nil, fmt.Errorf("create virtqueue: %w", err)
}
if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil {
return nil, fmt.Errorf("register virtqueue with index %d: %w", queueIndex, err)
}
return queue, nil
}
// truncateBuffers returns a new list of buffers whose combined length matches
// exactly the specified length. When the specified length exceeds the length of
// the buffers, this is an error. When it is smaller, the buffer list will be
// truncated accordingly.
func truncateBuffers(buffers [][]byte, length int) (out [][]byte) {
for _, buffer := range buffers {
if length < len(buffer) {
out = append(out, buffer[:length])
return
}
out = append(out, buffer)
length -= len(buffer)
}
if length > 0 {
panic("length exceeds the combined length of all buffers")
}
return
}
func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
var err error
var idx uint16
if !dev.fullTable {
idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
if err == virtqueue.ErrNotEnoughFreeDescriptors {
dev.fullTable = true
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
}
} else {
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
}
if err != nil {
return 0, nil, fmt.Errorf("transmit queue: %w", err)
}
buf, err := dev.TransmitQueue.GetDescriptorItem(idx)
if err != nil {
return 0, nil, fmt.Errorf("get descriptor chain: %w", err)
}
return idx, buf, nil
}
func (dev *Device) TransmitPacket(pkt *packet.OutPacket, kick bool) error {
if len(pkt.SegmentIDs) == 0 {
return nil
}
for idx := range pkt.SegmentIDs {
segmentID := pkt.SegmentIDs[idx]
dev.TransmitQueue.SetDescSize(segmentID, len(pkt.Segments[idx]))
}
err := dev.TransmitQueue.OfferDescriptorChains(pkt.SegmentIDs, false)
if err != nil {
return fmt.Errorf("offer descriptor chains: %w", err)
}
pkt.Reset()
if kick {
if err := dev.TransmitQueue.Kick(); err != nil {
return err
}
}
return nil
}
func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error {
if len(pkts) == 0 {
return nil
}
for i := range pkts {
if err := dev.TransmitPacket(pkts[i], false); err != nil {
return err
}
}
if err := dev.TransmitQueue.Kick(); err != nil {
return err
}
return nil
}
// TODO: Make above methods cancelable by taking a context.Context argument?
// TODO: Implement zero-copy variants to transmit and receive packets?
// processChains processes as many chains as needed to create one packet. The number of processed chains is returned.
func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) {
//read first element to see how many descriptors we need:
pkt.Reset()
err := dev.ReceiveQueue.GetDescriptorInbuffers(uint16(chains[0].DescriptorIndex), &pkt.ChainRefs)
if err != nil {
return 0, fmt.Errorf("get descriptor chain: %w", err)
}
if len(pkt.ChainRefs) == 0 {
return 1, nil
}
// The specification requires that the first descriptor chain starts
// with a virtio-net header. It is not clear, whether it is also
// required to be fully contained in the first buffer of that
// descriptor chain, but it is reasonable to assume that this is
// always the case.
// The decode method already does the buffer length check.
if err = pkt.Header.Decode(pkt.ChainRefs[0][0:]); err != nil {
// The device misbehaved. There is no way we can gracefully
// recover from this, because we don't know how many of the
// following descriptor chains belong to this packet.
return 0, fmt.Errorf("decode vnethdr: %w", err)
}
//we have the header now: what do we need to do?
if int(pkt.Header.NumBuffers) > len(chains) {
return 0, fmt.Errorf("number of buffers is greater than number of chains %d", len(chains))
}
if int(pkt.Header.NumBuffers) != 1 {
return 0, fmt.Errorf("packets spanning multiple descriptor chains are not supported yet: num_buffers=%d", pkt.Header.NumBuffers)
}
if chains[0].Length > 16000 {
// TODO: revisit this hard-coded size limit
return 1, fmt.Errorf("packet length %d exceeds the supported maximum", chains[0].Length)
}
//shift the buffer out of out:
pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length]
pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex))
return 1, nil
//cursor := n - virtio.NetHdrSize
//
//if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 {
// pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize]
// return 1, nil
//}
//
//i := 1
//// we used chain 0 already
//for i = 1; i < len(chains); i++ {
// n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length))
// if err != nil {
// // When this fails we may miss to free some descriptor chains. We
// // could try to mitigate this by deferring the freeing somehow, but
// // it's not worth the hassle. When this method fails, the queue will
// // be in a broken state anyway.
// return i, fmt.Errorf("get descriptor chain: %w", err)
// }
// cursor += n
//}
////todo this has to be wrong
//pkt.Payload = pkt.Payload[:cursor]
//return i, nil
}
func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
//todo optimize?
var chains []virtqueue.UsedElement
var err error
//if len(dev.extraRx) == 0 {
chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), len(out))
if err != nil {
return 0, err
}
if len(chains) == 0 {
return 0, nil
}
//} else {
// chains = dev.extraRx
//}
numPackets := 0
chainsIdx := 0
for numPackets = 0; chainsIdx < len(chains); numPackets++ {
if numPackets >= len(out) {
return numPackets, fmt.Errorf("dropping %d packets, no room", len(chains)-numPackets)
}
numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:])
if err != nil {
return 0, err
}
chainsIdx += numChains
}
// Now that we have copied all buffers, we can recycle the used descriptor chains
//if err = dev.ReceiveQueue.OfferDescriptorChains(chains); err != nil {
// return 0, err
//}
return numPackets, nil
}
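The removed processChains leans on the virtio-net header that prefixes every received buffer (virtio.NetHdrSize bytes) and on its num_buffers field to know how many descriptor chains make up one packet. A standalone sketch of that 12-byte virtio 1.0 header layout, with field names taken from the virtio spec rather than from nebula's virtio package:

package main

import (
	"encoding/binary"
	"fmt"
)

// netHdr mirrors virtio_net_hdr_mrg_rxbuf: all multi-byte fields are
// little-endian in virtio 1.0.
type netHdr struct {
	Flags      uint8
	GSOType    uint8
	HdrLen     uint16
	GSOSize    uint16
	CsumStart  uint16
	CsumOffset uint16
	NumBuffers uint16
}

func decode(b []byte) netHdr {
	return netHdr{
		Flags:      b[0],
		GSOType:    b[1],
		HdrLen:     binary.LittleEndian.Uint16(b[2:]),
		GSOSize:    binary.LittleEndian.Uint16(b[4:]),
		CsumStart:  binary.LittleEndian.Uint16(b[6:]),
		CsumOffset: binary.LittleEndian.Uint16(b[8:]),
		NumBuffers: binary.LittleEndian.Uint16(b[10:]),
	}
}

func main() {
	buf := make([]byte, 12)
	buf[10] = 1 // num_buffers = 1, the only case the removed code handled
	fmt.Printf("%+v\n", decode(buf))
}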

View File

@@ -1,3 +0,0 @@
// Package vhostnet implements methods to initialize vhost networking devices
// within the kernel-level virtio implementation and communicate with them.
package vhostnet

View File

@@ -1,31 +0,0 @@
package vhostnet
import (
"fmt"
"unsafe"
"github.com/slackhq/nebula/overlay/vhost"
)
const (
// vhostNetIoctlSetBackend can be used to attach a virtqueue to a RAW socket
// or TAP device.
//
// Request payload: [vhost.QueueFile]
// Kernel name: VHOST_NET_SET_BACKEND
vhostNetIoctlSetBackend = 0x4008af30
)
// SetQueueBackend attaches a virtqueue of the vhost networking device
// described by controlFD to the given backend file descriptor.
// The backend file descriptor can either be a RAW socket or a TAP device. When
// it is -1, the queue will be detached.
func SetQueueBackend(controlFD int, queueIndex uint32, backendFD int) error {
if err := vhost.IoctlPtr(controlFD, vhostNetIoctlSetBackend, unsafe.Pointer(&vhost.QueueFile{
QueueIndex: queueIndex,
FD: int32(backendFD),
})); err != nil {
return fmt.Errorf("set queue backend file descriptor: %w", err)
}
return nil
}
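A minimal usage sketch, not part of the original change: it assumes a vhost control file descriptor and a TAP file descriptor were already obtained elsewhere, and both the import path and the helper name attachQueues are assumptions.

package vhostnet_example

import (
    "fmt"

    "github.com/slackhq/nebula/overlay/vhostnet" // assumed import path
)

// attachQueues attaches the RX and TX virtqueues (indices 0 and 1) of a
// vhost-net device to a TAP device. Detaching a queue would work the same
// way, passing -1 as the backend file descriptor.
func attachQueues(controlFD, tapFD int) error {
    for _, queueIndex := range []uint32{0, 1} {
        if err := vhostnet.SetQueueBackend(controlFD, queueIndex, tapFD); err != nil {
            return fmt.Errorf("attach queue %d: %w", queueIndex, err)
        }
    }
    return nil
}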

View File

@@ -1,69 +0,0 @@
package vhostnet
import (
"errors"
"github.com/slackhq/nebula/overlay/virtqueue"
)
type optionValues struct {
queueSize int
backendFD int
}
func (o *optionValues) apply(options []Option) {
for _, option := range options {
option(o)
}
}
func (o *optionValues) validate() error {
if o.queueSize == -1 {
return errors.New("queue size is required")
}
if err := virtqueue.CheckQueueSize(o.queueSize); err != nil {
return err
}
if o.backendFD == -1 {
return errors.New("backend file descriptor is required")
}
return nil
}
var optionDefaults = optionValues{
// Required.
queueSize: -1,
// Required.
backendFD: -1,
}
// Option can be passed to [NewDevice] to influence device creation.
type Option func(*optionValues)
// WithQueueSize returns an [Option] that sets the size of the TX and RX queues
// that are to be created for the device. It specifies the number of
// entries/buffers each queue can hold. This also affects the memory
// consumption.
// This is required and must be an integer from 1 to 32768 that is also a power
// of 2.
func WithQueueSize(queueSize int) Option {
return func(o *optionValues) { o.queueSize = queueSize }
}
// WithBackendFD returns an [Option] that sets the file descriptor of the
// backend that will be used for the queues of the device. The device will write
// and read packets to/from that backend. The file descriptor can either be of a
// RAW socket or TUN/TAP device.
// Either this or [WithBackendDevice] is required.
func WithBackendFD(backendFD int) Option {
return func(o *optionValues) { o.backendFD = backendFD }
}
//// WithBackendDevice returns an [Option] that sets the given TAP device as the
//// backend that will be used for the queues of the device. The device will
//// write and read packets to/from that backend. The TAP device should have been
//// created with the [tuntap.WithVirtioNetHdr] option enabled.
//// Either this or [WithBackendFD] is required.
//func WithBackendDevice(dev *tuntap.Device) Option {
// return func(o *optionValues) { o.backendFD = int(dev.File().Fd()) }
//}
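A sketch of how these options are intended to be consumed; the real NewDevice body is not shown in this diff, so the function name exampleApplyOptions and its shape are illustrative assumptions only.

package vhostnet

import "fmt"

// exampleApplyOptions shows the intended option-handling flow: start from the
// defaults, apply the caller's options, then validate the result.
// Illustrative sketch; not part of the original change.
func exampleApplyOptions(tapFD int, options ...Option) error {
    opts := optionDefaults
    opts.apply(options)
    if err := opts.validate(); err != nil {
        return fmt.Errorf("invalid options: %w", err)
    }
    // opts.queueSize and opts.backendFD are now ready to be used for queue
    // and backend setup.
    return nil
}

// A caller would pass options like this:
//
//	err := exampleApplyOptions(tapFD, WithQueueSize(256), WithBackendFD(tapFD))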

View File

@@ -1,23 +0,0 @@
Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
MIT License
Copyright (c) 2025 Hetzner Cloud GmbH
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,140 +0,0 @@
package virtqueue
import (
"fmt"
"unsafe"
)
// availableRingFlag is a flag that describes an [AvailableRing].
type availableRingFlag uint16
const (
// availableRingFlagNoInterrupt is used by the guest to advise the host to
// not interrupt it when consuming a buffer. It's unreliable, so it's simply
// an optimization.
availableRingFlagNoInterrupt availableRingFlag = 1 << iota
)
// availableRingSize is the number of bytes needed to store an [AvailableRing]
// with the given queue size in memory: flags (2 bytes) + ring index (2 bytes)
// + one 16-bit ring entry per queue slot + used event (2 bytes).
func availableRingSize(queueSize int) int {
return 6 + 2*queueSize
}
// availableRingAlignment is the minimum alignment of an [AvailableRing]
// in memory, as required by the virtio spec.
const availableRingAlignment = 2
// AvailableRing is used by the driver to offer descriptor chains to the device.
// Each ring entry refers to the head of a descriptor chain. It is only written
// to by the driver and read by the device.
//
// Because the size of the ring depends on the queue size, we cannot define a
// Go struct with a static size that maps to the memory of the ring. Instead,
// this struct only contains pointers to the corresponding memory areas.
type AvailableRing struct {
initialized bool
// flags that describe this ring.
flags *availableRingFlag
// ringIndex indicates where the driver would put the next entry into the
// ring (modulo the queue size).
ringIndex *uint16
// ring references buffers using the index of the head of the descriptor
// chain in the [DescriptorTable]. It wraps around at queue size.
ring []uint16
// usedEvent is not used by this implementation, but we reserve it anyway to
// avoid issues in case a device may try to access it, contrary to the
// virtio specification.
usedEvent *uint16
}
// newAvailableRing creates an available ring that uses the given underlying
// memory. The length of the memory slice must match the size needed for the
// ring (see [availableRingSize]) for the given queue size.
func newAvailableRing(queueSize int, mem []byte) *AvailableRing {
ringSize := availableRingSize(queueSize)
if len(mem) != ringSize {
panic(fmt.Sprintf("memory size (%v) does not match required size "+
"for available ring: %v", len(mem), ringSize))
}
return &AvailableRing{
initialized: true,
flags: (*availableRingFlag)(unsafe.Pointer(&mem[0])),
ringIndex: (*uint16)(unsafe.Pointer(&mem[2])),
ring: unsafe.Slice((*uint16)(unsafe.Pointer(&mem[4])), queueSize),
usedEvent: (*uint16)(unsafe.Pointer(&mem[ringSize-2])),
}
}
// Address returns the pointer to the beginning of the ring in memory.
// Do not modify the memory directly to not interfere with this implementation.
func (r *AvailableRing) Address() uintptr {
if !r.initialized {
panic("available ring is not initialized")
}
return uintptr(unsafe.Pointer(r.flags))
}
// offer adds the given descriptor chain heads to the available ring and
// advances the ring index accordingly to make the device process the new
// descriptor chains.
func (r *AvailableRing) offerElements(chains []UsedElement) {
// Always called while the caller holds the queue lock, so no internal
// locking is needed here.
// Add descriptor chain heads to the ring.
for offset, x := range chains {
// The 16-bit ring index may overflow. This is expected and is not an
// issue because the size of the ring array (which equals the queue
// size) is always a power of 2 and smaller than the highest possible
// 16-bit value.
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
r.ring[insertIndex] = x.GetHead()
}
// Increase the ring index by the number of descriptor chains added to the
// ring.
*r.ringIndex += uint16(len(chains))
}
func (r *AvailableRing) offer(chains []uint16) {
// Always called while the caller holds the queue lock, so no internal
// locking is needed here.
// Add descriptor chain heads to the ring.
for offset, x := range chains {
// The 16-bit ring index may overflow. This is expected and is not an
// issue because the size of the ring array (which equals the queue
// size) is always a power of 2 and smaller than the highest possible
// 16-bit value.
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
r.ring[insertIndex] = x
}
// Increase the ring index by the number of descriptor chains added to the
// ring.
*r.ringIndex += uint16(len(chains))
}
func (r *AvailableRing) offerSingle(x uint16) {
// Always called while the caller holds the queue lock, so no internal
// locking is needed here.
// Add the descriptor chain head to the ring.
// The 16-bit ring index may overflow. This is expected and is not an
// issue because the size of the ring array (which equals the queue
// size) is always a power of 2 and smaller than the highest possible
// 16-bit value.
insertIndex := int(*r.ringIndex) % len(r.ring)
r.ring[insertIndex] = x
// Increase the ring index by the one entry added to the ring.
*r.ringIndex += 1
}
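Because the wraparound arithmetic above is easy to misread, here is a small self-contained sketch of it; the concrete numbers mirror the "index overflow" case in the test file below.

package main

import "fmt"

func main() {
    const queueSize = 8
    var ringIndex uint16 = 65535 // about to overflow
    heads := []uint16{42, 33, 69}
    for offset, head := range heads {
        // Same arithmetic as AvailableRing.offer: the uint16 addition is
        // allowed to overflow, the modulo keeps the slot inside the ring.
        insertIndex := int(ringIndex+uint16(offset)) % queueSize
        fmt.Printf("head %d goes into ring slot %d\n", head, insertIndex)
    }
    ringIndex += uint16(len(heads))
    fmt.Printf("ring index after offering: %d\n", ringIndex) // wraps around to 2
}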

View File

@@ -1,71 +0,0 @@
package virtqueue
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAvailableRing_MemoryLayout(t *testing.T) {
const queueSize = 2
memory := make([]byte, availableRingSize(queueSize))
r := newAvailableRing(queueSize, memory)
*r.flags = 0x01ff
*r.ringIndex = 1
r.ring[0] = 0x1234
r.ring[1] = 0x5678
assert.Equal(t, []byte{
0xff, 0x01,
0x01, 0x00,
0x34, 0x12,
0x78, 0x56,
0x00, 0x00,
}, memory)
}
func TestAvailableRing_Offer(t *testing.T) {
const queueSize = 8
chainHeads := []uint16{42, 33, 69}
tests := []struct {
name string
startRingIndex uint16
expectedRingIndex uint16
expectedRing []uint16
}{
{
name: "no overflow",
startRingIndex: 0,
expectedRingIndex: 3,
expectedRing: []uint16{42, 33, 69, 0, 0, 0, 0, 0},
},
{
name: "ring overflow",
startRingIndex: 6,
expectedRingIndex: 9,
expectedRing: []uint16{69, 0, 0, 0, 0, 0, 42, 33},
},
{
name: "index overflow",
startRingIndex: 65535,
expectedRingIndex: 2,
expectedRing: []uint16{33, 69, 0, 0, 0, 0, 0, 42},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
memory := make([]byte, availableRingSize(queueSize))
r := newAvailableRing(queueSize, memory)
*r.ringIndex = tt.startRingIndex
r.offer(chainHeads)
assert.Equal(t, tt.expectedRingIndex, *r.ringIndex)
assert.Equal(t, tt.expectedRing, r.ring)
})
}
}

View File

@@ -1,43 +0,0 @@
package virtqueue
// descriptorFlag is a flag that describes a [Descriptor].
type descriptorFlag uint16
const (
// descriptorFlagHasNext marks a descriptor chain as continuing via the next
// field.
descriptorFlagHasNext descriptorFlag = 1 << iota
// descriptorFlagWritable marks a buffer as device write-only (otherwise
// device read-only).
descriptorFlagWritable
// descriptorFlagIndirect means the buffer contains a list of buffer
// descriptors to provide an additional layer of indirection.
// Only allowed when the [virtio.FeatureIndirectDescriptors] feature was
// negotiated.
descriptorFlagIndirect
)
// descriptorSize is the number of bytes needed to store a [Descriptor] in
// memory.
const descriptorSize = 16
// Descriptor describes (a part of) a buffer which is either read-only for the
// device or write-only for the device (depending on [descriptorFlagWritable]).
// Multiple descriptors can be chained to produce a "descriptor chain" that can
// contain both device-readable and device-writable buffers. Device-readable
// descriptors always come first in a chain. A single, large buffer may be
// split up by chaining multiple similar descriptors that reference different
// memory pages. This is required, because buffers may exceed a single page size
// and the memory accessed by the device is expected to be continuous.
type Descriptor struct {
// address is the address to the continuous memory holding the data for this
// descriptor.
address uintptr
// length is the amount of bytes stored at address.
length uint32
// flags that describe this descriptor.
flags descriptorFlag
// next contains the index of the next descriptor continuing this descriptor
// chain when the [descriptorFlagHasNext] flag is set.
next uint16
}
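To make the flag semantics concrete, a short sketch in the same package (illustrative only, not part of the original code) lays out a two-element chain with one device-readable and one device-writable buffer.

package virtqueue

// exampleTwoElementChain returns descriptors 0 and 1 wired up as one chain:
// the first references a device-readable buffer and continues the chain, the
// second references a device-writable buffer and terminates the chain.
// Illustrative sketch; not part of the original change.
func exampleTwoElementChain(readableAddr, writableAddr uintptr, readableLen, writableLen uint32) [2]Descriptor {
    return [2]Descriptor{
        {
            address: readableAddr,
            length:  readableLen,
            // Device-readable (writable flag unset), continued by next.
            flags: descriptorFlagHasNext,
            next:  1,
        },
        {
            address: writableAddr,
            length:  writableLen,
            // Device-writable; no has-next flag, so this is the tail.
            flags: descriptorFlagWritable,
        },
    }
}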

View File

@@ -1,12 +0,0 @@
package virtqueue
import (
"testing"
"unsafe"
"github.com/stretchr/testify/assert"
)
func TestDescriptor_Size(t *testing.T) {
assert.EqualValues(t, descriptorSize, unsafe.Sizeof(Descriptor{}))
}

View File

@@ -1,641 +0,0 @@
package virtqueue
import (
"errors"
"fmt"
"math"
"unsafe"
"golang.org/x/sys/unix"
)
var (
// ErrDescriptorChainEmpty is returned when a descriptor chain would contain
// no buffers, which is not allowed.
ErrDescriptorChainEmpty = errors.New("empty descriptor chains are not allowed")
// ErrNotEnoughFreeDescriptors is returned when the free descriptors are
// exhausted, meaning that the queue is full.
ErrNotEnoughFreeDescriptors = errors.New("not enough free descriptors, queue is full")
// ErrInvalidDescriptorChain is returned when a descriptor chain is not
// valid for a given operation.
ErrInvalidDescriptorChain = errors.New("invalid descriptor chain")
)
// noFreeHead is used to mark when all descriptors are in use and we have no
// free chain. This value can never occur as a real index, because it exceeds
// the maximum queue size.
const noFreeHead = uint16(math.MaxUint16)
// descriptorTableSize is the number of bytes needed to store a
// [DescriptorTable] with the given queue size in memory.
func descriptorTableSize(queueSize int) int {
return descriptorSize * queueSize
}
// descriptorTableAlignment is the minimum alignment of a [DescriptorTable]
// in memory, as required by the virtio spec.
const descriptorTableAlignment = 16
// DescriptorTable is a table that holds [Descriptor]s, addressed via their
// index in the slice.
type DescriptorTable struct {
descriptors []Descriptor
// freeHeadIndex is the index of the head of the descriptor chain which
// contains all currently unused descriptors. When all descriptors are in
// use, this has the special value of noFreeHead.
freeHeadIndex uint16
// freeNum tracks the number of descriptors which are currently not in use.
freeNum uint16
bufferBase uintptr
bufferSize int
itemSize int
}
// newDescriptorTable creates a descriptor table that uses the given underlying
// memory. The length of the memory slice must match the size needed for the
// descriptor table (see [descriptorTableSize]) for the given queue size.
//
// Before this descriptor table can be used, [initialize] must be called.
func newDescriptorTable(queueSize int, mem []byte, itemSize int) *DescriptorTable {
dtSize := descriptorTableSize(queueSize)
if len(mem) != dtSize {
panic(fmt.Sprintf("memory size (%v) does not match required size "+
"for descriptor table: %v", len(mem), dtSize))
}
return &DescriptorTable{
descriptors: unsafe.Slice((*Descriptor)(unsafe.Pointer(&mem[0])), queueSize),
// We have no free descriptors until they were initialized.
freeHeadIndex: noFreeHead,
freeNum: 0,
itemSize: itemSize, // Must be a multiple of the page size; validated by the caller (see NewSplitQueue).
}
}
// Address returns the pointer to the beginning of the descriptor table in
// memory. Do not modify the memory directly to not interfere with this
// implementation.
func (dt *DescriptorTable) Address() uintptr {
if dt.descriptors == nil {
panic("descriptor table is not initialized")
}
// Note: this is the address of the descriptor table itself, not of the
// packet buffer region tracked by bufferBase.
return uintptr(unsafe.Pointer(&dt.descriptors[0]))
}
// Size returns the total size in bytes of the buffer region backing this
// descriptor table.
func (dt *DescriptorTable) Size() uintptr {
if dt.descriptors == nil {
panic("descriptor table is not initialized")
}
return uintptr(dt.bufferSize)
}
// BufferAddresses returns a map of pointer->size for all allocations used by the table
func (dt *DescriptorTable) BufferAddresses() map[uintptr]int {
if dt.descriptors == nil {
panic("descriptor table is not initialized")
}
return map[uintptr]int{dt.bufferBase: dt.bufferSize}
}
// initializeDescriptors allocates one contiguous memory region and assigns
// each descriptor in the table an itemSize-sized buffer within it. While this
// preallocation may be a bit wasteful, it makes dealing with descriptors much
// easier. Without it, we would have to allocate and free memory on demand,
// increasing complexity.
//
// All descriptors will be marked as free and will form a free chain. The
// addresses of all descriptors will be populated while their length remains
// zero.
func (dt *DescriptorTable) initializeDescriptors() error {
numDescriptors := len(dt.descriptors)
// Allocate ONE large region for all buffers
totalSize := dt.itemSize * numDescriptors
basePtr, err := unix.MmapPtr(-1, 0, nil, uintptr(totalSize),
unix.PROT_READ|unix.PROT_WRITE,
unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
if err != nil {
return fmt.Errorf("allocate buffer memory for descriptors: %w", err)
}
// Store the base for cleanup later
dt.bufferBase = uintptr(basePtr)
dt.bufferSize = totalSize
for i := range dt.descriptors {
dt.descriptors[i] = Descriptor{
address: dt.bufferBase + uintptr(i*dt.itemSize),
length: 0,
// All descriptors should form a free chain that loops around.
flags: descriptorFlagHasNext,
next: uint16((i + 1) % len(dt.descriptors)),
}
}
// All descriptors are free to use now.
dt.freeHeadIndex = 0
dt.freeNum = uint16(len(dt.descriptors))
return nil
}
// releaseBuffers releases all allocated buffers for this descriptor table.
// The implementation will try to release as many buffers as possible and
// collect potential errors before returning them.
// The descriptor table should no longer be used after calling this.
func (dt *DescriptorTable) releaseBuffers() error {
for i := range dt.descriptors {
descriptor := &dt.descriptors[i]
descriptor.address = 0
}
// As a safety measure, make sure no descriptors can be used anymore.
dt.freeHeadIndex = noFreeHead
dt.freeNum = 0
if dt.bufferBase != 0 {
// The pointer points to memory not managed by Go, so this conversion
// is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
err := unix.MunmapPtr(unsafe.Pointer(dt.bufferBase), uintptr(dt.bufferSize))
dt.bufferBase = 0
if err != nil {
return fmt.Errorf("release buffer memory: %w", err)
}
}
return nil
}
// createDescriptorChain creates a new descriptor chain within the descriptor
// table which contains a number of device-readable buffers (out buffers) and
// device-writable buffers (in buffers).
//
// All buffers in the outBuffers slice will be concatenated by chaining
// descriptors, one for each buffer in the slice. The size of a single buffer
// must not exceed the table's itemSize.
// When numInBuffers is greater than zero, the given number of device-writable
// descriptors will be appended to the end of the chain, each referencing a
// buffer of itemSize bytes.
//
// The index of the head of the new descriptor chain will be returned. Callers
// should make sure to free the descriptor chain using [freeDescriptorChain]
// after it was used by the device.
//
// When there are not enough free descriptors to hold the given number of
// buffers, an [ErrNotEnoughFreeDescriptors] will be returned. In this case, the
// caller should try again after some descriptor chains were used by the device
// and returned back into the free chain.
func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffers int) (uint16, error) {
// Calculate the number of descriptors needed to build the chain.
numDesc := uint16(len(outBuffers) + numInBuffers)
// Descriptor chains must always contain at least one descriptor.
if numDesc < 1 {
return 0, ErrDescriptorChainEmpty
}
// Do we still have enough free descriptors?
if numDesc > dt.freeNum {
return 0, ErrNotEnoughFreeDescriptors
}
// Above validation ensured that there is at least one free descriptor, so
// the free descriptor chain head should be valid.
if dt.freeHeadIndex == noFreeHead {
panic("free descriptor chain head is unset but there should be free descriptors")
}
// To avoid having to iterate over the whole table to find the descriptor
// pointing to the head just to replace the free head, we instead always
// create descriptor chains from the descriptors coming after the head.
// This way we only have to touch the head as a last resort, when all other
// descriptors are already used.
head := dt.descriptors[dt.freeHeadIndex].next
next := head
tail := head
for i, buffer := range outBuffers {
desc := &dt.descriptors[next]
checkUnusedDescriptorLength(next, desc)
if len(buffer) > dt.itemSize {
// The caller should already prevent that from happening.
panic(fmt.Sprintf("out buffer %d has size %d which exceeds desc length %d", i, len(buffer), dt.itemSize))
}
// Copy the buffer to the memory referenced by the descriptor.
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
copy(unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), dt.itemSize), buffer)
desc.length = uint32(len(buffer))
// Clear the flags in case there were any others set.
desc.flags = descriptorFlagHasNext
tail = next
next = desc.next
}
for range numInBuffers {
desc := &dt.descriptors[next]
checkUnusedDescriptorLength(next, desc)
// Give the device the maximum available number of bytes to write into.
desc.length = uint32(dt.itemSize)
// Mark the descriptor as device-writable.
desc.flags = descriptorFlagHasNext | descriptorFlagWritable
tail = next
next = desc.next
}
// The last descriptor should end the chain.
tailDesc := &dt.descriptors[tail]
tailDesc.flags &= ^descriptorFlagHasNext
tailDesc.next = 0 // Not necessary to clear this, it's just for looks.
dt.freeNum -= numDesc
if dt.freeNum == 0 {
// The last descriptor in the chain should be the free chain head
// itself.
if tail != dt.freeHeadIndex {
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
}
// When this new chain takes up all remaining descriptors, we no longer
// have a free chain.
dt.freeHeadIndex = noFreeHead
} else {
// We took some descriptors out of the free chain, so make sure to close
// the circle again.
dt.descriptors[dt.freeHeadIndex].next = next
}
return head, nil
}
// CreateDescriptorForOutputs reserves a single device-readable descriptor
// from the free chain and returns its index.
func (dt *DescriptorTable) CreateDescriptorForOutputs() (uint16, error) {
// TODO: consider pre-populating the whole table up front instead of
// reserving descriptors one at a time.
// Do we still have enough free descriptors?
if dt.freeNum < 1 {
return 0, ErrNotEnoughFreeDescriptors
}
// Above validation ensured that there is at least one free descriptor, so
// the free descriptor chain head should be valid.
if dt.freeHeadIndex == noFreeHead {
panic("free descriptor chain head is unset but there should be free descriptors")
}
// To avoid having to iterate over the whole table to find the descriptor
// pointing to the head just to replace the free head, we instead always
// create descriptor chains from the descriptors coming after the head.
// This way we only have to touch the head as a last resort, when all other
// descriptors are already used.
head := dt.descriptors[dt.freeHeadIndex].next
desc := &dt.descriptors[head]
next := desc.next
checkUnusedDescriptorLength(head, desc)
// Expose the full buffer; callers shrink the length to the actual payload
// size (see SetDescSize) before offering the descriptor.
desc.length = uint32(dt.itemSize)
desc.flags = 0 // Device-readable: no flags set.
desc.next = 0  // Not necessary to clear this, it's just for looks.
dt.freeNum -= 1
if dt.freeNum == 0 {
// The last descriptor in the chain should be the free chain head
// itself.
if next != dt.freeHeadIndex {
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
}
// When this new chain takes up all remaining descriptors, we no longer
// have a free chain.
dt.freeHeadIndex = noFreeHead
} else {
// We took some descriptors out of the free chain, so make sure to close
// the circle again.
dt.descriptors[dt.freeHeadIndex].next = next
}
return head, nil
}
// createDescriptorForInputs reserves a single device-writable descriptor
// from the free chain and returns its index.
func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) {
// Do we still have enough free descriptors?
if dt.freeNum < 1 {
return 0, ErrNotEnoughFreeDescriptors
}
// Above validation ensured that there is at least one free descriptor, so
// the free descriptor chain head should be valid.
if dt.freeHeadIndex == noFreeHead {
panic("free descriptor chain head is unset but there should be free descriptors")
}
// To avoid having to iterate over the whole table to find the descriptor
// pointing to the head just to replace the free head, we instead always
// create descriptor chains from the descriptors coming after the head.
// This way we only have to touch the head as a last resort, when all other
// descriptors are already used.
head := dt.descriptors[dt.freeHeadIndex].next
desc := &dt.descriptors[head]
next := desc.next
checkUnusedDescriptorLength(head, desc)
// Give the device the maximum available number of bytes to write into.
desc.length = uint32(dt.itemSize)
desc.flags = descriptorFlagWritable
desc.next = 0 // Not necessary to clear this, it's just for looks.
dt.freeNum -= 1
if dt.freeNum == 0 {
// The last descriptor in the chain should be the free chain head
// itself.
if next != dt.freeHeadIndex {
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
}
// When this new chain takes up all remaining descriptors, we no longer
// have a free chain.
dt.freeHeadIndex = noFreeHead
} else {
// We took some descriptors out of the free chain, so make sure to close
// the circle again.
dt.descriptors[dt.freeHeadIndex].next = next
}
return head, nil
}
// TODO: Implement a zero-copy variant of createDescriptorChain?
// getDescriptorChain returns the device-readable buffers (out buffers) and
// device-writable buffers (in buffers) of the descriptor chain that starts with
// the given head index. The descriptor chain must have been created using
// [createDescriptorChain] and must not have been freed yet (meaning that the
// head index must not be contained in the free chain).
//
// Be careful to only access the returned buffer slices when the device has not
// yet or is no longer using them. They must not be accessed after
// [freeDescriptorChain] has been called.
func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
if int(head) >= len(dt.descriptors) {
return nil, nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
// Iterate over the chain. The iteration is limited to the queue size to
// avoid ending up in an endless loop when things go very wrong.
next := head
for range len(dt.descriptors) {
if next == dt.freeHeadIndex {
return nil, nil, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[next]
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
if desc.flags&descriptorFlagWritable == 0 {
outBuffers = append(outBuffers, bs)
} else {
inBuffers = append(inBuffers, bs)
}
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
break
}
// Detect loops.
if desc.next == head {
return nil, nil, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
}
next = desc.next
}
return
}
// getDescriptorItem returns the buffer backing the single descriptor with the
// given head index.
func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
if int(head) >= len(dt.descriptors) {
return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
// NOTE: this reads a single descriptor directly, without walking or
// validating a chain.
desc := &dt.descriptors[head]
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
return bs, nil
}
// getDescriptorInbuffers appends the device-writable buffers of the descriptor
// chain starting at head to inBuffers.
func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
if int(head) >= len(dt.descriptors) {
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
}
// Iterate over the chain. The iteration is limited to the queue size to
// avoid ending up in an endless loop when things go very wrong.
next := head
for range len(dt.descriptors) {
if next == dt.freeHeadIndex {
return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[next]
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
if desc.flags&descriptorFlagWritable == 0 {
return fmt.Errorf("there should not be an outbuffer in %d", head)
} else {
*inBuffers = append(*inBuffers, bs)
}
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
break
}
// Detect loops.
if desc.next == head {
return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
}
next = desc.next
}
return nil
}
// getDescriptorChainContents walks the device-writable descriptor chain
// starting at head and copies its contents into out, returning the number of
// bytes copied. When maxLen is greater than zero, at most maxLen bytes are
// copied.
func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
if int(head) >= len(dt.descriptors) {
return 0, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
}
// Iterate over the chain. The iteration is limited to the queue size to
// avoid ending up in an endless loop when things go very wrong.
// First pass: walk the chain to compute the total payload length.
length := 0
next := head
for range len(dt.descriptors) {
if next == dt.freeHeadIndex {
return 0, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[next]
if desc.flags&descriptorFlagWritable == 0 {
return 0, fmt.Errorf("receive queue contains device-readable buffer")
}
length += int(desc.length)
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
break
}
// Detect loops.
if desc.next == head {
return 0, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
}
next = desc.next
}
if maxLen > 0 {
// Cap the number of bytes to copy at the caller-provided maximum.
length = min(maxLen, length)
}
// Resize out to the number of bytes that will be copied.
out = out[:length]
//now do the copying
copied := 0
for range len(dt.descriptors) {
desc := &dt.descriptors[next]
// The descriptor address points to memory not managed by Go, so this
// conversion is safe. See https://github.com/golang/go/issues/58625
//goland:noinspection GoVetUnsafePointer
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), min(uint32(length-copied), desc.length))
copied += copy(out[copied:], bs)
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
break
}
// we did this already, no need to detect loops.
next = desc.next
}
if copied != length {
panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", length, copied))
}
return length, nil
}
// freeDescriptorChain can be used to free a descriptor chain when it is no
// longer in use. The descriptor chain that starts with the given index will be
// put back into the free chain, so the descriptors can be used for later calls
// of [createDescriptorChain].
// The descriptor chain must have been created using [createDescriptorChain] and
// must not have been freed yet (meaning that the head index must not be
// contained in the free chain).
func (dt *DescriptorTable) freeDescriptorChain(head uint16) error {
if int(head) >= len(dt.descriptors) {
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
}
// Iterate over the chain. The iteration is limited to the queue size to
// avoid ending up in an endless loop when things go very wrong.
next := head
var tailDesc *Descriptor
var chainLen uint16
for range len(dt.descriptors) {
if next == dt.freeHeadIndex {
return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
}
desc := &dt.descriptors[next]
chainLen++
// Set the length of all unused descriptors back to zero.
desc.length = 0
// Unset all flags except the next flag.
desc.flags &= descriptorFlagHasNext
// Is this the tail of the chain?
if desc.flags&descriptorFlagHasNext == 0 {
tailDesc = desc
break
}
// Detect loops.
if desc.next == head {
return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
}
next = desc.next
}
if tailDesc == nil {
// A descriptor chain longer than the queue size but without loops
// should be impossible.
panic(fmt.Sprintf("could not find a tail for descriptor chain starting at %d", head))
}
// The tail descriptor does not have the next flag set, but when it comes
// back into the free chain, it should have.
tailDesc.flags = descriptorFlagHasNext
if dt.freeHeadIndex == noFreeHead {
// The whole free chain was used up, so we turn this returned descriptor
// chain into the new free chain by completing the circle and using its
// head.
tailDesc.next = head
dt.freeHeadIndex = head
} else {
// Attach the returned chain at the beginning of the free chain but
// right after the free chain head.
freeHeadDesc := &dt.descriptors[dt.freeHeadIndex]
tailDesc.next = freeHeadDesc.next
freeHeadDesc.next = head
}
dt.freeNum += chainLen
return nil
}
// checkUnusedDescriptorLength asserts that the length of an unused descriptor
// is zero, as it should be.
// This is not a requirement by the virtio spec but rather a thing we do to
// notice when our algorithm goes sideways.
func checkUnusedDescriptorLength(index uint16, desc *Descriptor) {
if desc.length != 0 {
panic(fmt.Sprintf("descriptor %d should be unused but has a non-zero length", index))
}
}
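As a worked illustration of the single-allocation layout set up by initializeDescriptors (illustrative only, not part of the original change): descriptor i owns the byte range [i*itemSize, (i+1)*itemSize) within the shared region.

package virtqueue

// exampleBufferLayout returns the start offset of each descriptor's buffer
// inside the single region allocated by initializeDescriptors.
// Illustrative sketch; not part of the original change.
func exampleBufferLayout(queueSize, itemSize int) []int {
    offsets := make([]int, queueSize)
    for i := range offsets {
        // Matches descriptor.address = bufferBase + i*itemSize.
        offsets[i] = i * itemSize
    }
    return offsets
}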

View File

@@ -1,407 +0,0 @@
package virtqueue
import (
"os"
"testing"
"unsafe"
"github.com/stretchr/testify/assert"
)
func TestDescriptorTable_InitializeDescriptors(t *testing.T) {
const queueSize = 32
dt := DescriptorTable{
descriptors: make([]Descriptor, queueSize),
itemSize:    os.Getpagesize(),
}
assert.NoError(t, dt.initializeDescriptors())
t.Cleanup(func() {
assert.NoError(t, dt.releaseBuffers())
})
for i, descriptor := range dt.descriptors {
assert.NotZero(t, descriptor.address)
assert.Zero(t, descriptor.length)
assert.EqualValues(t, descriptorFlagHasNext, descriptor.flags)
assert.EqualValues(t, (i+1)%queueSize, descriptor.next)
}
}
func TestDescriptorTable_DescriptorChains(t *testing.T) {
// Use a very short queue size to not make this test overly verbose.
const queueSize = 8
// Each descriptor gets a buffer of two pages; this is also the expected
// length of the device-writable buffers asserted below.
pageSize := os.Getpagesize() * 2
// Initialize descriptor table.
dt := DescriptorTable{
descriptors: make([]Descriptor, queueSize),
itemSize:    pageSize,
}
assert.NoError(t, dt.initializeDescriptors())
t.Cleanup(func() {
assert.NoError(t, dt.releaseBuffers())
})
// Some utilities for easier checking if the descriptor table looks as
// expected.
type desc struct {
buffer []byte
flags descriptorFlag
next uint16
}
assertDescriptorTable := func(expected [queueSize]desc) {
for i := 0; i < queueSize; i++ {
actualDesc := &dt.descriptors[i]
expectedDesc := &expected[i]
assert.Equal(t, uint32(len(expectedDesc.buffer)), actualDesc.length)
if len(expectedDesc.buffer) > 0 {
//goland:noinspection GoVetUnsafePointer
assert.EqualValues(t,
unsafe.Slice((*byte)(unsafe.Pointer(actualDesc.address)), actualDesc.length),
expectedDesc.buffer)
}
assert.Equal(t, expectedDesc.flags, actualDesc.flags)
if expectedDesc.flags&descriptorFlagHasNext != 0 {
assert.Equal(t, expectedDesc.next, actualDesc.next)
}
}
}
// Initial state: All descriptors are in the free chain.
assert.Equal(t, uint16(0), dt.freeHeadIndex)
assert.Equal(t, uint16(8), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
// Free head.
flags: descriptorFlagHasNext,
next: 1,
},
{
flags: descriptorFlagHasNext,
next: 2,
},
{
flags: descriptorFlagHasNext,
next: 3,
},
{
flags: descriptorFlagHasNext,
next: 4,
},
{
flags: descriptorFlagHasNext,
next: 5,
},
{
flags: descriptorFlagHasNext,
next: 6,
},
{
flags: descriptorFlagHasNext,
next: 7,
},
{
flags: descriptorFlagHasNext,
next: 0,
},
})
// Create the first chain.
firstChain, err := dt.createDescriptorChain([][]byte{
makeTestBuffer(t, 26),
makeTestBuffer(t, 256),
}, 1)
assert.NoError(t, err)
assert.Equal(t, uint16(1), firstChain)
// Now there should be a new chain next to the free chain.
assert.Equal(t, uint16(0), dt.freeHeadIndex)
assert.Equal(t, uint16(5), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
// Free head.
flags: descriptorFlagHasNext,
next: 4,
},
{
// Head of first chain.
buffer: makeTestBuffer(t, 26),
flags: descriptorFlagHasNext,
next: 2,
},
{
buffer: makeTestBuffer(t, 256),
flags: descriptorFlagHasNext,
next: 3,
},
{
// Tail of first chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
flags: descriptorFlagHasNext,
next: 5,
},
{
flags: descriptorFlagHasNext,
next: 6,
},
{
flags: descriptorFlagHasNext,
next: 7,
},
{
flags: descriptorFlagHasNext,
next: 0,
},
})
// Create a second chain with only a single in buffer.
secondChain, err := dt.createDescriptorChain(nil, 1)
assert.NoError(t, err)
assert.Equal(t, uint16(4), secondChain)
// Now there should be two chains next to the free chain.
assert.Equal(t, uint16(0), dt.freeHeadIndex)
assert.Equal(t, uint16(4), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
// Free head.
flags: descriptorFlagHasNext,
next: 5,
},
{
// Head of the first chain.
buffer: makeTestBuffer(t, 26),
flags: descriptorFlagHasNext,
next: 2,
},
{
buffer: makeTestBuffer(t, 256),
flags: descriptorFlagHasNext,
next: 3,
},
{
// Tail of the first chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
// Head and tail of the second chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
flags: descriptorFlagHasNext,
next: 6,
},
{
flags: descriptorFlagHasNext,
next: 7,
},
{
flags: descriptorFlagHasNext,
next: 0,
},
})
// Create a third chain taking up all remaining descriptors.
thirdChain, err := dt.createDescriptorChain([][]byte{
makeTestBuffer(t, 42),
makeTestBuffer(t, 96),
makeTestBuffer(t, 33),
makeTestBuffer(t, 222),
}, 0)
assert.NoError(t, err)
assert.Equal(t, uint16(5), thirdChain)
// Now there should be three chains and no free chain.
assert.Equal(t, noFreeHead, dt.freeHeadIndex)
assert.Equal(t, uint16(0), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
// Tail of the third chain.
buffer: makeTestBuffer(t, 222),
},
{
// Head of the first chain.
buffer: makeTestBuffer(t, 26),
flags: descriptorFlagHasNext,
next: 2,
},
{
buffer: makeTestBuffer(t, 256),
flags: descriptorFlagHasNext,
next: 3,
},
{
// Tail of the first chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
// Head and tail of the second chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
// Head of the third chain.
buffer: makeTestBuffer(t, 42),
flags: descriptorFlagHasNext,
next: 6,
},
{
buffer: makeTestBuffer(t, 96),
flags: descriptorFlagHasNext,
next: 7,
},
{
buffer: makeTestBuffer(t, 33),
flags: descriptorFlagHasNext,
next: 0,
},
})
// Free the third chain.
assert.NoError(t, dt.freeDescriptorChain(thirdChain))
// Now there should be two chains and a free chain again.
assert.Equal(t, uint16(5), dt.freeHeadIndex)
assert.Equal(t, uint16(4), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
flags: descriptorFlagHasNext,
next: 5,
},
{
// Head of the first chain.
buffer: makeTestBuffer(t, 26),
flags: descriptorFlagHasNext,
next: 2,
},
{
buffer: makeTestBuffer(t, 256),
flags: descriptorFlagHasNext,
next: 3,
},
{
// Tail of the first chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
// Head and tail of the second chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
// Free head.
flags: descriptorFlagHasNext,
next: 6,
},
{
flags: descriptorFlagHasNext,
next: 7,
},
{
flags: descriptorFlagHasNext,
next: 0,
},
})
// Free the first chain.
assert.NoError(t, dt.freeDescriptorChain(firstChain))
// Now there should be only a single chain next to the free chain.
assert.Equal(t, uint16(5), dt.freeHeadIndex)
assert.Equal(t, uint16(7), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
flags: descriptorFlagHasNext,
next: 5,
},
{
flags: descriptorFlagHasNext,
next: 2,
},
{
flags: descriptorFlagHasNext,
next: 3,
},
{
flags: descriptorFlagHasNext,
next: 6,
},
{
// Head and tail of the second chain.
buffer: make([]byte, pageSize),
flags: descriptorFlagWritable,
},
{
// Free head.
flags: descriptorFlagHasNext,
next: 1,
},
{
flags: descriptorFlagHasNext,
next: 7,
},
{
flags: descriptorFlagHasNext,
next: 0,
},
})
// Free the second chain.
assert.NoError(t, dt.freeDescriptorChain(secondChain))
// Now all descriptors should be in the free chain again.
assert.Equal(t, uint16(5), dt.freeHeadIndex)
assert.Equal(t, uint16(8), dt.freeNum)
assertDescriptorTable([queueSize]desc{
{
flags: descriptorFlagHasNext,
next: 5,
},
{
flags: descriptorFlagHasNext,
next: 2,
},
{
flags: descriptorFlagHasNext,
next: 3,
},
{
flags: descriptorFlagHasNext,
next: 6,
},
{
flags: descriptorFlagHasNext,
next: 1,
},
{
// Free head.
flags: descriptorFlagHasNext,
next: 4,
},
{
flags: descriptorFlagHasNext,
next: 7,
},
{
flags: descriptorFlagHasNext,
next: 0,
},
})
}
func makeTestBuffer(t *testing.T, length int) []byte {
t.Helper()
buf := make([]byte, length)
for i := 0; i < length; i++ {
buf[i] = byte(length - i)
}
return buf
}

View File

@@ -1,7 +0,0 @@
// Package virtqueue implements the driver-side for a virtio queue as described
// in the specification:
// https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-270006
// This package does not make assumptions about the device that consumes the
// queue. It rather just allocates the queue structures in memory and provides
// methods to interact with it.
package virtqueue

View File

@@ -1,45 +0,0 @@
package virtqueue
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gvisor.dev/gvisor/pkg/eventfd"
)
// Tests how an eventfd and a waiting goroutine can be gracefully closed.
// Extends the eventfd test suite:
// https://github.com/google/gvisor/blob/0799336d64be65eb97d330606c30162dc3440cab/pkg/eventfd/eventfd_test.go
func TestEventFD_CancelWait(t *testing.T) {
efd, err := eventfd.Create()
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, efd.Close())
})
var stop bool
done := make(chan struct{})
go func() {
for !stop {
_ = efd.Wait()
}
close(done)
}()
select {
case <-done:
t.Fatalf("goroutine ended early")
case <-time.After(500 * time.Millisecond):
}
stop = true
assert.NoError(t, efd.Notify())
select {
case <-done:
break
case <-time.After(5 * time.Second):
t.Error("goroutine did not end")
}
}

View File

@@ -1,33 +0,0 @@
package virtqueue
import (
"errors"
"fmt"
)
// ErrQueueSizeInvalid is returned when a queue size is invalid.
var ErrQueueSizeInvalid = errors.New("queue size is invalid")
// CheckQueueSize checks if the given value would be a valid size for a
// virtqueue and returns an [ErrQueueSizeInvalid], if not.
func CheckQueueSize(queueSize int) error {
if queueSize <= 0 {
return fmt.Errorf("%w: %d is too small", ErrQueueSizeInvalid, queueSize)
}
// The queue size must always be a power of 2.
// This ensures that ring indexes wrap correctly when the 16-bit integers
// overflow.
if queueSize&(queueSize-1) != 0 {
return fmt.Errorf("%w: %d is not a power of 2", ErrQueueSizeInvalid, queueSize)
}
// The largest power of 2 that fits into a 16-bit integer is 32768.
// 2 * 32768 would be 65536 which no longer fits.
if queueSize > 32768 {
return fmt.Errorf("%w: %d is larger than the maximum possible queue size 32768",
ErrQueueSizeInvalid, queueSize)
}
return nil
}
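The power-of-two check relies on a standard bit trick: a power of two has exactly one bit set, so ANDing it with its predecessor yields zero. A tiny self-contained illustration:

package main

import "fmt"

func main() {
    for _, queueSize := range []int{24, 256, 32768} {
        // Same expression as CheckQueueSize uses.
        isPowerOfTwo := queueSize&(queueSize-1) == 0
        fmt.Printf("%5d & %5d = %5d -> power of two: %v\n",
            queueSize, queueSize-1, queueSize&(queueSize-1), isPowerOfTwo)
    }
}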

View File

@@ -1,59 +0,0 @@
package virtqueue
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestCheckQueueSize(t *testing.T) {
tests := []struct {
name string
queueSize int
containsErr string
}{
{
name: "negative",
queueSize: -1,
containsErr: "too small",
},
{
name: "zero",
queueSize: 0,
containsErr: "too small",
},
{
name: "not a power of 2",
queueSize: 24,
containsErr: "not a power of 2",
},
{
name: "too large",
queueSize: 65536,
containsErr: "larger than the maximum",
},
{
name: "valid 1",
queueSize: 1,
},
{
name: "valid 256",
queueSize: 256,
},
{
name: "valid 32768",
queueSize: 32768,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckQueueSize(tt.queueSize)
if tt.containsErr != "" {
assert.ErrorContains(t, err, tt.containsErr)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -1,530 +0,0 @@
package virtqueue
import (
"context"
"errors"
"fmt"
"os"
"syscall"
"github.com/slackhq/nebula/overlay/eventfd"
"golang.org/x/sys/unix"
)
// SplitQueue is a virtqueue that consists of several parts, where each part is
// writeable by either the driver or the device, but not both.
type SplitQueue struct {
// size is the size of the queue.
size int
// buf is the underlying memory used for the queue.
buf []byte
descriptorTable *DescriptorTable
availableRing *AvailableRing
usedRing *UsedRing
// kickEventFD is used to signal the device when descriptor chains were
// added to the available ring.
kickEventFD eventfd.EventFD
// callEventFD is used by the device to signal when it has used descriptor
// chains and put them in the used ring.
callEventFD eventfd.EventFD
// stop is used by [SplitQueue.Close] to wake up any callers blocked on the
// call event file descriptor before the queue is torn down.
stop func() error
// itemSize is the size in bytes of the buffer backing each descriptor.
itemSize int
// epoll waits for signals on the call event file descriptor.
epoll eventfd.Epoll
// more counts used elements that are known to be pending in the used ring
// but have not yet been returned to a caller.
more int
}
// NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size
// specifies the number of entries/buffers the queue can hold. This also affects
// the memory consumption.
func NewSplitQueue(queueSize int, itemSize int) (_ *SplitQueue, err error) {
if err = CheckQueueSize(queueSize); err != nil {
return nil, err
}
if itemSize%os.Getpagesize() != 0 {
return nil, errors.New("item size must be a multiple of os.Getpagesize()")
}
sq := SplitQueue{
size: queueSize,
itemSize: itemSize,
}
// Clean up a partially initialized queue when something fails.
defer func() {
if err != nil {
_ = sq.Close()
}
}()
// There are multiple ways for how the memory for the virtqueue could be
// allocated. We could use Go native structs with arrays inside them, but
// this wouldn't allow us to make the queue size configurable. And including
// a slice in the Go structs wouldn't work, because this would just put the
// Go slice descriptor into the memory region which the virtio device will
// not understand.
// Additionally, Go does not allow us to ensure a correct alignment of the
// parts of the virtqueue, as it is required by the virtio specification.
//
// To resolve this, let's just allocate the memory manually by allocating
// one or more memory pages, depending on the queue size. Making the
// virtqueue start at the beginning of a page is not strictly necessary, as
// the virtio specification does not require it to be continuous in the
// physical memory of the host (e.g. the vhost implementation in the kernel
// always uses copy_from_user to access it), but this makes it very easy to
// guarantee the alignment. Also, it is not required for the virtqueue parts
// to be in the same memory region, as we pass separate pointers to them to
// the device, but this design just makes things easier to implement.
//
// One added benefit of allocating the memory manually is, that we have full
// control over its lifetime and don't risk the garbage collector to collect
// our valuable structures while the device still works with them.
// The descriptor table is at the start of the page, so alignment is not an
// issue here.
descriptorTableStart := 0
descriptorTableEnd := descriptorTableStart + descriptorTableSize(queueSize)
availableRingStart := align(descriptorTableEnd, availableRingAlignment)
availableRingEnd := availableRingStart + availableRingSize(queueSize)
usedRingStart := align(availableRingEnd, usedRingAlignment)
usedRingEnd := usedRingStart + usedRingSize(queueSize)
sq.buf, err = unix.Mmap(-1, 0, usedRingEnd,
unix.PROT_READ|unix.PROT_WRITE,
unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
if err != nil {
return nil, fmt.Errorf("allocate virtqueue buffer: %w", err)
}
sq.descriptorTable = newDescriptorTable(queueSize, sq.buf[descriptorTableStart:descriptorTableEnd], sq.itemSize)
sq.availableRing = newAvailableRing(queueSize, sq.buf[availableRingStart:availableRingEnd])
sq.usedRing = newUsedRing(queueSize, sq.buf[usedRingStart:usedRingEnd])
sq.kickEventFD, err = eventfd.New()
if err != nil {
return nil, fmt.Errorf("create kick event file descriptor: %w", err)
}
sq.callEventFD, err = eventfd.New()
if err != nil {
return nil, fmt.Errorf("create call event file descriptor: %w", err)
}
if err = sq.descriptorTable.initializeDescriptors(); err != nil {
return nil, fmt.Errorf("initialize descriptors: %w", err)
}
sq.epoll, err = eventfd.NewEpoll()
if err != nil {
return nil, err
}
err = sq.epoll.AddEvent(sq.callEventFD.FD())
if err != nil {
return nil, err
}
// Consume used buffer notifications in the background.
sq.stop = sq.startConsumeUsedRing()
return &sq, nil
}
// Size returns the size of this queue, which is the number of entries/buffers
// this queue can hold.
func (sq *SplitQueue) Size() int {
return sq.size
}
// DescriptorTable returns the [DescriptorTable] behind this queue.
func (sq *SplitQueue) DescriptorTable() *DescriptorTable {
return sq.descriptorTable
}
// AvailableRing returns the [AvailableRing] behind this queue.
func (sq *SplitQueue) AvailableRing() *AvailableRing {
return sq.availableRing
}
// UsedRing returns the [UsedRing] behind this queue.
func (sq *SplitQueue) UsedRing() *UsedRing {
return sq.usedRing
}
// KickEventFD returns the kick event file descriptor behind this queue.
// The returned file descriptor should be used with great care to not interfere
// with this implementation.
func (sq *SplitQueue) KickEventFD() int {
return sq.kickEventFD.FD()
}
// CallEventFD returns the call event file descriptor behind this queue.
// The returned file descriptor should be used with great care to not interfere
// with this implementation.
func (sq *SplitQueue) CallEventFD() int {
return sq.callEventFD.FD()
}
// startConsumeUsedRing returns a function that wakes up any caller blocked on
// the call event file descriptor so that [SplitQueue.Close] can shut down
// cleanly.
// TODO: rename; this no longer starts a background goroutine.
func (sq *SplitQueue) startConsumeUsedRing() func() error {
return func() error {
// Callers block on the call event file descriptor and would otherwise
// never notice the queue being closed. Produce a fake signal to wake
// them up.
if err := sq.callEventFD.Kick(); err != nil {
return fmt.Errorf("wake up blocked callers: %w", err)
}
return nil
}
}
// BlockAndGetHeads waits for the device to signal that it has used descriptor chains and returns all [UsedElement]s
func (sq *SplitQueue) BlockAndGetHeads(ctx context.Context) ([]UsedElement, error) {
var n int
var err error
for ctx.Err() == nil {
// Wait for a signal from the device.
if n, err = sq.epoll.Block(); err != nil {
return nil, fmt.Errorf("wait: %w", err)
}
if n > 0 {
stillNeedToTake, out := sq.usedRing.take(-1)
sq.more = stillNeedToTake
if stillNeedToTake == 0 {
_ = sq.epoll.Clear() // Nothing left to take; clear the event so the next wait blocks.
}
return out, nil
}
}
return nil, ctx.Err()
}
func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
var n int
var err error
for ctx.Err() == nil {
out, ok := sq.usedRing.takeOne()
if ok {
return out, nil
}
// Wait for a signal from the device.
if n, err = sq.epoll.Block(); err != nil {
return 0, fmt.Errorf("wait: %w", err)
}
if n > 0 {
out, ok = sq.usedRing.takeOne()
if ok {
_ = sq.epoll.Clear() // Element consumed; clear the event so the next wait blocks.
return out, nil
} else {
continue // Spurious wakeup; wait for the next signal.
}
}
}
return 0, ctx.Err()
}
// BlockAndGetHeadsCapped waits until the device has used descriptor chains and
// returns at most maxToTake [UsedElement]s. Elements beyond the cap are
// remembered and returned by subsequent calls.
func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) {
var n int
var err error
for ctx.Err() == nil {
// Return elements left over from a previous notification first.
if sq.more > 0 {
stillNeedToTake, out := sq.usedRing.take(maxToTake)
sq.more = stillNeedToTake
return out, nil
}
// Check whether the used ring already has new elements.
stillNeedToTake, out := sq.usedRing.take(maxToTake)
if len(out) > 0 {
sq.more = stillNeedToTake
return out, nil
}
// Nothing is pending right now.
// Wait for a signal from the device.
if n, err = sq.epoll.Block(); err != nil {
return nil, fmt.Errorf("wait: %w", err)
}
if n > 0 {
_ = sq.epoll.Clear() // Clear the event before taking so later notifications are not missed.
stillNeedToTake, out = sq.usedRing.take(maxToTake)
sq.more = stillNeedToTake
return out, nil
}
}
return nil, ctx.Err()
}
// OfferInDescriptorChains reserves a single device-writable descriptor, makes
// it available to the device via the [AvailableRing] and notifies (kicks) the
// device.
//
// When the queue is full and no descriptor can be reserved, an
// [ErrNotEnoughFreeDescriptors] is returned.
//
// The index of the descriptor is returned. Callers should hand it back once
// the device has used it, either by freeing it with
// [SplitQueue.FreeDescriptorChain] or by re-offering it with
// [SplitQueue.OfferDescriptorChains]; otherwise the queue will run full and
// further offers will fail.
func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
// Reserve a single device-writable descriptor.
head, err := sq.descriptorTable.createDescriptorForInputs()
if err != nil {
// Direct comparison instead of errors.Is to keep this hot path cheap.
//goland:noinspection GoDirectComparisonOfErrors
if err == ErrNotEnoughFreeDescriptors {
return 0, err
}
return 0, fmt.Errorf("create descriptor chain: %w", err)
}
// Make the descriptor chain available to the device.
sq.availableRing.offerSingle(head)
// Notify the device to make it process the updated available ring.
if err := sq.kickEventFD.Kick(); err != nil {
return head, fmt.Errorf("notify device: %w", err)
}
return head, nil
}
func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte) ([]uint16, error) {
// TODO: reconsider this approach.
// Each descriptor can hold at most itemSize bytes, so split larger out
// buffers into multiple smaller ones.
outBuffers = splitBuffers(outBuffers, sq.itemSize)
chains := make([]uint16, len(outBuffers))
// Create a descriptor chain for the given buffers.
var (
head uint16
err error
)
for i := range outBuffers {
for {
bufs := [][]byte{prepend, outBuffers[i]}
head, err = sq.descriptorTable.createDescriptorChain(bufs, 0)
if err == nil {
break
}
// Direct comparison instead of errors.Is to keep this hot path cheap.
//goland:noinspection GoDirectComparisonOfErrors
if err == ErrNotEnoughFreeDescriptors {
// Wait for more free descriptors to be put back into the queue.
// If the number of free descriptors is still not sufficient, we'll
// land here again. This should not happen when the queue is sized
// correctly.
syscall.Syscall(syscall.SYS_SCHED_YIELD, 0, 0, 0) // Yield instead of spinning hard.
continue
}
return nil, fmt.Errorf("create descriptor chain: %w", err)
}
chains[i] = head
}
// Make the descriptor chain available to the device.
sq.availableRing.offer(chains)
// Notify the device to make it process the updated available ring.
if err := sq.kickEventFD.Kick(); err != nil {
return chains, fmt.Errorf("notify device: %w", err)
}
return chains, nil
}
// GetDescriptorChain returns the device-readable buffers (out buffers) and
// device-writable buffers (in buffers) of the descriptor chain with the given
// head index.
// The head index must be one that was returned by a previous call to
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
// freed yet.
//
// Be careful to only access the returned buffer slices when the device is no
// longer using them. They must not be accessed after
// [SplitQueue.FreeDescriptorChain] has been called.
func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
return sq.descriptorTable.getDescriptorChain(head)
}
// GetDescriptorItem resets the descriptor with the given head index to expose
// its full buffer and returns that buffer.
func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) {
// Expose the full buffer again; callers shrink it via SetDescSize.
sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize)
return sq.descriptorTable.getDescriptorItem(head)
}
// GetDescriptorChainContents copies the contents of the descriptor chain with
// the given head index into out and returns the number of bytes copied.
func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
return sq.descriptorTable.getDescriptorChainContents(head, out, maxLen)
}
// GetDescriptorInbuffers appends the device-writable buffers of the descriptor
// chain with the given head index to inBuffers.
func (sq *SplitQueue) GetDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
return sq.descriptorTable.getDescriptorInbuffers(head, inBuffers)
}
// FreeDescriptorChain frees the descriptor chain with the given head index.
// The head index must be one that was returned by a previous call to
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
// freed yet.
//
// This creates new room in the queue which can be used by following
// [SplitQueue.OfferDescriptorChain] calls.
// When there are outstanding calls for [SplitQueue.OfferDescriptorChain] that
// are waiting for free room in the queue, they may become unblocked by this.
func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
//not called under lock
if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
return fmt.Errorf("free: %w", err)
}
return nil
}
// SetDescSize sets the payload length exposed by the descriptor with the given
// head index. Not called under the queue lock.
func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
sq.descriptorTable.descriptors[int(head)].length = uint32(sz)
}
func (sq *SplitQueue) OfferDescriptorChains(chains []uint16, kick bool) error {
// TODO: the descriptor-chain free below is skipped when re-offering; confirm that this cannot break eventually.
//not called under lock
//if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
// return fmt.Errorf("free: %w", err)
//}
// Make the descriptor chain available to the device.
sq.availableRing.offer(chains)
// Notify the device to make it process the updated available ring.
if kick {
return sq.Kick()
}
return nil
}
func (sq *SplitQueue) Kick() error {
if err := sq.kickEventFD.Kick(); err != nil {
return fmt.Errorf("notify device: %w", err)
}
return nil
}
// Close releases all resources used for this queue.
// The implementation will try to release as many resources as possible and
// collect potential errors before returning them.
func (sq *SplitQueue) Close() error {
var errs []error
if sq.stop != nil {
// This has to happen before the event file descriptors may be closed.
if err := sq.stop(); err != nil {
errs = append(errs, fmt.Errorf("stop consume used ring: %w", err))
}
// Make sure that this code block is executed only once.
sq.stop = nil
}
if err := sq.kickEventFD.Close(); err != nil {
errs = append(errs, fmt.Errorf("close kick event file descriptor: %w", err))
}
if err := sq.callEventFD.Close(); err != nil {
errs = append(errs, fmt.Errorf("close call event file descriptor: %w", err))
}
if err := sq.descriptorTable.releaseBuffers(); err != nil {
errs = append(errs, fmt.Errorf("release descriptor buffers: %w", err))
}
if sq.buf != nil {
if err := unix.Munmap(sq.buf); err == nil {
sq.buf = nil
} else {
errs = append(errs, fmt.Errorf("unmap virtqueue buffer: %w", err))
}
}
return errors.Join(errs...)
}
// ensureInitialized is a guard that prevents methods from being called on an
// uninitialized instance.
func (sq *SplitQueue) ensureInitialized() {
if sq.buf == nil {
panic("used ring is not initialized")
}
}
func align(index, alignment int) int {
remainder := index % alignment
if remainder == 0 {
return index
}
return index + alignment - remainder
}
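// Editor's note: a quick illustration of align (not part of the original
// change). It rounds index up to the next multiple of alignment:
//
//	align(0, 4)  == 0
//	align(5, 4)  == 8
//	align(8, 4)  == 8
//	align(22, 4) == 24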
// splitBuffers processes a list of buffers and splits each buffer that is
// larger than the size limit into multiple smaller buffers.
// If none of the buffers exceed the size limit, the input is returned unchanged to avoid an allocation.
func splitBuffers(buffers [][]byte, sizeLimit int) [][]byte {
for i := range buffers {
if len(buffers[i]) > sizeLimit {
return reallySplitBuffers(buffers, sizeLimit)
}
}
return buffers
}
func reallySplitBuffers(buffers [][]byte, sizeLimit int) [][]byte {
result := make([][]byte, 0, len(buffers))
for _, buffer := range buffers {
for added := 0; added < len(buffer); added += sizeLimit {
if len(buffer)-added <= sizeLimit {
result = append(result, buffer[added:])
break
}
result = append(result, buffer[added:added+sizeLimit])
}
}
return result
}
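// Editor's note (illustrative, matching the test cases below): with a size
// limit of 16, a 42-byte buffer is split into chunks of 16, 16 and 10 bytes,
// while buffers at or below the limit are passed through untouched.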

View File

@@ -1,105 +0,0 @@
package virtqueue
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSplitQueue_MemoryAlignment(t *testing.T) {
tests := []struct {
name string
queueSize int
}{
{
name: "minimal queue size",
queueSize: 1,
},
{
name: "small queue size",
queueSize: 8,
},
{
name: "large queue size",
queueSize: 256,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sq, err := NewSplitQueue(tt.queueSize)
require.NoError(t, err)
assert.Zero(t, sq.descriptorTable.Address()%descriptorTableAlignment)
assert.Zero(t, sq.availableRing.Address()%availableRingAlignment)
assert.Zero(t, sq.usedRing.Address()%usedRingAlignment)
})
}
}
func TestSplitBuffers(t *testing.T) {
const sizeLimit = 16
tests := []struct {
name string
buffers [][]byte
expected [][]byte
}{
{
name: "no buffers",
buffers: make([][]byte, 0),
expected: make([][]byte, 0),
},
{
name: "small",
buffers: [][]byte{
make([]byte, 11),
},
expected: [][]byte{
make([]byte, 11),
},
},
{
name: "exact size",
buffers: [][]byte{
make([]byte, sizeLimit),
},
expected: [][]byte{
make([]byte, sizeLimit),
},
},
{
name: "large",
buffers: [][]byte{
make([]byte, 42),
},
expected: [][]byte{
make([]byte, 16),
make([]byte, 16),
make([]byte, 10),
},
},
{
name: "mixed",
buffers: [][]byte{
make([]byte, 7),
make([]byte, 30),
make([]byte, 15),
make([]byte, 32),
},
expected: [][]byte{
make([]byte, 7),
make([]byte, 16),
make([]byte, 14),
make([]byte, 15),
make([]byte, 16),
make([]byte, 16),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := splitBuffers(tt.buffers, sizeLimit)
assert.Equal(t, tt.expected, actual)
})
}
}

View File

@@ -1,21 +0,0 @@
package virtqueue
// usedElementSize is the number of bytes needed to store a [UsedElement] in
// memory.
const usedElementSize = 8
// UsedElement is an element of the [UsedRing] and describes a descriptor chain
// that was used by the device.
type UsedElement struct {
// DescriptorIndex is the index of the head of the used descriptor chain in
// the [DescriptorTable].
// The index is 32-bit here for padding reasons.
DescriptorIndex uint32
// Length is the number of bytes written into the device writable portion of
// the buffer described by the descriptor chain.
Length uint32
}
func (u *UsedElement) GetHead() uint16 {
return uint16(u.DescriptorIndex)
}

View File

@@ -1,12 +0,0 @@
package virtqueue
import (
"testing"
"unsafe"
"github.com/stretchr/testify/assert"
)
func TestUsedElement_Size(t *testing.T) {
assert.EqualValues(t, usedElementSize, unsafe.Sizeof(UsedElement{}))
}

View File

@@ -1,184 +0,0 @@
package virtqueue
import (
"fmt"
"unsafe"
)
// usedRingFlag is a flag that describes a [UsedRing].
type usedRingFlag uint16
const (
// usedRingFlagNoNotify is used by the host to advise the guest not to
// kick it when adding a buffer. It's unreliable, so it's simply an
// optimization. The guest will still kick when it runs out of buffers.
usedRingFlagNoNotify usedRingFlag = 1 << iota
)
// usedRingSize is the number of bytes needed to store a [UsedRing] with the
// given queue size in memory.
func usedRingSize(queueSize int) int {
return 6 + usedElementSize*queueSize
}
// usedRingAlignment is the minimum alignment of a [UsedRing] in memory, as
// required by the virtio spec.
const usedRingAlignment = 4
// UsedRing is where the device returns descriptor chains once it is done with
// them. Each ring entry is a [UsedElement]. It is only written to by the device
// and read by the driver.
//
// Because the size of the ring depends on the queue size, we cannot define a
// Go struct with a static size that maps to the memory of the ring. Instead,
// this struct only contains pointers to the corresponding memory areas.
type UsedRing struct {
initialized bool
// flags that describe this ring.
flags *usedRingFlag
// ringIndex indicates where the device would put the next entry into the
// ring (modulo the queue size).
ringIndex *uint16
// ring contains the [UsedElement]s. It wraps around at queue size.
ring []UsedElement
// availableEvent is not used by this implementation, but we reserve it
// anyway to avoid issues in case a device may try to write to it, contrary
// to the virtio specification.
availableEvent *uint16
// lastIndex is the internal ringIndex up to which all [UsedElement]s were
// processed.
lastIndex uint16
//mu sync.Mutex
}
// newUsedRing creates a used ring that uses the given underlying memory. The
// length of the memory slice must match the size needed for the ring (see
// [usedRingSize]) for the given queue size.
func newUsedRing(queueSize int, mem []byte) *UsedRing {
ringSize := usedRingSize(queueSize)
if len(mem) != ringSize {
panic(fmt.Sprintf("memory size (%v) does not match required size "+
"for used ring: %v", len(mem), ringSize))
}
r := UsedRing{
initialized: true,
flags: (*usedRingFlag)(unsafe.Pointer(&mem[0])),
ringIndex: (*uint16)(unsafe.Pointer(&mem[2])),
ring: unsafe.Slice((*UsedElement)(unsafe.Pointer(&mem[4])), queueSize),
availableEvent: (*uint16)(unsafe.Pointer(&mem[ringSize-2])),
}
r.lastIndex = *r.ringIndex
return &r
}
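// Editor's note: resulting layout of the used ring memory for a queue of
// size N (illustrative, not part of the original change):
//
//	[0:2]        flags
//	[2:4]        ring index
//	[4:4+8N]     N UsedElement entries (8 bytes each)
//	[4+8N:6+8N]  availableEvent (reserved, unused by this implementation)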
// Address returns the pointer to the beginning of the ring in memory.
// Do not modify the memory directly to not interfere with this implementation.
func (r *UsedRing) Address() uintptr {
if !r.initialized {
panic("used ring is not initialized")
}
return uintptr(unsafe.Pointer(r.flags))
}
// take returns all new [UsedElement]s that the device put into the ring and
// that weren't already returned by a previous call to this method.
// Note: this method is no longer protected by a lock; callers must not invoke
// it concurrently.
func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
//r.mu.Lock()
//defer r.mu.Unlock()
ringIndex := *r.ringIndex
if ringIndex == r.lastIndex {
// Nothing new.
return 0, nil
}
// Calculate the number of new used elements that we can read from the ring.
// The ring index may wrap, so special handling for that case is needed.
count := int(ringIndex - r.lastIndex)
if count < 0 {
count += 0xffff
}
stillNeedToTake := 0
if maxToTake > 0 {
stillNeedToTake = count - maxToTake
if stillNeedToTake < 0 {
stillNeedToTake = 0
}
count = min(count, maxToTake)
}
// The number of new elements can never exceed the queue size.
if count > len(r.ring) {
panic("used ring contains more new elements than the ring is long")
}
elems := make([]UsedElement, count)
for i := range count {
elems[i] = r.ring[r.lastIndex%uint16(len(r.ring))]
r.lastIndex++
}
return stillNeedToTake, elems
}
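// Editor's note: a worked example of the wrap-around arithmetic above (not
// part of the original change). Both indices are uint16, so the subtraction
// wraps modulo 2^16 before the conversion to int:
//
//	lastIndex  = 0xFFFE (65534)
//	*ringIndex = 0x0002 (2)
//	ringIndex - lastIndex = uint16(2 - 65534) = 4 new used elements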
func (r *UsedRing) takeOne() (uint16, bool) {
//r.mu.Lock()
//defer r.mu.Unlock()
ringIndex := *r.ringIndex
if ringIndex == r.lastIndex {
// Nothing new.
return 0xffff, false
}
// Calculate the number of new used elements that we can read from the ring.
// The ring index may wrap, so special handling for that case is needed.
count := int(ringIndex - r.lastIndex)
if count < 0 {
count += 0xffff
}
// The number of new elements can never exceed the queue size.
if count > len(r.ring) {
panic("used ring contains more new elements than the ring is long")
}
if count == 0 {
return 0xffff, false
}
out := r.ring[r.lastIndex%uint16(len(r.ring))].GetHead()
r.lastIndex++
return out, true
}
// InitOfferSingle is only used to pre-fill the used queue at startup, and should not be used if the device is running!
func (r *UsedRing) InitOfferSingle(x uint16, size int) {
//always called under lock
//r.mu.Lock()
//defer r.mu.Unlock()
offset := 0
// Add descriptor chain heads to the ring.
// The 16-bit ring index may overflow. This is expected and is not an
// issue because the size of the ring array (which equals the queue
// size) is always a power of 2 and smaller than the highest possible
// 16-bit value.
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
r.ring[insertIndex] = UsedElement{
DescriptorIndex: uint32(x),
Length: uint32(size),
}
// Increase the ring index by the number of descriptor chains added to the ring.
*r.ringIndex += 1
}

View File

@@ -1,136 +0,0 @@
package virtqueue
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestUsedRing_MemoryLayout(t *testing.T) {
const queueSize = 2
memory := make([]byte, usedRingSize(queueSize))
r := newUsedRing(queueSize, memory)
*r.flags = 0x01ff
*r.ringIndex = 1
r.ring[0] = UsedElement{
DescriptorIndex: 0x0123,
Length: 0x4567,
}
r.ring[1] = UsedElement{
DescriptorIndex: 0x89ab,
Length: 0xcdef,
}
assert.Equal(t, []byte{
0xff, 0x01,
0x01, 0x00,
0x23, 0x01, 0x00, 0x00,
0x67, 0x45, 0x00, 0x00,
0xab, 0x89, 0x00, 0x00,
0xef, 0xcd, 0x00, 0x00,
0x00, 0x00,
}, memory)
}
//func TestUsedRing_Take(t *testing.T) {
// const queueSize = 8
//
// tests := []struct {
// name string
// ring []UsedElement
// ringIndex uint16
// lastIndex uint16
// expected []UsedElement
// }{
// {
// name: "nothing new",
// ring: []UsedElement{
// {DescriptorIndex: 1},
// {DescriptorIndex: 2},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// {},
// {},
// {},
// {},
// },
// ringIndex: 4,
// lastIndex: 4,
// expected: nil,
// },
// {
// name: "no overflow",
// ring: []UsedElement{
// {DescriptorIndex: 1},
// {DescriptorIndex: 2},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// {},
// {},
// {},
// {},
// },
// ringIndex: 4,
// lastIndex: 1,
// expected: []UsedElement{
// {DescriptorIndex: 2},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// },
// },
// {
// name: "ring overflow",
// ring: []UsedElement{
// {DescriptorIndex: 9},
// {DescriptorIndex: 10},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// {DescriptorIndex: 5},
// {DescriptorIndex: 6},
// {DescriptorIndex: 7},
// {DescriptorIndex: 8},
// },
// ringIndex: 10,
// lastIndex: 7,
// expected: []UsedElement{
// {DescriptorIndex: 8},
// {DescriptorIndex: 9},
// {DescriptorIndex: 10},
// },
// },
// {
// name: "index overflow",
// ring: []UsedElement{
// {DescriptorIndex: 9},
// {DescriptorIndex: 10},
// {DescriptorIndex: 3},
// {DescriptorIndex: 4},
// {DescriptorIndex: 5},
// {DescriptorIndex: 6},
// {DescriptorIndex: 7},
// {DescriptorIndex: 8},
// },
// ringIndex: 2,
// lastIndex: 65535,
// expected: []UsedElement{
// {DescriptorIndex: 8},
// {DescriptorIndex: 9},
// {DescriptorIndex: 10},
// },
// },
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// memory := make([]byte, usedRingSize(queueSize))
// r := newUsedRing(queueSize, memory)
//
// copy(r.ring, tt.ring)
// *r.ringIndex = tt.ringIndex
// r.lastIndex = tt.lastIndex
//
// assert.Equal(t, tt.expected, r.take())
// })
// }
//}

View File

@@ -1,70 +0,0 @@
package packet
import (
"github.com/slackhq/nebula/util/virtio"
"golang.org/x/sys/unix"
)
type OutPacket struct {
Segments [][]byte
SegmentPayloads [][]byte
SegmentHeaders [][]byte
SegmentIDs []uint16
//todo virtio header?
SegSize int
SegCounter int
Valid bool
wasSegmented bool
Scratch []byte
}
func NewOut() *OutPacket {
out := new(OutPacket)
out.Segments = make([][]byte, 0, 64)
out.SegmentHeaders = make([][]byte, 0, 64)
out.SegmentPayloads = make([][]byte, 0, 64)
out.SegmentIDs = make([]uint16, 0, 64)
out.Scratch = make([]byte, Size)
return out
}
func (pkt *OutPacket) Reset() {
pkt.Segments = pkt.Segments[:0]
pkt.SegmentPayloads = pkt.SegmentPayloads[:0]
pkt.SegmentHeaders = pkt.SegmentHeaders[:0]
pkt.SegmentIDs = pkt.SegmentIDs[:0]
pkt.SegSize = 0
pkt.Valid = false
pkt.wasSegmented = false
}
func (pkt *OutPacket) UseSegment(segID uint16, seg []byte, isV6 bool) int {
pkt.Valid = true
pkt.SegmentIDs = append(pkt.SegmentIDs, segID)
pkt.Segments = append(pkt.Segments, seg) //todo do we need this?
vhdr := virtio.NetHdr{ //todo
Flags: unix.VIRTIO_NET_HDR_F_DATA_VALID,
GSOType: unix.VIRTIO_NET_HDR_GSO_NONE,
HdrLen: 0,
GSOSize: 0,
CsumStart: 0,
CsumOffset: 0,
NumBuffers: 0,
}
hdr := seg[0 : virtio.NetHdrSize+14]
_ = vhdr.Encode(hdr)
if isV6 {
hdr[virtio.NetHdrSize+14-2] = 0x86
hdr[virtio.NetHdrSize+14-1] = 0xdd
} else {
hdr[virtio.NetHdrSize+14-2] = 0x08
hdr[virtio.NetHdrSize+14-1] = 0x00
}
pkt.SegmentHeaders = append(pkt.SegmentHeaders, hdr)
pkt.SegmentPayloads = append(pkt.SegmentPayloads, seg[virtio.NetHdrSize+14:])
return len(pkt.SegmentIDs) - 1
}
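// Editor's note: rough layout of a segment buffer as filled in by UseSegment
// above (not part of the original change), with virtio.NetHdrSize == 12:
//
//	seg[0:12]   virtio_net_hdr (written by vhdr.Encode)
//	seg[12:26]  14-byte Ethernet header; only the EtherType at seg[24:26] is
//	            written here (0x86dd for IPv6, 0x0800 for IPv4)
//	seg[26:]    payload (appended to SegmentPayloads)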

View File

@@ -1,119 +0,0 @@
package packet
import (
"encoding/binary"
"iter"
"net/netip"
"slices"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
const Size = 0xffff
type Packet struct {
Payload []byte
Control []byte
Name []byte
SegSize int
//todo should this hold out as well?
OutLen int
wasSegmented bool
isV4 bool
}
func New(isV4 bool) *Packet {
return &Packet{
Payload: make([]byte, Size),
Control: make([]byte, unix.CmsgSpace(2)),
Name: make([]byte, unix.SizeofSockaddrInet6),
isV4: isV4,
}
}
func (p *Packet) AddrPort() netip.AddrPort {
var ip netip.Addr
// It's ok to skip the ok check here; an out-of-range slice is the only way this can fail, and that would panic.
if p.isV4 {
ip, _ = netip.AddrFromSlice(p.Name[4:8])
} else {
ip, _ = netip.AddrFromSlice(p.Name[8:24])
}
return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(p.Name[2:4]))
}
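// Editor's note (illustrative, not part of the original change): the offsets
// used above follow the raw sockaddr layout the kernel writes into Name:
//
//	sockaddr_in  (IPv4): family at [0:2], port (big-endian) at [2:4], addr at [4:8]
//	sockaddr_in6 (IPv6): family at [0:2], port (big-endian) at [2:4],
//	                     flowinfo at [4:8], addr at [8:24]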
func (p *Packet) updateCtrl(ctrlLen int) {
p.SegSize = len(p.Payload)
p.wasSegmented = false
if ctrlLen == 0 {
return
}
if len(p.Control) == 0 {
return
}
cmsgs, err := unix.ParseSocketControlMessage(p.Control)
if err != nil {
return // ignore malformed control messages
}
for _, c := range cmsgs {
if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
p.wasSegmented = true
p.SegSize = int(binary.LittleEndian.Uint16(c.Data[:2]))
return
}
}
}
// Update sets a Packet into "just received, not processed" state
func (p *Packet) Update(ctrlLen int) {
p.OutLen = -1
p.updateCtrl(ctrlLen)
}
func (p *Packet) SetSegSizeForTX() {
p.SegSize = len(p.Payload)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&p.Control[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
hdr.SetLen(syscall.CmsgLen(2))
binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize))
}
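// Editor's note: on 64-bit Linux the control buffer written above ends up
// laid out roughly like this (illustrative, not part of the original change):
//
//	[0:8]   cmsg_len   = unix.CmsgLen(2) (16-byte header + 2 bytes of data)
//	[8:12]  cmsg_level = unix.SOL_UDP
//	[12:16] cmsg_type  = unix.UDP_SEGMENT
//	[16:18] uint16 segment size (native endian)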
func (p *Packet) CompatibleForSegmentationWith(otherP *Packet, currentTotalSize int) bool {
//same dest
if !slices.Equal(p.Name, otherP.Name) {
return false
}
// keep the merged payload under the 16-bit UDP length limit
if len(p.Payload)+currentTotalSize >= 0xffff {
return false
}
// segments must all be the same length
// TODO: allow a single, shorter segment at the end
if len(p.Payload) != len(otherP.Payload) {
return false // TODO: technically one final, shorter segment could still be appended
}
return true
}
func (p *Packet) Segments() iter.Seq[[]byte] {
return func(yield func([]byte) bool) {
//cursor := 0
for offset := 0; offset < len(p.Payload); offset += p.SegSize {
end := offset + p.SegSize
if end > len(p.Payload) {
end = len(p.Payload)
}
if !yield(p.Payload[offset:end]) {
return
}
}
}
}
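// Editor's note: Segments returns an iter.Seq, so callers can range over it
// directly (illustrative only; handle is a placeholder):
//
//	for seg := range p.Segments() {
//	    handle(seg) // each seg is at most p.SegSize bytes long
//	}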

View File

@@ -1,37 +0,0 @@
package packet
import (
"github.com/slackhq/nebula/util/virtio"
)
type VirtIOPacket struct {
Payload []byte
Header virtio.NetHdr
Chains []uint16
ChainRefs [][]byte
// OfferDescriptorChains(chains []uint16, kick bool) error
}
func NewVIO() *VirtIOPacket {
out := new(VirtIOPacket)
out.Payload = nil
out.ChainRefs = make([][]byte, 0, 4)
out.Chains = make([]uint16, 0, 8)
return out
}
func (v *VirtIOPacket) Reset() {
v.Payload = nil
v.ChainRefs = v.ChainRefs[:0]
v.Chains = v.Chains[:0]
}
type VirtIOTXPacket struct {
VirtIOPacket
}
func NewVIOTX(isV4 bool) *VirtIOTXPacket {
out := new(VirtIOTXPacket)
out.VirtIOPacket = *NewVIO()
return out
}

47
pki.go
View File

@@ -100,36 +100,41 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
currentState := p.cs.Load() currentState := p.cs.Load()
if newState.v1Cert != nil { if newState.v1Cert != nil {
if currentState.v1Cert == nil { if currentState.v1Cert == nil {
//adding certs is fine, actually. Networks-in-common confirmed in newCertState(). return util.NewContextualError("v1 certificate was added, restart required", nil, err)
} else { }
// did IP in cert change? if so, don't set // did IP in cert change? if so, don't set
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) { if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
return util.NewContextualError( return util.NewContextualError(
"Networks in new cert was different from old", "Networks in new cert was different from old",
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1}, m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
nil, nil,
) )
} }
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() { if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
return util.NewContextualError( return util.NewContextualError(
"Curve in new v1 cert was different from old", "Curve in new cert was different from old",
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1}, m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
nil, nil,
) )
} }
}
} else if currentState.v1Cert != nil {
//TODO: CERT-V2 we should be able to tear this down
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
} }
if newState.v2Cert != nil { if newState.v2Cert != nil {
if currentState.v2Cert == nil { if currentState.v2Cert == nil {
//adding certs is fine, actually return util.NewContextualError("v2 certificate was added, restart required", nil, err)
} else { }
// did IP in cert change? if so, don't set // did IP in cert change? if so, don't set
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) { if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
return util.NewContextualError( return util.NewContextualError(
"Networks in new cert was different from old", "Networks in new cert was different from old",
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2}, m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
nil, nil,
) )
} }
@@ -137,25 +142,13 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() { if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
return util.NewContextualError( return util.NewContextualError(
"Curve in new cert was different from old", "Curve in new cert was different from old",
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2}, m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
nil, nil,
) )
} }
}
} else if currentState.v2Cert != nil { } else if currentState.v2Cert != nil {
//newState.v1Cert is non-nil bc empty certstates aren't permitted return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
if newState.v1Cert == nil {
return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
}
//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
return util.NewContextualError(
"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
nil,
)
}
} }
// Cipher cant be hot swapped so just leave it at what it was before // Cipher cant be hot swapped so just leave it at what it was before
@@ -523,14 +516,10 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err) return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
} }
bl := c.GetStringSlice("pki.blocklist", []string{}) for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
if len(bl) > 0 { l.WithField("fingerprint", fp).Info("Blocklisting cert")
for _, fp := range bl {
caPool.BlocklistFingerprint(fp) caPool.BlocklistFingerprint(fp)
} }
l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates")
}
return caPool, nil return caPool, nil
} }

View File

@@ -16,8 +16,8 @@ import (
"github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"go.yaml.in/yaml/v3"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"gopkg.in/yaml.v3"
) )
type m = map[string]any type m = map[string]any

View File

@@ -4,13 +4,13 @@ import (
"net/netip" "net/netip"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/packet"
) )
const MTU = 9001 const MTU = 9001
type EncReader func( type EncReader func(
[]*packet.Packet, addr netip.AddrPort,
payload []byte,
) )
type Conn interface { type Conn interface {
@@ -19,8 +19,6 @@ type Conn interface {
ListenOut(r EncReader) ListenOut(r EncReader)
WriteTo(b []byte, addr netip.AddrPort) error WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C) ReloadConfig(c *config.C)
Prep(pkt *packet.Packet, addr netip.AddrPort) error
WriteBatch(pkt []*packet.Packet) (int, error)
Close() error Close() error
} }

View File

@@ -14,22 +14,22 @@ import (
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/packet"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
const iovMax = 128 //1024 //no unix constant for this? from limits.h
//todo we'd like this to be 1024, but we seem to hit errors at around ~130 iovecs
type StdConn struct { type StdConn struct {
sysFd int sysFd int
isV4 bool isV4 bool
l *logrus.Logger l *logrus.Logger
batch int batch int
enableGRO bool }
msgs []rawMessage func maybeIPV4(ip net.IP) (net.IP, bool) {
iovs [][]iovec ip4 := ip.To4()
if ip4 != nil {
return ip4, true
}
return ip, false
} }
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
@@ -69,20 +69,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
return nil, fmt.Errorf("unable to bind to socket: %s", err) return nil, fmt.Errorf("unable to bind to socket: %s", err)
} }
const batchSize = 8192 return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
msgs := make([]rawMessage, 0, batchSize) //todo configure
iovs := make([][]iovec, batchSize)
for i := range iovs {
iovs[i] = make([]iovec, iovMax)
}
return &StdConn{
sysFd: fd,
isV4: ip.Is4(),
l: l,
batch: batch,
msgs: msgs,
iovs: iovs,
}, err
} }
func (u *StdConn) Rebind() error { func (u *StdConn) Rebind() error {
@@ -132,7 +119,9 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
} }
func (u *StdConn) ListenOut(r EncReader) { func (u *StdConn) ListenOut(r EncReader) {
msgs, packets := u.PrepareRawMessages(u.batch, u.isV4) var ip netip.Addr
msgs, buffers, names := u.PrepareRawMessages(u.batch)
read := u.ReadMulti read := u.ReadMulti
if u.batch == 1 { if u.batch == 1 {
read = u.ReadSingle read = u.ReadSingle
@@ -146,12 +135,13 @@ func (u *StdConn) ListenOut(r EncReader) {
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
packets[i].Payload = packets[i].Payload[:msgs[i].Len] // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
packets[i].Update(getRawMessageControlLen(&msgs[i])) if u.isV4 {
ip, _ = netip.AddrFromSlice(names[i][4:8])
} else {
ip, _ = netip.AddrFromSlice(names[i][8:24])
} }
r(packets[:n]) r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
for i := 0; i < n; i++ { //todo reset this in prev loop, but this makes debug ez
msgs[i].Hdr.Controllen = uint64(unix.CmsgSpace(2))
} }
} }
} }
@@ -204,147 +194,6 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
return u.writeTo6(b, ip) return u.writeTo6(b, ip)
} }
func (u *StdConn) WriteToBatch(b []byte, ip netip.AddrPort) error {
if u.isV4 {
return u.writeTo4(b, ip)
}
return u.writeTo6(b, ip)
}
func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
nl, err := u.encodeSockaddr(pkt.Name, addr)
if err != nil {
return err
}
pkt.Name = pkt.Name[:nl]
pkt.OutLen = len(pkt.Payload)
return nil
}
func (u *StdConn) WriteBatch(pkts []*packet.Packet) (int, error) {
if len(pkts) == 0 {
return 0, nil
}
u.msgs = u.msgs[:0]
//u.iovs = u.iovs[:0]
sent := 0
var mostRecentPkt *packet.Packet
mostRecentPktSize := 0
//segmenting := false
idx := 0
for _, pkt := range pkts {
if len(pkt.Payload) == 0 || pkt.OutLen == -1 {
sent++
continue
}
lastIdx := idx - 1
if mostRecentPkt != nil && pkt.CompatibleForSegmentationWith(mostRecentPkt, mostRecentPktSize) && u.msgs[lastIdx].Hdr.Iovlen < iovMax {
u.msgs[lastIdx].Hdr.Controllen = uint64(len(mostRecentPkt.Control))
u.msgs[lastIdx].Hdr.Control = &mostRecentPkt.Control[0]
u.iovs[lastIdx][u.msgs[lastIdx].Hdr.Iovlen].Base = &pkt.Payload[0]
u.iovs[lastIdx][u.msgs[lastIdx].Hdr.Iovlen].Len = uint64(len(pkt.Payload))
u.msgs[lastIdx].Hdr.Iovlen++
mostRecentPktSize += len(pkt.Payload)
mostRecentPkt.SetSegSizeForTX()
} else {
u.msgs = append(u.msgs, rawMessage{})
u.iovs[idx][0] = iovec{
Base: &pkt.Payload[0],
Len: uint64(len(pkt.Payload)),
}
msg := &u.msgs[idx]
iov := &u.iovs[idx][0]
idx++
msg.Hdr.Iov = iov
msg.Hdr.Iovlen = 1
setRawMessageControl(msg, nil)
msg.Hdr.Flags = 0
msg.Hdr.Name = &pkt.Name[0]
msg.Hdr.Namelen = uint32(len(pkt.Name))
mostRecentPkt = pkt
mostRecentPktSize = len(pkt.Payload)
}
}
if len(u.msgs) == 0 {
return sent, nil
}
offset := 0
for offset < len(u.msgs) {
n, _, errno := unix.Syscall6(
unix.SYS_SENDMMSG,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&u.msgs[offset])),
uintptr(len(u.msgs)-offset),
0,
0,
0,
)
if errno != 0 {
if errno == unix.EINTR {
continue
}
//for i := 0; i < len(u.msgs); i++ {
// for j := 0; j < int(u.msgs[i].Hdr.Iovlen); j++ {
// u.l.WithFields(logrus.Fields{
// "msg_index": i,
// "iov idx": j,
// "iov": fmt.Sprintf("%+v", u.iovs[i][j]),
// }).Warn("failed to send message")
// }
//
//}
u.l.WithFields(logrus.Fields{
"errno": errno,
"idx": idx,
"len": len(u.msgs),
"deets": fmt.Sprintf("%+v", u.msgs),
"lastIOV": fmt.Sprintf("%+v", u.iovs[len(u.msgs)-1][u.msgs[len(u.msgs)-1].Hdr.Iovlen-1]),
}).Error("failed to send message")
return sent + offset, &net.OpError{Op: "sendmmsg", Err: errno}
}
if n == 0 {
break
}
offset += int(n)
}
return sent + len(u.msgs), nil
}
func (u *StdConn) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
if u.isV4 {
if !addr.Addr().Is4() {
return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
}
var sa unix.RawSockaddrInet4
sa.Family = unix.AF_INET
sa.Addr = addr.Addr().As4()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
size := unix.SizeofSockaddrInet4
copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
return uint32(size), nil
}
var sa unix.RawSockaddrInet6
sa.Family = unix.AF_INET6
sa.Addr = addr.Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
size := unix.SizeofSockaddrInet6
copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
return uint32(size), nil
}
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
var rsa unix.RawSockaddrInet6 var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6 rsa.Family = unix.AF_INET6
@@ -445,27 +294,6 @@ func (u *StdConn) ReloadConfig(c *config.C) {
u.l.WithError(err).Error("Failed to set listen.so_mark") u.l.WithError(err).Error("Failed to set listen.so_mark")
} }
} }
u.configureGRO(true)
}
func (u *StdConn) configureGRO(enable bool) {
if enable == u.enableGRO {
return
}
if enable {
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
u.l.WithError(err).Warn("Failed to enable UDP GRO")
return
}
u.enableGRO = true
u.l.Info("UDP GRO enabled")
} else {
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil && err != unix.ENOPROTOOPT {
u.l.WithError(err).Warn("Failed to disable UDP GRO")
}
u.enableGRO = false
}
} }
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {

View File

@@ -7,7 +7,6 @@
package udp package udp
import ( import (
"github.com/slackhq/nebula/packet"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@@ -34,59 +33,25 @@ type rawMessage struct {
Pad0 [4]byte Pad0 [4]byte
} }
func setRawMessageControl(msg *rawMessage, buf []byte) { func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
if len(buf) == 0 {
msg.Hdr.Control = nil
msg.Hdr.Controllen = 0
return
}
msg.Hdr.Control = &buf[0]
msg.Hdr.Controllen = uint64(len(buf))
}
func getRawMessageControlLen(msg *rawMessage) int {
return int(msg.Hdr.Controllen)
}
func setCmsgLen(h *unix.Cmsghdr, l int) {
h.Len = uint64(l)
}
func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet.Packet) {
msgs := make([]rawMessage, n) msgs := make([]rawMessage, n)
packets := make([]*packet.Packet, n) buffers := make([][]byte, n)
names := make([][]byte, n)
for i := range msgs { for i := range msgs {
packets[i] = packet.New(isV4) buffers[i] = make([]byte, MTU)
names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{ vs := []iovec{
{Base: &packets[i].Payload[0], Len: uint64(packet.Size)}, {Base: &buffers[i][0], Len: uint64(len(buffers[i]))},
} }
msgs[i].Hdr.Iov = &vs[0] msgs[i].Hdr.Iov = &vs[0]
msgs[i].Hdr.Iovlen = uint64(len(vs)) msgs[i].Hdr.Iovlen = uint64(len(vs))
msgs[i].Hdr.Name = &packets[i].Name[0] msgs[i].Hdr.Name = &names[i][0]
msgs[i].Hdr.Namelen = uint32(len(packets[i].Name)) msgs[i].Hdr.Namelen = uint32(len(names[i]))
if u.enableGRO {
msgs[i].Hdr.Control = &packets[i].Control[0]
msgs[i].Hdr.Controllen = uint64(len(packets[i].Control))
} else {
msgs[i].Hdr.Control = nil
msgs[i].Hdr.Controllen = 0
}
} }
return msgs, packets return msgs, buffers, names
}
func setIovecSlice(iov *iovec, b []byte) {
if len(b) == 0 {
iov.Base = nil
iov.Len = 0
return
}
iov.Base = &b[0]
iov.Len = uint64(len(b))
} }

View File

@@ -1,3 +0,0 @@
// Package virtio contains some generic types and concepts related to the virtio
// protocol.
package virtio

View File

@@ -1,136 +0,0 @@
package virtio
// Feature contains feature bits that describe a virtio device or driver.
type Feature uint64
// Device-independent feature bits.
//
// Source: https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-6600006
const (
// FeatureIndirectDescriptors indicates that the driver can use descriptors
// with an additional layer of indirection.
FeatureIndirectDescriptors Feature = 1 << 28
// FeatureVersion1 indicates compliance with version 1.0 of the virtio
// specification.
FeatureVersion1 Feature = 1 << 32
)
// Feature bits for networking devices.
//
// Source: https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-2200003
const (
// FeatureNetDeviceCsum indicates that the device can handle packets with
// partial checksum (checksum offload).
FeatureNetDeviceCsum Feature = 1 << 0
// FeatureNetDriverCsum indicates that the driver can handle packets with
// partial checksum.
FeatureNetDriverCsum Feature = 1 << 1
// FeatureNetCtrlDriverOffloads indicates support for dynamic offload state
// reconfiguration.
FeatureNetCtrlDriverOffloads Feature = 1 << 2
// FeatureNetMTU indicates that the device reports a maximum MTU value.
FeatureNetMTU Feature = 1 << 3
// FeatureNetMAC indicates that the device provides a MAC address.
FeatureNetMAC Feature = 1 << 5
// FeatureNetDriverTSO4 indicates that the driver supports the TCP
// segmentation offload for received IPv4 packets.
FeatureNetDriverTSO4 Feature = 1 << 7
// FeatureNetDriverTSO6 indicates that the driver supports the TCP
// segmentation offload for received IPv6 packets.
FeatureNetDriverTSO6 Feature = 1 << 8
// FeatureNetDriverECN indicates that the driver supports the TCP
// segmentation offload with ECN for received packets.
FeatureNetDriverECN Feature = 1 << 9
// FeatureNetDriverUFO indicates that the driver supports the UDP
// fragmentation offload for received packets.
FeatureNetDriverUFO Feature = 1 << 10
// FeatureNetDeviceTSO4 indicates that the device supports the TCP
// segmentation offload for received IPv4 packets.
FeatureNetDeviceTSO4 Feature = 1 << 11
// FeatureNetDeviceTSO6 indicates that the device supports the TCP
// segmentation offload for received IPv6 packets.
FeatureNetDeviceTSO6 Feature = 1 << 12
// FeatureNetDeviceECN indicates that the device supports the TCP
// segmentation offload with ECN for received packets.
FeatureNetDeviceECN Feature = 1 << 13
// FeatureNetDeviceUFO indicates that the device supports the UDP
// fragmentation offload for received packets.
FeatureNetDeviceUFO Feature = 1 << 14
// FeatureNetMergeRXBuffers indicates that the driver can handle merged
// receive buffers.
// When this feature is negotiated, devices may merge multiple descriptor
// chains together to transport large received packets. [NetHdr.NumBuffers]
// will then contain the number of merged descriptor chains.
FeatureNetMergeRXBuffers Feature = 1 << 15
// FeatureNetStatus indicates that the device configuration status field is
// available.
FeatureNetStatus Feature = 1 << 16
// FeatureNetCtrlVQ indicates that a control channel virtqueue is
// available.
FeatureNetCtrlVQ Feature = 1 << 17
// FeatureNetCtrlRX indicates support for RX mode control (e.g. promiscuous
// or all-multicast) for packet receive filtering.
FeatureNetCtrlRX Feature = 1 << 18
// FeatureNetCtrlVLAN indicates support for VLAN filtering through the
// control channel.
FeatureNetCtrlVLAN Feature = 1 << 19
// FeatureNetDriverAnnounce indicates that the driver can send gratuitous
// packets.
FeatureNetDriverAnnounce Feature = 1 << 21
// FeatureNetMQ indicates that the device supports multiqueue with automatic
// receive steering.
FeatureNetMQ Feature = 1 << 22
// FeatureNetCtrlMACAddr indicates that the MAC address can be set through
// the control channel.
FeatureNetCtrlMACAddr Feature = 1 << 23
// FeatureNetDeviceUSO indicates that the device supports the UDP
// segmentation offload for received packets.
FeatureNetDeviceUSO Feature = 1 << 56
// FeatureNetHashReport indicates that the device can report a per-packet
// hash value and type.
FeatureNetHashReport Feature = 1 << 57
// FeatureNetDriverHdrLen indicates that the driver can provide the exact
// header length value (see [NetHdr.HdrLen]).
// Devices may benefit from knowing the exact header length.
FeatureNetDriverHdrLen Feature = 1 << 59
// FeatureNetRSS indicates that the device supports RSS (receive-side
// scaling) with configurable hash parameters.
FeatureNetRSS Feature = 1 << 60
// FeatureNetRSCExt indicates that the device can process duplicated ACKs
// and report the number of coalesced segments and duplicated ACKs.
FeatureNetRSCExt Feature = 1 << 61
// FeatureNetStandby indicates that the device may act as a standby for a
// primary device with the same MAC address.
FeatureNetStandby Feature = 1 << 62
// FeatureNetSpeedDuplex indicates that the device can report link speed and
// duplex mode.
FeatureNetSpeedDuplex Feature = 1 << 63
)
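// Editor's note: an illustrative check of a negotiated feature set (not part
// of the original change). Feature values are plain bit flags, so they
// combine and test with the usual bitwise operators:
//
//	negotiated := FeatureVersion1 | FeatureNetMergeRXBuffers
//	if negotiated&FeatureNetMergeRXBuffers != 0 {
//	    // NetHdr.NumBuffers will carry the merged chain count on receive.
//	}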

View File

@@ -1,77 +0,0 @@
package virtio
import (
"errors"
"unsafe"
"golang.org/x/sys/unix"
)
// Workaround to make Go doc links work.
var _ unix.Errno
// NetHdrSize is the number of bytes needed to store a [NetHdr] in memory.
const NetHdrSize = 12
// ErrNetHdrBufferTooSmall is returned when a buffer is too small to fit a
// virtio_net_hdr.
var ErrNetHdrBufferTooSmall = errors.New("the buffer is too small to fit a virtio_net_hdr")
// NetHdr defines the virtio_net_hdr as described by the virtio specification.
type NetHdr struct {
// Flags that describe the packet.
// Possible values are:
// - [unix.VIRTIO_NET_HDR_F_NEEDS_CSUM]
// - [unix.VIRTIO_NET_HDR_F_DATA_VALID]
// - [unix.VIRTIO_NET_HDR_F_RSC_INFO]
Flags uint8
// GSOType contains the type of segmentation offload that should be used for
// the packet.
// Possible values are:
// - [unix.VIRTIO_NET_HDR_GSO_NONE]
// - [unix.VIRTIO_NET_HDR_GSO_TCPV4]
// - [unix.VIRTIO_NET_HDR_GSO_UDP]
// - [unix.VIRTIO_NET_HDR_GSO_TCPV6]
// - [unix.VIRTIO_NET_HDR_GSO_UDP_L4]
// - [unix.VIRTIO_NET_HDR_GSO_ECN]
GSOType uint8
// HdrLen contains the length of the headers that need to be replicated by
// segmentation offloads. It's the number of bytes from the beginning of the
// packet to the beginning of the transport payload.
// Only used when [FeatureNetDriverHdrLen] is negotiated.
HdrLen uint16
// GSOSize contains the maximum size of each segmented packet beyond the
// header (payload size). In case of TCP, this is the MSS.
GSOSize uint16
// CsumStart contains the offset within the packet from which on the
// checksum should be computed.
CsumStart uint16
// CsumOffset specifies how many bytes after [NetHdr.CsumStart] the computed
// 16-bit checksum should be inserted.
CsumOffset uint16
// NumBuffers contains the number of merged descriptor chains when
// [FeatureNetMergeRXBuffers] is negotiated.
// This field is only used for packets received by the driver and should be
// zero for transmitted packets.
NumBuffers uint16
}
// Decode decodes the [NetHdr] from the given byte slice. The slice must contain
// at least [NetHdrSize] bytes.
func (v *NetHdr) Decode(data []byte) error {
if len(data) < NetHdrSize {
return ErrNetHdrBufferTooSmall
}
copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), NetHdrSize), data[:NetHdrSize])
return nil
}
// Encode encodes the [NetHdr] into the given byte slice. The slice must have
// room for at least [NetHdrSize] bytes.
func (v *NetHdr) Encode(data []byte) error {
if len(data) < NetHdrSize {
return ErrNetHdrBufferTooSmall
}
copy(data[:NetHdrSize], unsafe.Slice((*byte)(unsafe.Pointer(v)), NetHdrSize))
return nil
}

View File

@@ -1,43 +0,0 @@
package virtio
import (
"testing"
"unsafe"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
func TestNetHdr_Size(t *testing.T) {
assert.EqualValues(t, NetHdrSize, unsafe.Sizeof(NetHdr{}))
}
func TestNetHdr_Encoding(t *testing.T) {
vnethdr := NetHdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
GSOType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
HdrLen: 42,
GSOSize: 1472,
CsumStart: 34,
CsumOffset: 6,
NumBuffers: 16,
}
buf := make([]byte, NetHdrSize)
require.NoError(t, vnethdr.Encode(buf))
assert.Equal(t, []byte{
0x01, 0x05,
0x2a, 0x00,
0xc0, 0x05,
0x22, 0x00,
0x06, 0x00,
0x10, 0x00,
}, buf)
var decoded NetHdr
require.NoError(t, decoded.Decode(buf))
assert.Equal(t, vnethdr, decoded)
}