mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 16:34:25 +01:00
Compare commits
9 Commits
vhost
...
jay.wren-w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8d4dd26484 | ||
|
|
0a94f9f990 | ||
|
|
433c531ae4 | ||
|
|
4c0aad1b1f | ||
|
|
c8b0281736 | ||
|
|
8281b1699f | ||
|
|
0827a6f1c5 | ||
|
|
273119638d | ||
|
|
484de41b58 |
2
.github/workflows/gofmt.yml
vendored
2
.github/workflows/gofmt.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
|
|||||||
20
.github/workflows/release.yml
vendored
20
.github/workflows/release.yml
vendored
@@ -10,7 +10,7 @@ jobs:
|
|||||||
name: Build Linux/BSD All
|
name: Build Linux/BSD All
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
@@ -24,7 +24,7 @@ jobs:
|
|||||||
mv build/*.tar.gz release
|
mv build/*.tar.gz release
|
||||||
|
|
||||||
- name: Upload artifacts
|
- name: Upload artifacts
|
||||||
uses: actions/upload-artifact@v5
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: linux-latest
|
name: linux-latest
|
||||||
path: release
|
path: release
|
||||||
@@ -33,7 +33,7 @@ jobs:
|
|||||||
name: Build Windows
|
name: Build Windows
|
||||||
runs-on: windows-latest
|
runs-on: windows-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
@@ -55,7 +55,7 @@ jobs:
|
|||||||
mv dist\windows\wintun build\dist\windows\
|
mv dist\windows\wintun build\dist\windows\
|
||||||
|
|
||||||
- name: Upload artifacts
|
- name: Upload artifacts
|
||||||
uses: actions/upload-artifact@v5
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: windows-latest
|
name: windows-latest
|
||||||
path: build
|
path: build
|
||||||
@@ -66,7 +66,7 @@ jobs:
|
|||||||
HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
|
HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
|
||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
@@ -104,7 +104,7 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Upload artifacts
|
- name: Upload artifacts
|
||||||
uses: actions/upload-artifact@v5
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: darwin-latest
|
name: darwin-latest
|
||||||
path: ./release/*
|
path: ./release/*
|
||||||
@@ -124,11 +124,11 @@ jobs:
|
|||||||
# be overwritten
|
# be overwritten
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
||||||
uses: actions/checkout@v5
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Download artifacts
|
- name: Download artifacts
|
||||||
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
|
||||||
uses: actions/download-artifact@v6
|
uses: actions/download-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: linux-latest
|
name: linux-latest
|
||||||
path: artifacts
|
path: artifacts
|
||||||
@@ -160,10 +160,10 @@ jobs:
|
|||||||
needs: [build-linux, build-darwin, build-windows]
|
needs: [build-linux, build-darwin, build-windows]
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Download artifacts
|
- name: Download artifacts
|
||||||
uses: actions/download-artifact@v6
|
uses: actions/download-artifact@v4
|
||||||
with:
|
with:
|
||||||
path: artifacts
|
path: artifacts
|
||||||
|
|
||||||
|
|||||||
2
.github/workflows/smoke-extra.yml
vendored
2
.github/workflows/smoke-extra.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
|
|||||||
2
.github/workflows/smoke.yml
vendored
2
.github/workflows/smoke.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
|
|||||||
16
.github/workflows/test.yml
vendored
16
.github/workflows/test.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
@@ -32,7 +32,7 @@ jobs:
|
|||||||
run: make vet
|
run: make vet
|
||||||
|
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v9
|
uses: golangci/golangci-lint-action@v8
|
||||||
with:
|
with:
|
||||||
version: v2.5
|
version: v2.5
|
||||||
|
|
||||||
@@ -45,7 +45,7 @@ jobs:
|
|||||||
- name: Build test mobile
|
- name: Build test mobile
|
||||||
run: make build-test-mobile
|
run: make build-test-mobile
|
||||||
|
|
||||||
- uses: actions/upload-artifact@v5
|
- uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: e2e packet flow linux-latest
|
name: e2e packet flow linux-latest
|
||||||
path: e2e/mermaid/linux-latest
|
path: e2e/mermaid/linux-latest
|
||||||
@@ -56,7 +56,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
@@ -77,7 +77,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
@@ -98,7 +98,7 @@ jobs:
|
|||||||
os: [windows-latest, macos-latest]
|
os: [windows-latest, macos-latest]
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- uses: actions/setup-go@v6
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
@@ -115,7 +115,7 @@ jobs:
|
|||||||
run: make vet
|
run: make vet
|
||||||
|
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v9
|
uses: golangci/golangci-lint-action@v8
|
||||||
with:
|
with:
|
||||||
version: v2.5
|
version: v2.5
|
||||||
|
|
||||||
@@ -125,7 +125,7 @@ jobs:
|
|||||||
- name: End 2 end
|
- name: End 2 end
|
||||||
run: make e2evv
|
run: make e2evv
|
||||||
|
|
||||||
- uses: actions/upload-artifact@v5
|
- uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: e2e packet flow ${{ matrix.os }}
|
name: e2e packet flow ${{ matrix.os }}
|
||||||
path: e2e/mermaid/${{ matrix.os }}
|
path: e2e/mermaid/${{ matrix.os }}
|
||||||
|
|||||||
2
bits.go
2
bits.go
@@ -43,7 +43,7 @@ func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Not within the window
|
// Not within the window
|
||||||
l.Debugf("rejected a packet (top) %d %d delta %d\n", b.current, i, b.current-i)
|
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -114,33 +114,6 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
|
|||||||
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
|
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
|
|
||||||
nc := &cert.TBSCertificate{
|
|
||||||
Version: v,
|
|
||||||
Curve: c.Curve(),
|
|
||||||
Name: c.Name(),
|
|
||||||
Networks: c.Networks(),
|
|
||||||
UnsafeNetworks: c.UnsafeNetworks(),
|
|
||||||
Groups: c.Groups(),
|
|
||||||
NotBefore: time.Unix(c.NotBefore().Unix(), 0),
|
|
||||||
NotAfter: time.Unix(c.NotAfter().Unix(), 0),
|
|
||||||
PublicKey: c.PublicKey(),
|
|
||||||
IsCA: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
c, err := nc.Sign(ca, ca.Curve(), key)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pem, err := c.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return c, pem
|
|
||||||
}
|
|
||||||
|
|
||||||
func X25519Keypair() ([]byte, []byte) {
|
func X25519Keypair() ([]byte, []byte) {
|
||||||
privkey := make([]byte, 32)
|
privkey := make([]byte, 32)
|
||||||
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
||||||
|
|||||||
@@ -3,9 +3,6 @@ package main
|
|||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
_ "net/http/pprof"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -61,10 +58,6 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
|
||||||
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
|
|
||||||
}()
|
|
||||||
|
|
||||||
if !*configTest {
|
if !*configTest {
|
||||||
ctrl.Start()
|
ctrl.Start()
|
||||||
notifyReady(l)
|
notifyReady(l)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import (
|
|||||||
|
|
||||||
"dario.cat/mergo"
|
"dario.cat/mergo"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"go.yaml.in/yaml/v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type C struct {
|
type C struct {
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.yaml.in/yaml/v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConfig_Load(t *testing.T) {
|
func TestConfig_Load(t *testing.T) {
|
||||||
|
|||||||
@@ -354,6 +354,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
|
|
||||||
if mainHostInfo {
|
if mainHostInfo {
|
||||||
decision = tryRehandshake
|
decision = tryRehandshake
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if cm.shouldSwapPrimary(hostinfo) {
|
if cm.shouldSwapPrimary(hostinfo) {
|
||||||
decision = swapPrimary
|
decision = swapPrimary
|
||||||
@@ -460,10 +461,6 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
||||||
if crt == nil {
|
|
||||||
//my cert was reloaded away. We should definitely swap from this tunnel
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
||||||
// settle down.
|
// settle down.
|
||||||
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
||||||
@@ -478,34 +475,31 @@ func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
|
|||||||
cm.hostMap.Unlock()
|
cm.hostMap.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// isInvalidCertificate decides if we should destroy a tunnel.
|
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
|
||||||
// returns true if pki.disconnect_invalid is true and the certificate is no longer valid.
|
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
|
||||||
// Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
|
// check and return true.
|
||||||
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
||||||
remoteCert := hostinfo.GetCert()
|
remoteCert := hostinfo.GetCert()
|
||||||
if remoteCert == nil {
|
if remoteCert == nil {
|
||||||
return false //don't tear down tunnels for handshakes in progress
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
caPool := cm.intf.pki.GetCAPool()
|
caPool := cm.intf.pki.GetCAPool()
|
||||||
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return false //cert is still valid! yay!
|
|
||||||
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
|
|
||||||
// Block listed certificates should always be disconnected
|
|
||||||
hostinfo.logger(cm.l).WithError(err).
|
|
||||||
WithField("fingerprint", remoteCert.Fingerprint).
|
|
||||||
Info("Remote certificate is blocked, tearing down the tunnel")
|
|
||||||
return true
|
|
||||||
} else if cm.intf.disconnectInvalid.Load() {
|
|
||||||
hostinfo.logger(cm.l).WithError(err).
|
|
||||||
WithField("fingerprint", remoteCert.Fingerprint).
|
|
||||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
|
||||||
return true
|
|
||||||
} else {
|
|
||||||
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
|
||||||
|
// Block listed certificates should always be disconnected
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
hostinfo.logger(cm.l).WithError(err).
|
||||||
|
WithField("fingerprint", remoteCert.Fingerprint).
|
||||||
|
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||||
@@ -536,45 +530,15 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
|||||||
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||||
cs := cm.intf.pki.getCertState()
|
cs := cm.intf.pki.getCertState()
|
||||||
curCrt := hostinfo.ConnectionState.myCert
|
curCrt := hostinfo.ConnectionState.myCert
|
||||||
curCrtVersion := curCrt.Version()
|
myCrt := cs.getCertificate(curCrt.Version())
|
||||||
myCrt := cs.getCertificate(curCrtVersion)
|
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
||||||
if myCrt == nil {
|
// The current tunnel is using the latest certificate and version, no need to rehandshake.
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
|
||||||
WithField("version", curCrtVersion).
|
|
||||||
WithField("reason", "local certificate removed").
|
|
||||||
Info("Re-handshaking with remote")
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
peerCrt := hostinfo.ConnectionState.peerCert
|
|
||||||
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
|
|
||||||
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
|
|
||||||
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
|
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
|
||||||
WithField("version", curCrtVersion).
|
|
||||||
WithField("peerVersion", peerCrt.Certificate.Version()).
|
|
||||||
WithField("reason", "local certificate version lower than peer, attempting to correct").
|
|
||||||
Info("Re-handshaking with remote")
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
|
|
||||||
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
|
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
|
||||||
WithField("reason", "local certificate is not current").
|
|
||||||
Info("Re-handshaking with remote")
|
|
||||||
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
return
|
WithField("reason", "local certificate is not current").
|
||||||
}
|
Info("Re-handshaking with remote")
|
||||||
if curCrtVersion < cs.initiatingVersion {
|
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
|
||||||
WithField("reason", "current cert version < pki.initiatingVersion").
|
|
||||||
Info("Re-handshaking with remote")
|
|
||||||
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/noiseutil"
|
"github.com/slackhq/nebula/noiseutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ReplayWindow = 4096
|
const ReplayWindow = 1024
|
||||||
|
|
||||||
type ConnectionState struct {
|
type ConnectionState struct {
|
||||||
eKey *NebulaCipherState
|
eKey *NebulaCipherState
|
||||||
|
|||||||
@@ -174,10 +174,6 @@ func (c *Control) GetHostmap() *HostMap {
|
|||||||
return c.f.hostMap
|
return c.f.hostMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) GetF() *Interface {
|
|
||||||
return c.f
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Control) GetCertState() *CertState {
|
func (c *Control) GetCertState() *CertState {
|
||||||
return c.f.pki.getCertState()
|
return c.f.pki.getCertState()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.yaml.in/yaml/v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BenchmarkHotPath(b *testing.B) {
|
func BenchmarkHotPath(b *testing.B) {
|
||||||
@@ -97,41 +97,6 @@ func TestGoodHandshake(t *testing.T) {
|
|||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGoodHandshakeNoOverlap(t *testing.T) {
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil)
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack!
|
|
||||||
|
|
||||||
// Put their info in our lighthouse
|
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
empty := []byte{}
|
|
||||||
t.Log("do something to cause a handshake")
|
|
||||||
myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, empty, empty)
|
|
||||||
|
|
||||||
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
|
|
||||||
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
|
|
||||||
|
|
||||||
t.Log("Get their stage 1 packet")
|
|
||||||
stage1Packet := theirControl.GetFromUDP(true)
|
|
||||||
|
|
||||||
t.Log("Have me consume their stage 1 packet. I have a tunnel now")
|
|
||||||
myControl.InjectUDPPacket(stage1Packet)
|
|
||||||
|
|
||||||
t.Log("Wait until we see a test packet come through to make sure we give the tunnel time to complete")
|
|
||||||
myControl.WaitForType(header.Test, 0, theirControl)
|
|
||||||
|
|
||||||
t.Log("Make sure our host infos are correct")
|
|
||||||
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
|
|
||||||
|
|
||||||
myControl.Stop()
|
|
||||||
theirControl.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWrongResponderHandshake(t *testing.T) {
|
func TestWrongResponderHandshake(t *testing.T) {
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
|
||||||
@@ -499,35 +464,6 @@ func TestRelays(t *testing.T) {
|
|||||||
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRelaysDontCareAboutIps(t *testing.T) {
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}})
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
|
||||||
|
|
||||||
// Teach my how to get to the relay and that their can be reached via the relay
|
|
||||||
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
|
||||||
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
|
|
||||||
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
|
||||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
|
||||||
defer r.RenderFlow()
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
relayControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
t.Log("Trigger a handshake from me to them via the relay")
|
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
|
||||||
|
|
||||||
p := r.RouteForAllUntilTxTun(theirControl)
|
|
||||||
r.Log("Assert the tunnel works")
|
|
||||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReestablishRelays(t *testing.T) {
|
func TestReestablishRelays(t *testing.T) {
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||||
@@ -1291,109 +1227,3 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
|
|||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}})
|
|
||||||
|
|
||||||
o := m{
|
|
||||||
"static_host_map": m{
|
|
||||||
lhVpnIpNet[0].Addr().String(): []string{lhUdpAddr.String()},
|
|
||||||
},
|
|
||||||
"lighthouse": m{
|
|
||||||
"hosts": []string{lhVpnIpNet[0].Addr().String()},
|
|
||||||
"local_allow_list": m{
|
|
||||||
// Try and block our lighthouse updates from using the actual addresses assigned to this computer
|
|
||||||
// If we start discovering addresses the test router doesn't know about then test traffic cant flow
|
|
||||||
"10.0.0.0/24": true,
|
|
||||||
"::/0": false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o)
|
|
||||||
theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)
|
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
|
||||||
r := router.NewR(t, lhControl, myControl, theirControl)
|
|
||||||
defer r.RenderFlow()
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
lhControl.Start()
|
|
||||||
myControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
t.Log("Stand up an ipv6 tunnel between me and them")
|
|
||||||
assert.True(t, myVpnIpNet[1].Addr().Is6())
|
|
||||||
assert.True(t, theirVpnIpNet[1].Addr().Is6())
|
|
||||||
assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)
|
|
||||||
|
|
||||||
lhControl.Stop()
|
|
||||||
myControl.Stop()
|
|
||||||
theirControl.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGoodHandshakeUnsafeDest(t *testing.T) {
|
|
||||||
unsafePrefix := "192.168.6.0/24"
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil)
|
|
||||||
route := m{"route": unsafePrefix, "via": theirVpnIpNet[0].Addr().String()}
|
|
||||||
myCfg := m{
|
|
||||||
"tun": m{
|
|
||||||
"unsafe_routes": []m{route},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg)
|
|
||||||
t.Logf("my config %v", myConfig)
|
|
||||||
// Put their info in our lighthouse
|
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
|
||||||
|
|
||||||
spookyDest := netip.MustParseAddr("192.168.6.4")
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
|
|
||||||
myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
|
||||||
|
|
||||||
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
|
|
||||||
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
|
|
||||||
|
|
||||||
t.Log("Get their stage 1 packet so that we can play with it")
|
|
||||||
stage1Packet := theirControl.GetFromUDP(true)
|
|
||||||
|
|
||||||
t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
|
|
||||||
// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
|
|
||||||
badPacket := stage1Packet.Copy()
|
|
||||||
badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
|
|
||||||
myControl.InjectUDPPacket(badPacket)
|
|
||||||
|
|
||||||
t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
|
|
||||||
myControl.InjectUDPPacket(stage1Packet)
|
|
||||||
|
|
||||||
t.Log("Wait until we see my cached packet come through")
|
|
||||||
myControl.WaitForType(1, 0, theirControl)
|
|
||||||
|
|
||||||
t.Log("Make sure our host infos are correct")
|
|
||||||
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
|
|
||||||
|
|
||||||
t.Log("Get that cached packet and make sure it looks right")
|
|
||||||
myCachedPacket := theirControl.GetFromTun(true)
|
|
||||||
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)
|
|
||||||
|
|
||||||
//reply
|
|
||||||
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))
|
|
||||||
//wait for reply
|
|
||||||
theirControl.WaitForType(1, 0, myControl)
|
|
||||||
theirCachedPacket := myControl.GetFromTun(true)
|
|
||||||
assertUdpPacket(t, []byte("Hi from the spookyman"), theirCachedPacket, spookyDest, myVpnIpNet[0].Addr(), 80, 80)
|
|
||||||
|
|
||||||
t.Log("Do a bidirectional tunnel test")
|
|
||||||
r := router.NewR(t, myControl, theirControl)
|
|
||||||
defer r.RenderFlow()
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
|
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
|
||||||
myControl.Stop()
|
|
||||||
theirControl.Stop()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -22,14 +22,15 @@ import (
|
|||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"gopkg.in/yaml.v3"
|
||||||
"go.yaml.in/yaml/v3"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type m = map[string]any
|
type m = map[string]any
|
||||||
|
|
||||||
// newSimpleServer creates a nebula instance with many assumptions
|
// newSimpleServer creates a nebula instance with many assumptions
|
||||||
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||||
|
l := NewTestLogger()
|
||||||
|
|
||||||
var vpnNetworks []netip.Prefix
|
var vpnNetworks []netip.Prefix
|
||||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||||
@@ -55,54 +56,7 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
|||||||
budpIp[3] = 239
|
budpIp[3] = 239
|
||||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||||
}
|
}
|
||||||
return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
|
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
|
||||||
}
|
|
||||||
|
|
||||||
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
|
||||||
return newSimpleServerWithUdpAndUnsafeNetworks(v, caCrt, caKey, name, sVpnNetworks, udpAddr, "", overrides)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, sUnsafeNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
|
||||||
l := NewTestLogger()
|
|
||||||
|
|
||||||
var vpnNetworks []netip.Prefix
|
|
||||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
|
||||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
vpnNetworks = append(vpnNetworks, vpnIpNet)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(vpnNetworks) == 0 {
|
|
||||||
panic("no vpn networks")
|
|
||||||
}
|
|
||||||
|
|
||||||
firewallInbound := []m{{
|
|
||||||
"proto": "any",
|
|
||||||
"port": "any",
|
|
||||||
"host": "any",
|
|
||||||
}}
|
|
||||||
|
|
||||||
var unsafeNetworks []netip.Prefix
|
|
||||||
if sUnsafeNetworks != "" {
|
|
||||||
firewallInbound = []m{{
|
|
||||||
"proto": "any",
|
|
||||||
"port": "any",
|
|
||||||
"host": "any",
|
|
||||||
"local_cidr": "0.0.0.0/0",
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, sn := range strings.Split(sUnsafeNetworks, ",") {
|
|
||||||
x, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
unsafeNetworks = append(unsafeNetworks, x)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{})
|
|
||||||
|
|
||||||
caB, err := caCrt.MarshalPEM()
|
caB, err := caCrt.MarshalPEM()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -122,7 +76,11 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
|
|||||||
"port": "any",
|
"port": "any",
|
||||||
"host": "any",
|
"host": "any",
|
||||||
}},
|
}},
|
||||||
"inbound": firewallInbound,
|
"inbound": []m{{
|
||||||
|
"proto": "any",
|
||||||
|
"port": "any",
|
||||||
|
"host": "any",
|
||||||
|
}},
|
||||||
},
|
},
|
||||||
//"handshakes": m{
|
//"handshakes": m{
|
||||||
// "try_interval": "1s",
|
// "try_interval": "1s",
|
||||||
@@ -171,109 +129,6 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
|
|||||||
return control, vpnNetworks, udpAddr, c
|
return control, vpnNetworks, udpAddr, c
|
||||||
}
|
}
|
||||||
|
|
||||||
// newServer creates a nebula instance with fewer assumptions
|
|
||||||
func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
|
||||||
l := NewTestLogger()
|
|
||||||
|
|
||||||
vpnNetworks := certs[len(certs)-1].Networks()
|
|
||||||
|
|
||||||
var udpAddr netip.AddrPort
|
|
||||||
if vpnNetworks[0].Addr().Is4() {
|
|
||||||
budpIp := vpnNetworks[0].Addr().As4()
|
|
||||||
budpIp[1] -= 128
|
|
||||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
|
|
||||||
} else {
|
|
||||||
budpIp := vpnNetworks[0].Addr().As16()
|
|
||||||
// beef for funsies
|
|
||||||
budpIp[2] = 190
|
|
||||||
budpIp[3] = 239
|
|
||||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
|
||||||
}
|
|
||||||
|
|
||||||
caStr := ""
|
|
||||||
for _, ca := range caCrt {
|
|
||||||
x, err := ca.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
caStr += string(x)
|
|
||||||
}
|
|
||||||
certStr := ""
|
|
||||||
for _, c := range certs {
|
|
||||||
x, err := c.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
certStr += string(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
mc := m{
|
|
||||||
"pki": m{
|
|
||||||
"ca": caStr,
|
|
||||||
"cert": certStr,
|
|
||||||
"key": string(key),
|
|
||||||
},
|
|
||||||
//"tun": m{"disabled": true},
|
|
||||||
"firewall": m{
|
|
||||||
"outbound": []m{{
|
|
||||||
"proto": "any",
|
|
||||||
"port": "any",
|
|
||||||
"host": "any",
|
|
||||||
}},
|
|
||||||
"inbound": []m{{
|
|
||||||
"proto": "any",
|
|
||||||
"port": "any",
|
|
||||||
"host": "any",
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
//"handshakes": m{
|
|
||||||
// "try_interval": "1s",
|
|
||||||
//},
|
|
||||||
"listen": m{
|
|
||||||
"host": udpAddr.Addr().String(),
|
|
||||||
"port": udpAddr.Port(),
|
|
||||||
},
|
|
||||||
"logging": m{
|
|
||||||
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
|
|
||||||
"level": l.Level.String(),
|
|
||||||
},
|
|
||||||
"timers": m{
|
|
||||||
"pending_deletion_interval": 2,
|
|
||||||
"connection_alive_interval": 2,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if overrides != nil {
|
|
||||||
final := m{}
|
|
||||||
err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
mc = final
|
|
||||||
}
|
|
||||||
|
|
||||||
cb, err := yaml.Marshal(mc)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c := config.NewC(l)
|
|
||||||
cStr := string(cb)
|
|
||||||
c.LoadString(cStr)
|
|
||||||
|
|
||||||
control, err := nebula.Main(c, false, "e2e-test", l, nil)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return control, vpnNetworks, udpAddr, c
|
|
||||||
}
|
|
||||||
|
|
||||||
type doneCb func()
|
type doneCb func()
|
||||||
|
|
||||||
func deadline(t *testing.T, seconds time.Duration) doneCb {
|
func deadline(t *testing.T, seconds time.Duration) doneCb {
|
||||||
@@ -308,10 +163,10 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
|
|||||||
// Get both host infos
|
// Get both host infos
|
||||||
//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
|
//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
|
||||||
hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
|
hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
|
||||||
require.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
|
assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
|
||||||
|
|
||||||
hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
|
hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
|
||||||
require.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
|
assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
|
||||||
|
|
||||||
// Check that both vpn and real addr are correct
|
// Check that both vpn and real addr are correct
|
||||||
assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
|
assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
|
||||||
|
|||||||
@@ -4,16 +4,12 @@
|
|||||||
package e2e
|
package e2e
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/cert_test"
|
"github.com/slackhq/nebula/cert_test"
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDropInactiveTunnels(t *testing.T) {
|
func TestDropInactiveTunnels(t *testing.T) {
|
||||||
@@ -59,309 +55,3 @@ func TestDropInactiveTunnels(t *testing.T) {
|
|||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertUpgrade(t *testing.T) {
|
|
||||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
|
||||||
// under ideal conditions
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
caB, err := ca.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
|
|
||||||
ca2B, err := ca2.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
|
||||||
|
|
||||||
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
|
||||||
_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
|
||||||
|
|
||||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
|
||||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
|
||||||
|
|
||||||
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
|
||||||
|
|
||||||
// Share our underlay information
|
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
|
||||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
r := router.NewR(t, myControl, theirControl)
|
|
||||||
defer r.RenderFlow()
|
|
||||||
|
|
||||||
r.Log("Assert the tunnel between me and them works")
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
r.Log("yay")
|
|
||||||
//todo ???
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
r.FlushAll()
|
|
||||||
|
|
||||||
mc := m{
|
|
||||||
"pki": m{
|
|
||||||
"ca": caStr,
|
|
||||||
"cert": string(myCert2Pem),
|
|
||||||
"key": string(myPrivKey),
|
|
||||||
},
|
|
||||||
//"tun": m{"disabled": true},
|
|
||||||
"firewall": myC.Settings["firewall"],
|
|
||||||
//"handshakes": m{
|
|
||||||
// "try_interval": "1s",
|
|
||||||
//},
|
|
||||||
"listen": myC.Settings["listen"],
|
|
||||||
"logging": myC.Settings["logging"],
|
|
||||||
"timers": myC.Settings["timers"],
|
|
||||||
}
|
|
||||||
|
|
||||||
cb, err := yaml.Marshal(mc)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Logf("reload new v2-only config")
|
|
||||||
err = myC.ReloadConfigString(string(cb))
|
|
||||||
assert.NoError(t, err)
|
|
||||||
r.Log("yay, spin until their sees it")
|
|
||||||
waitStart := time.Now()
|
|
||||||
for {
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
|
||||||
if c == nil {
|
|
||||||
r.Log("nil")
|
|
||||||
} else {
|
|
||||||
version := c.Cert.Version()
|
|
||||||
r.Logf("version %d", version)
|
|
||||||
if version == cert.Version2 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
since := time.Since(waitStart)
|
|
||||||
if since > time.Second*10 {
|
|
||||||
t.Fatal("Cert should be new by now")
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
|
||||||
|
|
||||||
myControl.Stop()
|
|
||||||
theirControl.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCertDowngrade(t *testing.T) {
|
|
||||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
|
||||||
// under ideal conditions
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
caB, err := ca.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
|
|
||||||
ca2B, err := ca2.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
|
||||||
|
|
||||||
myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
|
||||||
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
|
||||||
|
|
||||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
|
||||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
|
||||||
|
|
||||||
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
|
||||||
|
|
||||||
// Share our underlay information
|
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
|
||||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
r := router.NewR(t, myControl, theirControl)
|
|
||||||
defer r.RenderFlow()
|
|
||||||
|
|
||||||
r.Log("Assert the tunnel between me and them works")
|
|
||||||
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
|
||||||
//r.Log("yay")
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
r.Log("yay")
|
|
||||||
//todo ???
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
r.FlushAll()
|
|
||||||
|
|
||||||
mc := m{
|
|
||||||
"pki": m{
|
|
||||||
"ca": caStr,
|
|
||||||
"cert": string(myCertPem),
|
|
||||||
"key": string(myPrivKey),
|
|
||||||
},
|
|
||||||
"firewall": myC.Settings["firewall"],
|
|
||||||
"listen": myC.Settings["listen"],
|
|
||||||
"logging": myC.Settings["logging"],
|
|
||||||
"timers": myC.Settings["timers"],
|
|
||||||
}
|
|
||||||
|
|
||||||
cb, err := yaml.Marshal(mc)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Logf("reload new v1-only config")
|
|
||||||
err = myC.ReloadConfigString(string(cb))
|
|
||||||
assert.NoError(t, err)
|
|
||||||
r.Log("yay, spin until their sees it")
|
|
||||||
waitStart := time.Now()
|
|
||||||
for {
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
|
||||||
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
|
||||||
if c == nil || c2 == nil {
|
|
||||||
r.Log("nil")
|
|
||||||
} else {
|
|
||||||
version := c.Cert.Version()
|
|
||||||
theirVersion := c2.Cert.Version()
|
|
||||||
r.Logf("version %d,%d", version, theirVersion)
|
|
||||||
if version == cert.Version1 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
since := time.Since(waitStart)
|
|
||||||
if since > time.Second*5 {
|
|
||||||
r.Log("it is unusual that the cert is not new yet, but not a failure yet")
|
|
||||||
}
|
|
||||||
if since > time.Second*10 {
|
|
||||||
r.Log("wtf")
|
|
||||||
t.Fatal("Cert should be new by now")
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
|
||||||
|
|
||||||
myControl.Stop()
|
|
||||||
theirControl.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCertMismatchCorrection(t *testing.T) {
|
|
||||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
|
||||||
// under ideal conditions
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
|
|
||||||
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
|
||||||
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
|
||||||
|
|
||||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
|
||||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
|
||||||
|
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
|
||||||
|
|
||||||
// Share our underlay information
|
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
|
||||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
r := router.NewR(t, myControl, theirControl)
|
|
||||||
defer r.RenderFlow()
|
|
||||||
|
|
||||||
r.Log("Assert the tunnel between me and them works")
|
|
||||||
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
|
||||||
//r.Log("yay")
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
r.Log("yay")
|
|
||||||
//todo ???
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
r.FlushAll()
|
|
||||||
|
|
||||||
waitStart := time.Now()
|
|
||||||
for {
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
|
||||||
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
|
||||||
if c == nil || c2 == nil {
|
|
||||||
r.Log("nil")
|
|
||||||
} else {
|
|
||||||
version := c.Cert.Version()
|
|
||||||
theirVersion := c2.Cert.Version()
|
|
||||||
r.Logf("version %d,%d", version, theirVersion)
|
|
||||||
if version == theirVersion {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
since := time.Since(waitStart)
|
|
||||||
if since > time.Second*5 {
|
|
||||||
r.Log("wtf")
|
|
||||||
}
|
|
||||||
if since > time.Second*10 {
|
|
||||||
r.Log("wtf")
|
|
||||||
t.Fatal("Cert should be new by now")
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
|
||||||
|
|
||||||
myControl.Stop()
|
|
||||||
theirControl.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCrossStackRelaysWork(t *testing.T) {
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
|
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
|
|
||||||
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
|
|
||||||
|
|
||||||
//myVpnV4 := myVpnIpNet[0]
|
|
||||||
myVpnV6 := myVpnIpNet[1]
|
|
||||||
relayVpnV4 := relayVpnIpNet[0]
|
|
||||||
relayVpnV6 := relayVpnIpNet[1]
|
|
||||||
theirVpnV6 := theirVpnIpNet[0]
|
|
||||||
|
|
||||||
// Teach my how to get to the relay and that their can be reached via the relay
|
|
||||||
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
|
|
||||||
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
|
|
||||||
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
|
|
||||||
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
|
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
|
||||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
|
||||||
defer r.RenderFlow()
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
relayControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
t.Log("Trigger a handshake from me to them via the relay")
|
|
||||||
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
|
|
||||||
|
|
||||||
p := r.RouteForAllUntilTxTun(theirControl)
|
|
||||||
r.Log("Assert the tunnel works")
|
|
||||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
|
|
||||||
|
|
||||||
t.Log("reply?")
|
|
||||||
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
|
|
||||||
p = r.RouteForAllUntilTxTun(myControl)
|
|
||||||
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
|
|
||||||
|
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
|
||||||
//t.Log("finish up")
|
|
||||||
//myControl.Stop()
|
|
||||||
//theirControl.Stop()
|
|
||||||
//relayControl.Stop()
|
|
||||||
}
|
|
||||||
|
|||||||
57
firewall.go
57
firewall.go
@@ -417,45 +417,30 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrUnknownNetworkType = errors.New("unknown network type")
|
var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
|
||||||
var ErrPeerRejected = errors.New("remote address is not within a network that we handle")
|
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
|
||||||
var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate networks")
|
|
||||||
var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
|
|
||||||
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
||||||
|
|
||||||
// Drop returns an error if the packet should be dropped, explaining why. It
|
// Drop returns an error if the packet should be dropped, explaining why. It
|
||||||
// returns nil if the packet should not be dropped.
|
// returns nil if the packet should not be dropped.
|
||||||
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) error {
|
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
|
||||||
// Check if we spoke to this tuple, if we did then allow this packet
|
// Check if we spoke to this tuple, if we did then allow this packet
|
||||||
if f.inConns(fp, h, caPool, localCache, now) {
|
if f.inConns(fp, h, caPool, localCache) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure remote address matches nebula certificate, and determine how to treat it
|
// Make sure remote address matches nebula certificate
|
||||||
if h.networks == nil {
|
if h.networks != nil {
|
||||||
// Simple case: Certificate has one address and no unsafe networks
|
if !h.networks.Contains(fp.RemoteAddr) {
|
||||||
if h.vpnAddrs[0] != fp.RemoteAddr {
|
|
||||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
return ErrInvalidRemoteIP
|
return ErrInvalidRemoteIP
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
nwType, ok := h.networks.Lookup(fp.RemoteAddr)
|
// Simple case: Certificate has one address and no unsafe networks
|
||||||
if !ok {
|
if h.vpnAddrs[0] != fp.RemoteAddr {
|
||||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
return ErrInvalidRemoteIP
|
return ErrInvalidRemoteIP
|
||||||
}
|
}
|
||||||
switch nwType {
|
|
||||||
case NetworkTypeVPN:
|
|
||||||
break // nothing special
|
|
||||||
case NetworkTypeVPNPeer:
|
|
||||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
|
||||||
return ErrPeerRejected // reject for now, one day this may have different FW rules
|
|
||||||
case NetworkTypeUnsafe:
|
|
||||||
break // nothing special, one day this may have different FW rules
|
|
||||||
default:
|
|
||||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
|
||||||
return ErrUnknownNetworkType //should never happen
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure we are supposed to be handling this local ip address
|
// Make sure we are supposed to be handling this local ip address
|
||||||
@@ -476,7 +461,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// We always want to conntrack since it is a faster operation
|
// We always want to conntrack since it is a faster operation
|
||||||
f.addConn(fp, incoming, now)
|
f.addConn(fp, incoming)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -505,7 +490,7 @@ func (f *Firewall) EmitStats() {
|
|||||||
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
|
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) bool {
|
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
|
||||||
if localCache != nil {
|
if localCache != nil {
|
||||||
if _, ok := localCache[fp]; ok {
|
if _, ok := localCache[fp]; ok {
|
||||||
return true
|
return true
|
||||||
@@ -517,7 +502,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||||||
// Purge every time we test
|
// Purge every time we test
|
||||||
ep, has := conntrack.TimerWheel.Purge()
|
ep, has := conntrack.TimerWheel.Purge()
|
||||||
if has {
|
if has {
|
||||||
f.evict(ep, now)
|
f.evict(ep)
|
||||||
}
|
}
|
||||||
|
|
||||||
c, ok := conntrack.Conns[fp]
|
c, ok := conntrack.Conns[fp]
|
||||||
@@ -564,11 +549,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||||||
|
|
||||||
switch fp.Protocol {
|
switch fp.Protocol {
|
||||||
case firewall.ProtoTCP:
|
case firewall.ProtoTCP:
|
||||||
c.Expires = now.Add(f.TCPTimeout)
|
c.Expires = time.Now().Add(f.TCPTimeout)
|
||||||
case firewall.ProtoUDP:
|
case firewall.ProtoUDP:
|
||||||
c.Expires = now.Add(f.UDPTimeout)
|
c.Expires = time.Now().Add(f.UDPTimeout)
|
||||||
default:
|
default:
|
||||||
c.Expires = now.Add(f.DefaultTimeout)
|
c.Expires = time.Now().Add(f.DefaultTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
conntrack.Unlock()
|
conntrack.Unlock()
|
||||||
@@ -580,7 +565,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Firewall) addConn(fp firewall.Packet, incoming bool, now time.Time) {
|
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
|
||||||
var timeout time.Duration
|
var timeout time.Duration
|
||||||
c := &conn{}
|
c := &conn{}
|
||||||
|
|
||||||
@@ -596,7 +581,7 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool, now time.Time) {
|
|||||||
conntrack := f.Conntrack
|
conntrack := f.Conntrack
|
||||||
conntrack.Lock()
|
conntrack.Lock()
|
||||||
if _, ok := conntrack.Conns[fp]; !ok {
|
if _, ok := conntrack.Conns[fp]; !ok {
|
||||||
conntrack.TimerWheel.Advance(now)
|
conntrack.TimerWheel.Advance(time.Now())
|
||||||
conntrack.TimerWheel.Add(fp, timeout)
|
conntrack.TimerWheel.Add(fp, timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -604,14 +589,14 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool, now time.Time) {
|
|||||||
// firewall reload
|
// firewall reload
|
||||||
c.incoming = incoming
|
c.incoming = incoming
|
||||||
c.rulesVersion = f.rulesVersion
|
c.rulesVersion = f.rulesVersion
|
||||||
c.Expires = now.Add(timeout)
|
c.Expires = time.Now().Add(timeout)
|
||||||
conntrack.Conns[fp] = c
|
conntrack.Conns[fp] = c
|
||||||
conntrack.Unlock()
|
conntrack.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
|
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
|
||||||
// Caller must own the connMutex lock!
|
// Caller must own the connMutex lock!
|
||||||
func (f *Firewall) evict(p firewall.Packet, now time.Time) {
|
func (f *Firewall) evict(p firewall.Packet) {
|
||||||
// Are we still tracking this conn?
|
// Are we still tracking this conn?
|
||||||
conntrack := f.Conntrack
|
conntrack := f.Conntrack
|
||||||
t, ok := conntrack.Conns[p]
|
t, ok := conntrack.Conns[p]
|
||||||
@@ -619,11 +604,11 @@ func (f *Firewall) evict(p firewall.Packet, now time.Time) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newT := t.Expires.Sub(now)
|
newT := t.Expires.Sub(time.Now())
|
||||||
|
|
||||||
// Timeout is in the future, re-add the timer
|
// Timeout is in the future, re-add the timer
|
||||||
if newT > 0 {
|
if newT > 0 {
|
||||||
conntrack.TimerWheel.Advance(now)
|
conntrack.TimerWheel.Advance(time.Now())
|
||||||
conntrack.TimerWheel.Add(p, newT)
|
conntrack.TimerWheel.Add(p, newT)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
203
firewall_test.go
203
firewall_test.go
@@ -8,8 +8,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
@@ -151,8 +149,7 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -177,7 +174,7 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
||||||
}
|
}
|
||||||
h.buildNetworks(myVpnNetworksTable, &c)
|
h.buildNetworks(c.networks, c.unsafeNetworks)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -229,9 +226,6 @@ func TestFirewall_DropV6(t *testing.T) {
|
|||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||||
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
||||||
@@ -256,7 +250,7 @@ func TestFirewall_DropV6(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
|
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
|
||||||
}
|
}
|
||||||
h.buildNetworks(myVpnNetworksTable, &c)
|
h.buildNetworks(c.networks, c.unsafeNetworks)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -459,8 +453,6 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -486,7 +478,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
c1 := cert.CachedCertificate{
|
c1 := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -501,7 +493,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
peerCert: &c1,
|
peerCert: &c1,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
|
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -518,8 +510,6 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -551,7 +541,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
|
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
c2 := cert.CachedCertificate{
|
c2 := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -566,7 +556,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h2.buildNetworks(myVpnNetworksTable, c2.Certificate)
|
h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
c3 := cert.CachedCertificate{
|
c3 := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -581,7 +571,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
|
h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -607,8 +597,6 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||||
@@ -632,7 +620,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
// Test a remote address match
|
// Test a remote address match
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
@@ -645,8 +633,6 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -673,7 +659,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -710,8 +696,6 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
|
||||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
|
|
||||||
|
|
||||||
c := cert.CachedCertificate{
|
c := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -733,7 +717,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
|
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
|
||||||
}
|
}
|
||||||
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
|
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
|
|
||||||
@@ -1063,171 +1047,6 @@ func TestFirewall_convertRule(t *testing.T) {
|
|||||||
assert.Equal(t, "group1", r.Group)
|
assert.Equal(t, "group1", r.Group)
|
||||||
}
|
}
|
||||||
|
|
||||||
type testcase struct {
|
|
||||||
h *HostInfo
|
|
||||||
p firewall.Packet
|
|
||||||
c cert.Certificate
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *testcase) Test(t *testing.T, fw *Firewall) {
|
|
||||||
t.Helper()
|
|
||||||
cp := cert.NewCAPool()
|
|
||||||
resetConntrack(fw)
|
|
||||||
err := fw.Drop(c.p, true, c.h, cp, nil)
|
|
||||||
if c.err == nil {
|
|
||||||
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
|
|
||||||
} else {
|
|
||||||
require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
|
|
||||||
c1 := dummyCert{
|
|
||||||
name: "host1",
|
|
||||||
networks: theirPrefixes,
|
|
||||||
groups: []string{"default-group"},
|
|
||||||
issuer: "signer-shasum",
|
|
||||||
}
|
|
||||||
h := HostInfo{
|
|
||||||
ConnectionState: &ConnectionState{
|
|
||||||
peerCert: &cert.CachedCertificate{
|
|
||||||
Certificate: &c1,
|
|
||||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
vpnAddrs: make([]netip.Addr, len(theirPrefixes)),
|
|
||||||
}
|
|
||||||
for i := range theirPrefixes {
|
|
||||||
h.vpnAddrs[i] = theirPrefixes[i].Addr()
|
|
||||||
}
|
|
||||||
h.buildNetworks(setup.myVpnNetworksTable, &c1)
|
|
||||||
p := firewall.Packet{
|
|
||||||
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
|
|
||||||
RemoteAddr: theirPrefixes[0].Addr(),
|
|
||||||
LocalPort: 10,
|
|
||||||
RemotePort: 90,
|
|
||||||
Protocol: firewall.ProtoUDP,
|
|
||||||
Fragment: false,
|
|
||||||
}
|
|
||||||
return testcase{
|
|
||||||
h: &h,
|
|
||||||
p: p,
|
|
||||||
c: &c1,
|
|
||||||
err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type testsetup struct {
|
|
||||||
c dummyCert
|
|
||||||
myVpnNetworksTable *bart.Lite
|
|
||||||
fw *Firewall
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
|
|
||||||
c := dummyCert{
|
|
||||||
name: "me",
|
|
||||||
networks: myPrefixes,
|
|
||||||
groups: []string{"default-group"},
|
|
||||||
issuer: "signer-shasum",
|
|
||||||
}
|
|
||||||
|
|
||||||
return newSetupFromCert(t, l, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
|
|
||||||
myVpnNetworksTable := new(bart.Lite)
|
|
||||||
for _, prefix := range c.Networks() {
|
|
||||||
myVpnNetworksTable.Insert(prefix)
|
|
||||||
}
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
|
||||||
|
|
||||||
return testsetup{
|
|
||||||
c: c,
|
|
||||||
fw: fw,
|
|
||||||
myVpnNetworksTable: myVpnNetworksTable,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
|
||||||
l.SetOutput(ob)
|
|
||||||
|
|
||||||
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
|
|
||||||
// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
|
|
||||||
t.Run("allow inbound all matching", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup := newSetup(t, l, myPrefix)
|
|
||||||
tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24"))
|
|
||||||
tc.Test(t, setup.fw)
|
|
||||||
})
|
|
||||||
t.Run("allow inbound local matching", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup := newSetup(t, l, myPrefix)
|
|
||||||
tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24"))
|
|
||||||
tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8")
|
|
||||||
tc.Test(t, setup.fw)
|
|
||||||
})
|
|
||||||
t.Run("block inbound remote mismatched", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup := newSetup(t, l, myPrefix)
|
|
||||||
tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24"))
|
|
||||||
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
|
|
||||||
tc.Test(t, setup.fw)
|
|
||||||
})
|
|
||||||
t.Run("Block a vpn peer packet", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup := newSetup(t, l, myPrefix)
|
|
||||||
tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24"))
|
|
||||||
tc.Test(t, setup.fw)
|
|
||||||
})
|
|
||||||
twoPrefixes := []netip.Prefix{
|
|
||||||
netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"),
|
|
||||||
}
|
|
||||||
t.Run("allow inbound one matching", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup := newSetup(t, l, myPrefix)
|
|
||||||
tc := buildTestCase(setup, nil, twoPrefixes...)
|
|
||||||
tc.Test(t, setup.fw)
|
|
||||||
})
|
|
||||||
t.Run("block inbound multimismatch", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup := newSetup(t, l, myPrefix)
|
|
||||||
tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...)
|
|
||||||
tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
|
|
||||||
tc.Test(t, setup.fw)
|
|
||||||
})
|
|
||||||
t.Run("allow inbound 2nd one matching", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24"))
|
|
||||||
tc := buildTestCase(setup2, nil, twoPrefixes...)
|
|
||||||
tc.p.RemoteAddr = twoPrefixes[1].Addr()
|
|
||||||
tc.Test(t, setup2.fw)
|
|
||||||
})
|
|
||||||
t.Run("allow inbound unsafe route", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
|
|
||||||
c := dummyCert{
|
|
||||||
name: "me",
|
|
||||||
networks: []netip.Prefix{myPrefix},
|
|
||||||
unsafeNetworks: []netip.Prefix{unsafePrefix},
|
|
||||||
groups: []string{"default-group"},
|
|
||||||
issuer: "signer-shasum",
|
|
||||||
}
|
|
||||||
unsafeSetup := newSetupFromCert(t, l, c)
|
|
||||||
tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
|
|
||||||
tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
|
|
||||||
tc.err = ErrNoMatchingRule
|
|
||||||
tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
|
|
||||||
require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, unsafePrefix, "", ""))
|
|
||||||
tc.err = nil
|
|
||||||
tc.Test(t, unsafeSetup.fw) //should pass
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type addRuleCall struct {
|
type addRuleCall struct {
|
||||||
incoming bool
|
incoming bool
|
||||||
proto uint8
|
proto uint8
|
||||||
|
|||||||
17
go.mod
17
go.mod
@@ -22,19 +22,18 @@ require (
|
|||||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/vishvananda/netlink v1.3.1
|
github.com/vishvananda/netlink v1.3.1
|
||||||
go.yaml.in/yaml/v3 v3.0.4
|
golang.org/x/crypto v0.43.0
|
||||||
golang.org/x/crypto v0.44.0
|
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
||||||
golang.org/x/net v0.46.0
|
golang.org/x/net v0.45.0
|
||||||
golang.org/x/sync v0.18.0
|
golang.org/x/sync v0.17.0
|
||||||
golang.org/x/sys v0.38.0
|
golang.org/x/sys v0.37.0
|
||||||
golang.org/x/term v0.37.0
|
golang.org/x/term v0.36.0
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||||
google.golang.org/protobuf v1.36.10
|
google.golang.org/protobuf v1.36.8
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
|
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
|||||||
34
go.sum
34
go.sum
@@ -155,15 +155,13 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
|||||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||||
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
||||||
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
|
||||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
|
||||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||||
golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
|
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||||
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
|
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||||
@@ -182,8 +180,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
|
|||||||
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
|
||||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
||||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
@@ -191,8 +189,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
|||||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
@@ -209,11 +207,11 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||||
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
@@ -232,8 +230,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
|
|||||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
|
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
|
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
||||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||||
@@ -244,8 +242,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
|
|||||||
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
||||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||||
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
|
||||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
||||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
@@ -259,5 +257,5 @@ gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
|||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe h1:fre4i6mv4iBuz5lCMOzHD1rH1ljqHWSICFmZRbbgp3g=
|
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI=
|
||||||
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU=
|
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
|
||||||
|
|||||||
159
handshake_ix.go
159
handshake_ix.go
@@ -2,6 +2,7 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/flynn/noise"
|
"github.com/flynn/noise"
|
||||||
@@ -22,17 +23,13 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we're connecting to a v6 address we must use a v2 cert
|
||||||
cs := f.pki.getCertState()
|
cs := f.pki.getCertState()
|
||||||
v := cs.initiatingVersion
|
v := cs.initiatingVersion
|
||||||
if hh.initiatingVersionOverride != cert.VersionPre1 {
|
for _, a := range hh.hostinfo.vpnAddrs {
|
||||||
v = hh.initiatingVersionOverride
|
if a.Is6() {
|
||||||
} else if v < cert.Version2 {
|
v = cert.Version2
|
||||||
// If we're connecting to a v6 address we should encourage use of a V2 cert
|
break
|
||||||
for _, a := range hh.hostinfo.vpnAddrs {
|
|
||||||
if a.Is6() {
|
|
||||||
v = cert.Version2
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,7 +48,6 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||||
WithField("certVersion", v).
|
WithField("certVersion", v).
|
||||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
||||||
@@ -107,7 +103,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||||
WithField("certVersion", cs.initiatingVersion).
|
WithField("certVersion", cs.initiatingVersion).
|
||||||
Error("Unable to handshake with host because no certificate is available")
|
Error("Unable to handshake with host because no certificate is available")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
||||||
@@ -148,8 +143,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
|
|
||||||
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fp, fperr := rc.Fingerprint()
|
fp, err := rc.Fingerprint()
|
||||||
if fperr != nil {
|
if err != nil {
|
||||||
fp = "<error generating certificate fingerprint>"
|
fp = "<error generating certificate fingerprint>"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,19 +163,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
|
|
||||||
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
||||||
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
||||||
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
|
rc := cs.getCertificate(remoteCert.Certificate.Version())
|
||||||
if myCertOtherVersion == nil {
|
if rc == nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
f.l.WithError(err).WithFields(m{
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
||||||
"udpAddr": addr,
|
Info("Unable to handshake with host due to missing certificate version")
|
||||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
return
|
||||||
"cert": remoteCert,
|
|
||||||
}).Debug("Might be unable to handshake with host due to missing certificate version")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Record the certificate we are actually using
|
|
||||||
ci.myCert = myCertOtherVersion
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Record the certificate we are actually using
|
||||||
|
ci.myCert = rc
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||||
@@ -191,17 +183,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var vpnAddrs []netip.Addr
|
||||||
|
var filteredNetworks []netip.Prefix
|
||||||
certName := remoteCert.Certificate.Name()
|
certName := remoteCert.Certificate.Name()
|
||||||
certVersion := remoteCert.Certificate.Version()
|
certVersion := remoteCert.Certificate.Version()
|
||||||
fingerprint := remoteCert.Fingerprint
|
fingerprint := remoteCert.Fingerprint
|
||||||
issuer := remoteCert.Certificate.Issuer()
|
issuer := remoteCert.Certificate.Issuer()
|
||||||
vpnNetworks := remoteCert.Certificate.Networks()
|
|
||||||
|
|
||||||
anyVpnAddrsInCommon := false
|
for _, network := range remoteCert.Certificate.Networks() {
|
||||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
vpnAddr := network.Addr()
|
||||||
for i, network := range vpnNetworks {
|
if f.myVpnAddrsTable.Contains(vpnAddr) {
|
||||||
if f.myVpnAddrsTable.Contains(network.Addr()) {
|
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
|
||||||
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
|
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -209,10 +201,24 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
vpnAddrs[i] = network.Addr()
|
|
||||||
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
||||||
anyVpnAddrsInCommon = true
|
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
filteredNetworks = append(filteredNetworks, network)
|
||||||
|
vpnAddrs = append(vpnAddrs, vpnAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(vpnAddrs) == 0 {
|
||||||
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
|
WithField("issuer", issuer).
|
||||||
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if addr.IsValid() {
|
if addr.IsValid() {
|
||||||
@@ -249,30 +255,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
msgRxL := f.l.WithFields(m{
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
"vpnAddrs": vpnAddrs,
|
WithField("certName", certName).
|
||||||
"udpAddr": addr,
|
WithField("certVersion", certVersion).
|
||||||
"certName": certName,
|
WithField("fingerprint", fingerprint).
|
||||||
"certVersion": certVersion,
|
WithField("issuer", issuer).
|
||||||
"fingerprint": fingerprint,
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
"issuer": issuer,
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
"initiatorIndex": hs.Details.InitiatorIndex,
|
Info("Handshake message received")
|
||||||
"responderIndex": hs.Details.ResponderIndex,
|
|
||||||
"remoteIndex": h.RemoteIndex,
|
|
||||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
|
||||||
})
|
|
||||||
|
|
||||||
if anyVpnAddrsInCommon {
|
|
||||||
msgRxL.Info("Handshake message received")
|
|
||||||
} else {
|
|
||||||
//todo warn if not lighthouse or relay?
|
|
||||||
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.Details.ResponderIndex = myIndex
|
hs.Details.ResponderIndex = myIndex
|
||||||
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
||||||
if hs.Details.Cert == nil {
|
if hs.Details.Cert == nil {
|
||||||
msgRxL.WithField("myCertVersion", ci.myCert.Version()).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
|
WithField("issuer", issuer).
|
||||||
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
|
WithField("certVersion", ci.myCert.Version()).
|
||||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -330,7 +332,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
|
|
||||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
||||||
hostinfo.SetRemote(addr)
|
hostinfo.SetRemote(addr)
|
||||||
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
|
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -571,22 +573,31 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
correctHostResponded := false
|
var vpnAddrs []netip.Addr
|
||||||
anyVpnAddrsInCommon := false
|
var filteredNetworks []netip.Prefix
|
||||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
for _, network := range vpnNetworks {
|
||||||
for i, network := range vpnNetworks {
|
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
||||||
vpnAddrs[i] = network.Addr()
|
vpnAddr := network.Addr()
|
||||||
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
||||||
anyVpnAddrsInCommon = true
|
continue
|
||||||
}
|
|
||||||
if hostinfo.vpnAddrs[0] == network.Addr() {
|
|
||||||
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
|
|
||||||
correctHostResponded = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
filteredNetworks = append(filteredNetworks, network)
|
||||||
|
vpnAddrs = append(vpnAddrs, vpnAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(vpnAddrs) == 0 {
|
||||||
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
|
WithField("issuer", issuer).
|
||||||
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the right host responded
|
// Ensure the right host responded
|
||||||
if !correctHostResponded {
|
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
|
||||||
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
||||||
WithField("udpAddr", addr).
|
WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
@@ -598,7 +609,6 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||||
|
|
||||||
// Create a new hostinfo/handshake for the intended vpn ip
|
// Create a new hostinfo/handshake for the intended vpn ip
|
||||||
//TODO is hostinfo.vpnAddrs[0] always the address to use?
|
|
||||||
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
||||||
// Block the current used address
|
// Block the current used address
|
||||||
newHH.hostinfo.remotes = hostinfo.remotes
|
newHH.hostinfo.remotes = hostinfo.remotes
|
||||||
@@ -625,7 +635,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
ci.window.Update(f.l, 2)
|
ci.window.Update(f.l, 2)
|
||||||
|
|
||||||
duration := time.Since(hh.startTime).Nanoseconds()
|
duration := time.Since(hh.startTime).Nanoseconds()
|
||||||
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -633,17 +643,12 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
WithField("durationNs", duration).
|
WithField("durationNs", duration).
|
||||||
WithField("sentCachedPackets", len(hh.packetStore))
|
WithField("sentCachedPackets", len(hh.packetStore)).
|
||||||
if anyVpnAddrsInCommon {
|
Info("Handshake message received")
|
||||||
msgRxL.Info("Handshake message received")
|
|
||||||
} else {
|
|
||||||
//todo warn if not lighthouse or relay?
|
|
||||||
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build up the radix for the firewall if we have subnets in the cert
|
// Build up the radix for the firewall if we have subnets in the cert
|
||||||
hostinfo.vpnAddrs = vpnAddrs
|
hostinfo.vpnAddrs = vpnAddrs
|
||||||
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
|
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
||||||
f.handshakeManager.Complete(hostinfo, f)
|
f.handshakeManager.Complete(hostinfo, f)
|
||||||
|
|||||||
@@ -68,12 +68,11 @@ type HandshakeManager struct {
|
|||||||
type HandshakeHostInfo struct {
|
type HandshakeHostInfo struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
startTime time.Time // Time that we first started trying with this handshake
|
startTime time.Time // Time that we first started trying with this handshake
|
||||||
ready bool // Is the handshake ready
|
ready bool // Is the handshake ready
|
||||||
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
|
counter int64 // How many attempts have we made so far
|
||||||
counter int64 // How many attempts have we made so far
|
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
||||||
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
|
||||||
|
|
||||||
hostinfo *HostInfo
|
hostinfo *HostInfo
|
||||||
}
|
}
|
||||||
@@ -269,12 +268,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
|
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
|
||||||
// Send a RelayRequest to all known Relay IP's
|
// Send a RelayRequest to all known Relay IP's
|
||||||
for _, relay := range hostinfo.remotes.relays {
|
for _, relay := range hostinfo.remotes.relays {
|
||||||
// Don't relay through the host I'm trying to connect to
|
// Don't relay to myself
|
||||||
if relay == vpnIp {
|
if relay == vpnIp {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Don't relay to myself
|
// Don't relay through the host I'm trying to connect to
|
||||||
if hm.f.myVpnAddrsTable.Contains(relay) {
|
if hm.f.myVpnAddrsTable.Contains(relay) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
40
hostmap.go
40
hostmap.go
@@ -212,18 +212,6 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
|
|||||||
rs.relayForByIdx[idx] = r
|
rs.relayForByIdx[idx] = r
|
||||||
}
|
}
|
||||||
|
|
||||||
type NetworkType uint8
|
|
||||||
|
|
||||||
const (
|
|
||||||
NetworkTypeUnknown NetworkType = iota
|
|
||||||
// NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
|
|
||||||
NetworkTypeVPN
|
|
||||||
// NetworkTypeVPNPeer is a network that does not overlap one of our networks
|
|
||||||
NetworkTypeVPNPeer
|
|
||||||
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
|
|
||||||
NetworkTypeUnsafe
|
|
||||||
)
|
|
||||||
|
|
||||||
type HostInfo struct {
|
type HostInfo struct {
|
||||||
remote netip.AddrPort
|
remote netip.AddrPort
|
||||||
remotes *RemoteList
|
remotes *RemoteList
|
||||||
@@ -237,8 +225,8 @@ type HostInfo struct {
|
|||||||
// vpn networks but were removed because they are not usable
|
// vpn networks but were removed because they are not usable
|
||||||
vpnAddrs []netip.Addr
|
vpnAddrs []netip.Addr
|
||||||
|
|
||||||
// networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host.
|
// networks are both all vpn and unsafe networks assigned to this host
|
||||||
networks *bart.Table[NetworkType]
|
networks *bart.Lite
|
||||||
relayState RelayState
|
relayState RelayState
|
||||||
|
|
||||||
// HandshakePacket records the packets used to create this hostinfo
|
// HandshakePacket records the packets used to create this hostinfo
|
||||||
@@ -742,26 +730,20 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildNetworks fills in the networks field of HostInfo. It accepts a cert.Certificate so you never ever mix the network types up.
|
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
||||||
func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certificate) {
|
if len(networks) == 1 && len(unsafeNetworks) == 0 {
|
||||||
if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
|
// Simple case, no CIDRTree needed
|
||||||
if myVpnNetworksTable.Contains(c.Networks()[0].Addr()) {
|
return
|
||||||
return // Simple case, no BART needed
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
i.networks = new(bart.Table[NetworkType])
|
i.networks = new(bart.Lite)
|
||||||
for _, network := range c.Networks() {
|
for _, network := range networks {
|
||||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||||
if myVpnNetworksTable.Contains(network.Addr()) {
|
i.networks.Insert(nprefix)
|
||||||
i.networks.Insert(nprefix, NetworkTypeVPN)
|
|
||||||
} else {
|
|
||||||
i.networks.Insert(nprefix, NetworkTypeVPNPeer)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, network := range c.UnsafeNetworks() {
|
for _, network := range unsafeNetworks {
|
||||||
i.networks.Insert(network, NetworkTypeUnsafe)
|
i.networks.Insert(network)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
110
inside.go
110
inside.go
@@ -2,18 +2,16 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/slackhq/nebula/noiseutil"
|
"github.com/slackhq/nebula/noiseutil"
|
||||||
"github.com/slackhq/nebula/packet"
|
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache, now time.Time) {
|
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
err := newPacket(packet, false, fwPacket)
|
err := newPacket(packet, false, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
@@ -35,8 +33,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
// routes packets from the Nebula addr to the Nebula addr through the Nebula
|
// routes packets from the Nebula addr to the Nebula addr through the Nebula
|
||||||
// TUN device.
|
// TUN device.
|
||||||
if immediatelyForwardToSelf {
|
if immediatelyForwardToSelf {
|
||||||
_, err := f.readers[q].Write(packet)
|
if err := f.writeTun(q, packet); err != nil {
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).Error("Failed to forward to tun")
|
f.l.WithError(err).Error("Failed to forward to tun")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -55,7 +52,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
})
|
})
|
||||||
|
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
f.rejectInside(packet, out.Payload, q) //todo vector?
|
f.rejectInside(packet, out, q)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
|
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
|
||||||
WithField("fwPacket", fwPacket).
|
WithField("fwPacket", fwPacket).
|
||||||
@@ -68,11 +65,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache, now)
|
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||||
if dropReason == nil {
|
if dropReason == nil {
|
||||||
f.sendNoMetricsDelayed(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
|
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
f.rejectInside(packet, out.Payload, q) //todo vector?
|
f.rejectInside(packet, out, q)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger(f.l).
|
hostinfo.logger(f.l).
|
||||||
WithField("fwPacket", fwPacket).
|
WithField("fwPacket", fwPacket).
|
||||||
@@ -92,8 +90,7 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := f.readers[q].Write(out)
|
if err := f.writeTun(q, out); err != nil {
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).Error("Failed to write to tun")
|
f.l.WithError(err).Error("Failed to write to tun")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -121,10 +118,9 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
|
|||||||
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
|
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handshake will attempt to initiate a tunnel with the provided vpn address. This is a no-op if the tunnel is already established or being established
|
// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
|
||||||
// it does not check if it is within our vpn networks!
|
|
||||||
func (f *Interface) Handshake(vpnAddr netip.Addr) {
|
func (f *Interface) Handshake(vpnAddr netip.Addr) {
|
||||||
f.handshakeManager.GetOrHandshake(vpnAddr, nil)
|
f.getOrHandshakeNoRouting(vpnAddr, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
|
// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
|
||||||
@@ -140,6 +136,7 @@ func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback fu
|
|||||||
// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
|
// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
|
||||||
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
|
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
|
||||||
func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
||||||
|
|
||||||
destinationAddr := fwPacket.RemoteAddr
|
destinationAddr := fwPacket.RemoteAddr
|
||||||
|
|
||||||
hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
|
hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
|
||||||
@@ -219,7 +216,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check if packet is in outbound fw rules
|
// check if packet is in outbound fw rules
|
||||||
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil, time.Now())
|
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
|
||||||
if dropReason != nil {
|
if dropReason != nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithField("fwPacket", fp).
|
f.l.WithField("fwPacket", fp).
|
||||||
@@ -232,10 +229,9 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
|||||||
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
|
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr.
|
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
|
||||||
// This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr
|
|
||||||
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
|
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
|
||||||
hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
|
hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
|
||||||
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
|
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -411,81 +407,3 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.Packet, q int) {
|
|
||||||
if ci.eKey == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
|
|
||||||
fullOut := out.Payload
|
|
||||||
|
|
||||||
if useRelay {
|
|
||||||
if len(out.Payload) < header.Len {
|
|
||||||
// out always has a capacity of mtu, but not always a length greater than the header.Len.
|
|
||||||
// Grow it to make sure the next operation works.
|
|
||||||
out.Payload = out.Payload[:header.Len]
|
|
||||||
}
|
|
||||||
// Save a header's worth of data at the front of the 'out' buffer.
|
|
||||||
out.Payload = out.Payload[header.Len:]
|
|
||||||
}
|
|
||||||
|
|
||||||
if noiseutil.EncryptLockNeeded {
|
|
||||||
// NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
|
|
||||||
ci.writeLock.Lock()
|
|
||||||
}
|
|
||||||
c := ci.messageCounter.Add(1)
|
|
||||||
|
|
||||||
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
|
|
||||||
out.Payload = header.Encode(out.Payload, header.Version, t, st, hostinfo.remoteIndexId, c)
|
|
||||||
f.connectionManager.Out(hostinfo)
|
|
||||||
|
|
||||||
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
|
|
||||||
// all our addrs and enable a faster roaming.
|
|
||||||
if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
|
|
||||||
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
|
|
||||||
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
|
||||||
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
|
|
||||||
hostinfo.lastRebindCount = f.rebindCount
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
out.Payload, err = ci.eKey.EncryptDanger(out.Payload, out.Payload, p, c, nb)
|
|
||||||
if noiseutil.EncryptLockNeeded {
|
|
||||||
ci.writeLock.Unlock()
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithError(err).
|
|
||||||
WithField("udpAddr", remote).WithField("counter", c).
|
|
||||||
WithField("attemptedCounter", c).
|
|
||||||
Error("Failed to encrypt outgoing packet")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if remote.IsValid() {
|
|
||||||
err = f.writers[q].Prep(out, remote)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
|
||||||
}
|
|
||||||
} else if hostinfo.remote.IsValid() {
|
|
||||||
err = f.writers[q].Prep(out, hostinfo.remote)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Try to send via a relay
|
|
||||||
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
|
|
||||||
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.relayState.DeleteRelay(relayIP)
|
|
||||||
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
//todo vector!!
|
|
||||||
f.SendVia(relayHostInfo, relay, out.Payload, nb, fullOut[:header.Len+len(out.Payload)], true)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
185
interface.go
185
interface.go
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
@@ -17,12 +18,10 @@ import (
|
|||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
"github.com/slackhq/nebula/packet"
|
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
const mtu = 9001
|
const mtu = 9001
|
||||||
const batch = 1024 //todo config!
|
|
||||||
|
|
||||||
type InterfaceConfig struct {
|
type InterfaceConfig struct {
|
||||||
HostMap *HostMap
|
HostMap *HostMap
|
||||||
@@ -48,6 +47,7 @@ type InterfaceConfig struct {
|
|||||||
reQueryWait time.Duration
|
reQueryWait time.Duration
|
||||||
|
|
||||||
ConntrackCacheTimeout time.Duration
|
ConntrackCacheTimeout time.Duration
|
||||||
|
batchSize int
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -85,20 +85,15 @@ type Interface struct {
|
|||||||
version string
|
version string
|
||||||
|
|
||||||
conntrackCacheTimeout time.Duration
|
conntrackCacheTimeout time.Duration
|
||||||
|
batchSize int
|
||||||
|
|
||||||
writers []udp.Conn
|
writers []udp.Conn
|
||||||
readers []overlay.TunDev
|
readers []io.ReadWriteCloser
|
||||||
|
|
||||||
metricHandshakes metrics.Histogram
|
metricHandshakes metrics.Histogram
|
||||||
messageMetrics *MessageMetrics
|
messageMetrics *MessageMetrics
|
||||||
cachedPacketMetrics *cachedPacketMetrics
|
cachedPacketMetrics *cachedPacketMetrics
|
||||||
|
|
||||||
listenInN int
|
|
||||||
listenOutN int
|
|
||||||
|
|
||||||
listenInMetric metrics.Histogram
|
|
||||||
listenOutMetric metrics.Histogram
|
|
||||||
|
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -117,6 +112,16 @@ type EncWriter interface {
|
|||||||
GetCertState() *CertState
|
GetCertState() *CertState
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchReader is an interface for readers that support vectorized packet reading
|
||||||
|
type BatchReader interface {
|
||||||
|
BatchRead(buffers [][]byte, sizes []int) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchWriter is an interface for writers that support vectorized packet writing
|
||||||
|
type BatchWriter interface {
|
||||||
|
BatchWrite([][]byte) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
type sendRecvErrorConfig uint8
|
type sendRecvErrorConfig uint8
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -184,7 +189,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
routines: c.routines,
|
routines: c.routines,
|
||||||
version: c.version,
|
version: c.version,
|
||||||
writers: make([]udp.Conn, c.routines),
|
writers: make([]udp.Conn, c.routines),
|
||||||
readers: make([]overlay.TunDev, c.routines),
|
readers: make([]io.ReadWriteCloser, c.routines),
|
||||||
myVpnNetworks: cs.myVpnNetworks,
|
myVpnNetworks: cs.myVpnNetworks,
|
||||||
myVpnNetworksTable: cs.myVpnNetworksTable,
|
myVpnNetworksTable: cs.myVpnNetworksTable,
|
||||||
myVpnAddrs: cs.myVpnAddrs,
|
myVpnAddrs: cs.myVpnAddrs,
|
||||||
@@ -193,6 +198,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
relayManager: c.relayManager,
|
relayManager: c.relayManager,
|
||||||
connectionManager: c.connectionManager,
|
connectionManager: c.connectionManager,
|
||||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||||
|
batchSize: c.batchSize,
|
||||||
|
|
||||||
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||||
messageMetrics: c.MessageMetrics,
|
messageMetrics: c.MessageMetrics,
|
||||||
@@ -203,8 +209,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
|
|
||||||
l: c.l,
|
l: c.l,
|
||||||
}
|
}
|
||||||
ifce.listenInMetric = metrics.GetOrRegisterHistogram("vhost.listenIn.n", nil, metrics.NewExpDecaySample(1028, 0.015))
|
|
||||||
ifce.listenOutMetric = metrics.GetOrRegisterHistogram("vhost.listenOut.n", nil, metrics.NewExpDecaySample(1028, 0.015))
|
|
||||||
|
|
||||||
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
||||||
ifce.reQueryEvery.Store(c.reQueryEvery)
|
ifce.reQueryEvery.Store(c.reQueryEvery)
|
||||||
@@ -234,7 +238,7 @@ func (f *Interface) activate() {
|
|||||||
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
||||||
|
|
||||||
// Prepare n tun queues
|
// Prepare n tun queues
|
||||||
var reader overlay.TunDev = f.inside
|
var reader io.ReadWriteCloser = f.inside
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
reader, err = f.inside.NewMultiQueueReader()
|
reader, err = f.inside.NewMultiQueueReader()
|
||||||
@@ -263,100 +267,134 @@ func (f *Interface) run() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenOut(q int) {
|
func (f *Interface) listenOut(i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
var li udp.Conn
|
var li udp.Conn
|
||||||
if q > 0 {
|
if i > 0 {
|
||||||
li = f.writers[q]
|
li = f.writers[i]
|
||||||
} else {
|
} else {
|
||||||
li = f.outside
|
li = f.outside
|
||||||
}
|
}
|
||||||
|
|
||||||
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
lhh := f.lightHouse.NewRequestHandler()
|
lhh := f.lightHouse.NewRequestHandler()
|
||||||
|
plaintext := make([]byte, udp.MTU)
|
||||||
outPackets := make([]*packet.OutPacket, batch)
|
|
||||||
for i := 0; i < batch; i++ {
|
|
||||||
outPackets[i] = packet.NewOut()
|
|
||||||
}
|
|
||||||
|
|
||||||
h := &header.H{}
|
h := &header.H{}
|
||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12)
|
||||||
|
|
||||||
toSend := make([][]byte, batch)
|
|
||||||
|
|
||||||
li.ListenOut(func(pkts []*packet.Packet) {
|
|
||||||
toSend = toSend[:0]
|
|
||||||
for i := range outPackets {
|
|
||||||
outPackets[i].Valid = false
|
|
||||||
outPackets[i].SegCounter = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l), time.Now())
|
|
||||||
//we opportunistically tx, but try to also send stragglers
|
|
||||||
if _, err := f.readers[q].WriteMany(outPackets, q); err != nil {
|
|
||||||
f.l.WithError(err).Error("Failed to send packets")
|
|
||||||
}
|
|
||||||
//todo I broke this
|
|
||||||
//n := len(toSend)
|
|
||||||
//if f.l.Level == logrus.DebugLevel {
|
|
||||||
// f.listenOutMetric.Update(int64(n))
|
|
||||||
//}
|
|
||||||
//f.listenOutN = n
|
|
||||||
|
|
||||||
|
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||||
|
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
|
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
|
// Check if reader supports batch operations
|
||||||
|
if batchReader, ok := reader.(BatchReader); ok {
|
||||||
|
err := f.listenInBatch(batchReader, i)
|
||||||
|
if err != nil {
|
||||||
|
f.l.WithError(err).Error("Fatal error in batch packet reader, exiting goroutine")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to single-packet mode
|
||||||
|
packet := make([]byte, mtu)
|
||||||
|
out := make([]byte, mtu)
|
||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
|
|
||||||
packets := make([]*packet.VirtIOPacket, batch)
|
|
||||||
outPackets := make([]*packet.Packet, batch)
|
|
||||||
for i := 0; i < batch; i++ {
|
|
||||||
packets[i] = packet.NewVIO()
|
|
||||||
outPackets[i] = packet.New(false) //todo?
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, err := reader.ReadMany(packets, queueNum)
|
n, err := reader.Read(packet)
|
||||||
|
|
||||||
//todo!!
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.WithError(err).Error("Error while reading outbound packet")
|
f.l.WithError(err).Error("Fatal error while reading outbound packet, exiting goroutine")
|
||||||
// This only seems to happen when something fatal happens to the fd, so exit.
|
return
|
||||||
os.Exit(2)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if f.l.Level == logrus.DebugLevel {
|
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
||||||
f.listenInMetric.Update(int64(n))
|
}
|
||||||
}
|
}
|
||||||
f.listenInN = n
|
|
||||||
|
|
||||||
now := time.Now()
|
// listenInBatch handles vectorized packet reading for improved performance
|
||||||
for i, pkt := range packets[:n] {
|
func (f *Interface) listenInBatch(reader BatchReader, i int) error {
|
||||||
outPackets[i].OutLen = -1
|
// Allocate per-packet state and buffers for batch reading
|
||||||
f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now)
|
batchSize := f.batchSize
|
||||||
reader.RecycleRxSeg(pkt, i == (n-1), queueNum) //todo handle err?
|
if batchSize <= 0 {
|
||||||
pkt.Reset()
|
batchSize = 64 // Fallback to default if not configured
|
||||||
}
|
}
|
||||||
_, err = f.writers[queueNum].WriteBatch(outPackets[:n])
|
fwPackets := make([]*firewall.Packet, batchSize)
|
||||||
|
outBuffers := make([][]byte, batchSize)
|
||||||
|
nbBuffers := make([][]byte, batchSize)
|
||||||
|
packets := make([][]byte, batchSize)
|
||||||
|
sizes := make([]int, batchSize)
|
||||||
|
|
||||||
|
for j := 0; j < batchSize; j++ {
|
||||||
|
fwPackets[j] = &firewall.Packet{}
|
||||||
|
outBuffers[j] = make([]byte, mtu)
|
||||||
|
nbBuffers[j] = make([]byte, 12)
|
||||||
|
packets[j] = make([]byte, mtu)
|
||||||
|
}
|
||||||
|
|
||||||
|
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, err := reader.BatchRead(packets, sizes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Error while writing outbound packets")
|
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("error while batch reading outbound packets: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process each packet in the batch
|
||||||
|
cache := conntrackCache.Get(f.l)
|
||||||
|
for idx := 0; idx < n; idx++ {
|
||||||
|
if sizes[idx] > 0 {
|
||||||
|
// Use modulo to reuse fw packet state if batch is larger than our pre-allocated state
|
||||||
|
stateIdx := idx % len(fwPackets)
|
||||||
|
f.consumeInsidePacket(packets[idx][:sizes[idx]], fwPackets[stateIdx], nbBuffers[stateIdx], outBuffers[stateIdx], i, cache)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writeTunBatch attempts to write multiple packets to the TUN device using batch operations if supported
|
||||||
|
func (f *Interface) writeTunBatch(q int, packets [][]byte) error {
|
||||||
|
if len(packets) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the reader/writer supports batch operations
|
||||||
|
if batchWriter, ok := f.readers[q].(BatchWriter); ok {
|
||||||
|
_, err := batchWriter.BatchWrite(packets)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to writing packets individually
|
||||||
|
for _, packet := range packets {
|
||||||
|
if _, err := f.readers[q].Write(packet); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeTun writes a single packet to the TUN device
|
||||||
|
func (f *Interface) writeTun(q int, packet []byte) error {
|
||||||
|
_, err := f.readers[q].Write(packet)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
||||||
c.RegisterReloadCallback(f.reloadFirewall)
|
c.RegisterReloadCallback(f.reloadFirewall)
|
||||||
c.RegisterReloadCallback(f.reloadSendRecvError)
|
c.RegisterReloadCallback(f.reloadSendRecvError)
|
||||||
@@ -491,11 +529,6 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
|||||||
} else {
|
} else {
|
||||||
certMaxVersion.Update(int64(certState.v1Cert.Version()))
|
certMaxVersion.Update(int64(certState.v1Cert.Version()))
|
||||||
}
|
}
|
||||||
if f.l.Level != logrus.DebugLevel {
|
|
||||||
f.listenInMetric.Update(int64(f.listenInN))
|
|
||||||
f.listenOutMetric.Update(int64(f.listenOutN))
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -360,8 +360,7 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !lh.myVpnNetworksTable.Contains(addr) {
|
if !lh.myVpnNetworksTable.Contains(addr) {
|
||||||
lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}).
|
return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
|
||||||
Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
|
|
||||||
}
|
}
|
||||||
out[i] = addr
|
out[i] = addr
|
||||||
}
|
}
|
||||||
@@ -432,8 +431,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
|
if !lh.myVpnNetworksTable.Contains(vpnAddr) {
|
||||||
lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}).
|
return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
|
||||||
Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
vals, ok := v.([]any)
|
vals, ok := v.([]any)
|
||||||
@@ -1339,19 +1337,12 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteAllowList := lhh.lh.GetRemoteAllowList()
|
|
||||||
for _, a := range n.Details.V4AddrPorts {
|
for _, a := range n.Details.V4AddrPorts {
|
||||||
b := protoV4AddrPortToNetAddrPort(a)
|
punch(protoV4AddrPortToNetAddrPort(a), detailsVpnAddr)
|
||||||
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
|
||||||
punch(b, detailsVpnAddr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, a := range n.Details.V6AddrPorts {
|
for _, a := range n.Details.V6AddrPorts {
|
||||||
b := protoV6AddrPortToNetAddrPort(a)
|
punch(protoV6AddrPortToNetAddrPort(a), detailsVpnAddr)
|
||||||
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
|
||||||
punch(b, detailsVpnAddr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// This sends a nebula test packet to the host trying to contact us. In the case
|
// This sends a nebula test packet to the host trying to contact us. In the case
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.yaml.in/yaml/v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOldIPv4Only(t *testing.T) {
|
func TestOldIPv4Only(t *testing.T) {
|
||||||
|
|||||||
6
main.go
6
main.go
@@ -13,7 +13,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/sshd"
|
"github.com/slackhq/nebula/sshd"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
"go.yaml.in/yaml/v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type m = map[string]any
|
type m = map[string]any
|
||||||
@@ -75,8 +75,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
if c.GetBool("sshd.enabled", false) {
|
if c.GetBool("sshd.enabled", false) {
|
||||||
sshStart, err = configSSH(l, ssh, c)
|
sshStart, err = configSSH(l, ssh, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
|
return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
|
||||||
sshStart = nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -243,6 +242,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
relayManager: NewRelayManager(ctx, l, hostMap, c),
|
relayManager: NewRelayManager(ctx, l, hostMap, c),
|
||||||
punchy: punchy,
|
punchy: punchy,
|
||||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||||
|
batchSize: c.GetInt("tun.batch_size", 64),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
279
outside.go
279
outside.go
@@ -7,7 +7,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/slackhq/nebula/packet"
|
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -20,7 +19,7 @@ const (
|
|||||||
minFwPacketLen = 4
|
minFwPacketLen = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
|
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
err := h.Parse(packet)
|
err := h.Parse(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
||||||
@@ -62,7 +61,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
|
|
||||||
switch h.Subtype {
|
switch h.Subtype {
|
||||||
case header.MessageNone:
|
case header.MessageNone:
|
||||||
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, now) {
|
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case header.MessageRelay:
|
case header.MessageRelay:
|
||||||
@@ -97,7 +96,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
case TerminalType:
|
case TerminalType:
|
||||||
// If I am the target of this relay, process the unwrapped packet
|
// If I am the target of this relay, process the unwrapped packet
|
||||||
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
|
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
|
||||||
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
|
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
|
||||||
return
|
return
|
||||||
case ForwardingType:
|
case ForwardingType:
|
||||||
// Find the target HostInfo relay object
|
// Find the target HostInfo relay object
|
||||||
@@ -217,217 +216,6 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
f.connectionManager.In(hostinfo)
|
f.connectionManager.In(hostinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
|
|
||||||
for i, pkt := range packets {
|
|
||||||
out[i].Scratch = out[i].Scratch[:0]
|
|
||||||
ip := pkt.AddrPort()
|
|
||||||
|
|
||||||
//l.Error("in packet ", header, packet[HeaderLen:])
|
|
||||||
if ip.IsValid() {
|
|
||||||
if f.myVpnNetworksTable.Contains(ip.Addr()) {
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//todo per-segment!
|
|
||||||
for segment := range pkt.Segments() {
|
|
||||||
|
|
||||||
err := h.Parse(segment)
|
|
||||||
if err != nil {
|
|
||||||
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
|
||||||
if len(segment) > 1 {
|
|
||||||
f.l.WithField("packet", pkt).Infof("Error while parsing inbound packet from %s: %s", ip, err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var hostinfo *HostInfo
|
|
||||||
// verify if we've seen this index before, otherwise respond to the handshake initiation
|
|
||||||
if h.Type == header.Message && h.Subtype == header.MessageRelay {
|
|
||||||
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
|
|
||||||
} else {
|
|
||||||
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
|
|
||||||
}
|
|
||||||
|
|
||||||
var ci *ConnectionState
|
|
||||||
if hostinfo != nil {
|
|
||||||
ci = hostinfo.ConnectionState
|
|
||||||
}
|
|
||||||
|
|
||||||
switch h.Type {
|
|
||||||
case header.Message:
|
|
||||||
// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
|
|
||||||
if !f.handleEncrypted(ci, ip, h) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch h.Subtype {
|
|
||||||
case header.MessageNone:
|
|
||||||
if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i], pkt, segment, fwPacket, nb, q, localCache, now) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case header.MessageRelay:
|
|
||||||
// The entire body is sent as AD, not encrypted.
|
|
||||||
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
|
|
||||||
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
|
|
||||||
// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
|
|
||||||
// which will gracefully fail in the DecryptDanger call.
|
|
||||||
signedPayload := segment[:len(segment)-hostinfo.ConnectionState.dKey.Overhead()]
|
|
||||||
signatureValue := segment[len(segment)-hostinfo.ConnectionState.dKey.Overhead():]
|
|
||||||
out[i].Scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(out[i].Scratch, signedPayload, signatureValue, h.MessageCounter, nb)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Successfully validated the thing. Get rid of the Relay header.
|
|
||||||
signedPayload = signedPayload[header.Len:]
|
|
||||||
// Pull the Roaming parts up here, and return in all call paths.
|
|
||||||
f.handleHostRoaming(hostinfo, ip)
|
|
||||||
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
|
|
||||||
f.connectionManager.In(hostinfo)
|
|
||||||
f.connectionManager.RelayUsed(h.RemoteIndex)
|
|
||||||
|
|
||||||
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
|
|
||||||
if !ok {
|
|
||||||
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
|
|
||||||
// its internal mapping. This should never happen.
|
|
||||||
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch relay.Type {
|
|
||||||
case TerminalType:
|
|
||||||
// If I am the target of this relay, process the unwrapped packet
|
|
||||||
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
|
|
||||||
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i].Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
|
|
||||||
return
|
|
||||||
case ForwardingType:
|
|
||||||
// Find the target HostInfo relay object
|
|
||||||
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// If that relay is Established, forward the payload through it
|
|
||||||
if targetRelay.State == Established {
|
|
||||||
switch targetRelay.Type {
|
|
||||||
case ForwardingType:
|
|
||||||
// Forward this packet through the relay tunnel
|
|
||||||
// Find the target HostInfo
|
|
||||||
f.SendVia(targetHI, targetRelay, signedPayload, nb, out[i].Scratch, false)
|
|
||||||
return
|
|
||||||
case TerminalType:
|
|
||||||
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case header.LightHouse:
|
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
|
||||||
if !f.handleEncrypted(ci, ip, h) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
|
||||||
WithField("packet", segment).
|
|
||||||
Error("Failed to decrypt lighthouse packet")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
|
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
|
||||||
|
|
||||||
case header.Test:
|
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
|
||||||
if !f.handleEncrypted(ci, ip, h) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
|
||||||
WithField("packet", segment).
|
|
||||||
Error("Failed to decrypt test packet")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.Subtype == header.TestRequest {
|
|
||||||
// This testRequest might be from TryPromoteBest, so we should roam
|
|
||||||
// to the new IP address before responding
|
|
||||||
f.handleHostRoaming(hostinfo, ip)
|
|
||||||
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out[i].Scratch)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
|
||||||
|
|
||||||
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
|
|
||||||
// are unauthenticated
|
|
||||||
|
|
||||||
case header.Handshake:
|
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
|
||||||
f.handshakeManager.HandleIncoming(ip, nil, segment, h)
|
|
||||||
return
|
|
||||||
|
|
||||||
case header.RecvError:
|
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
|
||||||
f.handleRecvError(ip, h)
|
|
||||||
return
|
|
||||||
|
|
||||||
case header.CloseTunnel:
|
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
|
||||||
if !f.handleEncrypted(ci, ip, h) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hostinfo.logger(f.l).WithField("udpAddr", ip).
|
|
||||||
Info("Close tunnel received, tearing down.")
|
|
||||||
|
|
||||||
f.closeTunnel(hostinfo)
|
|
||||||
return
|
|
||||||
|
|
||||||
case header.Control:
|
|
||||||
if !f.handleEncrypted(ci, ip, h) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
|
||||||
WithField("packet", segment).
|
|
||||||
Error("Failed to decrypt Control packet")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f.relayManager.HandleControlMsg(hostinfo, d, f)
|
|
||||||
|
|
||||||
default:
|
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
|
||||||
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f.handleHostRoaming(hostinfo, ip)
|
|
||||||
|
|
||||||
f.connectionManager.In(hostinfo)
|
|
||||||
|
|
||||||
}
|
|
||||||
_, err := f.readers[q].WriteOne(out[i], false, q)
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).Error("Failed to write packet")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
|
// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
|
||||||
func (f *Interface) closeTunnel(hostInfo *HostInfo) {
|
func (f *Interface) closeTunnel(hostInfo *HostInfo) {
|
||||||
final := f.hostMap.DeleteHostInfo(hostInfo)
|
final := f.hostMap.DeleteHostInfo(hostInfo)
|
||||||
@@ -545,12 +333,13 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fp.Protocol = uint8(proto)
|
fp.Protocol = uint8(proto)
|
||||||
|
ports := data[offset : offset+4]
|
||||||
if incoming {
|
if incoming {
|
||||||
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
|
fp.RemotePort = binary.BigEndian.Uint16(ports[0:2])
|
||||||
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
fp.LocalPort = binary.BigEndian.Uint16(ports[2:4])
|
||||||
} else {
|
} else {
|
||||||
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
|
fp.LocalPort = binary.BigEndian.Uint16(ports[0:2])
|
||||||
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
fp.RemotePort = binary.BigEndian.Uint16(ports[2:4])
|
||||||
}
|
}
|
||||||
|
|
||||||
fp.Fragment = false
|
fp.Fragment = false
|
||||||
@@ -677,55 +466,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
|
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
|
||||||
var err error
|
|
||||||
|
|
||||||
seg, err := f.readers[q].AllocSeg(out, q)
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).Errorln("decryptToTunDelayWrite: failed to allocate segment")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
out.SegmentPayloads[seg] = out.SegmentPayloads[seg][:0]
|
|
||||||
out.SegmentPayloads[seg], err = hostinfo.ConnectionState.dKey.DecryptDanger(out.SegmentPayloads[seg], inSegment[:header.Len], inSegment[header.Len:], messageCounter, nb)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
err = newPacket(out.SegmentPayloads[seg], true, fwPacket)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
|
|
||||||
Warnf("Error while validating inbound packet")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
|
|
||||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
|
||||||
Debugln("dropping out of window packet")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now)
|
|
||||||
if dropReason != nil {
|
|
||||||
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
|
||||||
// This gives us a buffer to build the reject packet in
|
|
||||||
f.rejectOutside(out.SegmentPayloads[seg], hostinfo.ConnectionState, hostinfo, nb, inSegment, q)
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
|
||||||
WithField("reason", dropReason).
|
|
||||||
Debugln("dropping inbound packet")
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
f.connectionManager.In(hostinfo)
|
|
||||||
pkt.OutLen += len(inSegment)
|
|
||||||
out.Segments[seg] = out.Segments[seg][:len(out.SegmentHeaders[seg])+len(out.SegmentPayloads[seg])]
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
||||||
@@ -747,7 +488,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now)
|
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
||||||
if dropReason != nil {
|
if dropReason != nil {
|
||||||
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
||||||
// This gives us a buffer to build the reject packet in
|
// This gives us a buffer to build the reject packet in
|
||||||
|
|||||||
@@ -1,16 +1,17 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Device interface {
|
type Device interface {
|
||||||
TunDev
|
io.ReadWriteCloser
|
||||||
Activate() error
|
Activate() error
|
||||||
Networks() []netip.Prefix
|
Networks() []netip.Prefix
|
||||||
Name() string
|
Name() string
|
||||||
RoutesFor(netip.Addr) routing.Gateways
|
RoutesFor(netip.Addr) routing.Gateways
|
||||||
NewMultiQueueReader() (TunDev, error)
|
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,91 +0,0 @@
|
|||||||
package eventfd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
type EventFD struct {
|
|
||||||
fd int
|
|
||||||
buf [8]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func New() (EventFD, error) {
|
|
||||||
fd, err := unix.Eventfd(0, unix.EFD_NONBLOCK)
|
|
||||||
if err != nil {
|
|
||||||
return EventFD{}, err
|
|
||||||
}
|
|
||||||
return EventFD{
|
|
||||||
fd: fd,
|
|
||||||
buf: [8]byte{},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *EventFD) Kick() error {
|
|
||||||
binary.LittleEndian.PutUint64(e.buf[:], 1) //is this right???
|
|
||||||
_, err := syscall.Write(int(e.fd), e.buf[:])
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *EventFD) Close() error {
|
|
||||||
if e.fd != 0 {
|
|
||||||
return unix.Close(e.fd)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *EventFD) FD() int {
|
|
||||||
return e.fd
|
|
||||||
}
|
|
||||||
|
|
||||||
type Epoll struct {
|
|
||||||
fd int
|
|
||||||
buf [8]byte
|
|
||||||
events []syscall.EpollEvent
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewEpoll() (Epoll, error) {
|
|
||||||
fd, err := unix.EpollCreate1(0)
|
|
||||||
if err != nil {
|
|
||||||
return Epoll{}, err
|
|
||||||
}
|
|
||||||
return Epoll{
|
|
||||||
fd: fd,
|
|
||||||
buf: [8]byte{},
|
|
||||||
events: make([]syscall.EpollEvent, 1),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ep *Epoll) AddEvent(fdToAdd int) error {
|
|
||||||
event := syscall.EpollEvent{
|
|
||||||
Events: syscall.EPOLLIN,
|
|
||||||
Fd: int32(fdToAdd),
|
|
||||||
}
|
|
||||||
return syscall.EpollCtl(ep.fd, syscall.EPOLL_CTL_ADD, fdToAdd, &event)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ep *Epoll) Block() (int, error) {
|
|
||||||
n, err := syscall.EpollWait(ep.fd, ep.events, -1)
|
|
||||||
if err != nil {
|
|
||||||
//goland:noinspection GoDirectComparisonOfErrors
|
|
||||||
if err == syscall.EINTR {
|
|
||||||
return 0, nil //??
|
|
||||||
}
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ep *Epoll) Clear() error {
|
|
||||||
_, err := syscall.Read(int(ep.events[0].Fd), ep.buf[:])
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ep *Epoll) Close() error {
|
|
||||||
if ep.fd != 0 {
|
|
||||||
return unix.Close(ep.fd)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -3,7 +3,6 @@ package overlay
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -305,29 +304,3 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
|||||||
|
|
||||||
return routes, nil
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ipWithin(o *net.IPNet, i *net.IPNet) bool {
|
|
||||||
// Make sure o contains the lowest form of i
|
|
||||||
if !o.Contains(i.IP.Mask(i.Mask)) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find the max ip in i
|
|
||||||
ip4 := i.IP.To4()
|
|
||||||
if ip4 == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
last := make(net.IP, len(ip4))
|
|
||||||
copy(last, ip4)
|
|
||||||
for x := range ip4 {
|
|
||||||
last[x] |= ^i.Mask[x]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure o contains the max
|
|
||||||
if !o.Contains(last) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -225,6 +225,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
|
|||||||
// no mtu
|
// no mtu
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
|
require.NoError(t, err)
|
||||||
assert.Len(t, routes, 1)
|
assert.Len(t, routes, 1)
|
||||||
assert.Equal(t, 0, routes[0].MTU)
|
assert.Equal(t, 0, routes[0].MTU)
|
||||||
|
|
||||||
@@ -318,7 +319,7 @@ func Test_makeRouteTree(t *testing.T) {
|
|||||||
|
|
||||||
ip, err = netip.ParseAddr("1.1.0.1")
|
ip, err = netip.ParseAddr("1.1.0.1")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
r, ok = routeTree.Lookup(ip)
|
_, ok = routeTree.Lookup(ip)
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,30 +1,15 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/packet"
|
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const DefaultMTU = 1300
|
const DefaultMTU = 1300
|
||||||
|
|
||||||
type TunDev interface {
|
|
||||||
io.WriteCloser
|
|
||||||
ReadMany(x []*packet.VirtIOPacket, q int) (int, error)
|
|
||||||
|
|
||||||
//todo this interface sux
|
|
||||||
AllocSeg(pkt *packet.OutPacket, q int) (int, error)
|
|
||||||
WriteOne(x *packet.OutPacket, kick bool, q int) (int, error)
|
|
||||||
WriteMany(x []*packet.OutPacket, q int) (int, error)
|
|
||||||
RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: We may be able to remove routines
|
// TODO: We may be able to remove routines
|
||||||
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
||||||
|
|
||||||
@@ -39,11 +24,11 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//func NewFdDeviceFromConfig(fd *int) DeviceFactory {
|
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
|
||||||
// return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||||
// return newTunFromFd(c, l, *fd, vpnNetworks)
|
return newTunFromFd(c, l, *fd, vpnNetworks)
|
||||||
// }
|
}
|
||||||
//}
|
}
|
||||||
|
|
||||||
func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) {
|
func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) {
|
||||||
if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
|
if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
|
||||||
@@ -85,51 +70,3 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
|
|||||||
|
|
||||||
return removed
|
return removed
|
||||||
}
|
}
|
||||||
|
|
||||||
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
|
||||||
pLen := 128
|
|
||||||
if prefix.Addr().Is4() {
|
|
||||||
pLen = 32
|
|
||||||
}
|
|
||||||
|
|
||||||
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
|
|
||||||
return addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func flipBytes(b []byte) []byte {
|
|
||||||
for i := 0; i < len(b); i++ {
|
|
||||||
b[i] ^= 0xFF
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
func orBytes(a []byte, b []byte) []byte {
|
|
||||||
ret := make([]byte, len(a))
|
|
||||||
for i := 0; i < len(a); i++ {
|
|
||||||
ret[i] = a[i] | b[i]
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func getBroadcast(cidr netip.Prefix) netip.Addr {
|
|
||||||
broadcast, _ := netip.AddrFromSlice(
|
|
||||||
orBytes(
|
|
||||||
cidr.Addr().AsSlice(),
|
|
||||||
flipBytes(prefixToMask(cidr).AsSlice()),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return broadcast
|
|
||||||
}
|
|
||||||
|
|
||||||
func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
|
|
||||||
for _, gateway := range gateways {
|
|
||||||
if dest.Addr().Is4() && gateway.Addr().Is4() {
|
|
||||||
return gateway, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if dest.Addr().Is6() && gateway.Addr().Is6() {
|
|
||||||
return gateway, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
//go:build !ios && !e2e_testing
|
//go:build darwin && !ios && !e2e_testing
|
||||||
// +build !ios,!e2e_testing
|
// +build darwin,!ios,!e2e_testing
|
||||||
|
|
||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
@@ -8,48 +8,27 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
|
||||||
"sync/atomic"
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
netroute "golang.org/x/net/route"
|
netroute "golang.org/x/net/route"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
wgtun "golang.zx2c4.com/wireguard/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
io.ReadWriteCloser
|
linkAddr *netroute.LinkAddr
|
||||||
Device string
|
|
||||||
vpnNetworks []netip.Prefix
|
|
||||||
DefaultMTU int
|
|
||||||
Routes atomic.Pointer[[]Route]
|
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
|
||||||
linkAddr *netroute.LinkAddr
|
|
||||||
l *logrus.Logger
|
|
||||||
|
|
||||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
|
||||||
out []byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ioctl structures for Darwin network configuration
|
||||||
type ifReq struct {
|
type ifReq struct {
|
||||||
Name [unix.IFNAMSIZ]byte
|
Name [unix.IFNAMSIZ]byte
|
||||||
Flags uint16
|
Flags uint16
|
||||||
pad [8]byte
|
pad [8]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
|
||||||
_SIOCAIFADDR_IN6 = 2155899162
|
|
||||||
_UTUN_OPT_IFNAME = 2
|
|
||||||
_IN6_IFF_NODAD = 0x0020
|
|
||||||
_IN6_IFF_SECURED = 0x0400
|
|
||||||
utunControlName = "com.apple.net.utun_control"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ifreqMTU struct {
|
type ifreqMTU struct {
|
||||||
Name [16]byte
|
Name [16]byte
|
||||||
MTU int32
|
MTU int32
|
||||||
@@ -79,60 +58,61 @@ type ifreqAlias6 struct {
|
|||||||
Lifetime addrLifetime
|
Lifetime addrLifetime
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
const (
|
||||||
|
_SIOCAIFADDR_IN6 = 2155899162
|
||||||
|
_IN6_IFF_NODAD = 0x0020
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
|
||||||
|
return nil, fmt.Errorf("newTunFromFd not supported on Darwin")
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
|
||||||
name := c.GetString("tun.dev", "")
|
name := c.GetString("tun.dev", "")
|
||||||
ifIndex := -1
|
deviceName := "utun"
|
||||||
|
|
||||||
|
// Parse device name to handle utun[0-9]+ format
|
||||||
if name != "" && name != "utun" {
|
if name != "" && name != "utun" {
|
||||||
|
ifIndex := -1
|
||||||
_, err := fmt.Sscanf(name, "utun%d", &ifIndex)
|
_, err := fmt.Sscanf(name, "utun%d", &ifIndex)
|
||||||
if err != nil || ifIndex < 0 {
|
if err != nil || ifIndex < 0 {
|
||||||
// NOTE: we don't make this error so we don't break existing
|
// NOTE: we don't make this error so we don't break existing
|
||||||
// configs that set a name before it was used.
|
// configs that set a name before it was used.
|
||||||
l.Warn("interface name must be utun[0-9]+ on Darwin, ignoring")
|
l.Warn("interface name must be utun[0-9]+ on Darwin, ignoring")
|
||||||
ifIndex = -1
|
} else {
|
||||||
|
deviceName = name
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL)
|
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
||||||
|
|
||||||
|
// Create WireGuard TUN device
|
||||||
|
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("system socket: %v", err)
|
return nil, fmt.Errorf("failed to create TUN device: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var ctlInfo = &unix.CtlInfo{}
|
// Get the actual device name
|
||||||
copy(ctlInfo.Name[:], utunControlName)
|
actualName, err := tunDevice.Name()
|
||||||
|
|
||||||
err = unix.IoctlCtlInfo(fd, ctlInfo)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("CTLIOCGINFO: %v", err)
|
tunDevice.Close()
|
||||||
|
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = unix.Connect(fd, &unix.SockaddrCtl{
|
t := &wgTun{
|
||||||
ID: ctlInfo.Id,
|
tunDevice: tunDevice,
|
||||||
Unit: uint32(ifIndex) + 1,
|
vpnNetworks: vpnNetworks,
|
||||||
})
|
MaxMTU: mtu,
|
||||||
if err != nil {
|
DefaultMTU: mtu,
|
||||||
return nil, fmt.Errorf("SYS_CONNECT: %v", err)
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME)
|
// Create Darwin-specific route manager
|
||||||
if err != nil {
|
t.routeManager = &tun{}
|
||||||
return nil, fmt.Errorf("failed to retrieve tun name: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = unix.SetNonblock(fd, true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("SetNonblock: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t := &tun{
|
|
||||||
ReadWriteCloser: os.NewFile(uintptr(fd), ""),
|
|
||||||
Device: name,
|
|
||||||
vpnNetworks: vpnNetworks,
|
|
||||||
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
|
|
||||||
l: l,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
tunDevice.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -143,215 +123,251 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
l.WithField("name", actualName).Info("Created WireGuard TUN device")
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) deviceBytes() (o [16]byte) {
|
func (rm *tun) Activate(t *wgTun) error {
|
||||||
for i, c := range t.Device {
|
name, err := t.tunDevice.Name()
|
||||||
o[i] = byte(c)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Close() error {
|
|
||||||
if t.ReadWriteCloser != nil {
|
|
||||||
return t.ReadWriteCloser.Close()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Activate() error {
|
|
||||||
devName := t.deviceBytes()
|
|
||||||
|
|
||||||
s, err := unix.Socket(
|
|
||||||
unix.AF_INET,
|
|
||||||
unix.SOCK_DGRAM,
|
|
||||||
unix.IPPROTO_IP,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to get device name: %w", err)
|
||||||
}
|
|
||||||
defer unix.Close(s)
|
|
||||||
|
|
||||||
fd := uintptr(s)
|
|
||||||
|
|
||||||
// Set the MTU on the device
|
|
||||||
ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)}
|
|
||||||
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun mtu: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the device flags
|
// Set the MTU
|
||||||
ifrf := ifReq{Name: devName}
|
rm.SetMTU(t, t.MaxMTU)
|
||||||
if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
|
||||||
return fmt.Errorf("failed to get tun flags: %s", err)
|
// Add IP addresses
|
||||||
|
for _, network := range t.vpnNetworks {
|
||||||
|
if err := rm.addIP(t, name, network); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
linkAddr, err := getLinkAddr(t.Device)
|
// Bring up the interface using ioctl
|
||||||
|
if err := rm.bringUpInterface(name); err != nil {
|
||||||
|
return fmt.Errorf("failed to bring up interface: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the link address for routing
|
||||||
|
linkAddr, err := getLinkAddr(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to get link address: %w", err)
|
||||||
}
|
}
|
||||||
if linkAddr == nil {
|
if linkAddr == nil {
|
||||||
return fmt.Errorf("unable to discover link_addr for tun interface")
|
return fmt.Errorf("unable to discover link_addr for tun interface")
|
||||||
}
|
}
|
||||||
t.linkAddr = linkAddr
|
rm.linkAddr = linkAddr
|
||||||
|
|
||||||
for _, network := range t.vpnNetworks {
|
// Set the routes
|
||||||
if network.Addr().Is4() {
|
if err := rm.AddRoutes(t, false); err != nil {
|
||||||
err = t.activate4(network)
|
return err
|
||||||
if err != nil {
|
}
|
||||||
return err
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rm *tun) bringUpInterface(name string) error {
|
||||||
|
// Open a socket for ioctl
|
||||||
|
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create socket: %w", err)
|
||||||
|
}
|
||||||
|
defer unix.Close(fd)
|
||||||
|
|
||||||
|
// Get current flags
|
||||||
|
var ifrf ifReq
|
||||||
|
copy(ifrf.Name[:], name)
|
||||||
|
|
||||||
|
if err := ioctl(uintptr(fd), unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||||
|
return fmt.Errorf("failed to get interface flags: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set IFF_UP and IFF_RUNNING flags
|
||||||
|
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
|
||||||
|
|
||||||
|
if err := ioctl(uintptr(fd), unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||||
|
return fmt.Errorf("failed to set interface flags: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rm *tun) SetMTU(t *wgTun, mtu int) {
|
||||||
|
name, err := t.tunDevice.Name()
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithError(err).Error("Failed to get device name for MTU set")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open a socket for ioctl
|
||||||
|
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithError(err).Error("Failed to create socket for MTU set")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer unix.Close(fd)
|
||||||
|
|
||||||
|
// Prepare the ioctl request
|
||||||
|
var ifr ifreqMTU
|
||||||
|
copy(ifr.Name[:], name)
|
||||||
|
ifr.MTU = int32(mtu)
|
||||||
|
|
||||||
|
// Set the MTU using ioctl
|
||||||
|
if err := ioctl(uintptr(fd), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||||
|
t.l.WithError(err).Error("Failed to set tun mtu via ioctl")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
|
||||||
|
// On Darwin, routes are set via ifconfig and route commands
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
||||||
|
routes := *t.Routes.Load()
|
||||||
|
for _, r := range routes {
|
||||||
|
if !r.Install {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := rm.addRoute(r.Cidr)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, unix.EEXIST) {
|
||||||
|
t.l.WithField("route", r.Cidr).
|
||||||
|
Warnf("unable to add unsafe_route, identical route already exists")
|
||||||
|
} else {
|
||||||
|
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||||
|
if logErrors {
|
||||||
|
retErr.Log(t.l)
|
||||||
|
} else {
|
||||||
|
return retErr
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err = t.activate6(network)
|
t.l.WithField("route", r).Info("Added route")
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run the interface
|
return nil
|
||||||
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
|
|
||||||
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
|
||||||
return fmt.Errorf("failed to run tun device: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unsafe path routes
|
|
||||||
return t.addRoutes(false)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) activate4(network netip.Prefix) error {
|
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
|
||||||
s, err := unix.Socket(
|
for _, r := range routes {
|
||||||
unix.AF_INET,
|
if !r.Install {
|
||||||
unix.SOCK_DGRAM,
|
continue
|
||||||
unix.IPPROTO_IP,
|
}
|
||||||
)
|
|
||||||
|
err := rm.delRoute(r.Cidr)
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
|
} else {
|
||||||
|
t.l.WithField("route", r).Info("Removed route")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
|
||||||
|
// Darwin doesn't support multi-queue TUN devices in the same way as Linux
|
||||||
|
// Return a reader that wraps the same device
|
||||||
|
return &wgTunReader{
|
||||||
|
parent: t,
|
||||||
|
tunDevice: t.tunDevice,
|
||||||
|
offset: 0,
|
||||||
|
l: t.l,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rm *tun) addIP(t *wgTun, name string, network netip.Prefix) error {
|
||||||
|
addr := network.Addr()
|
||||||
|
|
||||||
|
if addr.Is4() {
|
||||||
|
return rm.addIPv4(name, network)
|
||||||
|
} else {
|
||||||
|
return rm.addIPv6(name, network)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rm *tun) addIPv4(name string, network netip.Prefix) error {
|
||||||
|
// Open an IPv4 socket for ioctl
|
||||||
|
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to create IPv4 socket: %w", err)
|
||||||
}
|
}
|
||||||
defer unix.Close(s)
|
defer unix.Close(s)
|
||||||
|
|
||||||
ifr := ifreqAlias4{
|
var ifr ifreqAlias4
|
||||||
Name: t.deviceBytes(),
|
copy(ifr.Name[:], name)
|
||||||
Addr: unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
// Set the address
|
||||||
Family: unix.AF_INET,
|
ifr.Addr = unix.RawSockaddrInet4{
|
||||||
Addr: network.Addr().As4(),
|
Len: unix.SizeofSockaddrInet4,
|
||||||
},
|
Family: unix.AF_INET,
|
||||||
DstAddr: unix.RawSockaddrInet4{
|
Addr: network.Addr().As4(),
|
||||||
Len: unix.SizeofSockaddrInet4,
|
}
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: network.Addr().As4(),
|
// Set the destination address (same as address for point-to-point)
|
||||||
},
|
ifr.DstAddr = unix.RawSockaddrInet4{
|
||||||
MaskAddr: unix.RawSockaddrInet4{
|
Len: unix.SizeofSockaddrInet4,
|
||||||
Len: unix.SizeofSockaddrInet4,
|
Family: unix.AF_INET,
|
||||||
Family: unix.AF_INET,
|
Addr: network.Addr().As4(),
|
||||||
Addr: prefixToMask(network).As4(),
|
}
|
||||||
},
|
|
||||||
|
// Set the netmask
|
||||||
|
ifr.MaskAddr = unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: prefixToMask(network).As4(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||||
return fmt.Errorf("failed to set tun v4 address: %s", err)
|
return fmt.Errorf("failed to set IPv4 address via ioctl: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
err = addRoute(network, t.linkAddr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) activate6(network netip.Prefix) error {
|
func (rm *tun) addIPv6(name string, network netip.Prefix) error {
|
||||||
s, err := unix.Socket(
|
// Open an IPv6 socket for ioctl
|
||||||
unix.AF_INET6,
|
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
unix.SOCK_DGRAM,
|
|
||||||
unix.IPPROTO_IP,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to create IPv6 socket: %w", err)
|
||||||
}
|
}
|
||||||
defer unix.Close(s)
|
defer unix.Close(s)
|
||||||
|
|
||||||
ifr := ifreqAlias6{
|
var ifr ifreqAlias6
|
||||||
Name: t.deviceBytes(),
|
copy(ifr.Name[:], name)
|
||||||
Addr: unix.RawSockaddrInet6{
|
|
||||||
Len: unix.SizeofSockaddrInet6,
|
// Set the address
|
||||||
Family: unix.AF_INET6,
|
ifr.Addr = unix.RawSockaddrInet6{
|
||||||
Addr: network.Addr().As16(),
|
Len: unix.SizeofSockaddrInet6,
|
||||||
},
|
Family: unix.AF_INET6,
|
||||||
PrefixMask: unix.RawSockaddrInet6{
|
Addr: network.Addr().As16(),
|
||||||
Len: unix.SizeofSockaddrInet6,
|
|
||||||
Family: unix.AF_INET6,
|
|
||||||
Addr: prefixToMask(network).As16(),
|
|
||||||
},
|
|
||||||
Lifetime: addrLifetime{
|
|
||||||
// never expires
|
|
||||||
Vltime: 0xffffffff,
|
|
||||||
Pltime: 0xffffffff,
|
|
||||||
},
|
|
||||||
Flags: _IN6_IFF_NODAD,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set the prefix mask
|
||||||
|
ifr.PrefixMask = unix.RawSockaddrInet6{
|
||||||
|
Len: unix.SizeofSockaddrInet6,
|
||||||
|
Family: unix.AF_INET6,
|
||||||
|
Addr: prefixToMask(network).As16(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set lifetime (never expires)
|
||||||
|
ifr.Lifetime = addrLifetime{
|
||||||
|
Vltime: 0xffffffff,
|
||||||
|
Pltime: 0xffffffff,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set flags (no DAD - Duplicate Address Detection)
|
||||||
|
ifr.Flags = _IN6_IFF_NODAD
|
||||||
|
|
||||||
if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||||
return fmt.Errorf("failed to set tun address: %s", err)
|
return fmt.Errorf("failed to set IPv6 address via ioctl: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
|
||||||
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !initial && !change {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Teach nebula how to handle the routes before establishing them in the system table
|
|
||||||
oldRoutes := t.Routes.Swap(&routes)
|
|
||||||
t.routeTree.Store(routeTree)
|
|
||||||
|
|
||||||
if !initial {
|
|
||||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
|
||||||
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
|
||||||
if err != nil {
|
|
||||||
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure any routes we actually want are installed
|
|
||||||
err = t.addRoutes(true)
|
|
||||||
if err != nil {
|
|
||||||
// Catch any stray logs
|
|
||||||
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
|
||||||
r, ok := t.routeTree.Load().Lookup(ip)
|
|
||||||
if ok {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
return routing.Gateways{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the LinkAddr for the interface of the given name
|
|
||||||
// Is there an easier way to fetch this when we create the interface?
|
|
||||||
// Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers.
|
|
||||||
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
||||||
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
|
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -377,53 +393,7 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (rm *tun) addRoute(prefix netip.Prefix) error {
|
||||||
routes := *t.Routes.Load()
|
|
||||||
|
|
||||||
for _, r := range routes {
|
|
||||||
if len(r.Via) == 0 || !r.Install {
|
|
||||||
// We don't allow route MTUs so only install routes with a via
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err := addRoute(r.Cidr, t.linkAddr)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, unix.EEXIST) {
|
|
||||||
t.l.WithField("route", r.Cidr).
|
|
||||||
Warnf("unable to add unsafe_route, identical route already exists")
|
|
||||||
} else {
|
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
|
||||||
if logErrors {
|
|
||||||
retErr.Log(t.l)
|
|
||||||
} else {
|
|
||||||
return retErr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
t.l.WithField("route", r).Info("Added route")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) removeRoutes(routes []Route) error {
|
|
||||||
for _, r := range routes {
|
|
||||||
if !r.Install {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err := delRoute(r.Cidr, t.linkAddr)
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
|
||||||
} else {
|
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
|
||||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||||
@@ -441,13 +411,13 @@ func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
|||||||
route.Addrs = []netroute.Addr{
|
route.Addrs = []netroute.Addr{
|
||||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||||
unix.RTAX_GATEWAY: gateway,
|
unix.RTAX_GATEWAY: rm.linkAddr,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
route.Addrs = []netroute.Addr{
|
route.Addrs = []netroute.Addr{
|
||||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||||
unix.RTAX_GATEWAY: gateway,
|
unix.RTAX_GATEWAY: rm.linkAddr,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -464,7 +434,7 @@ func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
func (rm *tun) delRoute(prefix netip.Prefix) error {
|
||||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||||
@@ -481,13 +451,13 @@ func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
|||||||
route.Addrs = []netroute.Addr{
|
route.Addrs = []netroute.Addr{
|
||||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||||
unix.RTAX_GATEWAY: gateway,
|
unix.RTAX_GATEWAY: rm.linkAddr,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
route.Addrs = []netroute.Addr{
|
route.Addrs = []netroute.Addr{
|
||||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||||
unix.RTAX_GATEWAY: gateway,
|
unix.RTAX_GATEWAY: rm.linkAddr,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -495,6 +465,7 @@ func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = unix.Write(sock, data[:])
|
_, err = unix.Write(sock, data[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||||
@@ -503,52 +474,34 @@ func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Read(to []byte) (int, error) {
|
func ioctl(a1, a2, a3 uintptr) error {
|
||||||
buf := make([]byte, len(to)+4)
|
_, _, errno := unix.Syscall(unix.SYS_IOCTL, a1, a2, a3)
|
||||||
|
if errno != 0 {
|
||||||
n, err := t.ReadWriteCloser.Read(buf)
|
return errno
|
||||||
|
}
|
||||||
copy(to, buf[4:])
|
return nil
|
||||||
return n - 4, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write is only valid for single threaded use
|
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
||||||
func (t *tun) Write(from []byte) (int, error) {
|
bits := prefix.Bits()
|
||||||
buf := t.out
|
if prefix.Addr().Is4() {
|
||||||
if cap(buf) < len(from)+4 {
|
// Create IPv4 netmask from prefix length
|
||||||
buf = make([]byte, len(from)+4)
|
mask := ^uint32(0) << (32 - bits)
|
||||||
t.out = buf
|
return netip.AddrFrom4([4]byte{
|
||||||
}
|
byte(mask >> 24),
|
||||||
buf = buf[:len(from)+4]
|
byte(mask >> 16),
|
||||||
|
byte(mask >> 8),
|
||||||
if len(from) == 0 {
|
byte(mask),
|
||||||
return 0, syscall.EIO
|
})
|
||||||
}
|
|
||||||
|
|
||||||
// Determine the IP Family for the NULL L2 Header
|
|
||||||
ipVer := from[0] >> 4
|
|
||||||
if ipVer == 4 {
|
|
||||||
buf[3] = syscall.AF_INET
|
|
||||||
} else if ipVer == 6 {
|
|
||||||
buf[3] = syscall.AF_INET6
|
|
||||||
} else {
|
} else {
|
||||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
// Create IPv6 netmask from prefix length
|
||||||
|
var mask [16]byte
|
||||||
|
for i := 0; i < bits/8; i++ {
|
||||||
|
mask[i] = 0xff
|
||||||
|
}
|
||||||
|
if bits%8 != 0 {
|
||||||
|
mask[bits/8] = ^byte(0) << (8 - bits%8)
|
||||||
|
}
|
||||||
|
return netip.AddrFrom16(mask)
|
||||||
}
|
}
|
||||||
|
|
||||||
copy(buf[4:], from)
|
|
||||||
|
|
||||||
n, err := t.ReadWriteCloser.Write(buf)
|
|
||||||
return n - 4, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Networks() []netip.Prefix {
|
|
||||||
return t.vpnNetworks
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Name() string {
|
|
||||||
return t.Device
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ import (
|
|||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/slackhq/nebula/overlay/virtqueue"
|
|
||||||
"github.com/slackhq/nebula/packet"
|
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -24,10 +22,6 @@ type disabledTun struct {
|
|||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*disabledTun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
|
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
|
||||||
tun := &disabledTun{
|
tun := &disabledTun{
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
@@ -46,10 +40,6 @@ func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled boo
|
|||||||
return tun
|
return tun
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*disabledTun) GetQueues() []*virtqueue.SplitQueue {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*disabledTun) Activate() error {
|
func (*disabledTun) Activate() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -115,23 +105,7 @@ func (t *disabledTun) Write(b []byte) (int, error) {
|
|||||||
return len(b), nil
|
return len(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *disabledTun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
|
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return 0, fmt.Errorf("tun_disabled: AllocSeg not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *disabledTun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
|
|
||||||
return 0, fmt.Errorf("tun_disabled: WriteOne not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *disabledTun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
|
|
||||||
return 0, fmt.Errorf("tun_disabled: WriteMany not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *disabledTun) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
|
|
||||||
return t.Read(b[0].Payload)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,284 +1,77 @@
|
|||||||
//go:build !e2e_testing
|
//go:build freebsd && !e2e_testing
|
||||||
// +build !e2e_testing
|
// +build freebsd,!e2e_testing
|
||||||
|
|
||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync/atomic"
|
"os/exec"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
netroute "golang.org/x/net/route"
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
wgtun "golang.zx2c4.com/wireguard/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
type tun struct{}
|
||||||
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
|
|
||||||
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
|
|
||||||
FIODGNAME = 0x80106678
|
|
||||||
TUNSIFMODE = 0x8004745e
|
|
||||||
TUNSIFHEAD = 0x80047460
|
|
||||||
OSIOCAIFADDR_IN6 = 0x8088691b
|
|
||||||
IN6_IFF_NODAD = 0x0020
|
|
||||||
)
|
|
||||||
|
|
||||||
type fiodgnameArg struct {
|
|
||||||
length int32
|
|
||||||
pad [4]byte
|
|
||||||
buf unsafe.Pointer
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// ifreqRename is used for renaming network interfaces on FreeBSD
|
||||||
type ifreqRename struct {
|
type ifreqRename struct {
|
||||||
Name [unix.IFNAMSIZ]byte
|
Name [unix.IFNAMSIZ]byte
|
||||||
Data uintptr
|
Data uintptr
|
||||||
}
|
}
|
||||||
|
|
||||||
type ifreqDestroy struct {
|
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
|
||||||
Name [unix.IFNAMSIZ]byte
|
return nil, fmt.Errorf("newTunFromFd not supported on FreeBSD")
|
||||||
pad [16]byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ifReq struct {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
|
||||||
Name [unix.IFNAMSIZ]byte
|
deviceName := c.GetString("tun.dev", "tun")
|
||||||
Flags uint16
|
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
||||||
}
|
|
||||||
|
|
||||||
type ifreqMTU struct {
|
// Create WireGuard TUN device
|
||||||
Name [unix.IFNAMSIZ]byte
|
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
|
||||||
MTU int32
|
|
||||||
}
|
|
||||||
|
|
||||||
type addrLifetime struct {
|
|
||||||
Expire uint64
|
|
||||||
Preferred uint64
|
|
||||||
Vltime uint32
|
|
||||||
Pltime uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifreqAlias4 struct {
|
|
||||||
Name [unix.IFNAMSIZ]byte
|
|
||||||
Addr unix.RawSockaddrInet4
|
|
||||||
DstAddr unix.RawSockaddrInet4
|
|
||||||
MaskAddr unix.RawSockaddrInet4
|
|
||||||
VHid uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifreqAlias6 struct {
|
|
||||||
Name [unix.IFNAMSIZ]byte
|
|
||||||
Addr unix.RawSockaddrInet6
|
|
||||||
DstAddr unix.RawSockaddrInet6
|
|
||||||
PrefixMask unix.RawSockaddrInet6
|
|
||||||
Flags uint32
|
|
||||||
Lifetime addrLifetime
|
|
||||||
VHid uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type tun struct {
|
|
||||||
Device string
|
|
||||||
vpnNetworks []netip.Prefix
|
|
||||||
MTU int
|
|
||||||
Routes atomic.Pointer[[]Route]
|
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
|
||||||
linkAddr *netroute.LinkAddr
|
|
||||||
l *logrus.Logger
|
|
||||||
devFd int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Read(to []byte) (int, error) {
|
|
||||||
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
|
|
||||||
if t.devFd < 0 {
|
|
||||||
return -1, syscall.EINVAL
|
|
||||||
}
|
|
||||||
|
|
||||||
// first 4 bytes is protocol family, in network byte order
|
|
||||||
head := make([]byte, 4)
|
|
||||||
|
|
||||||
iovecs := []syscall.Iovec{
|
|
||||||
{&head[0], 4},
|
|
||||||
{&to[0], uint64(len(to))},
|
|
||||||
}
|
|
||||||
|
|
||||||
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
|
||||||
|
|
||||||
var err error
|
|
||||||
if errno != 0 {
|
|
||||||
err = syscall.Errno(errno)
|
|
||||||
} else {
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
// fix bytes read number to exclude header
|
|
||||||
bytesRead := int(n)
|
|
||||||
if bytesRead < 0 {
|
|
||||||
return bytesRead, err
|
|
||||||
} else if bytesRead < 4 {
|
|
||||||
return 0, err
|
|
||||||
} else {
|
|
||||||
return bytesRead - 4, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write is only valid for single threaded use
|
|
||||||
func (t *tun) Write(from []byte) (int, error) {
|
|
||||||
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
|
|
||||||
if t.devFd < 0 {
|
|
||||||
return -1, syscall.EINVAL
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(from) <= 1 {
|
|
||||||
return 0, syscall.EIO
|
|
||||||
}
|
|
||||||
ipVer := from[0] >> 4
|
|
||||||
var head []byte
|
|
||||||
// first 4 bytes is protocol family, in network byte order
|
|
||||||
if ipVer == 4 {
|
|
||||||
head = []byte{0, 0, 0, syscall.AF_INET}
|
|
||||||
} else if ipVer == 6 {
|
|
||||||
head = []byte{0, 0, 0, syscall.AF_INET6}
|
|
||||||
} else {
|
|
||||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
|
||||||
}
|
|
||||||
iovecs := []syscall.Iovec{
|
|
||||||
{&head[0], 4},
|
|
||||||
{&from[0], uint64(len(from))},
|
|
||||||
}
|
|
||||||
|
|
||||||
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
|
||||||
|
|
||||||
var err error
|
|
||||||
if errno != 0 {
|
|
||||||
err = syscall.Errno(errno)
|
|
||||||
} else {
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return int(n) - 4, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Close() error {
|
|
||||||
if t.devFd >= 0 {
|
|
||||||
err := syscall.Close(t.devFd)
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).Error("Error closing device")
|
|
||||||
}
|
|
||||||
t.devFd = -1
|
|
||||||
|
|
||||||
c := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
// destroying the interface can block if a read() is still pending. Do this asynchronously.
|
|
||||||
defer close(c)
|
|
||||||
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
|
||||||
if err == nil {
|
|
||||||
defer syscall.Close(s)
|
|
||||||
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
|
||||||
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).Error("Error destroying tunnel")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// wait up to 1 second so we start blocking at the ioctl
|
|
||||||
select {
|
|
||||||
case <-c:
|
|
||||||
case <-time.After(1 * time.Second):
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
|
||||||
// Try to open existing tun device
|
|
||||||
var fd int
|
|
||||||
var err error
|
|
||||||
deviceName := c.GetString("tun.dev", "")
|
|
||||||
if deviceName != "" {
|
|
||||||
fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
|
|
||||||
}
|
|
||||||
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
|
|
||||||
// If the device doesn't already exist, request a new one and rename it
|
|
||||||
fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to create TUN device: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read the name of the interface
|
// Get the actual device name
|
||||||
var name [16]byte
|
actualName, err := tunDevice.Name()
|
||||||
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
|
if err != nil {
|
||||||
ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
|
tunDevice.Close()
|
||||||
|
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
|
||||||
if ctrlErr == nil {
|
|
||||||
// set broadcast mode and multicast
|
|
||||||
ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
|
|
||||||
ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
|
|
||||||
}
|
|
||||||
|
|
||||||
if ctrlErr == nil {
|
|
||||||
// turn on link-layer mode, to support ipv6
|
|
||||||
ifhead := uint32(1)
|
|
||||||
ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
|
|
||||||
}
|
|
||||||
|
|
||||||
if ctrlErr != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ifName := string(bytes.TrimRight(name[:], "\x00"))
|
|
||||||
if deviceName == "" {
|
|
||||||
deviceName = ifName
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the name doesn't match the desired interface name, rename it now
|
// If the name doesn't match the desired interface name, rename it now
|
||||||
if ifName != deviceName {
|
if actualName != deviceName && deviceName != "" && deviceName != "tun" {
|
||||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
if err := renameInterface(actualName, deviceName); err != nil {
|
||||||
if err != nil {
|
tunDevice.Close()
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to rename interface from %s to %s: %w", actualName, deviceName, err)
|
||||||
}
|
}
|
||||||
defer syscall.Close(s)
|
actualName = deviceName
|
||||||
|
|
||||||
fd := uintptr(s)
|
|
||||||
|
|
||||||
var fromName [16]byte
|
|
||||||
var toName [16]byte
|
|
||||||
copy(fromName[:], ifName)
|
|
||||||
copy(toName[:], deviceName)
|
|
||||||
|
|
||||||
ifrr := ifreqRename{
|
|
||||||
Name: fromName,
|
|
||||||
Data: uintptr(unsafe.Pointer(&toName)),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the device name
|
|
||||||
ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &tun{
|
t := &wgTun{
|
||||||
Device: deviceName,
|
tunDevice: tunDevice,
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
MaxMTU: mtu,
|
||||||
|
DefaultMTU: mtu,
|
||||||
l: l,
|
l: l,
|
||||||
devFd: fd,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create FreeBSD-specific route manager
|
||||||
|
t.routeManager = &tun{}
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
tunDevice.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -289,180 +82,86 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
l.WithField("name", actualName).Info("Created WireGuard TUN device")
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
func (rm *tun) Activate(t *wgTun) error {
|
||||||
if cidr.Addr().Is4() {
|
name, err := t.tunDevice.Name()
|
||||||
ifr := ifreqAlias4{
|
|
||||||
Name: t.deviceBytes(),
|
|
||||||
Addr: unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: cidr.Addr().As4(),
|
|
||||||
},
|
|
||||||
DstAddr: unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: getBroadcast(cidr).As4(),
|
|
||||||
},
|
|
||||||
MaskAddr: unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: prefixToMask(cidr).As4(),
|
|
||||||
},
|
|
||||||
VHid: 0,
|
|
||||||
}
|
|
||||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer syscall.Close(s)
|
|
||||||
// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
|
|
||||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if cidr.Addr().Is6() {
|
|
||||||
ifr := ifreqAlias6{
|
|
||||||
Name: t.deviceBytes(),
|
|
||||||
Addr: unix.RawSockaddrInet6{
|
|
||||||
Len: unix.SizeofSockaddrInet6,
|
|
||||||
Family: unix.AF_INET6,
|
|
||||||
Addr: cidr.Addr().As16(),
|
|
||||||
},
|
|
||||||
PrefixMask: unix.RawSockaddrInet6{
|
|
||||||
Len: unix.SizeofSockaddrInet6,
|
|
||||||
Family: unix.AF_INET6,
|
|
||||||
Addr: prefixToMask(cidr).As16(),
|
|
||||||
},
|
|
||||||
Lifetime: addrLifetime{
|
|
||||||
Expire: 0,
|
|
||||||
Preferred: 0,
|
|
||||||
Vltime: 0xffffffff,
|
|
||||||
Pltime: 0xffffffff,
|
|
||||||
},
|
|
||||||
Flags: IN6_IFF_NODAD,
|
|
||||||
}
|
|
||||||
s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer syscall.Close(s)
|
|
||||||
|
|
||||||
if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("unknown address type %v", cidr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Activate() error {
|
|
||||||
// Setup our default MTU
|
|
||||||
err := t.setMTU()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to get device name: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
linkAddr, err := getLinkAddr(t.Device)
|
// Set the MTU
|
||||||
if err != nil {
|
rm.SetMTU(t, t.MaxMTU)
|
||||||
return err
|
|
||||||
}
|
|
||||||
if linkAddr == nil {
|
|
||||||
return fmt.Errorf("unable to discover link_addr for tun interface")
|
|
||||||
}
|
|
||||||
t.linkAddr = linkAddr
|
|
||||||
|
|
||||||
for i := range t.vpnNetworks {
|
// Add IP addresses
|
||||||
err := t.addIp(t.vpnNetworks[i])
|
for _, network := range t.vpnNetworks {
|
||||||
if err != nil {
|
if err := rm.addIP(t, name, network); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return t.addRoutes(false)
|
// Bring up the interface
|
||||||
}
|
if err := runCommandBSD("ifconfig", name, "up"); err != nil {
|
||||||
|
return fmt.Errorf("failed to bring up interface: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) setMTU() error {
|
// Set the routes
|
||||||
// Set the MTU on the device
|
if err := rm.AddRoutes(t, false); err != nil {
|
||||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer syscall.Close(s)
|
|
||||||
|
|
||||||
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
|
|
||||||
err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
|
||||||
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !initial && !change {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Teach nebula how to handle the routes before establishing them in the system table
|
|
||||||
oldRoutes := t.Routes.Swap(&routes)
|
|
||||||
t.routeTree.Store(routeTree)
|
|
||||||
|
|
||||||
if !initial {
|
|
||||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
|
||||||
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
|
||||||
if err != nil {
|
|
||||||
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure any routes we actually want are installed
|
|
||||||
err = t.addRoutes(true)
|
|
||||||
if err != nil {
|
|
||||||
// Catch any stray logs
|
|
||||||
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
func (rm *tun) SetMTU(t *wgTun, mtu int) {
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
name, err := t.tunDevice.Name()
|
||||||
return r
|
if err != nil {
|
||||||
|
t.l.WithError(err).Error("Failed to get device name for MTU set")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := runCommandBSD("ifconfig", name, "mtu", strconv.Itoa(mtu)); err != nil {
|
||||||
|
t.l.WithError(err).Error("Failed to set tun mtu")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Networks() []netip.Prefix {
|
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
|
||||||
return t.vpnNetworks
|
// On FreeBSD, routes are set via ifconfig and route commands
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Name() string {
|
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
||||||
return t.Device
|
name, err := t.tunDevice.Name()
|
||||||
}
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get device name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) addRoutes(logErrors bool) error {
|
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if len(r.Via) == 0 || !r.Install {
|
if !r.Install {
|
||||||
// We don't allow route MTUs so only install routes with a via
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := addRoute(r.Cidr, t.linkAddr)
|
// Add route using route command
|
||||||
|
args := []string{"add"}
|
||||||
|
|
||||||
|
if r.Cidr.Addr().Is6() {
|
||||||
|
args = append(args, "-inet6")
|
||||||
|
} else {
|
||||||
|
args = append(args, "-inet")
|
||||||
|
}
|
||||||
|
|
||||||
|
args = append(args, r.Cidr.String(), "-interface", name)
|
||||||
|
|
||||||
|
if r.Metric > 0 {
|
||||||
|
// FreeBSD doesn't support route metrics directly like Linux
|
||||||
|
t.l.WithField("route", r).Warn("Route metrics are not fully supported on FreeBSD")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := runCommandBSD("route", args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
@@ -478,142 +177,99 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) removeRoutes(routes []Route) error {
|
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
|
||||||
|
name, err := t.tunDevice.Name()
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithError(err).Error("Failed to get device name for route removal")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Install {
|
if !r.Install {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := delRoute(r.Cidr, t.linkAddr)
|
args := []string{"delete"}
|
||||||
|
|
||||||
|
if r.Cidr.Addr().Is6() {
|
||||||
|
args = append(args, "-inet6")
|
||||||
|
} else {
|
||||||
|
args = append(args, "-inet")
|
||||||
|
}
|
||||||
|
|
||||||
|
args = append(args, r.Cidr.String(), "-interface", name)
|
||||||
|
|
||||||
|
err := runCommandBSD("route", args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.WithField("route", r).Info("Removed route")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) deviceBytes() (o [16]byte) {
|
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
|
||||||
for i, c := range t.Device {
|
// FreeBSD doesn't support multi-queue TUN devices in the same way as Linux
|
||||||
o[i] = byte(c)
|
// Return a reader that wraps the same device
|
||||||
}
|
return &wgTunReader{
|
||||||
return
|
parent: t,
|
||||||
|
tunDevice: t.tunDevice,
|
||||||
|
offset: 0,
|
||||||
|
l: t.l,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
func (rm *tun) addIP(t *wgTun, name string, network netip.Prefix) error {
|
||||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
addr := network.Addr()
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
|
||||||
}
|
|
||||||
defer unix.Close(sock)
|
|
||||||
|
|
||||||
route := &netroute.RouteMessage{
|
if addr.Is4() {
|
||||||
Version: unix.RTM_VERSION,
|
// For IPv4: ifconfig tun0 10.0.0.1/24
|
||||||
Type: unix.RTM_ADD,
|
if err := runCommandBSD("ifconfig", name, network.String()); err != nil {
|
||||||
Flags: unix.RTF_UP,
|
return fmt.Errorf("failed to add IPv4 address: %w", err)
|
||||||
Seq: 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
if prefix.Addr().Is4() {
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
|
||||||
unix.RTAX_GATEWAY: gateway,
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
route.Addrs = []netroute.Addr{
|
// For IPv6: ifconfig tun0 inet6 add 2001:db8::1/64
|
||||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
if err := runCommandBSD("ifconfig", name, "inet6", "add", network.String()); err != nil {
|
||||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
return fmt.Errorf("failed to add IPv6 address: %w", err)
|
||||||
unix.RTAX_GATEWAY: gateway,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := route.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = unix.Write(sock, data[:])
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, unix.EEXIST) {
|
|
||||||
// Try to do a change
|
|
||||||
route.Type = unix.RTM_CHANGE
|
|
||||||
data, err = route.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
|
||||||
}
|
|
||||||
_, err = unix.Write(sock, data[:])
|
|
||||||
fmt.Println("DOING CHANGE")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
func runCommandBSD(name string, args ...string) error {
|
||||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
cmd := exec.Command(name, args...)
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
return fmt.Errorf("%s %s failed: %w\nOutput: %s", name, strings.Join(args, " "), err, string(output))
|
||||||
}
|
}
|
||||||
defer unix.Close(sock)
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
route := netroute.RouteMessage{
|
func renameInterface(fromName, toName string) error {
|
||||||
Version: unix.RTM_VERSION,
|
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
Type: unix.RTM_DELETE,
|
if err != nil {
|
||||||
Seq: 1,
|
return fmt.Errorf("failed to create socket: %w", err)
|
||||||
|
}
|
||||||
|
defer syscall.Close(s)
|
||||||
|
|
||||||
|
fd := uintptr(s)
|
||||||
|
|
||||||
|
var fromNameBytes [unix.IFNAMSIZ]byte
|
||||||
|
var toNameBytes [unix.IFNAMSIZ]byte
|
||||||
|
copy(fromNameBytes[:], fromName)
|
||||||
|
copy(toNameBytes[:], toName)
|
||||||
|
|
||||||
|
ifrr := ifreqRename{
|
||||||
|
Name: fromNameBytes,
|
||||||
|
Data: uintptr(unsafe.Pointer(&toNameBytes)),
|
||||||
}
|
}
|
||||||
|
|
||||||
if prefix.Addr().Is4() {
|
// Set the device name using SIOCSIFNAME ioctl
|
||||||
route.Addrs = []netroute.Addr{
|
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr)))
|
||||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
if errno != 0 {
|
||||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
return fmt.Errorf("SIOCSIFNAME ioctl failed: %w", errno)
|
||||||
unix.RTAX_GATEWAY: gateway,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
|
||||||
unix.RTAX_GATEWAY: gateway,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := route.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
|
||||||
}
|
|
||||||
_, err = unix.Write(sock, data[:])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getLinkAddr Gets the link address for the interface of the given name
|
|
||||||
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
|
||||||
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, m := range msgs {
|
|
||||||
switch m := m.(type) {
|
|
||||||
case *netroute.InterfaceMessage:
|
|
||||||
if m.Name == name {
|
|
||||||
sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
|
|
||||||
if ok {
|
|
||||||
return sa, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,175 +1,113 @@
|
|||||||
//go:build !android && !e2e_testing
|
//go:build linux && !android && !e2e_testing
|
||||||
// +build !android,!e2e_testing
|
// +build linux,!android,!e2e_testing
|
||||||
|
|
||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/overlay/vhostnet"
|
|
||||||
"github.com/slackhq/nebula/packet"
|
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/slackhq/nebula/util/virtio"
|
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
wgtun "golang.zx2c4.com/wireguard/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
file *os.File
|
deviceIndex int
|
||||||
fd int
|
ioctlFd uintptr
|
||||||
vdev []*vhostnet.Device
|
txQueueLen int
|
||||||
Device string
|
|
||||||
vpnNetworks []netip.Prefix
|
|
||||||
MaxMTU int
|
|
||||||
DefaultMTU int
|
|
||||||
TXQueueLen int
|
|
||||||
deviceIndex int
|
|
||||||
ioctlFd uintptr
|
|
||||||
|
|
||||||
Routes atomic.Pointer[[]Route]
|
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
|
||||||
routeChan chan struct{}
|
|
||||||
useSystemRoutes bool
|
useSystemRoutes bool
|
||||||
useSystemRoutesBufferSize int
|
useSystemRoutesBufferSize int
|
||||||
|
|
||||||
isV6 bool
|
|
||||||
l *logrus.Logger
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Networks() []netip.Prefix {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*wgTun, error) {
|
||||||
return t.vpnNetworks
|
deviceName := c.GetString("tun.dev", "")
|
||||||
}
|
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
||||||
|
|
||||||
type ifReq struct {
|
// Create WireGuard TUN device
|
||||||
Name [16]byte
|
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
|
||||||
Flags uint16
|
|
||||||
pad [8]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifreqMTU struct {
|
|
||||||
Name [16]byte
|
|
||||||
MTU int32
|
|
||||||
pad [8]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifreqQLEN struct {
|
|
||||||
Name [16]byte
|
|
||||||
Value int32
|
|
||||||
pad [8]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
|
||||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
|
||||||
|
|
||||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to create TUN device: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Device = "tun0"
|
// Get the actual device name
|
||||||
|
actualName, err := tunDevice.Name()
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
tunDevice.Close()
|
||||||
if os.IsNotExist(err) {
|
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
|
||||||
err = os.MkdirAll("/dev/net", 0755)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
|
|
||||||
}
|
|
||||||
err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200)))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var req ifReq
|
t := &wgTun{
|
||||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_TUN_EXCL | unix.IFF_VNET_HDR | unix.IFF_NAPI)
|
tunDevice: tunDevice,
|
||||||
if multiqueue {
|
vpnNetworks: vpnNetworks,
|
||||||
req.Flags |= unix.IFF_MULTI_QUEUE
|
MaxMTU: mtu,
|
||||||
}
|
DefaultMTU: mtu,
|
||||||
copy(req.Name[:], c.GetString("tun.dev", ""))
|
l: l,
|
||||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
|
||||||
|
|
||||||
if err = unix.SetNonblock(fd, true); err != nil {
|
|
||||||
_ = unix.Close(fd)
|
|
||||||
return nil, fmt.Errorf("make file descriptor non-blocking: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
// Create Linux-specific route manager
|
||||||
|
routeManager := &tun{
|
||||||
err = unix.IoctlSetPointerInt(fd, unix.TUNSETVNETHDRSZ, virtio.NetHdrSize)
|
txQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("set vnethdr size: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
flags := 0
|
|
||||||
//flags = //unix.TUN_F_CSUM //| unix.TUN_F_TSO4 | unix.TUN_F_USO4 | unix.TUN_F_TSO6 | unix.TUN_F_USO6
|
|
||||||
err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, flags)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("set offloads: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
t.fd = fd
|
|
||||||
t.Device = name
|
|
||||||
|
|
||||||
vdev, err := vhostnet.NewDevice(
|
|
||||||
vhostnet.WithBackendFD(fd),
|
|
||||||
vhostnet.WithQueueSize(8192), //todo config
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
t.vdev = []*vhostnet.Device{vdev}
|
|
||||||
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
|
||||||
t := &tun{
|
|
||||||
file: file,
|
|
||||||
fd: int(file.Fd()),
|
|
||||||
vpnNetworks: vpnNetworks,
|
|
||||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
|
||||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||||
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
||||||
l: l,
|
|
||||||
}
|
}
|
||||||
if len(vpnNetworks) != 0 {
|
t.routeManager = routeManager
|
||||||
t.isV6 = vpnNetworks[0].Addr().Is6() //todo what about multi-IP?
|
|
||||||
|
err = t.reload(c, true)
|
||||||
|
if err != nil {
|
||||||
|
tunDevice.Close()
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err := t.reload(c, true)
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
|
err := t.reload(c, false)
|
||||||
|
if err != nil {
|
||||||
|
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
l.WithField("name", actualName).Info("Created WireGuard TUN device")
|
||||||
|
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*wgTun, error) {
|
||||||
|
// Create TUN device from file descriptor
|
||||||
|
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||||
|
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
||||||
|
tunDevice, err := wgtun.CreateTUNFromFile(file, mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create TUN device from fd: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t := &wgTun{
|
||||||
|
tunDevice: tunDevice,
|
||||||
|
vpnNetworks: vpnNetworks,
|
||||||
|
MaxMTU: mtu,
|
||||||
|
DefaultMTU: mtu,
|
||||||
|
l: l,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create Linux-specific route manager
|
||||||
|
routeManager := &tun{
|
||||||
|
txQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||||
|
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||||
|
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
||||||
|
}
|
||||||
|
t.routeManager = routeManager
|
||||||
|
|
||||||
|
err = t.reload(c, true)
|
||||||
|
if err != nil {
|
||||||
|
tunDevice.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -183,258 +121,105 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
func (rm *tun) Activate(t *wgTun) error {
|
||||||
routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
name, err := t.tunDevice.Name()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to get device name: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !initial && !routeChange && !c.HasChanged("tun.mtu") {
|
if t.routeManager.useSystemRoutes {
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
routeTree, err := makeRouteTree(t.l, routes, true)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
oldDefaultMTU := t.DefaultMTU
|
|
||||||
oldMaxMTU := t.MaxMTU
|
|
||||||
newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU)
|
|
||||||
newMaxMTU := newDefaultMTU
|
|
||||||
for i, r := range routes {
|
|
||||||
if r.MTU == 0 {
|
|
||||||
routes[i].MTU = newDefaultMTU
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.MTU > t.MaxMTU {
|
|
||||||
newMaxMTU = r.MTU
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
t.MaxMTU = newMaxMTU
|
|
||||||
t.DefaultMTU = newDefaultMTU
|
|
||||||
|
|
||||||
// Teach nebula how to handle the routes before establishing them in the system table
|
|
||||||
oldRoutes := t.Routes.Swap(&routes)
|
|
||||||
t.routeTree.Store(routeTree)
|
|
||||||
|
|
||||||
if !initial {
|
|
||||||
if oldMaxMTU != newMaxMTU {
|
|
||||||
t.setMTU()
|
|
||||||
t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
|
|
||||||
}
|
|
||||||
|
|
||||||
if oldDefaultMTU != newDefaultMTU {
|
|
||||||
for i := range t.vpnNetworks {
|
|
||||||
err := t.setDefaultRoute(t.vpnNetworks[i])
|
|
||||||
if err != nil {
|
|
||||||
t.l.Warn(err)
|
|
||||||
} else {
|
|
||||||
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
|
||||||
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
|
||||||
|
|
||||||
// Ensure any routes we actually want are installed
|
|
||||||
err = t.addRoutes(true)
|
|
||||||
if err != nil {
|
|
||||||
// This should never be called since addRoutes should log its own errors in a reload condition
|
|
||||||
util.LogWithContextIfNeeded("Failed to refresh routes", err, t.l)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (TunDev, error) {
|
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var req ifReq
|
|
||||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
|
||||||
copy(req.Name[:], t.Device)
|
|
||||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
vdev, err := vhostnet.NewDevice(
|
|
||||||
vhostnet.WithBackendFD(fd),
|
|
||||||
vhostnet.WithQueueSize(8192), //todo config
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
t.vdev = append(t.vdev, vdev)
|
|
||||||
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) deviceBytes() (o [16]byte) {
|
|
||||||
for i, c := range t.Device {
|
|
||||||
o[i] = byte(c)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
|
|
||||||
for i := range al {
|
|
||||||
if al[i].Equal(x) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there
|
|
||||||
func (t *tun) addIPs(link netlink.Link) error {
|
|
||||||
newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
|
|
||||||
for i := range t.vpnNetworks {
|
|
||||||
newAddrs[i] = &netlink.Addr{
|
|
||||||
IPNet: &net.IPNet{
|
|
||||||
IP: t.vpnNetworks[i].Addr().AsSlice(),
|
|
||||||
Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
|
|
||||||
},
|
|
||||||
Label: t.vpnNetworks[i].Addr().Zone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//add all new addresses
|
|
||||||
for i := range newAddrs {
|
|
||||||
//AddrReplace still adds new IPs, but if their properties change it will change them as well
|
|
||||||
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//iterate over remainder, remove whoever shouldn't be there
|
|
||||||
al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get tun address list: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range al {
|
|
||||||
if hasNetlinkAddr(newAddrs, al[i]) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
err = netlink.AddrDel(link, &al[i])
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).Error("failed to remove address from tun address list")
|
|
||||||
} else {
|
|
||||||
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Activate() error {
|
|
||||||
devName := t.deviceBytes()
|
|
||||||
|
|
||||||
if t.useSystemRoutes {
|
|
||||||
t.watchRoutes()
|
t.watchRoutes()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the netlink device
|
||||||
|
link, err := netlink.LinkByName(name)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get tun device link: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rm.deviceIndex = link.Attrs().Index
|
||||||
|
|
||||||
|
// Open socket for ioctl operations
|
||||||
s, err := unix.Socket(
|
s, err := unix.Socket(
|
||||||
unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine
|
unix.AF_INET,
|
||||||
unix.SOCK_DGRAM,
|
unix.SOCK_DGRAM,
|
||||||
unix.IPPROTO_IP,
|
unix.IPPROTO_IP,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
t.ioctlFd = uintptr(s)
|
rm.ioctlFd = uintptr(s)
|
||||||
|
|
||||||
// Set the device name
|
rm.SetMTU(t, t.MaxMTU)
|
||||||
ifrf := ifReq{Name: devName}
|
|
||||||
if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun device name: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
link, err := netlink.LinkByName(t.Device)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get tun device link: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.deviceIndex = link.Attrs().Index
|
|
||||||
|
|
||||||
// Setup our default MTU
|
|
||||||
t.setMTU()
|
|
||||||
|
|
||||||
// Set the transmit queue length
|
// Set the transmit queue length
|
||||||
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
|
devName := deviceBytes(name)
|
||||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
|
ifrq := ifreqQLEN{Name: devName, Value: int32(rm.txQueueLen)}
|
||||||
|
if err = ioctl(t.routeManager.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
|
||||||
// If we can't set the queue length nebula will still work but it may lead to packet loss
|
// If we can't set the queue length nebula will still work but it may lead to packet loss
|
||||||
t.l.WithError(err).Error("Failed to set tun tx queue length")
|
t.l.WithError(err).Error("Failed to set tun tx queue length")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Disable IPv6 link-local address generation
|
||||||
const modeNone = 1
|
const modeNone = 1
|
||||||
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
|
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
|
||||||
t.l.WithError(err).Warn("Failed to disable link local address generation")
|
t.l.WithError(err).Warn("Failed to disable link local address generation")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = t.addIPs(link); err != nil {
|
// Add IP addresses
|
||||||
|
if err = t.routeManager.addIPs(t, link); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bring up the interface
|
// Bring up the interface
|
||||||
ifrf.Flags = ifrf.Flags | unix.IFF_UP
|
if err = netlink.LinkSetUp(link); err != nil {
|
||||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
|
||||||
return fmt.Errorf("failed to bring the tun device up: %s", err)
|
return fmt.Errorf("failed to bring the tun device up: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
//set route MTU
|
// Set route MTU
|
||||||
for i := range t.vpnNetworks {
|
for i := range t.vpnNetworks {
|
||||||
if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil {
|
if err = t.routeManager.SetDefaultRoute(t, t.vpnNetworks[i]); err != nil {
|
||||||
return fmt.Errorf("failed to set default route MTU: %w", err)
|
return fmt.Errorf("failed to set default route MTU: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the routes
|
// Set the routes
|
||||||
if err = t.addRoutes(false); err != nil {
|
if err = t.routeManager.AddRoutes(t, false); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run the interface
|
|
||||||
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
|
|
||||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
|
||||||
return fmt.Errorf("failed to run tun device: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) setMTU() {
|
func (rm *tun) SetMTU(t *wgTun, mtu int) {
|
||||||
// Set the MTU on the device
|
name, err := t.tunDevice.Name()
|
||||||
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)}
|
if err != nil {
|
||||||
if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
|
t.l.WithError(err).Error("Failed to get device name for MTU set")
|
||||||
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
link, err := netlink.LinkByName(name)
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithError(err).Error("Failed to get link for MTU set")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := netlink.LinkSetMTU(link, mtu); err != nil {
|
||||||
t.l.WithError(err).Error("Failed to set tun mtu")
|
t.l.WithError(err).Error("Failed to set tun mtu")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
|
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
|
||||||
dr := &net.IPNet{
|
dr := &net.IPNet{
|
||||||
IP: cidr.Masked().Addr().AsSlice(),
|
IP: cidr.Masked().Addr().AsSlice(),
|
||||||
Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
|
Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
|
||||||
}
|
}
|
||||||
|
|
||||||
nr := netlink.Route{
|
nr := netlink.Route{
|
||||||
LinkIndex: t.deviceIndex,
|
LinkIndex: t.routeManager.deviceIndex,
|
||||||
Dst: dr,
|
Dst: dr,
|
||||||
MTU: t.DefaultMTU,
|
MTU: t.DefaultMTU,
|
||||||
AdvMSS: t.advMSS(Route{}),
|
AdvMSS: advMSS(Route{}, t.DefaultMTU, t.MaxMTU),
|
||||||
Scope: unix.RT_SCOPE_LINK,
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
Src: net.IP(cidr.Addr().AsSlice()),
|
Src: net.IP(cidr.Addr().AsSlice()),
|
||||||
Protocol: unix.RTPROT_KERNEL,
|
Protocol: unix.RTPROT_KERNEL,
|
||||||
@@ -444,7 +229,7 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
|
|||||||
err := netlink.RouteReplace(&nr)
|
err := netlink.RouteReplace(&nr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
|
t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
|
||||||
//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
|
// Retry twice more
|
||||||
for i := 0; i < 2; i++ {
|
for i := 0; i < 2; i++ {
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
err = netlink.RouteReplace(&nr)
|
err = netlink.RouteReplace(&nr)
|
||||||
@@ -462,8 +247,7 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
||||||
// Path routes
|
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Install {
|
if !r.Install {
|
||||||
@@ -476,10 +260,10 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
nr := netlink.Route{
|
nr := netlink.Route{
|
||||||
LinkIndex: t.deviceIndex,
|
LinkIndex: t.routeManager.deviceIndex,
|
||||||
Dst: dr,
|
Dst: dr,
|
||||||
MTU: r.MTU,
|
MTU: r.MTU,
|
||||||
AdvMSS: t.advMSS(r),
|
AdvMSS: advMSS(r, t.DefaultMTU, t.MaxMTU),
|
||||||
Scope: unix.RT_SCOPE_LINK,
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -503,7 +287,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) removeRoutes(routes []Route) {
|
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Install {
|
if !r.Install {
|
||||||
continue
|
continue
|
||||||
@@ -515,10 +299,10 @@ func (t *tun) removeRoutes(routes []Route) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
nr := netlink.Route{
|
nr := netlink.Route{
|
||||||
LinkIndex: t.deviceIndex,
|
LinkIndex: t.routeManager.deviceIndex,
|
||||||
Dst: dr,
|
Dst: dr,
|
||||||
MTU: r.MTU,
|
MTU: r.MTU,
|
||||||
AdvMSS: t.advMSS(r),
|
AdvMSS: advMSS(r, t.DefaultMTU, t.MaxMTU),
|
||||||
Scope: unix.RT_SCOPE_LINK,
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -535,30 +319,105 @@ func (t *tun) removeRoutes(routes []Route) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Name() string {
|
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
|
||||||
return t.Device
|
// For Linux with WireGuard TUN, we can reuse the same device
|
||||||
|
// The vectorized I/O will handle batching
|
||||||
|
return &wgTunReader{
|
||||||
|
parent: t,
|
||||||
|
tunDevice: t.tunDevice,
|
||||||
|
offset: 0,
|
||||||
|
l: t.l,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) advMSS(r Route) int {
|
func deviceBytes(name string) [16]byte {
|
||||||
|
var o [16]byte
|
||||||
|
for i, c := range name {
|
||||||
|
if i >= 16 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
o[i] = byte(c)
|
||||||
|
}
|
||||||
|
return o
|
||||||
|
}
|
||||||
|
|
||||||
|
func advMSS(r Route, defaultMTU, maxMTU int) int {
|
||||||
mtu := r.MTU
|
mtu := r.MTU
|
||||||
if r.MTU == 0 {
|
if r.MTU == 0 {
|
||||||
mtu = t.DefaultMTU
|
mtu = defaultMTU
|
||||||
}
|
}
|
||||||
|
|
||||||
// We only need to set advmss if the route MTU does not match the device MTU
|
// We only need to set advmss if the route MTU does not match the device MTU
|
||||||
if mtu != t.MaxMTU {
|
if mtu != maxMTU {
|
||||||
return mtu - 40
|
return mtu - 40
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) watchRoutes() {
|
type ifreqQLEN struct {
|
||||||
|
Name [16]byte
|
||||||
|
Value int32
|
||||||
|
pad [8]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
|
||||||
|
for i := range al {
|
||||||
|
if al[i].Equal(x) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rm *tun) addIPs(t *wgTun, link netlink.Link) error {
|
||||||
|
newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
|
||||||
|
for i := range t.vpnNetworks {
|
||||||
|
newAddrs[i] = &netlink.Addr{
|
||||||
|
IPNet: &net.IPNet{
|
||||||
|
IP: t.vpnNetworks[i].Addr().AsSlice(),
|
||||||
|
Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
|
||||||
|
},
|
||||||
|
Label: t.vpnNetworks[i].Addr().Zone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add all new addresses
|
||||||
|
for i := range newAddrs {
|
||||||
|
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate over remainder, remove whoever shouldn't be there
|
||||||
|
al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get tun address list: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range al {
|
||||||
|
if hasNetlinkAddr(newAddrs, al[i]) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = netlink.AddrDel(link, &al[i])
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithError(err).Error("failed to remove address from tun address list")
|
||||||
|
} else {
|
||||||
|
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// watchRoutes monitors system route changes
|
||||||
|
func (t *wgTun) watchRoutes() {
|
||||||
|
|
||||||
rch := make(chan netlink.RouteUpdate)
|
rch := make(chan netlink.RouteUpdate)
|
||||||
doneChan := make(chan struct{})
|
doneChan := make(chan struct{})
|
||||||
|
|
||||||
netlinkOptions := netlink.RouteSubscribeOptions{
|
netlinkOptions := netlink.RouteSubscribeOptions{
|
||||||
ReceiveBufferSize: t.useSystemRoutesBufferSize,
|
ReceiveBufferSize: t.routeManager.useSystemRoutesBufferSize,
|
||||||
ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0,
|
ReceiveBufferForceSize: t.routeManager.useSystemRoutesBufferSize != 0,
|
||||||
ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") },
|
ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") },
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -576,87 +435,19 @@ func (t *tun) watchRoutes() {
|
|||||||
if ok {
|
if ok {
|
||||||
t.updateRoutes(r)
|
t.updateRoutes(r)
|
||||||
} else {
|
} else {
|
||||||
// may be should do something here as
|
|
||||||
// netlink stops sending updates
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case <-doneChan:
|
case <-doneChan:
|
||||||
// netlink.RouteSubscriber will close the rch for us
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
|
func (t *wgTun) updateRoutes(r netlink.RouteUpdate) {
|
||||||
withinNetworks := false
|
gateways := t.getGatewaysFromRoute(&r.Route, t.routeManager.deviceIndex)
|
||||||
for i := range t.vpnNetworks {
|
|
||||||
if t.vpnNetworks[i].Contains(gwAddr) {
|
|
||||||
withinNetworks = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return withinNetworks
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
|
||||||
|
|
||||||
var gateways routing.Gateways
|
|
||||||
|
|
||||||
link, err := netlink.LinkByName(t.Device)
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
|
|
||||||
return gateways
|
|
||||||
}
|
|
||||||
|
|
||||||
// If this route is relevant to our interface and there is a gateway then add it
|
|
||||||
if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
|
|
||||||
gwAddr, ok := netip.AddrFromSlice(r.Gw)
|
|
||||||
if !ok {
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
|
|
||||||
} else {
|
|
||||||
gwAddr = gwAddr.Unmap()
|
|
||||||
|
|
||||||
if !t.isGatewayInVpnNetworks(gwAddr) {
|
|
||||||
// Gateway isn't in our overlay network, ignore
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
|
||||||
} else {
|
|
||||||
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, p := range r.MultiPath {
|
|
||||||
// If this route is relevant to our interface and there is a gateway then add it
|
|
||||||
if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
|
|
||||||
gwAddr, ok := netip.AddrFromSlice(p.Gw)
|
|
||||||
if !ok {
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
|
|
||||||
} else {
|
|
||||||
gwAddr = gwAddr.Unmap()
|
|
||||||
|
|
||||||
if !t.isGatewayInVpnNetworks(gwAddr) {
|
|
||||||
// Gateway isn't in our overlay network, ignore
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
|
||||||
} else {
|
|
||||||
// p.Hops+1 = weight of the route
|
|
||||||
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
routing.CalculateBucketsForGateways(gateways)
|
|
||||||
return gateways
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
|
||||||
|
|
||||||
gateways := t.getGatewaysFromRoute(&r.Route)
|
|
||||||
|
|
||||||
if len(gateways) == 0 {
|
if len(gateways) == 0 {
|
||||||
// No gateways relevant to our network, no routing changes required.
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
|
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -680,7 +471,6 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
|||||||
if r.Type == unix.RTM_NEWROUTE {
|
if r.Type == unix.RTM_NEWROUTE {
|
||||||
t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
|
t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
|
||||||
newTree.Insert(dst, gateways)
|
newTree.Insert(dst, gateways)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
|
t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
|
||||||
newTree.Delete(dst)
|
newTree.Delete(dst)
|
||||||
@@ -688,86 +478,71 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
|||||||
t.routeTree.Store(newTree)
|
t.routeTree.Store(newTree)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Close() error {
|
func (t *wgTun) getGatewaysFromRoute(r *netlink.Route, deviceIndex int) routing.Gateways {
|
||||||
if t.routeChan != nil {
|
var gateways routing.Gateways
|
||||||
close(t.routeChan)
|
|
||||||
|
name, err := t.tunDevice.Name()
|
||||||
|
if err != nil {
|
||||||
|
t.l.Error("Ignoring route update: failed to get device name")
|
||||||
|
return gateways
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, v := range t.vdev {
|
link, err := netlink.LinkByName(name)
|
||||||
if v != nil {
|
if err != nil {
|
||||||
_ = v.Close()
|
t.l.WithField("DeviceName", name).Error("Ignoring route update: failed to get link by name")
|
||||||
|
return gateways
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this route is relevant to our interface and there is a gateway then add it
|
||||||
|
if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
|
||||||
|
gwAddr, ok := netip.AddrFromSlice(r.Gw)
|
||||||
|
if !ok {
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
|
||||||
|
} else {
|
||||||
|
gwAddr = gwAddr.Unmap()
|
||||||
|
|
||||||
|
if !t.isGatewayInVpnNetworks(gwAddr) {
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
||||||
|
} else {
|
||||||
|
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.file != nil {
|
for _, p := range r.MultiPath {
|
||||||
_ = t.file.Close()
|
if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
|
||||||
|
gwAddr, ok := netip.AddrFromSlice(p.Gw)
|
||||||
|
if !ok {
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
|
||||||
|
} else {
|
||||||
|
gwAddr = gwAddr.Unmap()
|
||||||
|
|
||||||
|
if !t.isGatewayInVpnNetworks(gwAddr) {
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
||||||
|
} else {
|
||||||
|
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.ioctlFd > 0 {
|
routing.CalculateBucketsForGateways(gateways)
|
||||||
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
|
return gateways
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *wgTun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
|
||||||
|
for i := range t.vpnNetworks {
|
||||||
|
if t.vpnNetworks[i].Contains(gwAddr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func ioctl(a1, a2, a3 uintptr) error {
|
||||||
|
_, _, errno := unix.Syscall(unix.SYS_IOCTL, a1, a2, a3)
|
||||||
|
if errno != 0 {
|
||||||
|
return errno
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) ReadMany(p []*packet.VirtIOPacket, q int) (int, error) {
|
|
||||||
n, err := t.vdev[q].ReceivePackets(p) //we are TXing
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Write(b []byte) (int, error) {
|
|
||||||
maximum := len(b) //we are RXing
|
|
||||||
|
|
||||||
//todo garbagey
|
|
||||||
out := packet.NewOut()
|
|
||||||
x, err := t.AllocSeg(out, 0)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
copy(out.SegmentPayloads[x], b)
|
|
||||||
err = t.vdev[0].TransmitPacket(out, true)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).Error("Transmitting packet")
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return maximum, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
|
|
||||||
idx, buf, err := t.vdev[q].GetPacketForTx()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
x := pkt.UseSegment(idx, buf, t.isV6)
|
|
||||||
return x, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
|
|
||||||
if err := t.vdev[q].TransmitPacket(x, kick); err != nil {
|
|
||||||
t.l.WithError(err).Error("Transmitting packet")
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
|
|
||||||
maximum := len(x) //we are RXing
|
|
||||||
if maximum == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err := t.vdev[q].TransmitPackets(x)
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).Error("Transmitting packet")
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return maximum, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
|
|
||||||
return t.vdev[q].ReceiveQueue.OfferDescriptorChains(pkt.Chains, kick)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -6,26 +6,27 @@ package overlay
|
|||||||
import "testing"
|
import "testing"
|
||||||
|
|
||||||
var runAdvMSSTests = []struct {
|
var runAdvMSSTests = []struct {
|
||||||
name string
|
name string
|
||||||
tun *tun
|
defaultMTU int
|
||||||
r Route
|
maxMTU int
|
||||||
expected int
|
r Route
|
||||||
|
expected int
|
||||||
}{
|
}{
|
||||||
// Standard case, default MTU is the device max MTU
|
// Standard case, default MTU is the device max MTU
|
||||||
{"default", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0},
|
{"default", 1440, 1440, Route{}, 0},
|
||||||
{"default-min", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0},
|
{"default-min", 1440, 1440, Route{MTU: 1440}, 0},
|
||||||
{"default-low", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160},
|
{"default-low", 1440, 1440, Route{MTU: 1200}, 1160},
|
||||||
|
|
||||||
// Case where we have a route MTU set higher than the default
|
// Case where we have a route MTU set higher than the default
|
||||||
{"route", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400},
|
{"route", 1440, 8941, Route{}, 1400},
|
||||||
{"route-min", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400},
|
{"route-min", 1440, 8941, Route{MTU: 1440}, 1400},
|
||||||
{"route-high", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0},
|
{"route-high", 1440, 8941, Route{MTU: 8941}, 0},
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTunAdvMSS(t *testing.T) {
|
func TestTunAdvMSS(t *testing.T) {
|
||||||
for _, tt := range runAdvMSSTests {
|
for _, tt := range runAdvMSSTests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
o := tt.tun.advMSS(tt.r)
|
o := advMSS(tt.r, tt.defaultMTU, tt.maxMTU)
|
||||||
if o != tt.expected {
|
if o != tt.expected {
|
||||||
t.Errorf("got %d, want %d", o, tt.expected)
|
t.Errorf("got %d, want %d", o, tt.expected)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -547,3 +547,41 @@ func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ioctl(a1, a2, a3 uintptr) error {
|
||||||
|
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, a1, a2, a3)
|
||||||
|
if errno != 0 {
|
||||||
|
return errno
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
||||||
|
bits := prefix.Bits()
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
mask := ^uint32(0) << (32 - bits)
|
||||||
|
return netip.AddrFrom4([4]byte{
|
||||||
|
byte(mask >> 24),
|
||||||
|
byte(mask >> 16),
|
||||||
|
byte(mask >> 8),
|
||||||
|
byte(mask),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
var mask [16]byte
|
||||||
|
for i := 0; i < bits/8; i++ {
|
||||||
|
mask[i] = 0xff
|
||||||
|
}
|
||||||
|
if bits%8 != 0 {
|
||||||
|
mask[bits/8] = ^byte(0) << (8 - bits%8)
|
||||||
|
}
|
||||||
|
return netip.AddrFrom16(mask)
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectGateway(prefix netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
|
||||||
|
for _, gw := range gateways {
|
||||||
|
if prefix.Addr().Is4() == gw.Addr().Is4() {
|
||||||
|
return gw, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return netip.Prefix{}, fmt.Errorf("no suitable gateway found for prefix %v", prefix)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,14 +0,0 @@
|
|||||||
//go:build !windows
|
|
||||||
// +build !windows
|
|
||||||
|
|
||||||
package overlay
|
|
||||||
|
|
||||||
import "syscall"
|
|
||||||
|
|
||||||
func ioctl(a1, a2, a3 uintptr) error {
|
|
||||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, a1, a2, a3)
|
|
||||||
if errno != 0 {
|
|
||||||
return errno
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,104 +1,59 @@
|
|||||||
//go:build !e2e_testing
|
//go:build openbsd && !e2e_testing
|
||||||
// +build !e2e_testing
|
// +build openbsd,!e2e_testing
|
||||||
|
|
||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os/exec"
|
||||||
"regexp"
|
"strconv"
|
||||||
"sync/atomic"
|
"strings"
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
netroute "golang.org/x/net/route"
|
wgtun "golang.zx2c4.com/wireguard/tun"
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
type tun struct{}
|
||||||
SIOCAIFADDR_IN6 = 0x8080691a
|
|
||||||
)
|
|
||||||
|
|
||||||
type ifreqAlias4 struct {
|
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
|
||||||
Name [unix.IFNAMSIZ]byte
|
return nil, fmt.Errorf("newTunFromFd not supported on OpenBSD")
|
||||||
Addr unix.RawSockaddrInet4
|
|
||||||
DstAddr unix.RawSockaddrInet4
|
|
||||||
MaskAddr unix.RawSockaddrInet4
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ifreqAlias6 struct {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
|
||||||
Name [unix.IFNAMSIZ]byte
|
deviceName := c.GetString("tun.dev", "tun")
|
||||||
Addr unix.RawSockaddrInet6
|
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
||||||
DstAddr unix.RawSockaddrInet6
|
|
||||||
PrefixMask unix.RawSockaddrInet6
|
|
||||||
Flags uint32
|
|
||||||
Lifetime [2]uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifreq struct {
|
// Create WireGuard TUN device
|
||||||
Name [unix.IFNAMSIZ]byte
|
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
|
||||||
data int
|
|
||||||
}
|
|
||||||
|
|
||||||
type tun struct {
|
|
||||||
Device string
|
|
||||||
vpnNetworks []netip.Prefix
|
|
||||||
MTU int
|
|
||||||
Routes atomic.Pointer[[]Route]
|
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
|
||||||
l *logrus.Logger
|
|
||||||
f *os.File
|
|
||||||
fd int
|
|
||||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
|
||||||
out []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
|
||||||
// Try to open tun device
|
|
||||||
var err error
|
|
||||||
deviceName := c.GetString("tun.dev", "")
|
|
||||||
if deviceName == "" {
|
|
||||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
|
||||||
}
|
|
||||||
if !deviceNameRE.MatchString(deviceName) {
|
|
||||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
|
||||||
}
|
|
||||||
|
|
||||||
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to create TUN device: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = unix.SetNonblock(fd, true)
|
// Get the actual device name
|
||||||
|
actualName, err := tunDevice.Name()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
tunDevice.Close()
|
||||||
|
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &tun{
|
t := &wgTun{
|
||||||
f: os.NewFile(uintptr(fd), ""),
|
tunDevice: tunDevice,
|
||||||
fd: fd,
|
|
||||||
Device: deviceName,
|
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
MaxMTU: mtu,
|
||||||
|
DefaultMTU: mtu,
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create OpenBSD-specific route manager
|
||||||
|
t.routeManager = &tun{}
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
tunDevice.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,221 +64,86 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
l.WithField("name", actualName).Info("Created WireGuard TUN device")
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Close() error {
|
func (rm *tun) Activate(t *wgTun) error {
|
||||||
if t.f != nil {
|
name, err := t.tunDevice.Name()
|
||||||
if err := t.f.Close(); err != nil {
|
|
||||||
return fmt.Errorf("error closing tun file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// t.f.Close should have handled it for us but let's be extra sure
|
|
||||||
_ = unix.Close(t.fd)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Read(to []byte) (int, error) {
|
|
||||||
buf := make([]byte, len(to)+4)
|
|
||||||
|
|
||||||
n, err := t.f.Read(buf)
|
|
||||||
|
|
||||||
copy(to, buf[4:])
|
|
||||||
return n - 4, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write is only valid for single threaded use
|
|
||||||
func (t *tun) Write(from []byte) (int, error) {
|
|
||||||
buf := t.out
|
|
||||||
if cap(buf) < len(from)+4 {
|
|
||||||
buf = make([]byte, len(from)+4)
|
|
||||||
t.out = buf
|
|
||||||
}
|
|
||||||
buf = buf[:len(from)+4]
|
|
||||||
|
|
||||||
if len(from) == 0 {
|
|
||||||
return 0, syscall.EIO
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine the IP Family for the NULL L2 Header
|
|
||||||
ipVer := from[0] >> 4
|
|
||||||
if ipVer == 4 {
|
|
||||||
buf[3] = syscall.AF_INET
|
|
||||||
} else if ipVer == 6 {
|
|
||||||
buf[3] = syscall.AF_INET6
|
|
||||||
} else {
|
|
||||||
return 0, fmt.Errorf("unable to determine IP version from packet")
|
|
||||||
}
|
|
||||||
|
|
||||||
copy(buf[4:], from)
|
|
||||||
|
|
||||||
n, err := t.f.Write(buf)
|
|
||||||
return n - 4, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) addIp(cidr netip.Prefix) error {
|
|
||||||
if cidr.Addr().Is4() {
|
|
||||||
var req ifreqAlias4
|
|
||||||
req.Name = t.deviceBytes()
|
|
||||||
req.Addr = unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: cidr.Addr().As4(),
|
|
||||||
}
|
|
||||||
req.DstAddr = unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: cidr.Addr().As4(),
|
|
||||||
}
|
|
||||||
req.MaskAddr = unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: prefixToMask(cidr).As4(),
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer syscall.Close(s)
|
|
||||||
|
|
||||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = addRoute(cidr, t.vpnNetworks)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to set route for vpn network %v: %w", cidr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if cidr.Addr().Is6() {
|
|
||||||
var req ifreqAlias6
|
|
||||||
req.Name = t.deviceBytes()
|
|
||||||
req.Addr = unix.RawSockaddrInet6{
|
|
||||||
Len: unix.SizeofSockaddrInet6,
|
|
||||||
Family: unix.AF_INET6,
|
|
||||||
Addr: cidr.Addr().As16(),
|
|
||||||
}
|
|
||||||
req.PrefixMask = unix.RawSockaddrInet6{
|
|
||||||
Len: unix.SizeofSockaddrInet6,
|
|
||||||
Family: unix.AF_INET6,
|
|
||||||
Addr: prefixToMask(cidr).As16(),
|
|
||||||
}
|
|
||||||
req.Lifetime[0] = 0xffffffff
|
|
||||||
req.Lifetime[1] = 0xffffffff
|
|
||||||
|
|
||||||
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer syscall.Close(s)
|
|
||||||
|
|
||||||
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("unknown address type %v", cidr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Activate() error {
|
|
||||||
err := t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to set tun mtu: %w", err)
|
return fmt.Errorf("failed to get device name: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range t.vpnNetworks {
|
// Set the MTU
|
||||||
err = t.addIp(t.vpnNetworks[i])
|
rm.SetMTU(t, t.MaxMTU)
|
||||||
if err != nil {
|
|
||||||
|
// Add IP addresses
|
||||||
|
for _, network := range t.vpnNetworks {
|
||||||
|
if err := rm.addIP(t, name, network); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return t.addRoutes(false)
|
// Bring up the interface
|
||||||
}
|
if err := runCommandBSD("ifconfig", name, "up"); err != nil {
|
||||||
|
return fmt.Errorf("failed to bring up interface: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
|
// Set the routes
|
||||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
if err := rm.AddRoutes(t, false); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer syscall.Close(s)
|
|
||||||
|
|
||||||
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
|
|
||||||
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
|
||||||
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !initial && !change {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Teach nebula how to handle the routes before establishing them in the system table
|
|
||||||
oldRoutes := t.Routes.Swap(&routes)
|
|
||||||
t.routeTree.Store(routeTree)
|
|
||||||
|
|
||||||
if !initial {
|
|
||||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
|
||||||
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
|
||||||
if err != nil {
|
|
||||||
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure any routes we actually want are installed
|
|
||||||
err = t.addRoutes(true)
|
|
||||||
if err != nil {
|
|
||||||
// Catch any stray logs
|
|
||||||
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
func (rm *tun) SetMTU(t *wgTun, mtu int) {
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
name, err := t.tunDevice.Name()
|
||||||
return r
|
if err != nil {
|
||||||
|
t.l.WithError(err).Error("Failed to get device name for MTU set")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := runCommandBSD("ifconfig", name, "mtu", strconv.Itoa(mtu)); err != nil {
|
||||||
|
t.l.WithError(err).Error("Failed to set tun mtu")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Networks() []netip.Prefix {
|
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
|
||||||
return t.vpnNetworks
|
// On OpenBSD, routes are set via ifconfig and route commands
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Name() string {
|
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
||||||
return t.Device
|
name, err := t.tunDevice.Name()
|
||||||
}
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get device name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) addRoutes(logErrors bool) error {
|
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if len(r.Via) == 0 || !r.Install {
|
if !r.Install {
|
||||||
// We don't allow route MTUs so only install routes with a via
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := addRoute(r.Cidr, t.vpnNetworks)
|
// Add route using route command
|
||||||
|
args := []string{"add"}
|
||||||
|
|
||||||
|
if r.Cidr.Addr().Is6() {
|
||||||
|
args = append(args, "-inet6")
|
||||||
|
} else {
|
||||||
|
args = append(args, "-inet")
|
||||||
|
}
|
||||||
|
|
||||||
|
args = append(args, r.Cidr.String(), "-interface", name)
|
||||||
|
|
||||||
|
if r.Metric > 0 {
|
||||||
|
// OpenBSD doesn't support route metrics directly like Linux
|
||||||
|
t.l.WithField("route", r).Warn("Route metrics are not fully supported on OpenBSD")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := runCommandBSD("route", args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
@@ -339,131 +159,71 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) removeRoutes(routes []Route) error {
|
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
|
||||||
|
name, err := t.tunDevice.Name()
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithError(err).Error("Failed to get device name for route removal")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Install {
|
if !r.Install {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := delRoute(r.Cidr, t.vpnNetworks)
|
args := []string{"delete"}
|
||||||
|
|
||||||
|
if r.Cidr.Addr().Is6() {
|
||||||
|
args = append(args, "-inet6")
|
||||||
|
} else {
|
||||||
|
args = append(args, "-inet")
|
||||||
|
}
|
||||||
|
|
||||||
|
args = append(args, r.Cidr.String(), "-interface", name)
|
||||||
|
|
||||||
|
err := runCommandBSD("route", args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.WithField("route", r).Info("Removed route")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) deviceBytes() (o [16]byte) {
|
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
|
||||||
for i, c := range t.Device {
|
// OpenBSD doesn't support multi-queue TUN devices in the same way as Linux
|
||||||
o[i] = byte(c)
|
// Return a reader that wraps the same device
|
||||||
}
|
return &wgTunReader{
|
||||||
return
|
parent: t,
|
||||||
|
tunDevice: t.tunDevice,
|
||||||
|
offset: 0,
|
||||||
|
l: t.l,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
func (rm *tun) addIP(t *wgTun, name string, network netip.Prefix) error {
|
||||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
addr := network.Addr()
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
|
||||||
}
|
|
||||||
defer unix.Close(sock)
|
|
||||||
|
|
||||||
route := &netroute.RouteMessage{
|
if addr.Is4() {
|
||||||
Version: unix.RTM_VERSION,
|
// For IPv4: ifconfig tun0 10.0.0.1/24
|
||||||
Type: unix.RTM_ADD,
|
if err := runCommandBSD("ifconfig", name, network.String()); err != nil {
|
||||||
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
|
return fmt.Errorf("failed to add IPv4 address: %w", err)
|
||||||
Seq: 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
if prefix.Addr().Is4() {
|
|
||||||
gw, err := selectGateway(prefix, gateways)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
|
||||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
gw, err := selectGateway(prefix, gateways)
|
// For IPv6: ifconfig tun0 inet6 add 2001:db8::1/64
|
||||||
if err != nil {
|
if err := runCommandBSD("ifconfig", name, "inet6", "add", network.String()); err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to add IPv6 address: %w", err)
|
||||||
}
|
}
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
|
||||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := route.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = unix.Write(sock, data[:])
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, unix.EEXIST) {
|
|
||||||
// Try to do a change
|
|
||||||
route.Type = unix.RTM_CHANGE
|
|
||||||
data, err = route.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
|
||||||
}
|
|
||||||
_, err = unix.Write(sock, data[:])
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
func runCommandBSD(name string, args ...string) error {
|
||||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
cmd := exec.Command(name, args...)
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
return fmt.Errorf("%s %s failed: %w\nOutput: %s", name, strings.Join(args, " "), err, string(output))
|
||||||
}
|
}
|
||||||
defer unix.Close(sock)
|
|
||||||
|
|
||||||
route := netroute.RouteMessage{
|
|
||||||
Version: unix.RTM_VERSION,
|
|
||||||
Type: unix.RTM_DELETE,
|
|
||||||
Seq: 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
if prefix.Addr().Is4() {
|
|
||||||
gw, err := selectGateway(prefix, gateways)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
|
||||||
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
gw, err := selectGateway(prefix, gateways)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
route.Addrs = []netroute.Addr{
|
|
||||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
|
||||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
|
||||||
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := route.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
|
||||||
}
|
|
||||||
_, err = unix.Write(sock, data[:])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
242
overlay/tun_wg.go
Normal file
242
overlay/tun_wg.go
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
//go:build !android && !netbsd && !e2e_testing
|
||||||
|
// +build !android,!netbsd,!e2e_testing
|
||||||
|
|
||||||
|
package overlay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/netip"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/gaissmai/bart"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
|
"github.com/slackhq/nebula/util"
|
||||||
|
wgtun "golang.zx2c4.com/wireguard/tun"
|
||||||
|
)
|
||||||
|
|
||||||
|
// wgTun wraps a WireGuard TUN device and implements the overlay.Device interface
|
||||||
|
type wgTun struct {
|
||||||
|
tunDevice wgtun.Device
|
||||||
|
vpnNetworks []netip.Prefix
|
||||||
|
MaxMTU int
|
||||||
|
DefaultMTU int
|
||||||
|
|
||||||
|
Routes atomic.Pointer[[]Route]
|
||||||
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
|
routeChan chan struct{}
|
||||||
|
|
||||||
|
// Platform-specific route management
|
||||||
|
routeManager *tun
|
||||||
|
|
||||||
|
l *logrus.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchReader is implemented by readers that can receive several packets in
// one call. The caller supplies the packet buffers and a parallel sizes slice;
// the int result is the number of packets read.
type BatchReader interface {
	BatchRead(buffers [][]byte, sizes []int) (int, error)
}
|
||||||
|
|
||||||
|
// BatchWriter is implemented by writers that can send several packets in one
// call; each element of packets is one packet.
type BatchWriter interface {
	BatchWrite(packets [][]byte) (int, error)
}
|
||||||
|
|
||||||
|
// wgTunReader wraps a single TUN queue for multi-queue support
|
||||||
|
type wgTunReader struct {
|
||||||
|
parent *wgTun
|
||||||
|
tunDevice wgtun.Device
|
||||||
|
offset int
|
||||||
|
l *logrus.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *wgTun) Networks() []netip.Prefix {
|
||||||
|
return t.vpnNetworks
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *wgTun) Name() string {
|
||||||
|
name, err := t.tunDevice.Name()
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithError(err).Error("Failed to get TUN device name")
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *wgTun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *wgTun) Activate() error {
|
||||||
|
if t.routeManager == nil {
|
||||||
|
return fmt.Errorf("route manager not initialized")
|
||||||
|
}
|
||||||
|
return t.routeManager.Activate(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read implements single-packet read for backward compatibility
|
||||||
|
func (t *wgTun) Read(b []byte) (int, error) {
|
||||||
|
bufs := [][]byte{b}
|
||||||
|
sizes := []int{0}
|
||||||
|
n, err := t.tunDevice.Read(bufs, sizes, 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return 0, io.ErrNoProgress
|
||||||
|
}
|
||||||
|
return sizes[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements single-packet write for backward compatibility
|
||||||
|
func (t *wgTun) Write(b []byte) (int, error) {
|
||||||
|
bufs := [][]byte{b}
|
||||||
|
offset := 0
|
||||||
|
|
||||||
|
// WireGuard TUN expects the packet data to start at offset 0
|
||||||
|
n, err := t.tunDevice.Write(bufs, offset)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return 0, io.ErrShortWrite
|
||||||
|
}
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *wgTun) Close() error {
|
||||||
|
if t.routeChan != nil {
|
||||||
|
close(t.routeChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.tunDevice != nil {
|
||||||
|
return t.tunDevice.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *wgTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
|
// For WireGuard TUN, we need to create separate TUN device instances for multi-queue
|
||||||
|
// The platform-specific implementation will handle this
|
||||||
|
if t.routeManager == nil {
|
||||||
|
return nil, fmt.Errorf("route manager not initialized for multi-queue reader")
|
||||||
|
}
|
||||||
|
|
||||||
|
return t.routeManager.NewMultiQueueReader(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *wgTun) reload(c *config.C, initial bool) error {
|
||||||
|
routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !initial && !routeChange && !c.HasChanged("tun.mtu") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
routeTree, err := makeRouteTree(t.l, routes, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
oldDefaultMTU := t.DefaultMTU
|
||||||
|
oldMaxMTU := t.MaxMTU
|
||||||
|
newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU)
|
||||||
|
newMaxMTU := newDefaultMTU
|
||||||
|
for i, r := range routes {
|
||||||
|
if r.MTU == 0 {
|
||||||
|
routes[i].MTU = newDefaultMTU
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.MTU > t.MaxMTU {
|
||||||
|
newMaxMTU = r.MTU
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.MaxMTU = newMaxMTU
|
||||||
|
t.DefaultMTU = newDefaultMTU
|
||||||
|
|
||||||
|
// Teach nebula how to handle the routes before establishing them in the system table
|
||||||
|
oldRoutes := t.Routes.Swap(&routes)
|
||||||
|
t.routeTree.Store(routeTree)
|
||||||
|
|
||||||
|
if !initial && t.routeManager != nil {
|
||||||
|
if oldMaxMTU != newMaxMTU {
|
||||||
|
t.routeManager.SetMTU(t, t.MaxMTU)
|
||||||
|
t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
|
||||||
|
}
|
||||||
|
|
||||||
|
if oldDefaultMTU != newDefaultMTU {
|
||||||
|
for i := range t.vpnNetworks {
|
||||||
|
err := t.routeManager.SetDefaultRoute(t, t.vpnNetworks[i])
|
||||||
|
if err != nil {
|
||||||
|
t.l.Warn(err)
|
||||||
|
} else {
|
||||||
|
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||||
|
t.routeManager.RemoveRoutes(t, findRemovedRoutes(routes, *oldRoutes))
|
||||||
|
|
||||||
|
// Ensure any routes we actually want are installed
|
||||||
|
err = t.routeManager.AddRoutes(t, true)
|
||||||
|
if err != nil {
|
||||||
|
// This should never be called since AddRoutes should log its own errors in a reload condition
|
||||||
|
util.LogWithContextIfNeeded("Failed to refresh routes", err, t.l)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchRead reads multiple packets from the TUN device using vectorized I/O
|
||||||
|
// The caller provides buffers and sizes slices, and this function returns the number of packets read.
|
||||||
|
func (r *wgTunReader) BatchRead(buffers [][]byte, sizes []int) (int, error) {
|
||||||
|
return r.tunDevice.Read(buffers, sizes, r.offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read implements io.Reader for wgTunReader (single packet for compatibility)
|
||||||
|
func (r *wgTunReader) Read(b []byte) (int, error) {
|
||||||
|
bufs := [][]byte{b}
|
||||||
|
sizes := []int{0}
|
||||||
|
n, err := r.tunDevice.Read(bufs, sizes, r.offset)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return 0, io.ErrNoProgress
|
||||||
|
}
|
||||||
|
return sizes[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements io.Writer for wgTunReader
|
||||||
|
func (r *wgTunReader) Write(b []byte) (int, error) {
|
||||||
|
bufs := [][]byte{b}
|
||||||
|
n, err := r.tunDevice.Write(bufs, r.offset)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return 0, io.ErrShortWrite
|
||||||
|
}
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchWrite writes multiple packets to the TUN device using vectorized I/O
|
||||||
|
func (r *wgTunReader) BatchWrite(packets [][]byte) (int, error) {
|
||||||
|
return r.tunDevice.Write(packets, r.offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *wgTunReader) Close() error {
|
||||||
|
if r.tunDevice != nil {
|
||||||
|
return r.tunDevice.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,84 +1,77 @@
|
|||||||
//go:build !e2e_testing
|
//go:build windows && !e2e_testing
|
||||||
// +build !e2e_testing
|
// +build windows,!e2e_testing
|
||||||
|
|
||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto"
|
"crypto"
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
"sync/atomic"
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/slackhq/nebula/wintun"
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
|
wgtun "golang.zx2c4.com/wireguard/tun"
|
||||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||||
)
|
)
|
||||||
|
|
||||||
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
|
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
|
||||||
|
|
||||||
type winTun struct {
|
type tun struct {
|
||||||
Device string
|
luid winipcfg.LUID
|
||||||
vpnNetworks []netip.Prefix
|
|
||||||
MTU int
|
|
||||||
Routes atomic.Pointer[[]Route]
|
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
|
||||||
l *logrus.Logger
|
|
||||||
|
|
||||||
tun *wintun.NativeTun
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
|
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
|
||||||
err := checkWinTunExists()
|
deviceName := c.GetString("tun.dev", "Nebula")
|
||||||
|
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
||||||
|
|
||||||
|
// Create WireGuard TUN device
|
||||||
|
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("can not load the wintun driver: %w", err)
|
return nil, fmt.Errorf("failed to create TUN device: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
deviceName := c.GetString("tun.dev", "")
|
// Get the actual device name
|
||||||
guid, err := generateGUIDByDeviceName(deviceName)
|
actualName, err := tunDevice.Name()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("generate GUID failed: %w", err)
|
tunDevice.Close()
|
||||||
|
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &winTun{
|
t := &wgTun{
|
||||||
Device: deviceName,
|
tunDevice: tunDevice,
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
MaxMTU: mtu,
|
||||||
|
DefaultMTU: mtu,
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create Windows-specific route manager
|
||||||
|
rm := &tun{}
|
||||||
|
|
||||||
|
// Get LUID from the TUN device
|
||||||
|
// The WireGuard TUN device on Windows should provide a LUID() method
|
||||||
|
if nativeTun, ok := tunDevice.(interface{ LUID() uint64 }); ok {
|
||||||
|
rm.luid = winipcfg.LUID(nativeTun.LUID())
|
||||||
|
} else {
|
||||||
|
tunDevice.Close()
|
||||||
|
return nil, fmt.Errorf("failed to get LUID from TUN device")
|
||||||
|
}
|
||||||
|
t.routeManager = rm
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
tunDevice.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var tunDevice wintun.Device
|
|
||||||
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
|
|
||||||
if err != nil {
|
|
||||||
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
|
|
||||||
// Trying a second time resolves the issue.
|
|
||||||
l.WithError(err).Debug("Failed to create wintun device, retrying")
|
|
||||||
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create TUN device failed: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
t.tun = tunDevice.(*wintun.NativeTun)
|
|
||||||
|
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
err := t.reload(c, false)
|
err := t.reload(c, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -86,206 +79,140 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
l.WithField("name", actualName).Info("Created WireGuard TUN device")
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *winTun) reload(c *config.C, initial bool) error {
|
func (rm *tun) Activate(t *wgTun) error {
|
||||||
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
// Set MTU
|
||||||
|
err := rm.setMTU(t, t.MaxMTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to set MTU: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !initial && !change {
|
// Add IP addresses
|
||||||
return nil
|
for _, network := range t.vpnNetworks {
|
||||||
}
|
if err := rm.addIP(t, network); err != nil {
|
||||||
|
return err
|
||||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Teach nebula how to handle the routes before establishing them in the system table
|
|
||||||
oldRoutes := t.Routes.Swap(&routes)
|
|
||||||
t.routeTree.Store(routeTree)
|
|
||||||
|
|
||||||
if !initial {
|
|
||||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
|
||||||
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
|
||||||
if err != nil {
|
|
||||||
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure any routes we actually want are installed
|
|
||||||
err = t.addRoutes(true)
|
|
||||||
if err != nil {
|
|
||||||
// Catch any stray logs
|
|
||||||
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
// Add routes
|
||||||
}
|
if err := rm.AddRoutes(t, false); err != nil {
|
||||||
|
|
||||||
func (t *winTun) Activate() error {
|
|
||||||
luid := winipcfg.LUID(t.tun.LUID())
|
|
||||||
|
|
||||||
err := luid.SetIPAddresses(t.vpnNetworks)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to set address: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = t.addRoutes(false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *winTun) addRoutes(logErrors bool) error {
|
func (rm *tun) SetMTU(t *wgTun, mtu int) {
|
||||||
luid := winipcfg.LUID(t.tun.LUID())
|
if err := rm.setMTU(t, mtu); err != nil {
|
||||||
|
t.l.WithError(err).Error("Failed to set MTU")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rm *tun) setMTU(t *wgTun, mtu int) error {
|
||||||
|
// Set MTU using winipcfg
|
||||||
|
// Note: MTU setting on Windows TUN devices may be handled by the driver
|
||||||
|
// For now, we'll skip explicit MTU setting as the WireGuard TUN handles it
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
|
||||||
|
// On Windows, routes are managed differently
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
foundDefault4 := false
|
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if len(r.Via) == 0 || !r.Install {
|
if !r.Install {
|
||||||
// We don't allow route MTUs so only install routes with a via
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add our unsafe route
|
if r.MTU > 0 {
|
||||||
// Windows does not support multipath routes natively, so we install only a single route.
|
// Windows route MTU is not directly supported
|
||||||
// This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally.
|
t.l.WithField("route", r).Debug("Route MTU is not supported on Windows")
|
||||||
// In effect this provides multipath routing support to windows supporting loadbalancing and redundancy.
|
}
|
||||||
err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
|
|
||||||
|
// Use winipcfg to add the route
|
||||||
|
// The rm.luid should have the AddRoute method from winipcfg
|
||||||
|
if len(r.Via) == 0 {
|
||||||
|
t.l.WithField("route", r).Warn("Route has no via address, skipping")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := rm.luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
retErr.Log(t.l)
|
retErr.Log(t.l)
|
||||||
continue
|
|
||||||
} else {
|
} else {
|
||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Added route")
|
t.l.WithField("route", r).Info("Added route")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !foundDefault4 {
|
|
||||||
if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
|
|
||||||
foundDefault4 = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ipif, err := luid.IPInterface(windows.AF_INET)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get ip interface: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ipif.NLMTU = uint32(t.MTU)
|
|
||||||
if foundDefault4 {
|
|
||||||
ipif.UseAutomaticMetric = false
|
|
||||||
ipif.Metric = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipif.Set(); err != nil {
|
|
||||||
return fmt.Errorf("failed to set ip interface: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *winTun) removeRoutes(routes []Route) error {
|
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
|
||||||
luid := winipcfg.LUID(t.tun.LUID())
|
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Install {
|
if !r.Install {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// See comment on luid.AddRoute
|
if len(r.Via) == 0 {
|
||||||
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := rm.luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.WithField("route", r).Info("Removed route")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
|
||||||
|
// Windows doesn't support multi-queue TUN devices
|
||||||
|
// Return a reader that wraps the same device
|
||||||
|
return &wgTunReader{
|
||||||
|
parent: t,
|
||||||
|
tunDevice: t.tunDevice,
|
||||||
|
offset: 0,
|
||||||
|
l: t.l,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rm *tun) addIP(t *wgTun, network netip.Prefix) error {
|
||||||
|
// Add IP address using winipcfg
|
||||||
|
// SetIPAddresses expects a slice of prefixes
|
||||||
|
err := rm.luid.SetIPAddresses([]netip.Prefix{network})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add IP address %s: %w", network, err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *winTun) RoutesFor(ip netip.Addr) routing.Gateways {
|
// generateGUIDByDeviceName generates a GUID based on the device name
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
func generateGUIDByDeviceName(deviceName string) (*windows.GUID, error) {
|
||||||
return r
|
// Hash the device name to create a deterministic GUID
|
||||||
}
|
h := crypto.SHA256.New()
|
||||||
|
h.Write([]byte(tunGUIDLabel))
|
||||||
|
h.Write([]byte(deviceName))
|
||||||
|
sum := h.Sum(nil)
|
||||||
|
|
||||||
func (t *winTun) Networks() []netip.Prefix {
|
guid := &windows.GUID{
|
||||||
return t.vpnNetworks
|
Data1: binary.LittleEndian.Uint32(sum[0:4]),
|
||||||
}
|
Data2: binary.LittleEndian.Uint16(sum[4:6]),
|
||||||
|
Data3: binary.LittleEndian.Uint16(sum[6:8]),
|
||||||
func (t *winTun) Name() string {
|
|
||||||
return t.Device
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *winTun) Read(b []byte) (int, error) {
|
|
||||||
return t.tun.Read(b, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *winTun) Write(b []byte) (int, error) {
|
|
||||||
return t.tun.Write(b, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *winTun) Close() error {
|
|
||||||
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
|
|
||||||
// so to be certain, just remove everything before destroying.
|
|
||||||
luid := winipcfg.LUID(t.tun.LUID())
|
|
||||||
_ = luid.FlushRoutes(windows.AF_INET)
|
|
||||||
_ = luid.FlushIPAddresses(windows.AF_INET)
|
|
||||||
|
|
||||||
_ = luid.FlushRoutes(windows.AF_INET6)
|
|
||||||
_ = luid.FlushIPAddresses(windows.AF_INET6)
|
|
||||||
|
|
||||||
_ = luid.FlushDNS(windows.AF_INET)
|
|
||||||
_ = luid.FlushDNS(windows.AF_INET6)
|
|
||||||
|
|
||||||
return t.tun.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
|
|
||||||
// GUID is 128 bit
|
|
||||||
hash := crypto.MD5.New()
|
|
||||||
|
|
||||||
_, err := hash.Write([]byte(tunGUIDLabel))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
copy(guid.Data4[:], sum[8:16])
|
||||||
|
|
||||||
_, err = hash.Write([]byte(name))
|
return guid, nil
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sum := hash.Sum(nil)
|
|
||||||
|
|
||||||
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkWinTunExists() error {
|
|
||||||
myPath, err := os.Executable()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
arch := runtime.GOARCH
|
|
||||||
switch arch {
|
|
||||||
case "386":
|
|
||||||
//NOTE: wintun bundles 386 as x86
|
|
||||||
arch = "x86"
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = syscall.LoadDLL(filepath.Join(filepath.Dir(myPath), "dist", "windows", "wintun", "bin", arch, "wintun.dll"))
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,11 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/packet"
|
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -38,10 +36,6 @@ type UserDevice struct {
|
|||||||
inboundWriter *io.PipeWriter
|
inboundWriter *io.PipeWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *UserDevice) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *UserDevice) Activate() error {
|
func (d *UserDevice) Activate() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -52,7 +46,7 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
|
|||||||
return routing.Gateways{routing.NewGateway(ip, 1)}
|
return routing.Gateways{routing.NewGateway(ip, 1)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *UserDevice) NewMultiQueueReader() (TunDev, error) {
|
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,19 +65,3 @@ func (d *UserDevice) Close() error {
|
|||||||
d.outboundWriter.Close()
|
d.outboundWriter.Close()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *UserDevice) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
|
|
||||||
return d.Read(b[0].Payload)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *UserDevice) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
|
|
||||||
return 0, fmt.Errorf("user: AllocSeg not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *UserDevice) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
|
|
||||||
return 0, fmt.Errorf("user: WriteOne not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *UserDevice) WriteMany(x []*packet.OutPacket, q int) (int, error) {
|
|
||||||
return 0, fmt.Errorf("user: WriteMany not implemented")
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,23 +0,0 @@
|
|||||||
Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
|
|
||||||
|
|
||||||
MIT License
|
|
||||||
|
|
||||||
Copyright (c) 2025 Hetzner Cloud GmbH
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
// Package vhost implements the basic ioctl requests needed to interact with the
|
|
||||||
// kernel-level virtio server that provides accelerated virtio devices for
|
|
||||||
// networking and more.
|
|
||||||
package vhost
|
|
||||||
@@ -1,218 +0,0 @@
|
|||||||
package vhost
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/slackhq/nebula/overlay/virtqueue"
|
|
||||||
"github.com/slackhq/nebula/util/virtio"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Request codes passed as the req argument of IoctlPtr by the functions in
// this file. The values are precomputed rather than derived at build time.
//
// NOTE(review): the codes presumably follow the Linux _IO/_IOR/_IOW encoding
// (direction, payload size, type 0xaf, number). The embedded sizes (0x08 and
// 0x28) agree with the 8 and 40 byte struct sizes asserted by the tests of
// this package — confirm against the kernel's vhost header before changing
// any payload struct.
const (
	// vhostIoctlGetFeatures can be used to retrieve the features supported by
	// the vhost implementation in the kernel.
	//
	// Response payload: [virtio.Feature]
	// Kernel name: VHOST_GET_FEATURES
	vhostIoctlGetFeatures = 0x8008af00

	// vhostIoctlSetFeatures can be used to communicate the features supported
	// by this virtio implementation to the kernel.
	//
	// Request payload: [virtio.Feature]
	// Kernel name: VHOST_SET_FEATURES
	vhostIoctlSetFeatures = 0x4008af00

	// vhostIoctlSetOwner can be used to set the current process as the
	// exclusive owner of a control file descriptor.
	//
	// Request payload: none
	// Kernel name: VHOST_SET_OWNER
	vhostIoctlSetOwner = 0x0000af01

	// vhostIoctlSetMemoryLayout can be used to set up or modify the memory
	// layout which describes the IOTLB mappings in the kernel.
	//
	// Request payload: [MemoryLayout] with custom serialization
	// Kernel name: VHOST_SET_MEM_TABLE
	vhostIoctlSetMemoryLayout = 0x4008af03

	// vhostIoctlSetQueueSize can be used to set the size of the virtqueue.
	//
	// Request payload: [QueueState]
	// Kernel name: VHOST_SET_VRING_NUM
	vhostIoctlSetQueueSize = 0x4008af10

	// vhostIoctlSetQueueAddress can be used to set the addresses of the
	// different parts of the virtqueue.
	//
	// Request payload: [QueueAddresses]
	// Kernel name: VHOST_SET_VRING_ADDR
	vhostIoctlSetQueueAddress = 0x4028af11

	// vhostIoctlSetAvailableRingBase can be used to set the index of the next
	// available ring entry the device will process.
	//
	// Request payload: [QueueState]
	// Kernel name: VHOST_SET_VRING_BASE
	vhostIoctlSetAvailableRingBase = 0x4008af12

	// vhostIoctlSetQueueKickEventFD can be used to set the event file
	// descriptor to signal the device when descriptor chains were added to the
	// available ring.
	//
	// Request payload: [QueueFile]
	// Kernel name: VHOST_SET_VRING_KICK
	vhostIoctlSetQueueKickEventFD = 0x4008af20

	// vhostIoctlSetQueueCallEventFD can be used to set the event file
	// descriptor that gets signaled by the device when descriptor chains have
	// been used by it.
	//
	// Request payload: [QueueFile]
	// Kernel name: VHOST_SET_VRING_CALL
	vhostIoctlSetQueueCallEventFD = 0x4008af21
)
|
|
||||||
|
|
||||||
// QueueState is an ioctl request payload that can hold a queue index and any
// 32-bit number.
//
// RegisterQueue uses it for both the queue size and the available ring base
// requests. The struct is handed to the kernel as raw memory, so its layout
// (two consecutive 32-bit fields, 8 bytes in total) must not change.
//
// Kernel name: vhost_vring_state
type QueueState struct {
	// QueueIndex is the index of the virtqueue.
	QueueIndex uint32
	// Num is any 32-bit number, depending on the request.
	Num uint32
}
|
|
||||||
|
|
||||||
// QueueAddresses is an ioctl request payload that can hold the addresses of the
// different parts of a virtqueue.
//
// The struct is handed to the kernel as raw memory, so field order and sizes
// must not change. The size test of this package expects 40 bytes, which only
// holds where uintptr is 8 bytes wide.
// NOTE(review): on a 32-bit target the uintptr fields would shrink the struct
// and no longer match the request code's payload size — confirm whether such
// targets need to be supported.
//
// Kernel name: vhost_vring_addr
type QueueAddresses struct {
	// QueueIndex is the index of the virtqueue.
	QueueIndex uint32
	// Flags that are not used in this implementation.
	Flags uint32
	// DescriptorTableAddress is the address of the descriptor table in user
	// space memory. It must be 16-byte aligned.
	DescriptorTableAddress uintptr
	// UsedRingAddress is the address of the used ring in user space memory. It
	// must be 4-byte aligned.
	UsedRingAddress uintptr
	// AvailableRingAddress is the address of the available ring in user space
	// memory. It must be 2-byte aligned.
	AvailableRingAddress uintptr
	// LogAddress is used for an optional logging support, not supported by this
	// implementation.
	LogAddress uintptr
}
|
|
||||||
|
|
||||||
// QueueFile is an ioctl request payload that can hold a queue index and a file
// descriptor.
//
// RegisterQueue uses it to hand the kick and call event file descriptors of a
// queue to the kernel. The struct is passed as raw memory, so its layout (two
// consecutive 32-bit fields, 8 bytes in total) must not change.
//
// Kernel name: vhost_vring_file
type QueueFile struct {
	// QueueIndex is the index of the virtqueue.
	QueueIndex uint32
	// FD is the file descriptor of the file. Pass -1 to unbind from a file.
	FD int32
}
|
|
||||||
|
|
||||||
// IoctlPtr is a copy of the similarly named unexported function from the Go
|
|
||||||
// unix package. This is needed to do custom ioctl requests not supported by the
|
|
||||||
// standard library.
|
|
||||||
func IoctlPtr(fd int, req uint, arg unsafe.Pointer) error {
|
|
||||||
_, _, err := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(arg))
|
|
||||||
if err != 0 {
|
|
||||||
return fmt.Errorf("ioctl request %d: %w", req, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetFeatures requests the supported feature bits from the virtio device
|
|
||||||
// associated with the given control file descriptor.
|
|
||||||
func GetFeatures(controlFD int) (virtio.Feature, error) {
|
|
||||||
var features virtio.Feature
|
|
||||||
if err := IoctlPtr(controlFD, vhostIoctlGetFeatures, unsafe.Pointer(&features)); err != nil {
|
|
||||||
return 0, fmt.Errorf("get features: %w", err)
|
|
||||||
}
|
|
||||||
return features, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetFeatures communicates the feature bits supported by this implementation
|
|
||||||
// to the virtio device associated with the given control file descriptor.
|
|
||||||
func SetFeatures(controlFD int, features virtio.Feature) error {
|
|
||||||
if err := IoctlPtr(controlFD, vhostIoctlSetFeatures, unsafe.Pointer(&features)); err != nil {
|
|
||||||
return fmt.Errorf("set features: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// OwnControlFD sets the current process as the exclusive owner for the
|
|
||||||
// given control file descriptor. This must be called before interacting with
|
|
||||||
// the control file descriptor in any other way.
|
|
||||||
func OwnControlFD(controlFD int) error {
|
|
||||||
if err := IoctlPtr(controlFD, vhostIoctlSetOwner, unsafe.Pointer(nil)); err != nil {
|
|
||||||
return fmt.Errorf("set control file descriptor owner: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetMemoryLayout sets up or modifies the memory layout for the kernel-level
|
|
||||||
// virtio device associated with the given control file descriptor.
|
|
||||||
func SetMemoryLayout(controlFD int, layout MemoryLayout) error {
|
|
||||||
payload := layout.serializePayload()
|
|
||||||
if err := IoctlPtr(controlFD, vhostIoctlSetMemoryLayout, unsafe.Pointer(&payload[0])); err != nil {
|
|
||||||
return fmt.Errorf("set memory layout: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterQueue registers a virtio queue with the kernel-level virtio server.
|
|
||||||
// The virtqueue will be linked to the given control file descriptor and will
|
|
||||||
// have the given index. The kernel will use this queue until the control file
|
|
||||||
// descriptor is closed.
|
|
||||||
func RegisterQueue(controlFD int, queueIndex uint32, queue *virtqueue.SplitQueue) error {
|
|
||||||
if err := IoctlPtr(controlFD, vhostIoctlSetQueueSize, unsafe.Pointer(&QueueState{
|
|
||||||
QueueIndex: queueIndex,
|
|
||||||
Num: uint32(queue.Size()),
|
|
||||||
})); err != nil {
|
|
||||||
return fmt.Errorf("set queue size: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := IoctlPtr(controlFD, vhostIoctlSetQueueAddress, unsafe.Pointer(&QueueAddresses{
|
|
||||||
QueueIndex: queueIndex,
|
|
||||||
Flags: 0,
|
|
||||||
DescriptorTableAddress: queue.DescriptorTable().Address(),
|
|
||||||
UsedRingAddress: queue.UsedRing().Address(),
|
|
||||||
AvailableRingAddress: queue.AvailableRing().Address(),
|
|
||||||
LogAddress: 0,
|
|
||||||
})); err != nil {
|
|
||||||
return fmt.Errorf("set queue addresses: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := IoctlPtr(controlFD, vhostIoctlSetAvailableRingBase, unsafe.Pointer(&QueueState{
|
|
||||||
QueueIndex: queueIndex,
|
|
||||||
Num: 0,
|
|
||||||
})); err != nil {
|
|
||||||
return fmt.Errorf("set available ring base: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := IoctlPtr(controlFD, vhostIoctlSetQueueKickEventFD, unsafe.Pointer(&QueueFile{
|
|
||||||
QueueIndex: queueIndex,
|
|
||||||
FD: int32(queue.KickEventFD()),
|
|
||||||
})); err != nil {
|
|
||||||
return fmt.Errorf("set kick event file descriptor: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := IoctlPtr(controlFD, vhostIoctlSetQueueCallEventFD, unsafe.Pointer(&QueueFile{
|
|
||||||
QueueIndex: queueIndex,
|
|
||||||
FD: int32(queue.CallEventFD()),
|
|
||||||
})); err != nil {
|
|
||||||
return fmt.Errorf("set call event file descriptor: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
package vhost_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/slackhq/nebula/overlay/vhost"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestQueueState_Size(t *testing.T) {
|
|
||||||
assert.EqualValues(t, 8, unsafe.Sizeof(vhost.QueueState{}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueueAddresses_Size(t *testing.T) {
|
|
||||||
assert.EqualValues(t, 40, unsafe.Sizeof(vhost.QueueAddresses{}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueueFile_Size(t *testing.T) {
|
|
||||||
assert.EqualValues(t, 8, unsafe.Sizeof(vhost.QueueFile{}))
|
|
||||||
}
|
|
||||||
@@ -1,73 +0,0 @@
|
|||||||
package vhost
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/slackhq/nebula/overlay/virtqueue"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MemoryRegion describes a region of userspace memory which is being made
// accessible to a vhost device.
//
// serializePayload copies values of this type into the ioctl payload as raw
// memory, so field order and sizes must not change. The size test of this
// package expects 32 bytes, which only holds where uintptr is 8 bytes wide.
//
// Kernel name: vhost_memory_region
type MemoryRegion struct {
	// GuestPhysicalAddress is the physical address of the memory region within
	// the guest, when virtualization is used. When no virtualization is used,
	// this should be the same as UserspaceAddress.
	GuestPhysicalAddress uintptr
	// Size is the size of the memory region.
	Size uint64
	// UserspaceAddress is the virtual address in the userspace of the host
	// where the memory region can be found.
	UserspaceAddress uintptr
	// Padding and room for flags. Currently unused.
	_ uint64
}
|
|
||||||
|
|
||||||
// MemoryLayout is a list of [MemoryRegion]s. It is turned into the payload of
// the vhostIoctlSetMemoryLayout request by its serializePayload method.
type MemoryLayout []MemoryRegion
|
|
||||||
|
|
||||||
// NewMemoryLayoutForQueues returns a new [MemoryLayout] that describes the
|
|
||||||
// memory pages used by the descriptor tables of the given queues.
|
|
||||||
func NewMemoryLayoutForQueues(queues []*virtqueue.SplitQueue) MemoryLayout {
|
|
||||||
regions := make([]MemoryRegion, 0)
|
|
||||||
for _, queue := range queues {
|
|
||||||
for address, size := range queue.DescriptorTable().BufferAddresses() {
|
|
||||||
regions = append(regions, MemoryRegion{
|
|
||||||
// There is no virtualization in play here, so the guest address
|
|
||||||
// is the same as in the host's userspace.
|
|
||||||
GuestPhysicalAddress: address,
|
|
||||||
Size: uint64(size),
|
|
||||||
UserspaceAddress: address,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return regions
|
|
||||||
}
|
|
||||||
|
|
||||||
// serializePayload serializes the list of memory regions into a format that is
|
|
||||||
// compatible to the vhost_memory kernel struct. The returned byte slice can be
|
|
||||||
// used as a payload for the vhostIoctlSetMemoryLayout ioctl.
|
|
||||||
func (regions MemoryLayout) serializePayload() []byte {
|
|
||||||
regionCount := len(regions)
|
|
||||||
regionSize := int(unsafe.Sizeof(MemoryRegion{}))
|
|
||||||
payload := make([]byte, 8+regionCount*regionSize)
|
|
||||||
|
|
||||||
// The first 32 bits contain the number of memory regions. The following 32
|
|
||||||
// bits are padding.
|
|
||||||
binary.LittleEndian.PutUint32(payload[0:4], uint32(regionCount))
|
|
||||||
|
|
||||||
if regionCount > 0 {
|
|
||||||
// The underlying byte array of the slice should already have the correct
|
|
||||||
// format, so just copy that.
|
|
||||||
copied := copy(payload[8:], unsafe.Slice((*byte)(unsafe.Pointer(®ions[0])), regionCount*regionSize))
|
|
||||||
if copied != regionCount*regionSize {
|
|
||||||
panic(fmt.Sprintf("copied only %d bytes of the memory regions, but expected %d",
|
|
||||||
copied, regionCount*regionSize))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return payload
|
|
||||||
}
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
package vhost
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestMemoryRegion_Size(t *testing.T) {
|
|
||||||
assert.EqualValues(t, 32, unsafe.Sizeof(MemoryRegion{}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemoryLayout_SerializePayload(t *testing.T) {
|
|
||||||
layout := MemoryLayout([]MemoryRegion{
|
|
||||||
{
|
|
||||||
GuestPhysicalAddress: 42,
|
|
||||||
Size: 100,
|
|
||||||
UserspaceAddress: 142,
|
|
||||||
}, {
|
|
||||||
GuestPhysicalAddress: 99,
|
|
||||||
Size: 100,
|
|
||||||
UserspaceAddress: 99,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
payload := layout.serializePayload()
|
|
||||||
|
|
||||||
assert.Equal(t, []byte{
|
|
||||||
0x02, 0x00, 0x00, 0x00, // nregions
|
|
||||||
0x00, 0x00, 0x00, 0x00, // padding
|
|
||||||
// region 0
|
|
||||||
0x2a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // guest_phys_addr
|
|
||||||
0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // memory_size
|
|
||||||
0x8e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // userspace_addr
|
|
||||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // flags_padding
|
|
||||||
// region 1
|
|
||||||
0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // guest_phys_addr
|
|
||||||
0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // memory_size
|
|
||||||
0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // userspace_addr
|
|
||||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // flags_padding
|
|
||||||
}, payload)
|
|
||||||
}
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
|
|
||||||
|
|
||||||
MIT License
|
|
||||||
|
|
||||||
Copyright (c) 2025 Hetzner Cloud GmbH
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
||||||
@@ -1,427 +0,0 @@
|
|||||||
package vhostnet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
|
|
||||||
"github.com/slackhq/nebula/overlay/vhost"
|
|
||||||
"github.com/slackhq/nebula/overlay/virtqueue"
|
|
||||||
"github.com/slackhq/nebula/packet"
|
|
||||||
"github.com/slackhq/nebula/util/virtio"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ErrDeviceClosed is returned when the [Device] is closed while operations are
// still running.
// NOTE(review): nothing in the visible part of this file returns this error —
// confirm it is still in use.
var ErrDeviceClosed = errors.New("device was closed")

// The indexes for the receive and transmit queues. NewDevice passes them to
// createQueue and SetQueueBackend to tell the two queues apart.
const (
	receiveQueueIndex  = 0
	transmitQueueIndex = 1
)
|
|
||||||
|
|
||||||
// Device represents a vhost networking device within the kernel-level virtio
// implementation and provides methods to interact with it.
type Device struct {
	// initialized is set at the end of a successful NewDevice and cleared by
	// Close. ensureInitialized panics while it is false.
	initialized bool
	// controlFD is the file descriptor obtained by opening /dev/vhost-net. It
	// is -1 while no descriptor is open.
	controlFD int

	// fullTable is set by GetPacketForTx once the transmit descriptor table
	// reports virtqueue.ErrNotEnoughFreeDescriptors. From then on descriptors
	// are obtained through TransmitQueue.TakeSingle instead of being created.
	fullTable bool
	// ReceiveQueue is the queue registered under receiveQueueIndex.
	ReceiveQueue *virtqueue.SplitQueue
	// TransmitQueue is the queue registered under transmitQueueIndex.
	TransmitQueue *virtqueue.SplitQueue
}
|
|
||||||
|
|
||||||
// NewDevice initializes a new vhost networking device within the
|
|
||||||
// kernel-level virtio implementation, sets up the virtqueues and returns a
|
|
||||||
// [Device] instance that can be used to communicate with that vhost device.
|
|
||||||
//
|
|
||||||
// There are multiple options that can be passed to this constructor to
|
|
||||||
// influence device creation:
|
|
||||||
// - [WithQueueSize]
|
|
||||||
// - [WithBackendFD]
|
|
||||||
// - [WithBackendDevice]
|
|
||||||
//
|
|
||||||
// Remember to call [Device.Close] after use to free up resources.
|
|
||||||
func NewDevice(options ...Option) (*Device, error) {
|
|
||||||
var err error
|
|
||||||
opts := optionDefaults
|
|
||||||
opts.apply(options)
|
|
||||||
if err = opts.validate(); err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid options: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
dev := Device{
|
|
||||||
controlFD: -1,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clean up a partially initialized device when something fails.
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
_ = dev.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Retrieve a new control file descriptor. This will be used to configure
|
|
||||||
// the vhost networking device in the kernel.
|
|
||||||
dev.controlFD, err = unix.Open("/dev/vhost-net", os.O_RDWR, 0666)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get control file descriptor: %w", err)
|
|
||||||
}
|
|
||||||
if err = vhost.OwnControlFD(dev.controlFD); err != nil {
|
|
||||||
return nil, fmt.Errorf("own control file descriptor: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Advertise the supported features. This isn't much for now.
|
|
||||||
// TODO: Add feature options and implement proper feature negotiation.
|
|
||||||
getFeatures, err := vhost.GetFeatures(dev.controlFD) //0x1033D008000 but why
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get features: %w", err)
|
|
||||||
}
|
|
||||||
if getFeatures == 0 {
|
|
||||||
|
|
||||||
}
|
|
||||||
//const funky = virtio.Feature(1 << 27)
|
|
||||||
//features := virtio.FeatureVersion1 | funky // | todo virtio.FeatureNetMergeRXBuffers
|
|
||||||
features := virtio.FeatureVersion1 | virtio.FeatureNetMergeRXBuffers
|
|
||||||
if err = vhost.SetFeatures(dev.controlFD, features); err != nil {
|
|
||||||
return nil, fmt.Errorf("set features: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
itemSize := os.Getpagesize() * 4 //todo config
|
|
||||||
|
|
||||||
// Initialize and register the queues needed for the networking device.
|
|
||||||
if dev.ReceiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize, itemSize); err != nil {
|
|
||||||
return nil, fmt.Errorf("create receive queue: %w", err)
|
|
||||||
}
|
|
||||||
if dev.TransmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize, itemSize); err != nil {
|
|
||||||
return nil, fmt.Errorf("create transmit queue: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up memory mappings for all buffers used by the queues. This has to
|
|
||||||
// happen before a backend for the queues can be registered.
|
|
||||||
memoryLayout := vhost.NewMemoryLayoutForQueues(
|
|
||||||
[]*virtqueue.SplitQueue{dev.ReceiveQueue, dev.TransmitQueue},
|
|
||||||
)
|
|
||||||
if err = vhost.SetMemoryLayout(dev.controlFD, memoryLayout); err != nil {
|
|
||||||
return nil, fmt.Errorf("setup memory layout: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the queue backends. This activates the queues within the kernel.
|
|
||||||
if err = SetQueueBackend(dev.controlFD, receiveQueueIndex, opts.backendFD); err != nil {
|
|
||||||
return nil, fmt.Errorf("set receive queue backend: %w", err)
|
|
||||||
}
|
|
||||||
if err = SetQueueBackend(dev.controlFD, transmitQueueIndex, opts.backendFD); err != nil {
|
|
||||||
return nil, fmt.Errorf("set transmit queue backend: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fully populate the receive queue with available buffers which the device
|
|
||||||
// can write new packets into.
|
|
||||||
if err = dev.refillReceiveQueue(); err != nil {
|
|
||||||
return nil, fmt.Errorf("refill receive queue: %w", err)
|
|
||||||
}
|
|
||||||
if err = dev.refillTransmitQueue(); err != nil {
|
|
||||||
return nil, fmt.Errorf("refill receive queue: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
dev.initialized = true
|
|
||||||
|
|
||||||
// Make sure to clean up even when the device gets garbage collected without
|
|
||||||
// Close being called first.
|
|
||||||
devPtr := &dev
|
|
||||||
runtime.SetFinalizer(devPtr, (*Device).Close)
|
|
||||||
|
|
||||||
return devPtr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// refillReceiveQueue offers as many new device-writable buffers to the device
|
|
||||||
// as the queue can fit. The device will then use these to write received
|
|
||||||
// packets.
|
|
||||||
func (dev *Device) refillReceiveQueue() error {
|
|
||||||
for {
|
|
||||||
_, err := dev.ReceiveQueue.OfferInDescriptorChains()
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
|
|
||||||
// Queue is full, job is done.
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("offer descriptor chain: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// refillTransmitQueue is currently a no-op that always returns nil. NewDevice
// calls it right after refillReceiveQueue.
//
// The commented-out body below is an earlier draft that pre-created output
// descriptors until the table was full; it is kept for reference only.
func (dev *Device) refillTransmitQueue() error {
	//for {
	//	desc, err := dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
	//	if err != nil {
	//		if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
	//			// Queue is full, job is done.
	//			return nil
	//		}
	//		return fmt.Errorf("offer descriptor chain: %w", err)
	//	} else {
	//		dev.TransmitQueue.UsedRing().InitOfferSingle(desc, 0)
	//	}
	//}
	return nil
}
|
|
||||||
|
|
||||||
// Close cleans up the vhost networking device within the kernel and releases
|
|
||||||
// all resources used for it.
|
|
||||||
// The implementation will try to release as many resources as possible and
|
|
||||||
// collect potential errors before returning them.
|
|
||||||
func (dev *Device) Close() error {
|
|
||||||
dev.initialized = false
|
|
||||||
|
|
||||||
// Closing the control file descriptor will unregister all queues from the
|
|
||||||
// kernel.
|
|
||||||
if dev.controlFD >= 0 {
|
|
||||||
if err := unix.Close(dev.controlFD); err != nil {
|
|
||||||
// Return an error and do not continue, because the memory used for
|
|
||||||
// the queues should not be released before they were unregistered
|
|
||||||
// from the kernel.
|
|
||||||
return fmt.Errorf("close control file descriptor: %w", err)
|
|
||||||
}
|
|
||||||
dev.controlFD = -1
|
|
||||||
}
|
|
||||||
|
|
||||||
var errs []error
|
|
||||||
|
|
||||||
if dev.ReceiveQueue != nil {
|
|
||||||
if err := dev.ReceiveQueue.Close(); err == nil {
|
|
||||||
dev.ReceiveQueue = nil
|
|
||||||
} else {
|
|
||||||
errs = append(errs, fmt.Errorf("close receive queue: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if dev.TransmitQueue != nil {
|
|
||||||
if err := dev.TransmitQueue.Close(); err == nil {
|
|
||||||
dev.TransmitQueue = nil
|
|
||||||
} else {
|
|
||||||
errs = append(errs, fmt.Errorf("close transmit queue: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(errs) == 0 {
|
|
||||||
// Everything was cleaned up. No need to run the finalizer anymore.
|
|
||||||
runtime.SetFinalizer(dev, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
return errors.Join(errs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ensureInitialized is used as a guard to prevent methods to be called on an
|
|
||||||
// uninitialized instance.
|
|
||||||
func (dev *Device) ensureInitialized() {
|
|
||||||
if !dev.initialized {
|
|
||||||
panic("device is not initialized")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// createQueue creates a new virtqueue and registers it with the vhost device
|
|
||||||
// using the given index.
|
|
||||||
func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) {
|
|
||||||
var (
|
|
||||||
queue *virtqueue.SplitQueue
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if queue, err = virtqueue.NewSplitQueue(queueSize, itemSize); err != nil {
|
|
||||||
return nil, fmt.Errorf("create virtqueue: %w", err)
|
|
||||||
}
|
|
||||||
if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil {
|
|
||||||
return nil, fmt.Errorf("register virtqueue with index %d: %w", queueIndex, err)
|
|
||||||
}
|
|
||||||
return queue, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// truncateBuffers returns a new list of buffers whose combined length is
// exactly length. Buffers are taken over in order; the buffer in which the
// requested length ends is cut short, and everything after it is dropped.
//
// When the requested length ends exactly at a buffer boundary and another
// buffer follows, that next buffer is still included as an empty slice.
//
// It panics when length exceeds the combined length of all buffers.
func truncateBuffers(buffers [][]byte, length int) (out [][]byte) {
	remaining := length
	for i := range buffers {
		current := buffers[i]
		if remaining >= len(current) {
			// The whole buffer fits into what is still needed.
			out = append(out, current)
			remaining -= len(current)
			continue
		}
		// The requested length ends inside this buffer.
		return append(out, current[:remaining])
	}
	if remaining > 0 {
		panic("length exceeds the combined length of all buffers")
	}
	return out
}
|
|
||||||
|
|
||||||
func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
|
|
||||||
var err error
|
|
||||||
var idx uint16
|
|
||||||
if !dev.fullTable {
|
|
||||||
|
|
||||||
idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
|
|
||||||
if err == virtqueue.ErrNotEnoughFreeDescriptors {
|
|
||||||
dev.fullTable = true
|
|
||||||
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return 0, nil, fmt.Errorf("transmit queue: %w", err)
|
|
||||||
}
|
|
||||||
buf, err := dev.TransmitQueue.GetDescriptorItem(idx)
|
|
||||||
if err != nil {
|
|
||||||
return 0, nil, fmt.Errorf("get descriptor chain: %w", err)
|
|
||||||
}
|
|
||||||
return idx, buf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dev *Device) TransmitPacket(pkt *packet.OutPacket, kick bool) error {
|
|
||||||
if len(pkt.SegmentIDs) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
for idx := range pkt.SegmentIDs {
|
|
||||||
segmentID := pkt.SegmentIDs[idx]
|
|
||||||
dev.TransmitQueue.SetDescSize(segmentID, len(pkt.Segments[idx]))
|
|
||||||
}
|
|
||||||
err := dev.TransmitQueue.OfferDescriptorChains(pkt.SegmentIDs, false)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("offer descriptor chains: %w", err)
|
|
||||||
}
|
|
||||||
pkt.Reset()
|
|
||||||
if kick {
|
|
||||||
if err := dev.TransmitQueue.Kick(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error {
|
|
||||||
if len(pkts) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range pkts {
|
|
||||||
if err := dev.TransmitPacket(pkts[i], false); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := dev.TransmitQueue.Kick(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Make above methods cancelable by taking a context.Context argument?
|
|
||||||
// TODO: Implement zero-copy variants to transmit and receive packets?
|
|
||||||
|
|
||||||
// processChains processes as many chains as needed to create one packet. The number of processed chains is returned.
|
|
||||||
func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) {
|
|
||||||
//read first element to see how many descriptors we need:
|
|
||||||
pkt.Reset()
|
|
||||||
|
|
||||||
err := dev.ReceiveQueue.GetDescriptorInbuffers(uint16(chains[0].DescriptorIndex), &pkt.ChainRefs)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("get descriptor chain: %w", err)
|
|
||||||
}
|
|
||||||
if len(pkt.ChainRefs) == 0 {
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// The specification requires that the first descriptor chain starts
|
|
||||||
// with a virtio-net header. It is not clear, whether it is also
|
|
||||||
// required to be fully contained in the first buffer of that
|
|
||||||
// descriptor chain, but it is reasonable to assume that this is
|
|
||||||
// always the case.
|
|
||||||
// The decode method already does the buffer length check.
|
|
||||||
if err = pkt.Header.Decode(pkt.ChainRefs[0][0:]); err != nil {
|
|
||||||
// The device misbehaved. There is no way we can gracefully
|
|
||||||
// recover from this, because we don't know how many of the
|
|
||||||
// following descriptor chains belong to this packet.
|
|
||||||
return 0, fmt.Errorf("decode vnethdr: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
//we have the header now: what do we need to do?
|
|
||||||
if int(pkt.Header.NumBuffers) > len(chains) {
|
|
||||||
return 0, fmt.Errorf("number of buffers is greater than number of chains %d", len(chains))
|
|
||||||
}
|
|
||||||
if int(pkt.Header.NumBuffers) != 1 {
|
|
||||||
return 0, fmt.Errorf("too smol-brain to handle more than one chain right now: %d chains", len(chains))
|
|
||||||
}
|
|
||||||
if chains[0].Length > 16000 {
|
|
||||||
//todo!
|
|
||||||
return 1, fmt.Errorf("too big packet length: %d", chains[0].Length)
|
|
||||||
}
|
|
||||||
|
|
||||||
//shift the buffer out of out:
|
|
||||||
pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length]
|
|
||||||
pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex))
|
|
||||||
return 1, nil
|
|
||||||
|
|
||||||
//cursor := n - virtio.NetHdrSize
|
|
||||||
//
|
|
||||||
//if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 {
|
|
||||||
// pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize]
|
|
||||||
// return 1, nil
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//i := 1
|
|
||||||
//// we used chain 0 already
|
|
||||||
//for i = 1; i < len(chains); i++ {
|
|
||||||
// n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length))
|
|
||||||
// if err != nil {
|
|
||||||
// // When this fails we may miss to free some descriptor chains. We
|
|
||||||
// // could try to mitigate this by deferring the freeing somehow, but
|
|
||||||
// // it's not worth the hassle. When this method fails, the queue will
|
|
||||||
// // be in a broken state anyway.
|
|
||||||
// return i, fmt.Errorf("get descriptor chain: %w", err)
|
|
||||||
// }
|
|
||||||
// cursor += n
|
|
||||||
//}
|
|
||||||
////todo this has to be wrong
|
|
||||||
//pkt.Payload = pkt.Payload[:cursor]
|
|
||||||
//return i, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
|
|
||||||
//todo optimize?
|
|
||||||
var chains []virtqueue.UsedElement
|
|
||||||
var err error
|
|
||||||
//if len(dev.extraRx) == 0 {
|
|
||||||
chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), len(out))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if len(chains) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
//} else {
|
|
||||||
// chains = dev.extraRx
|
|
||||||
//}
|
|
||||||
|
|
||||||
numPackets := 0
|
|
||||||
chainsIdx := 0
|
|
||||||
for numPackets = 0; chainsIdx < len(chains); numPackets++ {
|
|
||||||
if numPackets >= len(out) {
|
|
||||||
return numPackets, fmt.Errorf("dropping %d packets, no room", len(chains)-numPackets)
|
|
||||||
}
|
|
||||||
numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
chainsIdx += numChains
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that we have copied all buffers, we can recycle the used descriptor chains
|
|
||||||
//if err = dev.ReceiveQueue.OfferDescriptorChains(chains); err != nil {
|
|
||||||
// return 0, err
|
|
||||||
//}
|
|
||||||
|
|
||||||
return numPackets, nil
|
|
||||||
}
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
// Package vhostnet implements methods to initialize vhost networking devices
|
|
||||||
// within the kernel-level virtio implementation and communicate with them.
|
|
||||||
package vhostnet
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
package vhostnet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/slackhq/nebula/overlay/vhost"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// vhostNetIoctlSetBackend can be used to attach a virtqueue to a RAW socket
|
|
||||||
// or TAP device.
|
|
||||||
//
|
|
||||||
// Request payload: [vhost.QueueFile]
|
|
||||||
// Kernel name: VHOST_NET_SET_BACKEND
|
|
||||||
vhostNetIoctlSetBackend = 0x4008af30
|
|
||||||
)
|
|
||||||
|
|
||||||
// SetQueueBackend attaches a virtqueue of the vhost networking device
|
|
||||||
// described by controlFD to the given backend file descriptor.
|
|
||||||
// The backend file descriptor can either be a RAW socket or a TAP device. When
|
|
||||||
// it is -1, the queue will be detached.
|
|
||||||
func SetQueueBackend(controlFD int, queueIndex uint32, backendFD int) error {
|
|
||||||
if err := vhost.IoctlPtr(controlFD, vhostNetIoctlSetBackend, unsafe.Pointer(&vhost.QueueFile{
|
|
||||||
QueueIndex: queueIndex,
|
|
||||||
FD: int32(backendFD),
|
|
||||||
})); err != nil {
|
|
||||||
return fmt.Errorf("set queue backend file descriptor: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,69 +0,0 @@
|
|||||||
package vhostnet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
|
|
||||||
"github.com/slackhq/nebula/overlay/virtqueue"
|
|
||||||
)
|
|
||||||
|
|
||||||
// optionValues holds the settings that can be adjusted through [Option]
// functions.
type optionValues struct {
	// queueSize is the queue size to use; -1 means it was not set.
	queueSize int
	// backendFD is the file descriptor of the backend; -1 means it was not
	// set.
	backendFD int
}
|
|
||||||
|
|
||||||
func (o *optionValues) apply(options []Option) {
|
|
||||||
for _, option := range options {
|
|
||||||
option(o)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *optionValues) validate() error {
|
|
||||||
if o.queueSize == -1 {
|
|
||||||
return errors.New("queue size is required")
|
|
||||||
}
|
|
||||||
if err := virtqueue.CheckQueueSize(o.queueSize); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if o.backendFD == -1 {
|
|
||||||
return errors.New("backend file descriptor is required")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var optionDefaults = optionValues{
|
|
||||||
// Required.
|
|
||||||
queueSize: -1,
|
|
||||||
// Required.
|
|
||||||
backendFD: -1,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Option can be passed to [NewDevice] to influence device creation.
|
|
||||||
type Option func(*optionValues)
|
|
||||||
|
|
||||||
// WithQueueSize returns an [Option] that sets the size of the TX and RX queues
|
|
||||||
// that are to be created for the device. It specifies the number of
|
|
||||||
// entries/buffers each queue can hold. This also affects the memory
|
|
||||||
// consumption.
|
|
||||||
// This is required and must be an integer from 1 to 32768 that is also a power
|
|
||||||
// of 2.
|
|
||||||
func WithQueueSize(queueSize int) Option {
|
|
||||||
return func(o *optionValues) { o.queueSize = queueSize }
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithBackendFD returns an [Option] that sets the file descriptor of the
|
|
||||||
// backend that will be used for the queues of the device. The device will write
|
|
||||||
// and read packets to/from that backend. The file descriptor can either be of a
|
|
||||||
// RAW socket or TUN/TAP device.
|
|
||||||
// Either this or [WithBackendDevice] is required.
|
|
||||||
func WithBackendFD(backendFD int) Option {
|
|
||||||
return func(o *optionValues) { o.backendFD = backendFD }
|
|
||||||
}
|
|
||||||
|
|
||||||
//// WithBackendDevice returns an [Option] that sets the given TAP device as the
|
|
||||||
//// backend that will be used for the queues of the device. The device will
|
|
||||||
//// write and read packets to/from that backend. The TAP device should have been
|
|
||||||
//// created with the [tuntap.WithVirtioNetHdr] option enabled.
|
|
||||||
//// Either this or [WithBackendFD] is required.
|
|
||||||
//func WithBackendDevice(dev *tuntap.Device) Option {
|
|
||||||
// return func(o *optionValues) { o.backendFD = int(dev.File().Fd()) }
|
|
||||||
//}
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
|
|
||||||
|
|
||||||
MIT License
|
|
||||||
|
|
||||||
Copyright (c) 2025 Hetzner Cloud GmbH
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
||||||
@@ -1,140 +0,0 @@
|
|||||||
package virtqueue
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
// availableRingFlag is a flag that describes an [AvailableRing].
|
|
||||||
type availableRingFlag uint16
|
|
||||||
|
|
||||||
const (
|
|
||||||
// availableRingFlagNoInterrupt is used by the guest to advise the host to
|
|
||||||
// not interrupt it when consuming a buffer. It's unreliable, so it's simply
|
|
||||||
// an optimization.
|
|
||||||
availableRingFlagNoInterrupt availableRingFlag = 1 << iota
|
|
||||||
)
|
|
||||||
|
|
||||||
// availableRingSize returns how many bytes of memory an [AvailableRing] with
// the given queue size occupies.
func availableRingSize(queueSize int) int {
	// Three 16-bit fields surround the ring entries: flags, ring index and
	// used event.
	const fixedBytes = 6
	// Every ring entry is a 16-bit descriptor index.
	entryBytes := 2 * queueSize
	return fixedBytes + entryBytes
}
|
|
||||||
|
|
||||||
// availableRingAlignment is the minimum alignment of an [AvailableRing]
|
|
||||||
// in memory, as required by the virtio spec.
|
|
||||||
const availableRingAlignment = 2
|
|
||||||
|
|
||||||
// AvailableRing is used by the driver to offer descriptor chains to the device.
|
|
||||||
// Each ring entry refers to the head of a descriptor chain. It is only written
|
|
||||||
// to by the driver and read by the device.
|
|
||||||
//
|
|
||||||
// Because the size of the ring depends on the queue size, we cannot define a
|
|
||||||
// Go struct with a static size that maps to the memory of the ring. Instead,
|
|
||||||
// this struct only contains pointers to the corresponding memory areas.
|
|
||||||
type AvailableRing struct {
|
|
||||||
initialized bool
|
|
||||||
|
|
||||||
// flags that describe this ring.
|
|
||||||
flags *availableRingFlag
|
|
||||||
// ringIndex indicates where the driver would put the next entry into the
|
|
||||||
// ring (modulo the queue size).
|
|
||||||
ringIndex *uint16
|
|
||||||
// ring references buffers using the index of the head of the descriptor
|
|
||||||
// chain in the [DescriptorTable]. It wraps around at queue size.
|
|
||||||
ring []uint16
|
|
||||||
// usedEvent is not used by this implementation, but we reserve it anyway to
|
|
||||||
// avoid issues in case a device may try to access it, contrary to the
|
|
||||||
// virtio specification.
|
|
||||||
usedEvent *uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
// newAvailableRing creates an available ring that uses the given underlying
|
|
||||||
// memory. The length of the memory slice must match the size needed for the
|
|
||||||
// ring (see [availableRingSize]) for the given queue size.
|
|
||||||
func newAvailableRing(queueSize int, mem []byte) *AvailableRing {
|
|
||||||
ringSize := availableRingSize(queueSize)
|
|
||||||
if len(mem) != ringSize {
|
|
||||||
panic(fmt.Sprintf("memory size (%v) does not match required size "+
|
|
||||||
"for available ring: %v", len(mem), ringSize))
|
|
||||||
}
|
|
||||||
|
|
||||||
return &AvailableRing{
|
|
||||||
initialized: true,
|
|
||||||
flags: (*availableRingFlag)(unsafe.Pointer(&mem[0])),
|
|
||||||
ringIndex: (*uint16)(unsafe.Pointer(&mem[2])),
|
|
||||||
ring: unsafe.Slice((*uint16)(unsafe.Pointer(&mem[4])), queueSize),
|
|
||||||
usedEvent: (*uint16)(unsafe.Pointer(&mem[ringSize-2])),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Address returns the pointer to the beginning of the ring in memory.
|
|
||||||
// Do not modify the memory directly to not interfere with this implementation.
|
|
||||||
func (r *AvailableRing) Address() uintptr {
|
|
||||||
if !r.initialized {
|
|
||||||
panic("available ring is not initialized")
|
|
||||||
}
|
|
||||||
return uintptr(unsafe.Pointer(r.flags))
|
|
||||||
}
|
|
||||||
|
|
||||||
// offer adds the given descriptor chain heads to the available ring and
|
|
||||||
// advances the ring index accordingly to make the device process the new
|
|
||||||
// descriptor chains.
|
|
||||||
func (r *AvailableRing) offerElements(chains []UsedElement) {
|
|
||||||
//always called under lock
|
|
||||||
//r.mu.Lock()
|
|
||||||
//defer r.mu.Unlock()
|
|
||||||
|
|
||||||
// Add descriptor chain heads to the ring.
|
|
||||||
for offset, x := range chains {
|
|
||||||
// The 16-bit ring index may overflow. This is expected and is not an
|
|
||||||
// issue because the size of the ring array (which equals the queue
|
|
||||||
// size) is always a power of 2 and smaller than the highest possible
|
|
||||||
// 16-bit value.
|
|
||||||
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
|
|
||||||
r.ring[insertIndex] = x.GetHead()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Increase the ring index by the number of descriptor chains added to the
|
|
||||||
// ring.
|
|
||||||
*r.ringIndex += uint16(len(chains))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AvailableRing) offer(chains []uint16) {
|
|
||||||
//always called under lock
|
|
||||||
//r.mu.Lock()
|
|
||||||
//defer r.mu.Unlock()
|
|
||||||
|
|
||||||
// Add descriptor chain heads to the ring.
|
|
||||||
for offset, x := range chains {
|
|
||||||
// The 16-bit ring index may overflow. This is expected and is not an
|
|
||||||
// issue because the size of the ring array (which equals the queue
|
|
||||||
// size) is always a power of 2 and smaller than the highest possible
|
|
||||||
// 16-bit value.
|
|
||||||
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
|
|
||||||
r.ring[insertIndex] = x
|
|
||||||
}
|
|
||||||
|
|
||||||
// Increase the ring index by the number of descriptor chains added to the
|
|
||||||
// ring.
|
|
||||||
*r.ringIndex += uint16(len(chains))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *AvailableRing) offerSingle(x uint16) {
|
|
||||||
//always called under lock
|
|
||||||
//r.mu.Lock()
|
|
||||||
//defer r.mu.Unlock()
|
|
||||||
|
|
||||||
offset := 0
|
|
||||||
// Add descriptor chain heads to the ring.
|
|
||||||
|
|
||||||
// The 16-bit ring index may overflow. This is expected and is not an
|
|
||||||
// issue because the size of the ring array (which equals the queue
|
|
||||||
// size) is always a power of 2 and smaller than the highest possible
|
|
||||||
// 16-bit value.
|
|
||||||
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
|
|
||||||
r.ring[insertIndex] = x
|
|
||||||
|
|
||||||
// Increase the ring index by the number of descriptor chains added to the ring.
|
|
||||||
*r.ringIndex += 1
|
|
||||||
}
|
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
package virtqueue
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAvailableRing_MemoryLayout(t *testing.T) {
|
|
||||||
const queueSize = 2
|
|
||||||
|
|
||||||
memory := make([]byte, availableRingSize(queueSize))
|
|
||||||
r := newAvailableRing(queueSize, memory)
|
|
||||||
|
|
||||||
*r.flags = 0x01ff
|
|
||||||
*r.ringIndex = 1
|
|
||||||
r.ring[0] = 0x1234
|
|
||||||
r.ring[1] = 0x5678
|
|
||||||
|
|
||||||
assert.Equal(t, []byte{
|
|
||||||
0xff, 0x01,
|
|
||||||
0x01, 0x00,
|
|
||||||
0x34, 0x12,
|
|
||||||
0x78, 0x56,
|
|
||||||
0x00, 0x00,
|
|
||||||
}, memory)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAvailableRing_Offer(t *testing.T) {
|
|
||||||
const queueSize = 8
|
|
||||||
|
|
||||||
chainHeads := []uint16{42, 33, 69}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
startRingIndex uint16
|
|
||||||
expectedRingIndex uint16
|
|
||||||
expectedRing []uint16
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "no overflow",
|
|
||||||
startRingIndex: 0,
|
|
||||||
expectedRingIndex: 3,
|
|
||||||
expectedRing: []uint16{42, 33, 69, 0, 0, 0, 0, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ring overflow",
|
|
||||||
startRingIndex: 6,
|
|
||||||
expectedRingIndex: 9,
|
|
||||||
expectedRing: []uint16{69, 0, 0, 0, 0, 0, 42, 33},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "index overflow",
|
|
||||||
startRingIndex: 65535,
|
|
||||||
expectedRingIndex: 2,
|
|
||||||
expectedRing: []uint16{33, 69, 0, 0, 0, 0, 0, 42},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
memory := make([]byte, availableRingSize(queueSize))
|
|
||||||
r := newAvailableRing(queueSize, memory)
|
|
||||||
*r.ringIndex = tt.startRingIndex
|
|
||||||
|
|
||||||
r.offer(chainHeads)
|
|
||||||
|
|
||||||
assert.Equal(t, tt.expectedRingIndex, *r.ringIndex)
|
|
||||||
assert.Equal(t, tt.expectedRing, r.ring)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
package virtqueue
|
|
||||||
|
|
||||||
// descriptorFlag is a flag that describes a [Descriptor].
|
|
||||||
type descriptorFlag uint16
|
|
||||||
|
|
||||||
const (
|
|
||||||
// descriptorFlagHasNext marks a descriptor chain as continuing via the next
|
|
||||||
// field.
|
|
||||||
descriptorFlagHasNext descriptorFlag = 1 << iota
|
|
||||||
// descriptorFlagWritable marks a buffer as device write-only (otherwise
|
|
||||||
// device read-only).
|
|
||||||
descriptorFlagWritable
|
|
||||||
// descriptorFlagIndirect means the buffer contains a list of buffer
|
|
||||||
// descriptors to provide an additional layer of indirection.
|
|
||||||
// Only allowed when the [virtio.FeatureIndirectDescriptors] feature was
|
|
||||||
// negotiated.
|
|
||||||
descriptorFlagIndirect
|
|
||||||
)
|
|
||||||
|
|
||||||
// descriptorSize is the number of bytes needed to store a [Descriptor] in
|
|
||||||
// memory.
|
|
||||||
const descriptorSize = 16
|
|
||||||
|
|
||||||
// Descriptor describes (a part of) a buffer which is either read-only for the
|
|
||||||
// device or write-only for the device (depending on [descriptorFlagWritable]).
|
|
||||||
// Multiple descriptors can be chained to produce a "descriptor chain" that can
|
|
||||||
// contain both device-readable and device-writable buffers. Device-readable
|
|
||||||
// descriptors always come first in a chain. A single, large buffer may be
|
|
||||||
// split up by chaining multiple similar descriptors that reference different
|
|
||||||
// memory pages. This is required, because buffers may exceed a single page size
|
|
||||||
// and the memory accessed by the device is expected to be continuous.
|
|
||||||
type Descriptor struct {
|
|
||||||
// address is the address to the continuous memory holding the data for this
|
|
||||||
// descriptor.
|
|
||||||
address uintptr
|
|
||||||
// length is the amount of bytes stored at address.
|
|
||||||
length uint32
|
|
||||||
// flags that describe this descriptor.
|
|
||||||
flags descriptorFlag
|
|
||||||
// next contains the index of the next descriptor continuing this descriptor
|
|
||||||
// chain when the [descriptorFlagHasNext] flag is set.
|
|
||||||
next uint16
|
|
||||||
}
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
package virtqueue
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDescriptor_Size(t *testing.T) {
|
|
||||||
assert.EqualValues(t, descriptorSize, unsafe.Sizeof(Descriptor{}))
|
|
||||||
}
|
|
||||||
@@ -1,641 +0,0 @@
|
|||||||
package virtqueue
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// ErrDescriptorChainEmpty is returned when a descriptor chain would contain
|
|
||||||
// no buffers, which is not allowed.
|
|
||||||
ErrDescriptorChainEmpty = errors.New("empty descriptor chains are not allowed")
|
|
||||||
|
|
||||||
// ErrNotEnoughFreeDescriptors is returned when the free descriptors are
|
|
||||||
// exhausted, meaning that the queue is full.
|
|
||||||
ErrNotEnoughFreeDescriptors = errors.New("not enough free descriptors, queue is full")
|
|
||||||
|
|
||||||
// ErrInvalidDescriptorChain is returned when a descriptor chain is not
|
|
||||||
// valid for a given operation.
|
|
||||||
ErrInvalidDescriptorChain = errors.New("invalid descriptor chain")
|
|
||||||
)
|
|
||||||
|
|
||||||
// noFreeHead is used to mark when all descriptors are in use and we have no
|
|
||||||
// free chain. This value is impossible to occur as an index naturally, because
|
|
||||||
// it exceeds the maximum queue size.
|
|
||||||
const noFreeHead = uint16(math.MaxUint16)
|
|
||||||
|
|
||||||
// descriptorTableSize is the number of bytes needed to store a
|
|
||||||
// [DescriptorTable] with the given queue size in memory.
|
|
||||||
func descriptorTableSize(queueSize int) int {
|
|
||||||
return descriptorSize * queueSize
|
|
||||||
}
|
|
||||||
|
|
||||||
// descriptorTableAlignment is the minimum alignment of a [DescriptorTable]
|
|
||||||
// in memory, as required by the virtio spec.
|
|
||||||
const descriptorTableAlignment = 16
|
|
||||||
|
|
||||||
// DescriptorTable is a table that holds [Descriptor]s, addressed via their
|
|
||||||
// index in the slice.
|
|
||||||
type DescriptorTable struct {
|
|
||||||
descriptors []Descriptor
|
|
||||||
|
|
||||||
// freeHeadIndex is the index of the head of the descriptor chain which
|
|
||||||
// contains all currently unused descriptors. When all descriptors are in
|
|
||||||
// use, this has the special value of noFreeHead.
|
|
||||||
freeHeadIndex uint16
|
|
||||||
// freeNum tracks the number of descriptors which are currently not in use.
|
|
||||||
freeNum uint16
|
|
||||||
|
|
||||||
bufferBase uintptr
|
|
||||||
bufferSize int
|
|
||||||
itemSize int
|
|
||||||
}
|
|
||||||
|
|
||||||
// newDescriptorTable creates a descriptor table that uses the given underlying
|
|
||||||
// memory. The Length of the memory slice must match the size needed for the
|
|
||||||
// descriptor table (see [descriptorTableSize]) for the given queue size.
|
|
||||||
//
|
|
||||||
// Before this descriptor table can be used, [initialize] must be called.
|
|
||||||
func newDescriptorTable(queueSize int, mem []byte, itemSize int) *DescriptorTable {
|
|
||||||
dtSize := descriptorTableSize(queueSize)
|
|
||||||
if len(mem) != dtSize {
|
|
||||||
panic(fmt.Sprintf("memory size (%v) does not match required size "+
|
|
||||||
"for descriptor table: %v", len(mem), dtSize))
|
|
||||||
}
|
|
||||||
|
|
||||||
return &DescriptorTable{
|
|
||||||
descriptors: unsafe.Slice((*Descriptor)(unsafe.Pointer(&mem[0])), queueSize),
|
|
||||||
// We have no free descriptors until they were initialized.
|
|
||||||
freeHeadIndex: noFreeHead,
|
|
||||||
freeNum: 0,
|
|
||||||
itemSize: itemSize, //todo configurable? needs to be page-aligned
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Address returns the pointer to the beginning of the descriptor table in
|
|
||||||
// memory. Do not modify the memory directly to not interfere with this
|
|
||||||
// implementation.
|
|
||||||
func (dt *DescriptorTable) Address() uintptr {
|
|
||||||
if dt.descriptors == nil {
|
|
||||||
panic("descriptor table is not initialized")
|
|
||||||
}
|
|
||||||
//should be same as dt.bufferBase
|
|
||||||
return uintptr(unsafe.Pointer(&dt.descriptors[0]))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dt *DescriptorTable) Size() uintptr {
|
|
||||||
if dt.descriptors == nil {
|
|
||||||
panic("descriptor table is not initialized")
|
|
||||||
}
|
|
||||||
return uintptr(dt.bufferSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BufferAddresses returns a map of pointer->size for all allocations used by the table
|
|
||||||
func (dt *DescriptorTable) BufferAddresses() map[uintptr]int {
|
|
||||||
if dt.descriptors == nil {
|
|
||||||
panic("descriptor table is not initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
return map[uintptr]int{dt.bufferBase: dt.bufferSize}
|
|
||||||
}
|
|
||||||
|
|
||||||
// initializeDescriptors allocates buffers with the size of a full memory page
|
|
||||||
// for each descriptor in the table. While this may be a bit wasteful, it makes
|
|
||||||
// dealing with descriptors way easier. Without this preallocation, we would
|
|
||||||
// have to allocate and free memory on demand, increasing complexity.
|
|
||||||
//
|
|
||||||
// All descriptors will be marked as free and will form a free chain. The
|
|
||||||
// addresses of all descriptors will be populated while their length remains
|
|
||||||
// zero.
|
|
||||||
func (dt *DescriptorTable) initializeDescriptors() error {
|
|
||||||
numDescriptors := len(dt.descriptors)
|
|
||||||
|
|
||||||
// Allocate ONE large region for all buffers
|
|
||||||
totalSize := dt.itemSize * numDescriptors
|
|
||||||
basePtr, err := unix.MmapPtr(-1, 0, nil, uintptr(totalSize),
|
|
||||||
unix.PROT_READ|unix.PROT_WRITE,
|
|
||||||
unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("allocate buffer memory for descriptors: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store the base for cleanup later
|
|
||||||
dt.bufferBase = uintptr(basePtr)
|
|
||||||
dt.bufferSize = totalSize
|
|
||||||
|
|
||||||
for i := range dt.descriptors {
|
|
||||||
dt.descriptors[i] = Descriptor{
|
|
||||||
address: dt.bufferBase + uintptr(i*dt.itemSize),
|
|
||||||
length: 0,
|
|
||||||
// All descriptors should form a free chain that loops around.
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: uint16((i + 1) % len(dt.descriptors)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// All descriptors are free to use now.
|
|
||||||
dt.freeHeadIndex = 0
|
|
||||||
dt.freeNum = uint16(len(dt.descriptors))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// releaseBuffers releases all allocated buffers for this descriptor table.
|
|
||||||
// The implementation will try to release as many buffers as possible and
|
|
||||||
// collect potential errors before returning them.
|
|
||||||
// The descriptor table should no longer be used after calling this.
|
|
||||||
func (dt *DescriptorTable) releaseBuffers() error {
|
|
||||||
for i := range dt.descriptors {
|
|
||||||
descriptor := &dt.descriptors[i]
|
|
||||||
descriptor.address = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// As a safety measure, make sure no descriptors can be used anymore.
|
|
||||||
dt.freeHeadIndex = noFreeHead
|
|
||||||
dt.freeNum = 0
|
|
||||||
|
|
||||||
if dt.bufferBase != 0 {
|
|
||||||
// The pointer points to memory not managed by Go, so this conversion
|
|
||||||
// is safe. See https://github.com/golang/go/issues/58625
|
|
||||||
dt.bufferBase = 0
|
|
||||||
//goland:noinspection GoVetUnsafePointer
|
|
||||||
err := unix.MunmapPtr(unsafe.Pointer(dt.bufferBase), uintptr(dt.bufferSize))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("release buffer memory: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// createDescriptorChain creates a new descriptor chain within the descriptor
|
|
||||||
// table which contains a number of device-readable buffers (out buffers) and
|
|
||||||
// device-writable buffers (in buffers).
|
|
||||||
//
|
|
||||||
// All buffers in the outBuffers slice will be concatenated by chaining
|
|
||||||
// descriptors, one for each buffer in the slice. The size of the single buffers
|
|
||||||
// must not exceed the size of a memory page (see [os.Getpagesize]).
|
|
||||||
// When numInBuffers is greater than zero, the given number of device-writable
|
|
||||||
// descriptors will be appended to the end of the chain, each referencing a
|
|
||||||
// whole memory page.
|
|
||||||
//
|
|
||||||
// The index of the head of the new descriptor chain will be returned. Callers
|
|
||||||
// should make sure to free the descriptor chain using [freeDescriptorChain]
|
|
||||||
// after it was used by the device.
|
|
||||||
//
|
|
||||||
// When there are not enough free descriptors to hold the given number of
|
|
||||||
// buffers, an [ErrNotEnoughFreeDescriptors] will be returned. In this case, the
|
|
||||||
// caller should try again after some descriptor chains were used by the device
|
|
||||||
// and returned back into the free chain.
|
|
||||||
func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffers int) (uint16, error) {
|
|
||||||
// Calculate the number of descriptors needed to build the chain.
|
|
||||||
numDesc := uint16(len(outBuffers) + numInBuffers)
|
|
||||||
|
|
||||||
// Descriptor chains must always contain at least one descriptor.
|
|
||||||
if numDesc < 1 {
|
|
||||||
return 0, ErrDescriptorChainEmpty
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do we still have enough free descriptors?
|
|
||||||
if numDesc > dt.freeNum {
|
|
||||||
return 0, ErrNotEnoughFreeDescriptors
|
|
||||||
}
|
|
||||||
|
|
||||||
// Above validation ensured that there is at least one free descriptor, so
|
|
||||||
// the free descriptor chain head should be valid.
|
|
||||||
if dt.freeHeadIndex == noFreeHead {
|
|
||||||
panic("free descriptor chain head is unset but there should be free descriptors")
|
|
||||||
}
|
|
||||||
|
|
||||||
// To avoid having to iterate over the whole table to find the descriptor
|
|
||||||
// pointing to the head just to replace the free head, we instead always
|
|
||||||
// create descriptor chains from the descriptors coming after the head.
|
|
||||||
// This way we only have to touch the head as a last resort, when all other
|
|
||||||
// descriptors are already used.
|
|
||||||
head := dt.descriptors[dt.freeHeadIndex].next
|
|
||||||
next := head
|
|
||||||
tail := head
|
|
||||||
for i, buffer := range outBuffers {
|
|
||||||
desc := &dt.descriptors[next]
|
|
||||||
checkUnusedDescriptorLength(next, desc)
|
|
||||||
|
|
||||||
if len(buffer) > dt.itemSize {
|
|
||||||
// The caller should already prevent that from happening.
|
|
||||||
panic(fmt.Sprintf("out buffer %d has size %d which exceeds desc length %d", i, len(buffer), dt.itemSize))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy the buffer to the memory referenced by the descriptor.
|
|
||||||
// The descriptor address points to memory not managed by Go, so this
|
|
||||||
// conversion is safe. See https://github.com/golang/go/issues/58625
|
|
||||||
//goland:noinspection GoVetUnsafePointer
|
|
||||||
copy(unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), dt.itemSize), buffer)
|
|
||||||
desc.length = uint32(len(buffer))
|
|
||||||
|
|
||||||
// Clear the flags in case there were any others set.
|
|
||||||
desc.flags = descriptorFlagHasNext
|
|
||||||
|
|
||||||
tail = next
|
|
||||||
next = desc.next
|
|
||||||
}
|
|
||||||
for range numInBuffers {
|
|
||||||
desc := &dt.descriptors[next]
|
|
||||||
checkUnusedDescriptorLength(next, desc)
|
|
||||||
|
|
||||||
// Give the device the maximum available number of bytes to write into.
|
|
||||||
desc.length = uint32(dt.itemSize)
|
|
||||||
|
|
||||||
// Mark the descriptor as device-writable.
|
|
||||||
desc.flags = descriptorFlagHasNext | descriptorFlagWritable
|
|
||||||
|
|
||||||
tail = next
|
|
||||||
next = desc.next
|
|
||||||
}
|
|
||||||
|
|
||||||
// The last descriptor should end the chain.
|
|
||||||
tailDesc := &dt.descriptors[tail]
|
|
||||||
tailDesc.flags &= ^descriptorFlagHasNext
|
|
||||||
tailDesc.next = 0 // Not necessary to clear this, it's just for looks.
|
|
||||||
|
|
||||||
dt.freeNum -= numDesc
|
|
||||||
|
|
||||||
if dt.freeNum == 0 {
|
|
||||||
// The last descriptor in the chain should be the free chain head
|
|
||||||
// itself.
|
|
||||||
if tail != dt.freeHeadIndex {
|
|
||||||
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
|
|
||||||
}
|
|
||||||
|
|
||||||
// When this new chain takes up all remaining descriptors, we no longer
|
|
||||||
// have a free chain.
|
|
||||||
dt.freeHeadIndex = noFreeHead
|
|
||||||
} else {
|
|
||||||
// We took some descriptors out of the free chain, so make sure to close
|
|
||||||
// the circle again.
|
|
||||||
dt.descriptors[dt.freeHeadIndex].next = next
|
|
||||||
}
|
|
||||||
|
|
||||||
return head, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dt *DescriptorTable) CreateDescriptorForOutputs() (uint16, error) {
|
|
||||||
//todo just fill the damn table
|
|
||||||
// Do we still have enough free descriptors?
|
|
||||||
|
|
||||||
if 1 > dt.freeNum {
|
|
||||||
return 0, ErrNotEnoughFreeDescriptors
|
|
||||||
}
|
|
||||||
|
|
||||||
// Above validation ensured that there is at least one free descriptor, so
|
|
||||||
// the free descriptor chain head should be valid.
|
|
||||||
if dt.freeHeadIndex == noFreeHead {
|
|
||||||
panic("free descriptor chain head is unset but there should be free descriptors")
|
|
||||||
}
|
|
||||||
|
|
||||||
// To avoid having to iterate over the whole table to find the descriptor
|
|
||||||
// pointing to the head just to replace the free head, we instead always
|
|
||||||
// create descriptor chains from the descriptors coming after the head.
|
|
||||||
// This way we only have to touch the head as a last resort, when all other
|
|
||||||
// descriptors are already used.
|
|
||||||
head := dt.descriptors[dt.freeHeadIndex].next
|
|
||||||
desc := &dt.descriptors[head]
|
|
||||||
next := desc.next
|
|
||||||
|
|
||||||
checkUnusedDescriptorLength(head, desc)
|
|
||||||
|
|
||||||
// Give the device the maximum available number of bytes to write into.
|
|
||||||
desc.length = uint32(dt.itemSize)
|
|
||||||
desc.flags = 0 // descriptorFlagWritable
|
|
||||||
desc.next = 0 // Not necessary to clear this, it's just for looks.
|
|
||||||
|
|
||||||
dt.freeNum -= 1
|
|
||||||
|
|
||||||
if dt.freeNum == 0 {
|
|
||||||
// The last descriptor in the chain should be the free chain head
|
|
||||||
// itself.
|
|
||||||
if next != dt.freeHeadIndex {
|
|
||||||
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
|
|
||||||
}
|
|
||||||
|
|
||||||
// When this new chain takes up all remaining descriptors, we no longer
|
|
||||||
// have a free chain.
|
|
||||||
dt.freeHeadIndex = noFreeHead
|
|
||||||
} else {
|
|
||||||
// We took some descriptors out of the free chain, so make sure to close
|
|
||||||
// the circle again.
|
|
||||||
dt.descriptors[dt.freeHeadIndex].next = next
|
|
||||||
}
|
|
||||||
|
|
||||||
return head, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) {
|
|
||||||
// Do we still have enough free descriptors?
|
|
||||||
if 1 > dt.freeNum {
|
|
||||||
return 0, ErrNotEnoughFreeDescriptors
|
|
||||||
}
|
|
||||||
|
|
||||||
// Above validation ensured that there is at least one free descriptor, so
|
|
||||||
// the free descriptor chain head should be valid.
|
|
||||||
if dt.freeHeadIndex == noFreeHead {
|
|
||||||
panic("free descriptor chain head is unset but there should be free descriptors")
|
|
||||||
}
|
|
||||||
|
|
||||||
// To avoid having to iterate over the whole table to find the descriptor
|
|
||||||
// pointing to the head just to replace the free head, we instead always
|
|
||||||
// create descriptor chains from the descriptors coming after the head.
|
|
||||||
// This way we only have to touch the head as a last resort, when all other
|
|
||||||
// descriptors are already used.
|
|
||||||
head := dt.descriptors[dt.freeHeadIndex].next
|
|
||||||
desc := &dt.descriptors[head]
|
|
||||||
next := desc.next
|
|
||||||
|
|
||||||
checkUnusedDescriptorLength(head, desc)
|
|
||||||
|
|
||||||
// Give the device the maximum available number of bytes to write into.
|
|
||||||
desc.length = uint32(dt.itemSize)
|
|
||||||
desc.flags = descriptorFlagWritable
|
|
||||||
desc.next = 0 // Not necessary to clear this, it's just for looks.
|
|
||||||
|
|
||||||
dt.freeNum -= 1
|
|
||||||
|
|
||||||
if dt.freeNum == 0 {
|
|
||||||
// The last descriptor in the chain should be the free chain head
|
|
||||||
// itself.
|
|
||||||
if next != dt.freeHeadIndex {
|
|
||||||
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
|
|
||||||
}
|
|
||||||
|
|
||||||
// When this new chain takes up all remaining descriptors, we no longer
|
|
||||||
// have a free chain.
|
|
||||||
dt.freeHeadIndex = noFreeHead
|
|
||||||
} else {
|
|
||||||
// We took some descriptors out of the free chain, so make sure to close
|
|
||||||
// the circle again.
|
|
||||||
dt.descriptors[dt.freeHeadIndex].next = next
|
|
||||||
}
|
|
||||||
|
|
||||||
return head, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Implement a zero-copy variant of createDescriptorChain?
|
|
||||||
|
|
||||||
// getDescriptorChain returns the device-readable buffers (out buffers) and
|
|
||||||
// device-writable buffers (in buffers) of the descriptor chain that starts with
|
|
||||||
// the given head index. The descriptor chain must have been created using
|
|
||||||
// [createDescriptorChain] and must not have been freed yet (meaning that the
|
|
||||||
// head index must not be contained in the free chain).
|
|
||||||
//
|
|
||||||
// Be careful to only access the returned buffer slices when the device has not
|
|
||||||
// yet or is no longer using them. They must not be accessed after
|
|
||||||
// [freeDescriptorChain] has been called.
|
|
||||||
func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
|
|
||||||
if int(head) > len(dt.descriptors) {
|
|
||||||
return nil, nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate over the chain. The iteration is limited to the queue size to
|
|
||||||
// avoid ending up in an endless loop when things go very wrong.
|
|
||||||
next := head
|
|
||||||
for range len(dt.descriptors) {
|
|
||||||
if next == dt.freeHeadIndex {
|
|
||||||
return nil, nil, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
desc := &dt.descriptors[next]
|
|
||||||
|
|
||||||
// The descriptor address points to memory not managed by Go, so this
|
|
||||||
// conversion is safe. See https://github.com/golang/go/issues/58625
|
|
||||||
//goland:noinspection GoVetUnsafePointer
|
|
||||||
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
|
|
||||||
|
|
||||||
if desc.flags&descriptorFlagWritable == 0 {
|
|
||||||
outBuffers = append(outBuffers, bs)
|
|
||||||
} else {
|
|
||||||
inBuffers = append(inBuffers, bs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Is this the tail of the chain?
|
|
||||||
if desc.flags&descriptorFlagHasNext == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detect loops.
|
|
||||||
if desc.next == head {
|
|
||||||
return nil, nil, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
next = desc.next
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
|
|
||||||
if int(head) > len(dt.descriptors) {
|
|
||||||
return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
desc := &dt.descriptors[head] //todo this is a pretty nasty hack with no checks
|
|
||||||
|
|
||||||
// The descriptor address points to memory not managed by Go, so this
|
|
||||||
// conversion is safe. See https://github.com/golang/go/issues/58625
|
|
||||||
//goland:noinspection GoVetUnsafePointer
|
|
||||||
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
|
|
||||||
return bs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
|
|
||||||
if int(head) > len(dt.descriptors) {
|
|
||||||
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate over the chain. The iteration is limited to the queue size to
|
|
||||||
// avoid ending up in an endless loop when things go very wrong.
|
|
||||||
next := head
|
|
||||||
for range len(dt.descriptors) {
|
|
||||||
if next == dt.freeHeadIndex {
|
|
||||||
return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
desc := &dt.descriptors[next]
|
|
||||||
|
|
||||||
// The descriptor address points to memory not managed by Go, so this
|
|
||||||
// conversion is safe. See https://github.com/golang/go/issues/58625
|
|
||||||
//goland:noinspection GoVetUnsafePointer
|
|
||||||
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
|
|
||||||
|
|
||||||
if desc.flags&descriptorFlagWritable == 0 {
|
|
||||||
return fmt.Errorf("there should not be an outbuffer in %d", head)
|
|
||||||
} else {
|
|
||||||
*inBuffers = append(*inBuffers, bs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Is this the tail of the chain?
|
|
||||||
if desc.flags&descriptorFlagHasNext == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detect loops.
|
|
||||||
if desc.next == head {
|
|
||||||
return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
next = desc.next
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
|
|
||||||
if int(head) > len(dt.descriptors) {
|
|
||||||
return 0, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate over the chain. The iteration is limited to the queue size to
|
|
||||||
// avoid ending up in an endless loop when things go very wrong.
|
|
||||||
|
|
||||||
length := 0
|
|
||||||
//find length
|
|
||||||
next := head
|
|
||||||
for range len(dt.descriptors) {
|
|
||||||
if next == dt.freeHeadIndex {
|
|
||||||
return 0, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
desc := &dt.descriptors[next]
|
|
||||||
|
|
||||||
if desc.flags&descriptorFlagWritable == 0 {
|
|
||||||
return 0, fmt.Errorf("receive queue contains device-readable buffer")
|
|
||||||
}
|
|
||||||
length += int(desc.length)
|
|
||||||
|
|
||||||
// Is this the tail of the chain?
|
|
||||||
if desc.flags&descriptorFlagHasNext == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detect loops.
|
|
||||||
if desc.next == head {
|
|
||||||
return 0, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
next = desc.next
|
|
||||||
}
|
|
||||||
if maxLen > 0 {
|
|
||||||
//todo length = min(maxLen, length)
|
|
||||||
}
|
|
||||||
//set out to length:
|
|
||||||
out = out[:length]
|
|
||||||
|
|
||||||
//now do the copying
|
|
||||||
copied := 0
|
|
||||||
for range len(dt.descriptors) {
|
|
||||||
desc := &dt.descriptors[next]
|
|
||||||
|
|
||||||
// The descriptor address points to memory not managed by Go, so this
|
|
||||||
// conversion is safe. See https://github.com/golang/go/issues/58625
|
|
||||||
//goland:noinspection GoVetUnsafePointer
|
|
||||||
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), min(uint32(length-copied), desc.length))
|
|
||||||
copied += copy(out[copied:], bs)
|
|
||||||
|
|
||||||
// Is this the tail of the chain?
|
|
||||||
if desc.flags&descriptorFlagHasNext == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// we did this already, no need to detect loops.
|
|
||||||
next = desc.next
|
|
||||||
}
|
|
||||||
if copied != length {
|
|
||||||
panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", length, copied))
|
|
||||||
}
|
|
||||||
|
|
||||||
return length, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// freeDescriptorChain can be used to free a descriptor chain when it is no
|
|
||||||
// longer in use. The descriptor chain that starts with the given index will be
|
|
||||||
// put back into the free chain, so the descriptors can be used for later calls
|
|
||||||
// of [createDescriptorChain].
|
|
||||||
// The descriptor chain must have been created using [createDescriptorChain] and
|
|
||||||
// must not have been freed yet (meaning that the head index must not be
|
|
||||||
// contained in the free chain).
|
|
||||||
func (dt *DescriptorTable) freeDescriptorChain(head uint16) error {
|
|
||||||
if int(head) > len(dt.descriptors) {
|
|
||||||
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate over the chain. The iteration is limited to the queue size to
|
|
||||||
// avoid ending up in an endless loop when things go very wrong.
|
|
||||||
next := head
|
|
||||||
var tailDesc *Descriptor
|
|
||||||
var chainLen uint16
|
|
||||||
for range len(dt.descriptors) {
|
|
||||||
if next == dt.freeHeadIndex {
|
|
||||||
return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
desc := &dt.descriptors[next]
|
|
||||||
chainLen++
|
|
||||||
|
|
||||||
// Set the length of all unused descriptors back to zero.
|
|
||||||
desc.length = 0
|
|
||||||
|
|
||||||
// Unset all flags except the next flag.
|
|
||||||
desc.flags &= descriptorFlagHasNext
|
|
||||||
|
|
||||||
// Is this the tail of the chain?
|
|
||||||
if desc.flags&descriptorFlagHasNext == 0 {
|
|
||||||
tailDesc = desc
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detect loops.
|
|
||||||
if desc.next == head {
|
|
||||||
return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
|
|
||||||
}
|
|
||||||
|
|
||||||
next = desc.next
|
|
||||||
}
|
|
||||||
if tailDesc == nil {
|
|
||||||
// A descriptor chain longer than the queue size but without loops
|
|
||||||
// should be impossible.
|
|
||||||
panic(fmt.Sprintf("could not find a tail for descriptor chain starting at %d", head))
|
|
||||||
}
|
|
||||||
|
|
||||||
// The tail descriptor does not have the next flag set, but when it comes
|
|
||||||
// back into the free chain, it should have.
|
|
||||||
tailDesc.flags = descriptorFlagHasNext
|
|
||||||
|
|
||||||
if dt.freeHeadIndex == noFreeHead {
|
|
||||||
// The whole free chain was used up, so we turn this returned descriptor
|
|
||||||
// chain into the new free chain by completing the circle and using its
|
|
||||||
// head.
|
|
||||||
tailDesc.next = head
|
|
||||||
dt.freeHeadIndex = head
|
|
||||||
} else {
|
|
||||||
// Attach the returned chain at the beginning of the free chain but
|
|
||||||
// right after the free chain head.
|
|
||||||
freeHeadDesc := &dt.descriptors[dt.freeHeadIndex]
|
|
||||||
tailDesc.next = freeHeadDesc.next
|
|
||||||
freeHeadDesc.next = head
|
|
||||||
}
|
|
||||||
|
|
||||||
dt.freeNum += chainLen
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkUnusedDescriptorLength asserts that the length of an unused descriptor
|
|
||||||
// is zero, as it should be.
|
|
||||||
// This is not a requirement by the virtio spec but rather a thing we do to
|
|
||||||
// notice when our algorithm goes sideways.
|
|
||||||
func checkUnusedDescriptorLength(index uint16, desc *Descriptor) {
|
|
||||||
if desc.length != 0 {
|
|
||||||
panic(fmt.Sprintf("descriptor %d should be unused but has a non-zero length", index))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,407 +0,0 @@
|
|||||||
package virtqueue
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDescriptorTable_InitializeDescriptors(t *testing.T) {
|
|
||||||
const queueSize = 32
|
|
||||||
|
|
||||||
dt := DescriptorTable{
|
|
||||||
descriptors: make([]Descriptor, queueSize),
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.NoError(t, dt.initializeDescriptors())
|
|
||||||
t.Cleanup(func() {
|
|
||||||
assert.NoError(t, dt.releaseBuffers())
|
|
||||||
})
|
|
||||||
|
|
||||||
for i, descriptor := range dt.descriptors {
|
|
||||||
assert.NotZero(t, descriptor.address)
|
|
||||||
assert.Zero(t, descriptor.length)
|
|
||||||
assert.EqualValues(t, descriptorFlagHasNext, descriptor.flags)
|
|
||||||
assert.EqualValues(t, (i+1)%queueSize, descriptor.next)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDescriptorTable_DescriptorChains(t *testing.T) {
|
|
||||||
// Use a very short queue size to not make this test overly verbose.
|
|
||||||
const queueSize = 8
|
|
||||||
|
|
||||||
pageSize := os.Getpagesize() * 2
|
|
||||||
|
|
||||||
// Initialize descriptor table.
|
|
||||||
dt := DescriptorTable{
|
|
||||||
descriptors: make([]Descriptor, queueSize),
|
|
||||||
}
|
|
||||||
assert.NoError(t, dt.initializeDescriptors())
|
|
||||||
t.Cleanup(func() {
|
|
||||||
assert.NoError(t, dt.releaseBuffers())
|
|
||||||
})
|
|
||||||
|
|
||||||
// Some utilities for easier checking if the descriptor table looks as
|
|
||||||
// expected.
|
|
||||||
type desc struct {
|
|
||||||
buffer []byte
|
|
||||||
flags descriptorFlag
|
|
||||||
next uint16
|
|
||||||
}
|
|
||||||
assertDescriptorTable := func(expected [queueSize]desc) {
|
|
||||||
for i := 0; i < queueSize; i++ {
|
|
||||||
actualDesc := &dt.descriptors[i]
|
|
||||||
expectedDesc := &expected[i]
|
|
||||||
assert.Equal(t, uint32(len(expectedDesc.buffer)), actualDesc.length)
|
|
||||||
if len(expectedDesc.buffer) > 0 {
|
|
||||||
//goland:noinspection GoVetUnsafePointer
|
|
||||||
assert.EqualValues(t,
|
|
||||||
unsafe.Slice((*byte)(unsafe.Pointer(actualDesc.address)), actualDesc.length),
|
|
||||||
expectedDesc.buffer)
|
|
||||||
}
|
|
||||||
assert.Equal(t, expectedDesc.flags, actualDesc.flags)
|
|
||||||
if expectedDesc.flags&descriptorFlagHasNext != 0 {
|
|
||||||
assert.Equal(t, expectedDesc.next, actualDesc.next)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initial state: All descriptors are in the free chain.
|
|
||||||
assert.Equal(t, uint16(0), dt.freeHeadIndex)
|
|
||||||
assert.Equal(t, uint16(8), dt.freeNum)
|
|
||||||
assertDescriptorTable([queueSize]desc{
|
|
||||||
{
|
|
||||||
// Free head.
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 3,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 4,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 5,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 6,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 7,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 0,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
// Create the first chain.
|
|
||||||
firstChain, err := dt.createDescriptorChain([][]byte{
|
|
||||||
makeTestBuffer(t, 26),
|
|
||||||
makeTestBuffer(t, 256),
|
|
||||||
}, 1)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, uint16(1), firstChain)
|
|
||||||
|
|
||||||
// Now there should be a new chain next to the free chain.
|
|
||||||
assert.Equal(t, uint16(0), dt.freeHeadIndex)
|
|
||||||
assert.Equal(t, uint16(5), dt.freeNum)
|
|
||||||
assertDescriptorTable([queueSize]desc{
|
|
||||||
{
|
|
||||||
// Free head.
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 4,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Head of first chain.
|
|
||||||
buffer: makeTestBuffer(t, 26),
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
buffer: makeTestBuffer(t, 256),
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 3,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Tail of first chain.
|
|
||||||
buffer: make([]byte, pageSize),
|
|
||||||
flags: descriptorFlagWritable,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 5,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 6,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 7,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 0,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
// Create a second chain with only a single in buffer.
|
|
||||||
secondChain, err := dt.createDescriptorChain(nil, 1)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, uint16(4), secondChain)
|
|
||||||
|
|
||||||
// Now there should be two chains next to the free chain.
|
|
||||||
assert.Equal(t, uint16(0), dt.freeHeadIndex)
|
|
||||||
assert.Equal(t, uint16(4), dt.freeNum)
|
|
||||||
assertDescriptorTable([queueSize]desc{
|
|
||||||
{
|
|
||||||
// Free head.
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 5,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Head of the first chain.
|
|
||||||
buffer: makeTestBuffer(t, 26),
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
buffer: makeTestBuffer(t, 256),
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 3,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Tail of the first chain.
|
|
||||||
buffer: make([]byte, pageSize),
|
|
||||||
flags: descriptorFlagWritable,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Head and tail of the second chain.
|
|
||||||
buffer: make([]byte, pageSize),
|
|
||||||
flags: descriptorFlagWritable,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 6,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 7,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 0,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
// Create a third chain taking up all remaining descriptors.
|
|
||||||
thirdChain, err := dt.createDescriptorChain([][]byte{
|
|
||||||
makeTestBuffer(t, 42),
|
|
||||||
makeTestBuffer(t, 96),
|
|
||||||
makeTestBuffer(t, 33),
|
|
||||||
makeTestBuffer(t, 222),
|
|
||||||
}, 0)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, uint16(5), thirdChain)
|
|
||||||
|
|
||||||
// Now there should be three chains and no free chain.
|
|
||||||
assert.Equal(t, noFreeHead, dt.freeHeadIndex)
|
|
||||||
assert.Equal(t, uint16(0), dt.freeNum)
|
|
||||||
assertDescriptorTable([queueSize]desc{
|
|
||||||
{
|
|
||||||
// Tail of the third chain.
|
|
||||||
buffer: makeTestBuffer(t, 222),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Head of the first chain.
|
|
||||||
buffer: makeTestBuffer(t, 26),
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
buffer: makeTestBuffer(t, 256),
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 3,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Tail of the first chain.
|
|
||||||
buffer: make([]byte, pageSize),
|
|
||||||
flags: descriptorFlagWritable,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Head and tail of the second chain.
|
|
||||||
buffer: make([]byte, pageSize),
|
|
||||||
flags: descriptorFlagWritable,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Head of the third chain.
|
|
||||||
buffer: makeTestBuffer(t, 42),
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 6,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
buffer: makeTestBuffer(t, 96),
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 7,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
buffer: makeTestBuffer(t, 33),
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 0,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
// Free the third chain.
|
|
||||||
assert.NoError(t, dt.freeDescriptorChain(thirdChain))
|
|
||||||
|
|
||||||
// Now there should be two chains and a free chain again.
|
|
||||||
assert.Equal(t, uint16(5), dt.freeHeadIndex)
|
|
||||||
assert.Equal(t, uint16(4), dt.freeNum)
|
|
||||||
assertDescriptorTable([queueSize]desc{
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 5,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Head of the first chain.
|
|
||||||
buffer: makeTestBuffer(t, 26),
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
buffer: makeTestBuffer(t, 256),
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 3,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Tail of the first chain.
|
|
||||||
buffer: make([]byte, pageSize),
|
|
||||||
flags: descriptorFlagWritable,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Head and tail of the second chain.
|
|
||||||
buffer: make([]byte, pageSize),
|
|
||||||
flags: descriptorFlagWritable,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Free head.
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 6,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 7,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 0,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
// Free the first chain.
|
|
||||||
assert.NoError(t, dt.freeDescriptorChain(firstChain))
|
|
||||||
|
|
||||||
// Now there should be only a single chain next to the free chain.
|
|
||||||
assert.Equal(t, uint16(5), dt.freeHeadIndex)
|
|
||||||
assert.Equal(t, uint16(7), dt.freeNum)
|
|
||||||
assertDescriptorTable([queueSize]desc{
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 5,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 3,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 6,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Head and tail of the second chain.
|
|
||||||
buffer: make([]byte, pageSize),
|
|
||||||
flags: descriptorFlagWritable,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Free head.
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 7,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 0,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
// Free the second chain.
|
|
||||||
assert.NoError(t, dt.freeDescriptorChain(secondChain))
|
|
||||||
|
|
||||||
// Now all descriptors should be in the free chain again.
|
|
||||||
assert.Equal(t, uint16(5), dt.freeHeadIndex)
|
|
||||||
assert.Equal(t, uint16(8), dt.freeNum)
|
|
||||||
assertDescriptorTable([queueSize]desc{
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 5,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 3,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 6,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Free head.
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 4,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 7,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
flags: descriptorFlagHasNext,
|
|
||||||
next: 0,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// makeTestBuffer returns a buffer of the given length that is filled with a
// descending pattern: the byte at position i holds the low eight bits of
// length-i.
func makeTestBuffer(t *testing.T, length int) []byte {
	t.Helper()

	pattern := make([]byte, length)
	for pos := range pattern {
		pattern[pos] = byte(length - pos)
	}
	return pattern
}
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
// Package virtqueue implements the driver side of a virtio queue as described
|
|
||||||
// in the specification:
|
|
||||||
// https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-270006
|
|
||||||
// This package does not make assumptions about the device that consumes the
|
|
||||||
// queue. It rather just allocates the queue structures in memory and provides
|
|
||||||
// methods to interact with it.
|
|
||||||
package virtqueue
|
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
package virtqueue
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"gvisor.dev/gvisor/pkg/eventfd"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Tests how an eventfd and a waiting goroutine can be gracefully closed.
|
|
||||||
// Extends the eventfd test suite:
|
|
||||||
// https://github.com/google/gvisor/blob/0799336d64be65eb97d330606c30162dc3440cab/pkg/eventfd/eventfd_test.go
|
|
||||||
func TestEventFD_CancelWait(t *testing.T) {
|
|
||||||
efd, err := eventfd.Create()
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
assert.NoError(t, efd.Close())
|
|
||||||
})
|
|
||||||
|
|
||||||
var stop bool
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
for !stop {
|
|
||||||
_ = efd.Wait()
|
|
||||||
}
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
t.Fatalf("goroutine ended early")
|
|
||||||
case <-time.After(500 * time.Millisecond):
|
|
||||||
}
|
|
||||||
|
|
||||||
stop = true
|
|
||||||
assert.NoError(t, efd.Notify())
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
break
|
|
||||||
case <-time.After(5 * time.Second):
|
|
||||||
t.Error("goroutine did not end")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
package virtqueue
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ErrQueueSizeInvalid is the sentinel error wrapped by every error that
// [CheckQueueSize] returns.
var ErrQueueSizeInvalid = errors.New("queue size is invalid")

// CheckQueueSize reports whether the given value can be used as the size of a
// virtqueue. It returns nil for valid sizes and an error wrapping
// [ErrQueueSizeInvalid] otherwise.
//
// A valid size is a positive power of 2 that does not exceed 32768.
func CheckQueueSize(queueSize int) error {
	// 32768 is the largest power of 2 that fits into a 16-bit integer, the
	// next one (65536) no longer does.
	const largest = 32768

	switch {
	case queueSize <= 0:
		return fmt.Errorf("%w: %d is too small", ErrQueueSizeInvalid, queueSize)
	case queueSize&(queueSize-1) != 0:
		// Only with a power of 2 do the ring indexes wrap correctly when the
		// 16-bit integers overflow.
		return fmt.Errorf("%w: %d is not a power of 2", ErrQueueSizeInvalid, queueSize)
	case queueSize > largest:
		return fmt.Errorf("%w: %d is larger than the maximum possible queue size 32768",
			ErrQueueSizeInvalid, queueSize)
	}
	return nil
}
|
|
||||||
@@ -1,59 +0,0 @@
|
|||||||
package virtqueue
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCheckQueueSize(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
queueSize int
|
|
||||||
containsErr string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "negative",
|
|
||||||
queueSize: -1,
|
|
||||||
containsErr: "too small",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero",
|
|
||||||
queueSize: 0,
|
|
||||||
containsErr: "too small",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "not a power of 2",
|
|
||||||
queueSize: 24,
|
|
||||||
containsErr: "not a power of 2",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "too large",
|
|
||||||
queueSize: 65536,
|
|
||||||
containsErr: "larger than the maximum",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid 1",
|
|
||||||
queueSize: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid 256",
|
|
||||||
queueSize: 256,
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
name: "valid 32768",
|
|
||||||
queueSize: 32768,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
err := CheckQueueSize(tt.queueSize)
|
|
||||||
if tt.containsErr != "" {
|
|
||||||
assert.ErrorContains(t, err, tt.containsErr)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,530 +0,0 @@
|
|||||||
package virtqueue
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"github.com/slackhq/nebula/overlay/eventfd"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SplitQueue is a virtqueue that consists of several parts, where each part is
|
|
||||||
// writeable by either the driver or the device, but not both.
|
|
||||||
type SplitQueue struct {
|
|
||||||
// size is the size of the queue.
|
|
||||||
size int
|
|
||||||
// buf is the underlying memory used for the queue.
|
|
||||||
buf []byte
|
|
||||||
|
|
||||||
descriptorTable *DescriptorTable
|
|
||||||
availableRing *AvailableRing
|
|
||||||
usedRing *UsedRing
|
|
||||||
|
|
||||||
// kickEventFD is used to signal the device when descriptor chains were
|
|
||||||
// added to the available ring.
|
|
||||||
kickEventFD eventfd.EventFD
|
|
||||||
// callEventFD is used by the device to signal when it has used descriptor
|
|
||||||
// chains and put them in the used ring.
|
|
||||||
callEventFD eventfd.EventFD
|
|
||||||
|
|
||||||
// stop is used by [SplitQueue.Close] to cancel the goroutine that handles
|
|
||||||
// used buffer notifications. It blocks until the goroutine ended.
|
|
||||||
stop func() error
|
|
||||||
|
|
||||||
itemSize int
|
|
||||||
|
|
||||||
epoll eventfd.Epoll
|
|
||||||
more int
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size
|
|
||||||
// specifies the number of entries/buffers the queue can hold. This also affects
|
|
||||||
// the memory consumption.
|
|
||||||
func NewSplitQueue(queueSize int, itemSize int) (_ *SplitQueue, err error) {
|
|
||||||
if err = CheckQueueSize(queueSize); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if itemSize%os.Getpagesize() != 0 {
|
|
||||||
return nil, errors.New("split queue size must be multiple of os.Getpagesize()")
|
|
||||||
}
|
|
||||||
|
|
||||||
sq := SplitQueue{
|
|
||||||
size: queueSize,
|
|
||||||
itemSize: itemSize,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clean up a partially initialized queue when something fails.
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
_ = sq.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// There are multiple ways for how the memory for the virtqueue could be
|
|
||||||
// allocated. We could use Go native structs with arrays inside them, but
|
|
||||||
// this wouldn't allow us to make the queue size configurable. And including
|
|
||||||
// a slice in the Go structs wouldn't work, because this would just put the
|
|
||||||
// Go slice descriptor into the memory region which the virtio device will
|
|
||||||
// not understand.
|
|
||||||
// Additionally, Go does not allow us to ensure a correct alignment of the
|
|
||||||
// parts of the virtqueue, as it is required by the virtio specification.
|
|
||||||
//
|
|
||||||
// To resolve this, let's just allocate the memory manually by allocating
|
|
||||||
// one or more memory pages, depending on the queue size. Making the
|
|
||||||
// virtqueue start at the beginning of a page is not strictly necessary, as
|
|
||||||
// the virtio specification does not require it to be continuous in the
|
|
||||||
// physical memory of the host (e.g. the vhost implementation in the kernel
|
|
||||||
// always uses copy_from_user to access it), but this makes it very easy to
|
|
||||||
// guarantee the alignment. Also, it is not required for the virtqueue parts
|
|
||||||
// to be in the same memory region, as we pass separate pointers to them to
|
|
||||||
// the device, but this design just makes things easier to implement.
|
|
||||||
//
|
|
||||||
// One added benefit of allocating the memory manually is, that we have full
|
|
||||||
// control over its lifetime and don't risk the garbage collector to collect
|
|
||||||
// our valuable structures while the device still works with them.
|
|
||||||
|
|
||||||
// The descriptor table is at the start of the page, so alignment is not an
|
|
||||||
// issue here.
|
|
||||||
descriptorTableStart := 0
|
|
||||||
descriptorTableEnd := descriptorTableStart + descriptorTableSize(queueSize)
|
|
||||||
availableRingStart := align(descriptorTableEnd, availableRingAlignment)
|
|
||||||
availableRingEnd := availableRingStart + availableRingSize(queueSize)
|
|
||||||
usedRingStart := align(availableRingEnd, usedRingAlignment)
|
|
||||||
usedRingEnd := usedRingStart + usedRingSize(queueSize)
|
|
||||||
|
|
||||||
sq.buf, err = unix.Mmap(-1, 0, usedRingEnd,
|
|
||||||
unix.PROT_READ|unix.PROT_WRITE,
|
|
||||||
unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("allocate virtqueue buffer: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sq.descriptorTable = newDescriptorTable(queueSize, sq.buf[descriptorTableStart:descriptorTableEnd], sq.itemSize)
|
|
||||||
sq.availableRing = newAvailableRing(queueSize, sq.buf[availableRingStart:availableRingEnd])
|
|
||||||
sq.usedRing = newUsedRing(queueSize, sq.buf[usedRingStart:usedRingEnd])
|
|
||||||
|
|
||||||
sq.kickEventFD, err = eventfd.New()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create kick event file descriptor: %w", err)
|
|
||||||
}
|
|
||||||
sq.callEventFD, err = eventfd.New()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create call event file descriptor: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = sq.descriptorTable.initializeDescriptors(); err != nil {
|
|
||||||
return nil, fmt.Errorf("initialize descriptors: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sq.epoll, err = eventfd.NewEpoll()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = sq.epoll.AddEvent(sq.callEventFD.FD())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Consume used buffer notifications in the background.
|
|
||||||
sq.stop = sq.startConsumeUsedRing()
|
|
||||||
|
|
||||||
return &sq, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Size returns the size of this queue, which is the number of entries/buffers
|
|
||||||
// this queue can hold.
|
|
||||||
func (sq *SplitQueue) Size() int {
|
|
||||||
return sq.size
|
|
||||||
}
|
|
||||||
|
|
||||||
// DescriptorTable returns the [DescriptorTable] behind this queue.
|
|
||||||
func (sq *SplitQueue) DescriptorTable() *DescriptorTable {
|
|
||||||
return sq.descriptorTable
|
|
||||||
}
|
|
||||||
|
|
||||||
// AvailableRing returns the [AvailableRing] behind this queue.
|
|
||||||
func (sq *SplitQueue) AvailableRing() *AvailableRing {
|
|
||||||
return sq.availableRing
|
|
||||||
}
|
|
||||||
|
|
||||||
// UsedRing returns the [UsedRing] behind this queue.
|
|
||||||
func (sq *SplitQueue) UsedRing() *UsedRing {
|
|
||||||
return sq.usedRing
|
|
||||||
}
|
|
||||||
|
|
||||||
// KickEventFD returns the kick event file descriptor behind this queue.
|
|
||||||
// The returned file descriptor should be used with great care to not interfere
|
|
||||||
// with this implementation.
|
|
||||||
func (sq *SplitQueue) KickEventFD() int {
|
|
||||||
return sq.kickEventFD.FD()
|
|
||||||
}
|
|
||||||
|
|
||||||
// CallEventFD returns the call event file descriptor behind this queue.
|
|
||||||
// The returned file descriptor should be used with great care to not interfere
|
|
||||||
// with this implementation.
|
|
||||||
func (sq *SplitQueue) CallEventFD() int {
|
|
||||||
return sq.callEventFD.FD()
|
|
||||||
}
|
|
||||||
|
|
||||||
// startConsumeUsedRing starts a goroutine that runs [consumeUsedRing].
|
|
||||||
// A function is returned that can be used to gracefully cancel it. todo rename
|
|
||||||
func (sq *SplitQueue) startConsumeUsedRing() func() error {
|
|
||||||
return func() error {
|
|
||||||
|
|
||||||
// The goroutine blocks until it receives a signal on the event file
|
|
||||||
// descriptor, so it will never notice the context being canceled.
|
|
||||||
// To resolve this, we can just produce a fake-signal ourselves to wake
|
|
||||||
// it up.
|
|
||||||
if err := sq.callEventFD.Kick(); err != nil {
|
|
||||||
return fmt.Errorf("wake up goroutine: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BlockAndGetHeads waits for the device to signal that it has used descriptor chains and returns all [UsedElement]s
|
|
||||||
func (sq *SplitQueue) BlockAndGetHeads(ctx context.Context) ([]UsedElement, error) {
|
|
||||||
var n int
|
|
||||||
var err error
|
|
||||||
for ctx.Err() == nil {
|
|
||||||
|
|
||||||
// Wait for a signal from the device.
|
|
||||||
if n, err = sq.epoll.Block(); err != nil {
|
|
||||||
return nil, fmt.Errorf("wait: %w", err)
|
|
||||||
}
|
|
||||||
if n > 0 {
|
|
||||||
stillNeedToTake, out := sq.usedRing.take(-1)
|
|
||||||
sq.more = stillNeedToTake
|
|
||||||
if stillNeedToTake == 0 {
|
|
||||||
_ = sq.epoll.Clear() //???
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
|
|
||||||
var n int
|
|
||||||
var err error
|
|
||||||
for ctx.Err() == nil {
|
|
||||||
out, ok := sq.usedRing.takeOne()
|
|
||||||
if ok {
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
// Wait for a signal from the device.
|
|
||||||
if n, err = sq.epoll.Block(); err != nil {
|
|
||||||
return 0, fmt.Errorf("wait: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if n > 0 {
|
|
||||||
out, ok = sq.usedRing.takeOne()
|
|
||||||
if ok {
|
|
||||||
_ = sq.epoll.Clear() //???
|
|
||||||
return out, nil
|
|
||||||
} else {
|
|
||||||
continue //???
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0, ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) {
|
|
||||||
var n int
|
|
||||||
var err error
|
|
||||||
for ctx.Err() == nil {
|
|
||||||
|
|
||||||
//we have leftovers in the fridge
|
|
||||||
if sq.more > 0 {
|
|
||||||
stillNeedToTake, out := sq.usedRing.take(maxToTake)
|
|
||||||
sq.more = stillNeedToTake
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
//look inside the fridge
|
|
||||||
stillNeedToTake, out := sq.usedRing.take(maxToTake)
|
|
||||||
if len(out) > 0 {
|
|
||||||
sq.more = stillNeedToTake
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
//fridge is empty I guess
|
|
||||||
|
|
||||||
// Wait for a signal from the device.
|
|
||||||
if n, err = sq.epoll.Block(); err != nil {
|
|
||||||
return nil, fmt.Errorf("wait: %w", err)
|
|
||||||
}
|
|
||||||
if n > 0 {
|
|
||||||
_ = sq.epoll.Clear() //???
|
|
||||||
stillNeedToTake, out = sq.usedRing.take(maxToTake)
|
|
||||||
sq.more = stillNeedToTake
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// OfferDescriptorChain offers a descriptor chain to the device which contains a
|
|
||||||
// number of device-readable buffers (out buffers) and device-writable buffers
|
|
||||||
// (in buffers).
|
|
||||||
//
|
|
||||||
// All buffers in the outBuffers slice will be concatenated by chaining
|
|
||||||
// descriptors, one for each buffer in the slice. When a buffer is too large to
|
|
||||||
// fit into a single descriptor (limited by the system's page size), it will be
|
|
||||||
// split up into multiple descriptors within the chain.
|
|
||||||
// When numInBuffers is greater than zero, the given number of device-writable
|
|
||||||
// descriptors will be appended to the end of the chain, each referencing a
|
|
||||||
// whole memory page (see [os.Getpagesize]).
|
|
||||||
//
|
|
||||||
// When the queue is full and no more descriptor chains can be added, a wrapped
|
|
||||||
// [ErrNotEnoughFreeDescriptors] will be returned. If you set waitFree to true,
|
|
||||||
// this method will handle this error and will block instead until there are
|
|
||||||
// enough free descriptors again.
|
|
||||||
//
|
|
||||||
// After defining the descriptor chain in the [DescriptorTable], the index of
|
|
||||||
// the head of the chain will be made available to the device using the
|
|
||||||
// [AvailableRing] and will be returned by this method.
|
|
||||||
// Callers should read from the [SplitQueue.UsedDescriptorChains] channel to be
|
|
||||||
// notified when the descriptor chain was used by the device and should free the
|
|
||||||
// used descriptor chains again using [SplitQueue.FreeDescriptorChain] when
|
|
||||||
// they're done with them. When this does not happen, the queue will run full
|
|
||||||
// and any further calls to [SplitQueue.OfferDescriptorChain] will stall.
|
|
||||||
|
|
||||||
func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
|
|
||||||
// Create a descriptor chain for the given buffers.
|
|
||||||
var (
|
|
||||||
head uint16
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
for {
|
|
||||||
head, err = sq.descriptorTable.createDescriptorForInputs()
|
|
||||||
if err == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// I don't wanna use errors.Is, it's slow
|
|
||||||
//goland:noinspection GoDirectComparisonOfErrors
|
|
||||||
if err == ErrNotEnoughFreeDescriptors {
|
|
||||||
return 0, err
|
|
||||||
} else {
|
|
||||||
return 0, fmt.Errorf("create descriptor chain: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make the descriptor chain available to the device.
|
|
||||||
sq.availableRing.offerSingle(head)
|
|
||||||
|
|
||||||
// Notify the device to make it process the updated available ring.
|
|
||||||
if err := sq.kickEventFD.Kick(); err != nil {
|
|
||||||
return head, fmt.Errorf("notify device: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return head, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte) ([]uint16, error) {
|
|
||||||
// TODO change this
|
|
||||||
// Each descriptor can only hold a whole memory page, so split large out
|
|
||||||
// buffers into multiple smaller ones.
|
|
||||||
outBuffers = splitBuffers(outBuffers, sq.itemSize)
|
|
||||||
|
|
||||||
chains := make([]uint16, len(outBuffers))
|
|
||||||
|
|
||||||
// Create a descriptor chain for the given buffers.
|
|
||||||
var (
|
|
||||||
head uint16
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
for i := range outBuffers {
|
|
||||||
for {
|
|
||||||
bufs := [][]byte{prepend, outBuffers[i]}
|
|
||||||
head, err = sq.descriptorTable.createDescriptorChain(bufs, 0)
|
|
||||||
if err == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// I don't wanna use errors.Is, it's slow
|
|
||||||
//goland:noinspection GoDirectComparisonOfErrors
|
|
||||||
if err == ErrNotEnoughFreeDescriptors {
|
|
||||||
// Wait for more free descriptors to be put back into the queue.
|
|
||||||
// If the number of free descriptors is still not sufficient, we'll
|
|
||||||
// land here again.
|
|
||||||
//todo should never happen
|
|
||||||
syscall.Syscall(syscall.SYS_SCHED_YIELD, 0, 0, 0) // Cheap barrier
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("create descriptor chain: %w", err)
|
|
||||||
}
|
|
||||||
chains[i] = head
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make the descriptor chain available to the device.
|
|
||||||
sq.availableRing.offer(chains)
|
|
||||||
|
|
||||||
// Notify the device to make it process the updated available ring.
|
|
||||||
if err := sq.kickEventFD.Kick(); err != nil {
|
|
||||||
return chains, fmt.Errorf("notify device: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return chains, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDescriptorChain returns the device-readable buffers (out buffers) and
|
|
||||||
// device-writable buffers (in buffers) of the descriptor chain with the given
|
|
||||||
// head index.
|
|
||||||
// The head index must be one that was returned by a previous call to
|
|
||||||
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
|
|
||||||
// freed yet.
|
|
||||||
//
|
|
||||||
// Be careful to only access the returned buffer slices when the device is no
|
|
||||||
// longer using them. They must not be accessed after
|
|
||||||
// [SplitQueue.FreeDescriptorChain] has been called.
|
|
||||||
func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
|
|
||||||
return sq.descriptorTable.getDescriptorChain(head)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) {
|
|
||||||
sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize)
|
|
||||||
return sq.descriptorTable.getDescriptorItem(head)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
|
|
||||||
return sq.descriptorTable.getDescriptorChainContents(head, out, maxLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sq *SplitQueue) GetDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
|
|
||||||
return sq.descriptorTable.getDescriptorInbuffers(head, inBuffers)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FreeDescriptorChain frees the descriptor chain with the given head index.
|
|
||||||
// The head index must be one that was returned by a previous call to
|
|
||||||
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
|
|
||||||
// freed yet.
|
|
||||||
//
|
|
||||||
// This creates new room in the queue which can be used by following
|
|
||||||
// [SplitQueue.OfferDescriptorChain] calls.
|
|
||||||
// When there are outstanding calls for [SplitQueue.OfferDescriptorChain] that
|
|
||||||
// are waiting for free room in the queue, they may become unblocked by this.
|
|
||||||
func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
|
|
||||||
//not called under lock
|
|
||||||
if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
|
|
||||||
return fmt.Errorf("free: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
|
|
||||||
//not called under lock
|
|
||||||
sq.descriptorTable.descriptors[int(head)].length = uint32(sz)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sq *SplitQueue) OfferDescriptorChains(chains []uint16, kick bool) error {
|
|
||||||
//todo not doing this may break eventually?
|
|
||||||
//not called under lock
|
|
||||||
//if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
|
|
||||||
// return fmt.Errorf("free: %w", err)
|
|
||||||
//}
|
|
||||||
|
|
||||||
// Make the descriptor chain available to the device.
|
|
||||||
sq.availableRing.offer(chains)
|
|
||||||
|
|
||||||
// Notify the device to make it process the updated available ring.
|
|
||||||
if kick {
|
|
||||||
return sq.Kick()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sq *SplitQueue) Kick() error {
|
|
||||||
if err := sq.kickEventFD.Kick(); err != nil {
|
|
||||||
return fmt.Errorf("notify device: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close releases all resources used for this queue.
|
|
||||||
// The implementation will try to release as many resources as possible and
|
|
||||||
// collect potential errors before returning them.
|
|
||||||
func (sq *SplitQueue) Close() error {
|
|
||||||
var errs []error
|
|
||||||
|
|
||||||
if sq.stop != nil {
|
|
||||||
// This has to happen before the event file descriptors may be closed.
|
|
||||||
if err := sq.stop(); err != nil {
|
|
||||||
errs = append(errs, fmt.Errorf("stop consume used ring: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure that this code block is executed only once.
|
|
||||||
sq.stop = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := sq.kickEventFD.Close(); err != nil {
|
|
||||||
errs = append(errs, fmt.Errorf("close kick event file descriptor: %w", err))
|
|
||||||
}
|
|
||||||
if err := sq.callEventFD.Close(); err != nil {
|
|
||||||
errs = append(errs, fmt.Errorf("close call event file descriptor: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := sq.descriptorTable.releaseBuffers(); err != nil {
|
|
||||||
errs = append(errs, fmt.Errorf("release descriptor buffers: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if sq.buf != nil {
|
|
||||||
if err := unix.Munmap(sq.buf); err == nil {
|
|
||||||
sq.buf = nil
|
|
||||||
} else {
|
|
||||||
errs = append(errs, fmt.Errorf("unmap virtqueue buffer: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return errors.Join(errs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ensureInitialized is used as a guard to prevent methods to be called on an
|
|
||||||
// uninitialized instance.
|
|
||||||
func (sq *SplitQueue) ensureInitialized() {
|
|
||||||
if sq.buf == nil {
|
|
||||||
panic("used ring is not initialized")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// align rounds index up to the next multiple of alignment. An index that is
// already a multiple of alignment is returned unchanged.
func align(index, alignment int) int {
	if over := index % alignment; over != 0 {
		return index + alignment - over
	}
	return index
}
|
|
||||||
|
|
||||||
// splitBuffers processes a list of buffers and splits each buffer that is
// larger than the size limit into multiple smaller buffers.
// If none of the buffers are too big though, the input slice is returned
// as-is, to avoid an allocation.
//
// A sizeLimit of zero or less cannot be satisfied by any split. In that case
// the buffers are returned unchanged instead of looping forever.
//
// Note that empty buffers are dropped from the result when a split happens,
// but kept when the input is returned as-is.
func splitBuffers(buffers [][]byte, sizeLimit int) [][]byte {
	if sizeLimit <= 0 {
		return buffers
	}
	for i := range buffers {
		if len(buffers[i]) > sizeLimit {
			return reallySplitBuffers(buffers, sizeLimit)
		}
	}
	return buffers
}

// reallySplitBuffers returns a new list in which every buffer is cut into
// consecutive pieces of at most sizeLimit bytes. The pieces alias the memory
// of the input buffers. Empty buffers produce no piece. A sizeLimit of zero
// or less returns the input unchanged.
func reallySplitBuffers(buffers [][]byte, sizeLimit int) [][]byte {
	if sizeLimit <= 0 {
		return buffers
	}

	result := make([][]byte, 0, len(buffers))
	for _, buffer := range buffers {
		// Cut full-sized pieces off the front while more than one piece
		// remains.
		for len(buffer) > sizeLimit {
			result = append(result, buffer[:sizeLimit])
			buffer = buffer[sizeLimit:]
		}
		// The rest is between 1 and sizeLimit bytes, or empty for an empty
		// input buffer.
		if len(buffer) > 0 {
			result = append(result, buffer)
		}
	}

	return result
}
|
|
||||||
@@ -1,105 +0,0 @@
|
|||||||
package virtqueue
|
|
||||||
|
|
||||||
import (
	"os"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
|
|
||||||
|
|
||||||
func TestSplitQueue_MemoryAlignment(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
queueSize int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "minimal queue size",
|
|
||||||
queueSize: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "small queue size",
|
|
||||||
queueSize: 8,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "large queue size",
|
|
||||||
queueSize: 256,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
sq, err := NewSplitQueue(tt.queueSize)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Zero(t, sq.descriptorTable.Address()%descriptorTableAlignment)
|
|
||||||
assert.Zero(t, sq.availableRing.Address()%availableRingAlignment)
|
|
||||||
assert.Zero(t, sq.usedRing.Address()%usedRingAlignment)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSplitBuffers(t *testing.T) {
|
|
||||||
const sizeLimit = 16
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
buffers [][]byte
|
|
||||||
expected [][]byte
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "no buffers",
|
|
||||||
buffers: make([][]byte, 0),
|
|
||||||
expected: make([][]byte, 0),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "small",
|
|
||||||
buffers: [][]byte{
|
|
||||||
make([]byte, 11),
|
|
||||||
},
|
|
||||||
expected: [][]byte{
|
|
||||||
make([]byte, 11),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "exact size",
|
|
||||||
buffers: [][]byte{
|
|
||||||
make([]byte, sizeLimit),
|
|
||||||
},
|
|
||||||
expected: [][]byte{
|
|
||||||
make([]byte, sizeLimit),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "large",
|
|
||||||
buffers: [][]byte{
|
|
||||||
make([]byte, 42),
|
|
||||||
},
|
|
||||||
expected: [][]byte{
|
|
||||||
make([]byte, 16),
|
|
||||||
make([]byte, 16),
|
|
||||||
make([]byte, 10),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "mixed",
|
|
||||||
buffers: [][]byte{
|
|
||||||
make([]byte, 7),
|
|
||||||
make([]byte, 30),
|
|
||||||
make([]byte, 15),
|
|
||||||
make([]byte, 32),
|
|
||||||
},
|
|
||||||
expected: [][]byte{
|
|
||||||
make([]byte, 7),
|
|
||||||
make([]byte, 16),
|
|
||||||
make([]byte, 14),
|
|
||||||
make([]byte, 15),
|
|
||||||
make([]byte, 16),
|
|
||||||
make([]byte, 16),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
actual := splitBuffers(tt.buffers, sizeLimit)
|
|
||||||
assert.Equal(t, tt.expected, actual)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
package virtqueue
|
|
||||||
|
|
||||||
// usedElementSize is the size in bytes of one [UsedElement] in memory: two
// 32-bit fields.
const usedElementSize = 8
|
|
||||||
|
|
||||||
// UsedElement is one entry of the [UsedRing]. It identifies a descriptor
// chain that the device has used. The field order and widths define the
// in-memory layout shared with the device and must not be changed.
type UsedElement struct {
	// DescriptorIndex is the index of the head of the used descriptor chain
	// in the [DescriptorTable]. It is 32 bits wide for padding reasons.
	DescriptorIndex uint32
	// Length is the number of bytes written into the device-writable portion
	// of the buffer described by the descriptor chain.
	Length uint32
}
|
|
||||||
|
|
||||||
func (u *UsedElement) GetHead() uint16 {
|
|
||||||
return uint16(u.DescriptorIndex)
|
|
||||||
}
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
package virtqueue
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestUsedElement_Size(t *testing.T) {
|
|
||||||
assert.EqualValues(t, usedElementSize, unsafe.Sizeof(UsedElement{}))
|
|
||||||
}
|
|
||||||
@@ -1,184 +0,0 @@
|
|||||||
package virtqueue
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
// usedRingFlag is a bit flag stored in the flags field of a [UsedRing].
type usedRingFlag uint16

const (
	// usedRingFlagNoNotify is set by the host to advise the guest not to
	// kick it when adding a buffer. It is unreliable and therefore only an
	// optimization: the guest will still kick when it's out of buffers.
	usedRingFlagNoNotify usedRingFlag = 1 << iota
)
|
|
||||||
|
|
||||||
// usedRingSize is the number of bytes needed to store a [UsedRing] with the
|
|
||||||
// given queue size in memory.
|
|
||||||
func usedRingSize(queueSize int) int {
|
|
||||||
return 6 + usedElementSize*queueSize
|
|
||||||
}
|
|
||||||
|
|
||||||
// usedRingAlignment is the minimum alignment in bytes that the virtio spec
// requires for a [UsedRing] in memory.
const usedRingAlignment = 4
|
|
||||||
|
|
||||||
// UsedRing is where the device returns descriptor chains once it is done with
|
|
||||||
// them. Each ring entry is a [UsedElement]. It is only written to by the device
|
|
||||||
// and read by the driver.
|
|
||||||
//
|
|
||||||
// Because the size of the ring depends on the queue size, we cannot define a
|
|
||||||
// Go struct with a static size that maps to the memory of the ring. Instead,
|
|
||||||
// this struct only contains pointers to the corresponding memory areas.
|
|
||||||
type UsedRing struct {
|
|
||||||
initialized bool
|
|
||||||
|
|
||||||
// flags that describe this ring.
|
|
||||||
flags *usedRingFlag
|
|
||||||
// ringIndex indicates where the device would put the next entry into the
|
|
||||||
// ring (modulo the queue size).
|
|
||||||
ringIndex *uint16
|
|
||||||
// ring contains the [UsedElement]s. It wraps around at queue size.
|
|
||||||
ring []UsedElement
|
|
||||||
// availableEvent is not used by this implementation, but we reserve it
|
|
||||||
// anyway to avoid issues in case a device may try to write to it, contrary
|
|
||||||
// to the virtio specification.
|
|
||||||
availableEvent *uint16
|
|
||||||
|
|
||||||
// lastIndex is the internal ringIndex up to which all [UsedElement]s were
|
|
||||||
// processed.
|
|
||||||
lastIndex uint16
|
|
||||||
|
|
||||||
//mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// newUsedRing creates a used ring that uses the given underlying memory. The
|
|
||||||
// length of the memory slice must match the size needed for the ring (see
|
|
||||||
// [usedRingSize]) for the given queue size.
|
|
||||||
func newUsedRing(queueSize int, mem []byte) *UsedRing {
|
|
||||||
ringSize := usedRingSize(queueSize)
|
|
||||||
if len(mem) != ringSize {
|
|
||||||
panic(fmt.Sprintf("memory size (%v) does not match required size "+
|
|
||||||
"for used ring: %v", len(mem), ringSize))
|
|
||||||
}
|
|
||||||
|
|
||||||
r := UsedRing{
|
|
||||||
initialized: true,
|
|
||||||
flags: (*usedRingFlag)(unsafe.Pointer(&mem[0])),
|
|
||||||
ringIndex: (*uint16)(unsafe.Pointer(&mem[2])),
|
|
||||||
ring: unsafe.Slice((*UsedElement)(unsafe.Pointer(&mem[4])), queueSize),
|
|
||||||
availableEvent: (*uint16)(unsafe.Pointer(&mem[ringSize-2])),
|
|
||||||
}
|
|
||||||
r.lastIndex = *r.ringIndex
|
|
||||||
return &r
|
|
||||||
}
|
|
||||||
|
|
||||||
// Address returns the pointer to the beginning of the ring in memory.
|
|
||||||
// Do not modify the memory directly to not interfere with this implementation.
|
|
||||||
func (r *UsedRing) Address() uintptr {
|
|
||||||
if !r.initialized {
|
|
||||||
panic("used ring is not initialized")
|
|
||||||
}
|
|
||||||
return uintptr(unsafe.Pointer(r.flags))
|
|
||||||
}
|
|
||||||
|
|
||||||
// take returns all new [UsedElement]s that the device put into the ring and
|
|
||||||
// that weren't already returned by a previous call to this method.
|
|
||||||
// had a lock, I removed it
|
|
||||||
func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
|
|
||||||
//r.mu.Lock()
|
|
||||||
//defer r.mu.Unlock()
|
|
||||||
|
|
||||||
ringIndex := *r.ringIndex
|
|
||||||
if ringIndex == r.lastIndex {
|
|
||||||
// Nothing new.
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate the number new used elements that we can read from the ring.
|
|
||||||
// The ring index may wrap, so special handling for that case is needed.
|
|
||||||
count := int(ringIndex - r.lastIndex)
|
|
||||||
if count < 0 {
|
|
||||||
count += 0xffff
|
|
||||||
}
|
|
||||||
|
|
||||||
stillNeedToTake := 0
|
|
||||||
|
|
||||||
if maxToTake > 0 {
|
|
||||||
stillNeedToTake = count - maxToTake
|
|
||||||
if stillNeedToTake < 0 {
|
|
||||||
stillNeedToTake = 0
|
|
||||||
}
|
|
||||||
count = min(count, maxToTake)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The number of new elements can never exceed the queue size.
|
|
||||||
if count > len(r.ring) {
|
|
||||||
panic("used ring contains more new elements than the ring is long")
|
|
||||||
}
|
|
||||||
|
|
||||||
elems := make([]UsedElement, count)
|
|
||||||
for i := range count {
|
|
||||||
elems[i] = r.ring[r.lastIndex%uint16(len(r.ring))]
|
|
||||||
r.lastIndex++
|
|
||||||
}
|
|
||||||
|
|
||||||
return stillNeedToTake, elems
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *UsedRing) takeOne() (uint16, bool) {
|
|
||||||
//r.mu.Lock()
|
|
||||||
//defer r.mu.Unlock()
|
|
||||||
|
|
||||||
ringIndex := *r.ringIndex
|
|
||||||
if ringIndex == r.lastIndex {
|
|
||||||
// Nothing new.
|
|
||||||
return 0xffff, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate the number new used elements that we can read from the ring.
|
|
||||||
// The ring index may wrap, so special handling for that case is needed.
|
|
||||||
count := int(ringIndex - r.lastIndex)
|
|
||||||
if count < 0 {
|
|
||||||
count += 0xffff
|
|
||||||
}
|
|
||||||
|
|
||||||
// The number of new elements can never exceed the queue size.
|
|
||||||
if count > len(r.ring) {
|
|
||||||
panic("used ring contains more new elements than the ring is long")
|
|
||||||
}
|
|
||||||
|
|
||||||
if count == 0 {
|
|
||||||
return 0xffff, false
|
|
||||||
}
|
|
||||||
|
|
||||||
out := r.ring[r.lastIndex%uint16(len(r.ring))].GetHead()
|
|
||||||
r.lastIndex++
|
|
||||||
|
|
||||||
return out, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitOfferSingle is only used to pre-fill the used queue at startup, and should not be used if the device is running!
|
|
||||||
func (r *UsedRing) InitOfferSingle(x uint16, size int) {
|
|
||||||
//always called under lock
|
|
||||||
//r.mu.Lock()
|
|
||||||
//defer r.mu.Unlock()
|
|
||||||
|
|
||||||
offset := 0
|
|
||||||
// Add descriptor chain heads to the ring.
|
|
||||||
|
|
||||||
// The 16-bit ring index may overflow. This is expected and is not an
|
|
||||||
// issue because the size of the ring array (which equals the queue
|
|
||||||
// size) is always a power of 2 and smaller than the highest possible
|
|
||||||
// 16-bit value.
|
|
||||||
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
|
|
||||||
r.ring[insertIndex] = UsedElement{
|
|
||||||
DescriptorIndex: uint32(x),
|
|
||||||
Length: uint32(size),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Increase the ring index by the number of descriptor chains added to the ring.
|
|
||||||
*r.ringIndex += 1
|
|
||||||
}
|
|
||||||
@@ -1,136 +0,0 @@
|
|||||||
package virtqueue
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestUsedRing_MemoryLayout(t *testing.T) {
|
|
||||||
const queueSize = 2
|
|
||||||
|
|
||||||
memory := make([]byte, usedRingSize(queueSize))
|
|
||||||
r := newUsedRing(queueSize, memory)
|
|
||||||
|
|
||||||
*r.flags = 0x01ff
|
|
||||||
*r.ringIndex = 1
|
|
||||||
r.ring[0] = UsedElement{
|
|
||||||
DescriptorIndex: 0x0123,
|
|
||||||
Length: 0x4567,
|
|
||||||
}
|
|
||||||
r.ring[1] = UsedElement{
|
|
||||||
DescriptorIndex: 0x89ab,
|
|
||||||
Length: 0xcdef,
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, []byte{
|
|
||||||
0xff, 0x01,
|
|
||||||
0x01, 0x00,
|
|
||||||
0x23, 0x01, 0x00, 0x00,
|
|
||||||
0x67, 0x45, 0x00, 0x00,
|
|
||||||
0xab, 0x89, 0x00, 0x00,
|
|
||||||
0xef, 0xcd, 0x00, 0x00,
|
|
||||||
0x00, 0x00,
|
|
||||||
}, memory)
|
|
||||||
}
|
|
||||||
|
|
||||||
//func TestUsedRing_Take(t *testing.T) {
|
|
||||||
// const queueSize = 8
|
|
||||||
//
|
|
||||||
// tests := []struct {
|
|
||||||
// name string
|
|
||||||
// ring []UsedElement
|
|
||||||
// ringIndex uint16
|
|
||||||
// lastIndex uint16
|
|
||||||
// expected []UsedElement
|
|
||||||
// }{
|
|
||||||
// {
|
|
||||||
// name: "nothing new",
|
|
||||||
// ring: []UsedElement{
|
|
||||||
// {DescriptorIndex: 1},
|
|
||||||
// {DescriptorIndex: 2},
|
|
||||||
// {DescriptorIndex: 3},
|
|
||||||
// {DescriptorIndex: 4},
|
|
||||||
// {},
|
|
||||||
// {},
|
|
||||||
// {},
|
|
||||||
// {},
|
|
||||||
// },
|
|
||||||
// ringIndex: 4,
|
|
||||||
// lastIndex: 4,
|
|
||||||
// expected: nil,
|
|
||||||
// },
|
|
||||||
// {
|
|
||||||
// name: "no overflow",
|
|
||||||
// ring: []UsedElement{
|
|
||||||
// {DescriptorIndex: 1},
|
|
||||||
// {DescriptorIndex: 2},
|
|
||||||
// {DescriptorIndex: 3},
|
|
||||||
// {DescriptorIndex: 4},
|
|
||||||
// {},
|
|
||||||
// {},
|
|
||||||
// {},
|
|
||||||
// {},
|
|
||||||
// },
|
|
||||||
// ringIndex: 4,
|
|
||||||
// lastIndex: 1,
|
|
||||||
// expected: []UsedElement{
|
|
||||||
// {DescriptorIndex: 2},
|
|
||||||
// {DescriptorIndex: 3},
|
|
||||||
// {DescriptorIndex: 4},
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// {
|
|
||||||
// name: "ring overflow",
|
|
||||||
// ring: []UsedElement{
|
|
||||||
// {DescriptorIndex: 9},
|
|
||||||
// {DescriptorIndex: 10},
|
|
||||||
// {DescriptorIndex: 3},
|
|
||||||
// {DescriptorIndex: 4},
|
|
||||||
// {DescriptorIndex: 5},
|
|
||||||
// {DescriptorIndex: 6},
|
|
||||||
// {DescriptorIndex: 7},
|
|
||||||
// {DescriptorIndex: 8},
|
|
||||||
// },
|
|
||||||
// ringIndex: 10,
|
|
||||||
// lastIndex: 7,
|
|
||||||
// expected: []UsedElement{
|
|
||||||
// {DescriptorIndex: 8},
|
|
||||||
// {DescriptorIndex: 9},
|
|
||||||
// {DescriptorIndex: 10},
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// {
|
|
||||||
// name: "index overflow",
|
|
||||||
// ring: []UsedElement{
|
|
||||||
// {DescriptorIndex: 9},
|
|
||||||
// {DescriptorIndex: 10},
|
|
||||||
// {DescriptorIndex: 3},
|
|
||||||
// {DescriptorIndex: 4},
|
|
||||||
// {DescriptorIndex: 5},
|
|
||||||
// {DescriptorIndex: 6},
|
|
||||||
// {DescriptorIndex: 7},
|
|
||||||
// {DescriptorIndex: 8},
|
|
||||||
// },
|
|
||||||
// ringIndex: 2,
|
|
||||||
// lastIndex: 65535,
|
|
||||||
// expected: []UsedElement{
|
|
||||||
// {DescriptorIndex: 8},
|
|
||||||
// {DescriptorIndex: 9},
|
|
||||||
// {DescriptorIndex: 10},
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// }
|
|
||||||
// for _, tt := range tests {
|
|
||||||
// t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// memory := make([]byte, usedRingSize(queueSize))
|
|
||||||
// r := newUsedRing(queueSize, memory)
|
|
||||||
//
|
|
||||||
// copy(r.ring, tt.ring)
|
|
||||||
// *r.ringIndex = tt.ringIndex
|
|
||||||
// r.lastIndex = tt.lastIndex
|
|
||||||
//
|
|
||||||
// assert.Equal(t, tt.expected, r.take())
|
|
||||||
// })
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
package packet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/slackhq/nebula/util/virtio"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
// OutPacket collects the segments of one outgoing packet. UseSegment appends
// one entry to each of the four parallel slices per segment, so index i in all
// of them refers to the same segment.
type OutPacket struct {
	// Segments holds each complete segment buffer as passed to UseSegment.
	Segments [][]byte
	// SegmentPayloads holds, per segment, the bytes following the header area.
	SegmentPayloads [][]byte
	// SegmentHeaders holds, per segment, the leading header area
	// (virtio.NetHdrSize+14 bytes) that UseSegment fills in.
	SegmentHeaders [][]byte
	// SegmentIDs holds the segID given to UseSegment for each segment.
	SegmentIDs []uint16
	//todo virtio header?
	// SegSize is cleared by Reset; nothing else in this file writes it.
	SegSize int
	// SegCounter is not touched by any method in this file.
	SegCounter int
	// Valid is set by UseSegment and cleared by Reset.
	Valid        bool
	wasSegmented bool

	// Scratch is a Size-byte buffer allocated by NewOut; Reset leaves it alone.
	Scratch []byte
}
|
|
||||||
|
|
||||||
func NewOut() *OutPacket {
|
|
||||||
out := new(OutPacket)
|
|
||||||
out.Segments = make([][]byte, 0, 64)
|
|
||||||
out.SegmentHeaders = make([][]byte, 0, 64)
|
|
||||||
out.SegmentPayloads = make([][]byte, 0, 64)
|
|
||||||
out.SegmentIDs = make([]uint16, 0, 64)
|
|
||||||
out.Scratch = make([]byte, Size)
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pkt *OutPacket) Reset() {
|
|
||||||
pkt.Segments = pkt.Segments[:0]
|
|
||||||
pkt.SegmentPayloads = pkt.SegmentPayloads[:0]
|
|
||||||
pkt.SegmentHeaders = pkt.SegmentHeaders[:0]
|
|
||||||
pkt.SegmentIDs = pkt.SegmentIDs[:0]
|
|
||||||
pkt.SegSize = 0
|
|
||||||
pkt.Valid = false
|
|
||||||
pkt.wasSegmented = false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pkt *OutPacket) UseSegment(segID uint16, seg []byte, isV6 bool) int {
|
|
||||||
pkt.Valid = true
|
|
||||||
pkt.SegmentIDs = append(pkt.SegmentIDs, segID)
|
|
||||||
pkt.Segments = append(pkt.Segments, seg) //todo do we need this?
|
|
||||||
|
|
||||||
vhdr := virtio.NetHdr{ //todo
|
|
||||||
Flags: unix.VIRTIO_NET_HDR_F_DATA_VALID,
|
|
||||||
GSOType: unix.VIRTIO_NET_HDR_GSO_NONE,
|
|
||||||
HdrLen: 0,
|
|
||||||
GSOSize: 0,
|
|
||||||
CsumStart: 0,
|
|
||||||
CsumOffset: 0,
|
|
||||||
NumBuffers: 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
hdr := seg[0 : virtio.NetHdrSize+14]
|
|
||||||
_ = vhdr.Encode(hdr)
|
|
||||||
if isV6 {
|
|
||||||
hdr[virtio.NetHdrSize+14-2] = 0x86
|
|
||||||
hdr[virtio.NetHdrSize+14-1] = 0xdd
|
|
||||||
} else {
|
|
||||||
hdr[virtio.NetHdrSize+14-2] = 0x08
|
|
||||||
hdr[virtio.NetHdrSize+14-1] = 0x00
|
|
||||||
}
|
|
||||||
|
|
||||||
pkt.SegmentHeaders = append(pkt.SegmentHeaders, hdr)
|
|
||||||
pkt.SegmentPayloads = append(pkt.SegmentPayloads, seg[virtio.NetHdrSize+14:])
|
|
||||||
return len(pkt.SegmentIDs) - 1
|
|
||||||
}
|
|
||||||
119
packet/packet.go
119
packet/packet.go
@@ -1,119 +0,0 @@
|
|||||||
package packet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"iter"
|
|
||||||
"net/netip"
|
|
||||||
"slices"
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Size is the length in bytes (0xffff = 65535) of the payload buffer that New
// allocates for every Packet.
const Size = 0xffff
|
|
||||||
|
|
||||||
// Packet bundles the buffers used to move one UDP datagram (or one batch of
// equally sized segments) through the socket layer.
type Packet struct {
	// Payload is the data buffer; New allocates it with Size bytes.
	Payload []byte
	// Control is the socket control-message buffer. updateCtrl looks for a
	// UDP_GRO message in it and SetSegSizeForTX writes a UDP_SEGMENT message
	// into it.
	Control []byte
	// Name holds the raw socket address; AddrPort decodes it.
	Name []byte
	// SegSize is the segment length used by Segments to split Payload.
	SegSize int

	//todo should this hold out as well?
	// OutLen is set to -1 by Update.
	OutLen int

	// wasSegmented records whether updateCtrl found a UDP_GRO control message.
	wasSegmented bool
	// isV4 selects how AddrPort interprets Name.
	isV4 bool
}
|
|
||||||
|
|
||||||
func New(isV4 bool) *Packet {
|
|
||||||
return &Packet{
|
|
||||||
Payload: make([]byte, Size),
|
|
||||||
Control: make([]byte, unix.CmsgSpace(2)),
|
|
||||||
Name: make([]byte, unix.SizeofSockaddrInet6),
|
|
||||||
isV4: isV4,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Packet) AddrPort() netip.AddrPort {
|
|
||||||
var ip netip.Addr
|
|
||||||
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
|
||||||
if p.isV4 {
|
|
||||||
ip, _ = netip.AddrFromSlice(p.Name[4:8])
|
|
||||||
} else {
|
|
||||||
ip, _ = netip.AddrFromSlice(p.Name[8:24])
|
|
||||||
}
|
|
||||||
return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(p.Name[2:4]))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Packet) updateCtrl(ctrlLen int) {
|
|
||||||
p.SegSize = len(p.Payload)
|
|
||||||
p.wasSegmented = false
|
|
||||||
if ctrlLen == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(p.Control) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cmsgs, err := unix.ParseSocketControlMessage(p.Control)
|
|
||||||
if err != nil {
|
|
||||||
return // oh well
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range cmsgs {
|
|
||||||
if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
|
|
||||||
p.wasSegmented = true
|
|
||||||
p.SegSize = int(binary.LittleEndian.Uint16(c.Data[:2]))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update sets a Packet into "just received, not processed" state
|
|
||||||
func (p *Packet) Update(ctrlLen int) {
|
|
||||||
p.OutLen = -1
|
|
||||||
p.updateCtrl(ctrlLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Packet) SetSegSizeForTX() {
|
|
||||||
p.SegSize = len(p.Payload)
|
|
||||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&p.Control[0]))
|
|
||||||
hdr.Level = unix.SOL_UDP
|
|
||||||
hdr.Type = unix.UDP_SEGMENT
|
|
||||||
hdr.SetLen(syscall.CmsgLen(2))
|
|
||||||
binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Packet) CompatibleForSegmentationWith(otherP *Packet, currentTotalSize int) bool {
|
|
||||||
//same dest
|
|
||||||
if !slices.Equal(p.Name, otherP.Name) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
//don't get too big
|
|
||||||
if len(p.Payload)+currentTotalSize >= 0xffff {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
//same body len
|
|
||||||
//todo allow single different size at end
|
|
||||||
if len(p.Payload) != len(otherP.Payload) {
|
|
||||||
return false //todo technically you can cram one extra in
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Packet) Segments() iter.Seq[[]byte] {
|
|
||||||
return func(yield func([]byte) bool) {
|
|
||||||
//cursor := 0
|
|
||||||
for offset := 0; offset < len(p.Payload); offset += p.SegSize {
|
|
||||||
end := offset + p.SegSize
|
|
||||||
if end > len(p.Payload) {
|
|
||||||
end = len(p.Payload)
|
|
||||||
}
|
|
||||||
if !yield(p.Payload[offset:end]) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
package packet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/slackhq/nebula/util/virtio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type VirtIOPacket struct {
|
|
||||||
Payload []byte
|
|
||||||
Header virtio.NetHdr
|
|
||||||
Chains []uint16
|
|
||||||
ChainRefs [][]byte
|
|
||||||
// OfferDescriptorChains(chains []uint16, kick bool) error
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewVIO() *VirtIOPacket {
|
|
||||||
out := new(VirtIOPacket)
|
|
||||||
out.Payload = nil
|
|
||||||
out.ChainRefs = make([][]byte, 0, 4)
|
|
||||||
out.Chains = make([]uint16, 0, 8)
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *VirtIOPacket) Reset() {
|
|
||||||
v.Payload = nil
|
|
||||||
v.ChainRefs = v.ChainRefs[:0]
|
|
||||||
v.Chains = v.Chains[:0]
|
|
||||||
}
|
|
||||||
|
|
||||||
type VirtIOTXPacket struct {
|
|
||||||
VirtIOPacket
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewVIOTX(isV4 bool) *VirtIOTXPacket {
|
|
||||||
out := new(VirtIOTXPacket)
|
|
||||||
out.VirtIOPacket = *NewVIO()
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
93
pki.go
93
pki.go
@@ -100,62 +100,55 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
|||||||
currentState := p.cs.Load()
|
currentState := p.cs.Load()
|
||||||
if newState.v1Cert != nil {
|
if newState.v1Cert != nil {
|
||||||
if currentState.v1Cert == nil {
|
if currentState.v1Cert == nil {
|
||||||
//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
|
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
|
||||||
} else {
|
|
||||||
// did IP in cert change? if so, don't set
|
|
||||||
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
|
||||||
return util.NewContextualError(
|
|
||||||
"Networks in new cert was different from old",
|
|
||||||
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
|
||||||
return util.NewContextualError(
|
|
||||||
"Curve in new v1 cert was different from old",
|
|
||||||
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// did IP in cert change? if so, don't set
|
||||||
|
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
||||||
|
return util.NewContextualError(
|
||||||
|
"Networks in new cert was different from old",
|
||||||
|
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
||||||
|
return util.NewContextualError(
|
||||||
|
"Curve in new cert was different from old",
|
||||||
|
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if currentState.v1Cert != nil {
|
||||||
|
//TODO: CERT-V2 we should be able to tear this down
|
||||||
|
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if newState.v2Cert != nil {
|
if newState.v2Cert != nil {
|
||||||
if currentState.v2Cert == nil {
|
if currentState.v2Cert == nil {
|
||||||
//adding certs is fine, actually
|
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
|
||||||
} else {
|
|
||||||
// did IP in cert change? if so, don't set
|
|
||||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
|
||||||
return util.NewContextualError(
|
|
||||||
"Networks in new cert was different from old",
|
|
||||||
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
|
||||||
return util.NewContextualError(
|
|
||||||
"Curve in new cert was different from old",
|
|
||||||
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if currentState.v2Cert != nil {
|
// did IP in cert change? if so, don't set
|
||||||
//newState.v1Cert is non-nil bc empty certstates aren't permitted
|
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
||||||
if newState.v1Cert == nil {
|
|
||||||
return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
|
|
||||||
}
|
|
||||||
//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
|
|
||||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
|
|
||||||
return util.NewContextualError(
|
return util.NewContextualError(
|
||||||
"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
|
"Networks in new cert was different from old",
|
||||||
m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
|
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
||||||
|
return util.NewContextualError(
|
||||||
|
"Curve in new cert was different from old",
|
||||||
|
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if currentState.v2Cert != nil {
|
||||||
|
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cipher cant be hot swapped so just leave it at what it was before
|
// Cipher cant be hot swapped so just leave it at what it was before
|
||||||
@@ -523,13 +516,9 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
|
|||||||
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
|
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
bl := c.GetStringSlice("pki.blocklist", []string{})
|
for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
|
||||||
if len(bl) > 0 {
|
l.WithField("fingerprint", fp).Info("Blocklisting cert")
|
||||||
for _, fp := range bl {
|
caPool.BlocklistFingerprint(fp)
|
||||||
caPool.BlocklistFingerprint(fp)
|
|
||||||
}
|
|
||||||
|
|
||||||
l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return caPool, nil
|
return caPool, nil
|
||||||
|
|||||||
@@ -16,8 +16,8 @@ import (
|
|||||||
"github.com/slackhq/nebula/cert_test"
|
"github.com/slackhq/nebula/cert_test"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
"go.yaml.in/yaml/v3"
|
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type m = map[string]any
|
type m = map[string]any
|
||||||
|
|||||||
@@ -4,13 +4,13 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/packet"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const MTU = 9001
|
const MTU = 9001
|
||||||
|
|
||||||
type EncReader func(
|
type EncReader func(
|
||||||
[]*packet.Packet,
|
addr netip.AddrPort,
|
||||||
|
payload []byte,
|
||||||
)
|
)
|
||||||
|
|
||||||
type Conn interface {
|
type Conn interface {
|
||||||
@@ -19,8 +19,6 @@ type Conn interface {
|
|||||||
ListenOut(r EncReader)
|
ListenOut(r EncReader)
|
||||||
WriteTo(b []byte, addr netip.AddrPort) error
|
WriteTo(b []byte, addr netip.AddrPort) error
|
||||||
ReloadConfig(c *config.C)
|
ReloadConfig(c *config.C)
|
||||||
Prep(pkt *packet.Packet, addr netip.AddrPort) error
|
|
||||||
WriteBatch(pkt []*packet.Packet) (int, error)
|
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
216
udp/udp_linux.go
216
udp/udp_linux.go
@@ -14,22 +14,22 @@ import (
|
|||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/packet"
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
const iovMax = 128 //1024 //no unix constant for this? from limits.h
|
|
||||||
//todo I'd like this to be 1024 but we seem to hit errors around ~130?
|
|
||||||
|
|
||||||
type StdConn struct {
|
type StdConn struct {
|
||||||
sysFd int
|
sysFd int
|
||||||
isV4 bool
|
isV4 bool
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
batch int
|
batch int
|
||||||
enableGRO bool
|
}
|
||||||
|
|
||||||
msgs []rawMessage
|
func maybeIPV4(ip net.IP) (net.IP, bool) {
|
||||||
iovs [][]iovec
|
ip4 := ip.To4()
|
||||||
|
if ip4 != nil {
|
||||||
|
return ip4, true
|
||||||
|
}
|
||||||
|
return ip, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||||
@@ -69,20 +69,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
|||||||
return nil, fmt.Errorf("unable to bind to socket: %s", err)
|
return nil, fmt.Errorf("unable to bind to socket: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
const batchSize = 8192
|
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
|
||||||
msgs := make([]rawMessage, 0, batchSize) //todo configure
|
|
||||||
iovs := make([][]iovec, batchSize)
|
|
||||||
for i := range iovs {
|
|
||||||
iovs[i] = make([]iovec, iovMax)
|
|
||||||
}
|
|
||||||
return &StdConn{
|
|
||||||
sysFd: fd,
|
|
||||||
isV4: ip.Is4(),
|
|
||||||
l: l,
|
|
||||||
batch: batch,
|
|
||||||
msgs: msgs,
|
|
||||||
iovs: iovs,
|
|
||||||
}, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) Rebind() error {
|
func (u *StdConn) Rebind() error {
|
||||||
@@ -132,7 +119,9 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) ListenOut(r EncReader) {
|
func (u *StdConn) ListenOut(r EncReader) {
|
||||||
msgs, packets := u.PrepareRawMessages(u.batch, u.isV4)
|
var ip netip.Addr
|
||||||
|
|
||||||
|
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
||||||
read := u.ReadMulti
|
read := u.ReadMulti
|
||||||
if u.batch == 1 {
|
if u.batch == 1 {
|
||||||
read = u.ReadSingle
|
read = u.ReadSingle
|
||||||
@@ -146,12 +135,13 @@ func (u *StdConn) ListenOut(r EncReader) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
packets[i].Payload = packets[i].Payload[:msgs[i].Len]
|
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
||||||
packets[i].Update(getRawMessageControlLen(&msgs[i]))
|
if u.isV4 {
|
||||||
}
|
ip, _ = netip.AddrFromSlice(names[i][4:8])
|
||||||
r(packets[:n])
|
} else {
|
||||||
for i := 0; i < n; i++ { //todo reset this in prev loop, but this makes debug ez
|
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
||||||
msgs[i].Hdr.Controllen = uint64(unix.CmsgSpace(2))
|
}
|
||||||
|
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -204,147 +194,6 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
|||||||
return u.writeTo6(b, ip)
|
return u.writeTo6(b, ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) WriteToBatch(b []byte, ip netip.AddrPort) error {
|
|
||||||
if u.isV4 {
|
|
||||||
return u.writeTo4(b, ip)
|
|
||||||
}
|
|
||||||
return u.writeTo6(b, ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
|
|
||||||
nl, err := u.encodeSockaddr(pkt.Name, addr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
pkt.Name = pkt.Name[:nl]
|
|
||||||
pkt.OutLen = len(pkt.Payload)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) WriteBatch(pkts []*packet.Packet) (int, error) {
|
|
||||||
if len(pkts) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
u.msgs = u.msgs[:0]
|
|
||||||
//u.iovs = u.iovs[:0]
|
|
||||||
|
|
||||||
sent := 0
|
|
||||||
var mostRecentPkt *packet.Packet
|
|
||||||
mostRecentPktSize := 0
|
|
||||||
//segmenting := false
|
|
||||||
idx := 0
|
|
||||||
for _, pkt := range pkts {
|
|
||||||
if len(pkt.Payload) == 0 || pkt.OutLen == -1 {
|
|
||||||
sent++
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
lastIdx := idx - 1
|
|
||||||
if mostRecentPkt != nil && pkt.CompatibleForSegmentationWith(mostRecentPkt, mostRecentPktSize) && u.msgs[lastIdx].Hdr.Iovlen < iovMax {
|
|
||||||
u.msgs[lastIdx].Hdr.Controllen = uint64(len(mostRecentPkt.Control))
|
|
||||||
u.msgs[lastIdx].Hdr.Control = &mostRecentPkt.Control[0]
|
|
||||||
|
|
||||||
u.iovs[lastIdx][u.msgs[lastIdx].Hdr.Iovlen].Base = &pkt.Payload[0]
|
|
||||||
u.iovs[lastIdx][u.msgs[lastIdx].Hdr.Iovlen].Len = uint64(len(pkt.Payload))
|
|
||||||
u.msgs[lastIdx].Hdr.Iovlen++
|
|
||||||
|
|
||||||
mostRecentPktSize += len(pkt.Payload)
|
|
||||||
mostRecentPkt.SetSegSizeForTX()
|
|
||||||
} else {
|
|
||||||
u.msgs = append(u.msgs, rawMessage{})
|
|
||||||
u.iovs[idx][0] = iovec{
|
|
||||||
Base: &pkt.Payload[0],
|
|
||||||
Len: uint64(len(pkt.Payload)),
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := &u.msgs[idx]
|
|
||||||
iov := &u.iovs[idx][0]
|
|
||||||
idx++
|
|
||||||
|
|
||||||
msg.Hdr.Iov = iov
|
|
||||||
msg.Hdr.Iovlen = 1
|
|
||||||
setRawMessageControl(msg, nil)
|
|
||||||
msg.Hdr.Flags = 0
|
|
||||||
|
|
||||||
msg.Hdr.Name = &pkt.Name[0]
|
|
||||||
msg.Hdr.Namelen = uint32(len(pkt.Name))
|
|
||||||
mostRecentPkt = pkt
|
|
||||||
mostRecentPktSize = len(pkt.Payload)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(u.msgs) == 0 {
|
|
||||||
return sent, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
offset := 0
|
|
||||||
for offset < len(u.msgs) {
|
|
||||||
n, _, errno := unix.Syscall6(
|
|
||||||
unix.SYS_SENDMMSG,
|
|
||||||
uintptr(u.sysFd),
|
|
||||||
uintptr(unsafe.Pointer(&u.msgs[offset])),
|
|
||||||
uintptr(len(u.msgs)-offset),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
|
|
||||||
if errno != 0 {
|
|
||||||
if errno == unix.EINTR {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
//for i := 0; i < len(u.msgs); i++ {
|
|
||||||
// for j := 0; j < int(u.msgs[i].Hdr.Iovlen); j++ {
|
|
||||||
// u.l.WithFields(logrus.Fields{
|
|
||||||
// "msg_index": i,
|
|
||||||
// "iov idx": j,
|
|
||||||
// "iov": fmt.Sprintf("%+v", u.iovs[i][j]),
|
|
||||||
// }).Warn("failed to send message")
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
//}
|
|
||||||
u.l.WithFields(logrus.Fields{
|
|
||||||
"errno": errno,
|
|
||||||
"idx": idx,
|
|
||||||
"len": len(u.msgs),
|
|
||||||
"deets": fmt.Sprintf("%+v", u.msgs),
|
|
||||||
"lastIOV": fmt.Sprintf("%+v", u.iovs[len(u.msgs)-1][u.msgs[len(u.msgs)-1].Hdr.Iovlen-1]),
|
|
||||||
}).Error("failed to send message")
|
|
||||||
return sent + offset, &net.OpError{Op: "sendmmsg", Err: errno}
|
|
||||||
}
|
|
||||||
|
|
||||||
if n == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
offset += int(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
return sent + len(u.msgs), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
|
|
||||||
if u.isV4 {
|
|
||||||
if !addr.Addr().Is4() {
|
|
||||||
return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
|
|
||||||
}
|
|
||||||
var sa unix.RawSockaddrInet4
|
|
||||||
sa.Family = unix.AF_INET
|
|
||||||
sa.Addr = addr.Addr().As4()
|
|
||||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
|
|
||||||
size := unix.SizeofSockaddrInet4
|
|
||||||
copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
|
|
||||||
return uint32(size), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var sa unix.RawSockaddrInet6
|
|
||||||
sa.Family = unix.AF_INET6
|
|
||||||
sa.Addr = addr.Addr().As16()
|
|
||||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
|
|
||||||
size := unix.SizeofSockaddrInet6
|
|
||||||
copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
|
|
||||||
return uint32(size), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
|
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
|
||||||
var rsa unix.RawSockaddrInet6
|
var rsa unix.RawSockaddrInet6
|
||||||
rsa.Family = unix.AF_INET6
|
rsa.Family = unix.AF_INET6
|
||||||
@@ -445,27 +294,6 @@ func (u *StdConn) ReloadConfig(c *config.C) {
|
|||||||
u.l.WithError(err).Error("Failed to set listen.so_mark")
|
u.l.WithError(err).Error("Failed to set listen.so_mark")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
u.configureGRO(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) configureGRO(enable bool) {
|
|
||||||
if enable == u.enableGRO {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if enable {
|
|
||||||
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
|
|
||||||
u.l.WithError(err).Warn("Failed to enable UDP GRO")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
u.enableGRO = true
|
|
||||||
u.l.Info("UDP GRO enabled")
|
|
||||||
} else {
|
|
||||||
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil && err != unix.ENOPROTOOPT {
|
|
||||||
u.l.WithError(err).Warn("Failed to disable UDP GRO")
|
|
||||||
}
|
|
||||||
u.enableGRO = false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
||||||
|
|||||||
@@ -7,7 +7,6 @@
|
|||||||
package udp
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/slackhq/nebula/packet"
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -34,59 +33,25 @@ type rawMessage struct {
|
|||||||
Pad0 [4]byte
|
Pad0 [4]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func setRawMessageControl(msg *rawMessage, buf []byte) {
|
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||||
if len(buf) == 0 {
|
|
||||||
msg.Hdr.Control = nil
|
|
||||||
msg.Hdr.Controllen = 0
|
|
||||||
return
|
|
||||||
}
|
|
||||||
msg.Hdr.Control = &buf[0]
|
|
||||||
msg.Hdr.Controllen = uint64(len(buf))
|
|
||||||
}
|
|
||||||
|
|
||||||
func getRawMessageControlLen(msg *rawMessage) int {
|
|
||||||
return int(msg.Hdr.Controllen)
|
|
||||||
}
|
|
||||||
|
|
||||||
func setCmsgLen(h *unix.Cmsghdr, l int) {
|
|
||||||
h.Len = uint64(l)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet.Packet) {
|
|
||||||
msgs := make([]rawMessage, n)
|
msgs := make([]rawMessage, n)
|
||||||
packets := make([]*packet.Packet, n)
|
buffers := make([][]byte, n)
|
||||||
|
names := make([][]byte, n)
|
||||||
|
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
packets[i] = packet.New(isV4)
|
buffers[i] = make([]byte, MTU)
|
||||||
|
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
vs := []iovec{
|
vs := []iovec{
|
||||||
{Base: &packets[i].Payload[0], Len: uint64(packet.Size)},
|
{Base: &buffers[i][0], Len: uint64(len(buffers[i]))},
|
||||||
}
|
}
|
||||||
|
|
||||||
msgs[i].Hdr.Iov = &vs[0]
|
msgs[i].Hdr.Iov = &vs[0]
|
||||||
msgs[i].Hdr.Iovlen = uint64(len(vs))
|
msgs[i].Hdr.Iovlen = uint64(len(vs))
|
||||||
|
|
||||||
msgs[i].Hdr.Name = &packets[i].Name[0]
|
msgs[i].Hdr.Name = &names[i][0]
|
||||||
msgs[i].Hdr.Namelen = uint32(len(packets[i].Name))
|
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
||||||
|
|
||||||
if u.enableGRO {
|
|
||||||
msgs[i].Hdr.Control = &packets[i].Control[0]
|
|
||||||
msgs[i].Hdr.Controllen = uint64(len(packets[i].Control))
|
|
||||||
} else {
|
|
||||||
msgs[i].Hdr.Control = nil
|
|
||||||
msgs[i].Hdr.Controllen = 0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return msgs, packets
|
return msgs, buffers, names
|
||||||
}
|
|
||||||
|
|
||||||
func setIovecSlice(iov *iovec, b []byte) {
|
|
||||||
if len(b) == 0 {
|
|
||||||
iov.Base = nil
|
|
||||||
iov.Len = 0
|
|
||||||
return
|
|
||||||
}
|
|
||||||
iov.Base = &b[0]
|
|
||||||
iov.Len = uint64(len(b))
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
// Package virtio contains some generic types and concepts related to the virtio
|
|
||||||
// protocol.
|
|
||||||
package virtio
|
|
||||||
@@ -1,136 +0,0 @@
|
|||||||
package virtio
|
|
||||||
|
|
||||||
// Feature contains feature bits that describe a virtio device or driver.
|
|
||||||
type Feature uint64
|
|
||||||
|
|
||||||
// Device-independent feature bits.
|
|
||||||
//
|
|
||||||
// Source: https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-6600006
|
|
||||||
const (
|
|
||||||
// FeatureIndirectDescriptors indicates that the driver can use descriptors
|
|
||||||
// with an additional layer of indirection.
|
|
||||||
FeatureIndirectDescriptors Feature = 1 << 28
|
|
||||||
|
|
||||||
// FeatureVersion1 indicates compliance with version 1.0 of the virtio
|
|
||||||
// specification.
|
|
||||||
FeatureVersion1 Feature = 1 << 32
|
|
||||||
)
|
|
||||||
|
|
||||||
// Feature bits for networking devices.
|
|
||||||
//
|
|
||||||
// Source: https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-2200003
|
|
||||||
const (
|
|
||||||
// FeatureNetDeviceCsum indicates that the device can handle packets with
|
|
||||||
// partial checksum (checksum offload).
|
|
||||||
FeatureNetDeviceCsum Feature = 1 << 0
|
|
||||||
|
|
||||||
// FeatureNetDriverCsum indicates that the driver can handle packets with
|
|
||||||
// partial checksum.
|
|
||||||
FeatureNetDriverCsum Feature = 1 << 1
|
|
||||||
|
|
||||||
// FeatureNetCtrlDriverOffloads indicates support for dynamic offload state
|
|
||||||
// reconfiguration.
|
|
||||||
FeatureNetCtrlDriverOffloads Feature = 1 << 2
|
|
||||||
|
|
||||||
// FeatureNetMTU indicates that the device reports a maximum MTU value.
|
|
||||||
FeatureNetMTU Feature = 1 << 3
|
|
||||||
|
|
||||||
// FeatureNetMAC indicates that the device provides a MAC address.
|
|
||||||
FeatureNetMAC Feature = 1 << 5
|
|
||||||
|
|
||||||
// FeatureNetDriverTSO4 indicates that the driver supports the TCP
|
|
||||||
// segmentation offload for received IPv4 packets.
|
|
||||||
FeatureNetDriverTSO4 Feature = 1 << 7
|
|
||||||
|
|
||||||
// FeatureNetDriverTSO6 indicates that the driver supports the TCP
|
|
||||||
// segmentation offload for received IPv6 packets.
|
|
||||||
FeatureNetDriverTSO6 Feature = 1 << 8
|
|
||||||
|
|
||||||
// FeatureNetDriverECN indicates that the driver supports the TCP
|
|
||||||
// segmentation offload with ECN for received packets.
|
|
||||||
FeatureNetDriverECN Feature = 1 << 9
|
|
||||||
|
|
||||||
// FeatureNetDriverUFO indicates that the driver supports the UDP
|
|
||||||
// fragmentation offload for received packets.
|
|
||||||
FeatureNetDriverUFO Feature = 1 << 10
|
|
||||||
|
|
||||||
// FeatureNetDeviceTSO4 indicates that the device supports the TCP
|
|
||||||
// segmentation offload for received IPv4 packets.
|
|
||||||
FeatureNetDeviceTSO4 Feature = 1 << 11
|
|
||||||
|
|
||||||
// FeatureNetDeviceTSO6 indicates that the device supports the TCP
|
|
||||||
// segmentation offload for received IPv6 packets.
|
|
||||||
FeatureNetDeviceTSO6 Feature = 1 << 12
|
|
||||||
|
|
||||||
// FeatureNetDeviceECN indicates that the device supports the TCP
|
|
||||||
// segmentation offload with ECN for received packets.
|
|
||||||
FeatureNetDeviceECN Feature = 1 << 13
|
|
||||||
|
|
||||||
// FeatureNetDeviceUFO indicates that the device supports the UDP
|
|
||||||
// fragmentation offload for received packets.
|
|
||||||
FeatureNetDeviceUFO Feature = 1 << 14
|
|
||||||
|
|
||||||
// FeatureNetMergeRXBuffers indicates that the driver can handle merged
|
|
||||||
// receive buffers.
|
|
||||||
// When this feature is negotiated, devices may merge multiple descriptor
|
|
||||||
// chains together to transport large received packets. [NetHdr.NumBuffers]
|
|
||||||
// will then contain the number of merged descriptor chains.
|
|
||||||
FeatureNetMergeRXBuffers Feature = 1 << 15
|
|
||||||
|
|
||||||
// FeatureNetStatus indicates that the device configuration status field is
|
|
||||||
// available.
|
|
||||||
FeatureNetStatus Feature = 1 << 16
|
|
||||||
|
|
||||||
// FeatureNetCtrlVQ indicates that a control channel virtqueue is
|
|
||||||
// available.
|
|
||||||
FeatureNetCtrlVQ Feature = 1 << 17
|
|
||||||
|
|
||||||
// FeatureNetCtrlRX indicates support for RX mode control (e.g. promiscuous
|
|
||||||
// or all-multicast) for packet receive filtering.
|
|
||||||
FeatureNetCtrlRX Feature = 1 << 18
|
|
||||||
|
|
||||||
// FeatureNetCtrlVLAN indicates support for VLAN filtering through the
|
|
||||||
// control channel.
|
|
||||||
FeatureNetCtrlVLAN Feature = 1 << 19
|
|
||||||
|
|
||||||
// FeatureNetDriverAnnounce indicates that the driver can send gratuitous
|
|
||||||
// packets.
|
|
||||||
FeatureNetDriverAnnounce Feature = 1 << 21
|
|
||||||
|
|
||||||
// FeatureNetMQ indicates that the device supports multiqueue with automatic
|
|
||||||
// receive steering.
|
|
||||||
FeatureNetMQ Feature = 1 << 22
|
|
||||||
|
|
||||||
// FeatureNetCtrlMACAddr indicates that the MAC address can be set through
|
|
||||||
// the control channel.
|
|
||||||
FeatureNetCtrlMACAddr Feature = 1 << 23
|
|
||||||
|
|
||||||
// FeatureNetDeviceUSO indicates that the device supports the UDP
|
|
||||||
// segmentation offload for received packets.
|
|
||||||
FeatureNetDeviceUSO Feature = 1 << 56
|
|
||||||
|
|
||||||
// FeatureNetHashReport indicates that the device can report a per-packet
|
|
||||||
// hash value and type.
|
|
||||||
FeatureNetHashReport Feature = 1 << 57
|
|
||||||
|
|
||||||
// FeatureNetDriverHdrLen indicates that the driver can provide the exact
|
|
||||||
// header length value (see [NetHdr.HdrLen]).
|
|
||||||
// Devices may benefit from knowing the exact header length.
|
|
||||||
FeatureNetDriverHdrLen Feature = 1 << 59
|
|
||||||
|
|
||||||
// FeatureNetRSS indicates that the device supports RSS (receive-side
|
|
||||||
// scaling) with configurable hash parameters.
|
|
||||||
FeatureNetRSS Feature = 1 << 60
|
|
||||||
|
|
||||||
// FeatureNetRSCExt indicates that the device can process duplicated ACKs
|
|
||||||
// and report the number of coalesced segments and duplicated ACKs.
|
|
||||||
FeatureNetRSCExt Feature = 1 << 61
|
|
||||||
|
|
||||||
// FeatureNetStandby indicates that the device may act as a standby for a
|
|
||||||
// primary device with the same MAC address.
|
|
||||||
FeatureNetStandby Feature = 1 << 62
|
|
||||||
|
|
||||||
// FeatureNetSpeedDuplex indicates that the device can report link speed and
|
|
||||||
// duplex mode.
|
|
||||||
FeatureNetSpeedDuplex Feature = 1 << 63
|
|
||||||
)
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
package virtio
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Workaround to make Go doc links work.
|
|
||||||
var _ unix.Errno
|
|
||||||
|
|
||||||
// NetHdrSize is the number of bytes needed to store a [NetHdr] in memory.
|
|
||||||
const NetHdrSize = 12
|
|
||||||
|
|
||||||
// ErrNetHdrBufferTooSmall is returned when a buffer is too small to fit a
|
|
||||||
// virtio_net_hdr.
|
|
||||||
var ErrNetHdrBufferTooSmall = errors.New("the buffer is too small to fit a virtio_net_hdr")
|
|
||||||
|
|
||||||
// NetHdr defines the virtio_net_hdr as described by the virtio specification.
|
|
||||||
type NetHdr struct {
|
|
||||||
// Flags that describe the packet.
|
|
||||||
// Possible values are:
|
|
||||||
// - [unix.VIRTIO_NET_HDR_F_NEEDS_CSUM]
|
|
||||||
// - [unix.VIRTIO_NET_HDR_F_DATA_VALID]
|
|
||||||
// - [unix.VIRTIO_NET_HDR_F_RSC_INFO]
|
|
||||||
Flags uint8
|
|
||||||
// GSOType contains the type of segmentation offload that should be used for
|
|
||||||
// the packet.
|
|
||||||
// Possible values are:
|
|
||||||
// - [unix.VIRTIO_NET_HDR_GSO_NONE]
|
|
||||||
// - [unix.VIRTIO_NET_HDR_GSO_TCPV4]
|
|
||||||
// - [unix.VIRTIO_NET_HDR_GSO_UDP]
|
|
||||||
// - [unix.VIRTIO_NET_HDR_GSO_TCPV6]
|
|
||||||
// - [unix.VIRTIO_NET_HDR_GSO_UDP_L4]
|
|
||||||
// - [unix.VIRTIO_NET_HDR_GSO_ECN]
|
|
||||||
GSOType uint8
|
|
||||||
// HdrLen contains the length of the headers that need to be replicated by
|
|
||||||
// segmentation offloads. It's the number of bytes from the beginning of the
|
|
||||||
// packet to the beginning of the transport payload.
|
|
||||||
// Only used when [FeatureNetDriverHdrLen] is negotiated.
|
|
||||||
HdrLen uint16
|
|
||||||
// GSOSize contains the maximum size of each segmented packet beyond the
|
|
||||||
// header (payload size). In case of TCP, this is the MSS.
|
|
||||||
GSOSize uint16
|
|
||||||
// CsumStart contains the offset within the packet from which on the
|
|
||||||
// checksum should be computed.
|
|
||||||
CsumStart uint16
|
|
||||||
// CsumOffset specifies how many bytes after [NetHdr.CsumStart] the computed
|
|
||||||
// 16-bit checksum should be inserted.
|
|
||||||
CsumOffset uint16
|
|
||||||
// NumBuffers contains the number of merged descriptor chains when
|
|
||||||
// [FeatureNetMergeRXBuffers] is negotiated.
|
|
||||||
// This field is only used for packets received by the driver and should be
|
|
||||||
// zero for transmitted packets.
|
|
||||||
NumBuffers uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode decodes the [NetHdr] from the given byte slice. The slice must contain
|
|
||||||
// at least [NetHdrSize] bytes.
|
|
||||||
func (v *NetHdr) Decode(data []byte) error {
|
|
||||||
if len(data) < NetHdrSize {
|
|
||||||
return ErrNetHdrBufferTooSmall
|
|
||||||
}
|
|
||||||
copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), NetHdrSize), data[:NetHdrSize])
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encode encodes the [NetHdr] into the given byte slice. The slice must have
|
|
||||||
// room for at least [NetHdrSize] bytes.
|
|
||||||
func (v *NetHdr) Encode(data []byte) error {
|
|
||||||
if len(data) < NetHdrSize {
|
|
||||||
return ErrNetHdrBufferTooSmall
|
|
||||||
}
|
|
||||||
copy(data[:NetHdrSize], unsafe.Slice((*byte)(unsafe.Pointer(v)), NetHdrSize))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
package virtio
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNetHdr_Size(t *testing.T) {
|
|
||||||
assert.EqualValues(t, NetHdrSize, unsafe.Sizeof(NetHdr{}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNetHdr_Encoding(t *testing.T) {
|
|
||||||
vnethdr := NetHdr{
|
|
||||||
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
|
||||||
GSOType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
|
|
||||||
HdrLen: 42,
|
|
||||||
GSOSize: 1472,
|
|
||||||
CsumStart: 34,
|
|
||||||
CsumOffset: 6,
|
|
||||||
NumBuffers: 16,
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := make([]byte, NetHdrSize)
|
|
||||||
require.NoError(t, vnethdr.Encode(buf))
|
|
||||||
|
|
||||||
assert.Equal(t, []byte{
|
|
||||||
0x01, 0x05,
|
|
||||||
0x2a, 0x00,
|
|
||||||
0xc0, 0x05,
|
|
||||||
0x22, 0x00,
|
|
||||||
0x06, 0x00,
|
|
||||||
0x10, 0x00,
|
|
||||||
}, buf)
|
|
||||||
|
|
||||||
var decoded NetHdr
|
|
||||||
require.NoError(t, decoded.Decode(buf))
|
|
||||||
|
|
||||||
assert.Equal(t, vnethdr, decoded)
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user