mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 16:34:25 +01:00
Compare commits
4 Commits
jay.wren-w
...
stinkier
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
29157f413c | ||
|
|
68746bd907 | ||
|
|
51b383f680 | ||
|
|
71c849e63e |
18
CHANGELOG.md
18
CHANGELOG.md
@@ -7,12 +7,30 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- Experimental Linux UDP offload support: enable `listen.enable_gso` and
|
||||||
|
`listen.enable_gro` to activate UDP_SEGMENT batching and GRO receive
|
||||||
|
splitting. Includes automatic capability probing, per-packet fallbacks, and
|
||||||
|
runtime metrics/logs for visibility.
|
||||||
|
- Optional Linux TUN `virtio_net_hdr` support: set `tun.enable_vnet_hdr` to
|
||||||
|
have Nebula negotiate VNET headers and offload flags so future batches can
|
||||||
|
be delivered to the kernel with metadata instead of per-packet writes.
|
||||||
|
- Linux UDP send sharding can now be tuned with `listen.send_shards`; defaults
|
||||||
|
to `GOMAXPROCS` but can be increased to stripe heavy peers across more
|
||||||
|
goroutines.
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
|
||||||
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
|
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
|
||||||
intended to target an `unsafe_routes` entry must explicitly declare it via the
|
intended to target an `unsafe_routes` entry must explicitly declare it via the
|
||||||
`local_cidr` field. This is almost always the intended behavior. This flag is
|
`local_cidr` field. This is almost always the intended behavior. This flag is
|
||||||
deprecated and will be removed in a future release.
|
deprecated and will be removed in a future release.
|
||||||
|
- UDP receive path now enqueues into per-worker lock-free rings, restoring the
|
||||||
|
`listen.decrypt_workers`/`listen.decrypt_queue_depth` tuning knobs while
|
||||||
|
eliminating the mutex contention from the old shared channel.
|
||||||
|
- Increased replay protection window to 32k packets so high-throughput links
|
||||||
|
tolerate larger bursts of reordering without tripping the anti-replay logic.
|
||||||
|
|
||||||
## [1.9.4] - 2024-09-09
|
## [1.9.4] - 2024-09-09
|
||||||
|
|
||||||
|
|||||||
@@ -114,33 +114,6 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
|
|||||||
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
|
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
|
|
||||||
nc := &cert.TBSCertificate{
|
|
||||||
Version: v,
|
|
||||||
Curve: c.Curve(),
|
|
||||||
Name: c.Name(),
|
|
||||||
Networks: c.Networks(),
|
|
||||||
UnsafeNetworks: c.UnsafeNetworks(),
|
|
||||||
Groups: c.Groups(),
|
|
||||||
NotBefore: time.Unix(c.NotBefore().Unix(), 0),
|
|
||||||
NotAfter: time.Unix(c.NotAfter().Unix(), 0),
|
|
||||||
PublicKey: c.PublicKey(),
|
|
||||||
IsCA: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
c, err := nc.Sign(ca, ca.Curve(), key)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pem, err := c.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return c, pem
|
|
||||||
}
|
|
||||||
|
|
||||||
func X25519Keypair() ([]byte, []byte) {
|
func X25519Keypair() ([]byte, []byte) {
|
||||||
privkey := make([]byte, 32)
|
privkey := make([]byte, 32)
|
||||||
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
||||||
|
|||||||
@@ -354,6 +354,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
|
|
||||||
if mainHostInfo {
|
if mainHostInfo {
|
||||||
decision = tryRehandshake
|
decision = tryRehandshake
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if cm.shouldSwapPrimary(hostinfo) {
|
if cm.shouldSwapPrimary(hostinfo) {
|
||||||
decision = swapPrimary
|
decision = swapPrimary
|
||||||
@@ -460,10 +461,6 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
||||||
if crt == nil {
|
|
||||||
//my cert was reloaded away. We should definitely swap from this tunnel
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
||||||
// settle down.
|
// settle down.
|
||||||
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
||||||
@@ -478,34 +475,31 @@ func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
|
|||||||
cm.hostMap.Unlock()
|
cm.hostMap.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// isInvalidCertificate decides if we should destroy a tunnel.
|
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
|
||||||
// returns true if pki.disconnect_invalid is true and the certificate is no longer valid.
|
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
|
||||||
// Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
|
// check and return true.
|
||||||
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
||||||
remoteCert := hostinfo.GetCert()
|
remoteCert := hostinfo.GetCert()
|
||||||
if remoteCert == nil {
|
if remoteCert == nil {
|
||||||
return false //don't tear down tunnels for handshakes in progress
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
caPool := cm.intf.pki.GetCAPool()
|
caPool := cm.intf.pki.GetCAPool()
|
||||||
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return false //cert is still valid! yay!
|
return false
|
||||||
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
|
}
|
||||||
|
|
||||||
|
if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
|
||||||
// Block listed certificates should always be disconnected
|
// Block listed certificates should always be disconnected
|
||||||
hostinfo.logger(cm.l).WithError(err).
|
return false
|
||||||
WithField("fingerprint", remoteCert.Fingerprint).
|
}
|
||||||
Info("Remote certificate is blocked, tearing down the tunnel")
|
|
||||||
return true
|
|
||||||
} else if cm.intf.disconnectInvalid.Load() {
|
|
||||||
hostinfo.logger(cm.l).WithError(err).
|
hostinfo.logger(cm.l).WithError(err).
|
||||||
WithField("fingerprint", remoteCert.Fingerprint).
|
WithField("fingerprint", remoteCert.Fingerprint).
|
||||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||||
|
|
||||||
return true
|
return true
|
||||||
} else {
|
|
||||||
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||||
@@ -536,45 +530,15 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
|||||||
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||||
cs := cm.intf.pki.getCertState()
|
cs := cm.intf.pki.getCertState()
|
||||||
curCrt := hostinfo.ConnectionState.myCert
|
curCrt := hostinfo.ConnectionState.myCert
|
||||||
curCrtVersion := curCrt.Version()
|
myCrt := cs.getCertificate(curCrt.Version())
|
||||||
myCrt := cs.getCertificate(curCrtVersion)
|
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
||||||
if myCrt == nil {
|
// The current tunnel is using the latest certificate and version, no need to rehandshake.
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
|
||||||
WithField("version", curCrtVersion).
|
|
||||||
WithField("reason", "local certificate removed").
|
|
||||||
Info("Re-handshaking with remote")
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
peerCrt := hostinfo.ConnectionState.peerCert
|
|
||||||
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
|
|
||||||
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
|
|
||||||
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
|
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
|
||||||
WithField("version", curCrtVersion).
|
|
||||||
WithField("peerVersion", peerCrt.Certificate.Version()).
|
|
||||||
WithField("reason", "local certificate version lower than peer, attempting to correct").
|
|
||||||
Info("Re-handshaking with remote")
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
|
|
||||||
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
|
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
WithField("reason", "local certificate is not current").
|
WithField("reason", "local certificate is not current").
|
||||||
Info("Re-handshaking with remote")
|
Info("Re-handshaking with remote")
|
||||||
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
return
|
|
||||||
}
|
|
||||||
if curCrtVersion < cs.initiatingVersion {
|
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
|
||||||
WithField("reason", "current cert version < pki.initiatingVersion").
|
|
||||||
Info("Re-handshaking with remote")
|
|
||||||
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,10 @@ import (
|
|||||||
"github.com/slackhq/nebula/noiseutil"
|
"github.com/slackhq/nebula/noiseutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ReplayWindow = 1024
|
// ReplayWindow controls the size of the sliding window used to detect replays.
|
||||||
|
// High-bandwidth links with GRO/GSO can reorder more than a thousand packets in
|
||||||
|
// flight, so keep this comfortably above the largest expected burst.
|
||||||
|
const ReplayWindow = 32768
|
||||||
|
|
||||||
type ConnectionState struct {
|
type ConnectionState struct {
|
||||||
eKey *NebulaCipherState
|
eKey *NebulaCipherState
|
||||||
|
|||||||
@@ -129,109 +129,6 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
|||||||
return control, vpnNetworks, udpAddr, c
|
return control, vpnNetworks, udpAddr, c
|
||||||
}
|
}
|
||||||
|
|
||||||
// newServer creates a nebula instance with fewer assumptions
|
|
||||||
func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
|
||||||
l := NewTestLogger()
|
|
||||||
|
|
||||||
vpnNetworks := certs[len(certs)-1].Networks()
|
|
||||||
|
|
||||||
var udpAddr netip.AddrPort
|
|
||||||
if vpnNetworks[0].Addr().Is4() {
|
|
||||||
budpIp := vpnNetworks[0].Addr().As4()
|
|
||||||
budpIp[1] -= 128
|
|
||||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
|
|
||||||
} else {
|
|
||||||
budpIp := vpnNetworks[0].Addr().As16()
|
|
||||||
// beef for funsies
|
|
||||||
budpIp[2] = 190
|
|
||||||
budpIp[3] = 239
|
|
||||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
|
||||||
}
|
|
||||||
|
|
||||||
caStr := ""
|
|
||||||
for _, ca := range caCrt {
|
|
||||||
x, err := ca.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
caStr += string(x)
|
|
||||||
}
|
|
||||||
certStr := ""
|
|
||||||
for _, c := range certs {
|
|
||||||
x, err := c.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
certStr += string(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
mc := m{
|
|
||||||
"pki": m{
|
|
||||||
"ca": caStr,
|
|
||||||
"cert": certStr,
|
|
||||||
"key": string(key),
|
|
||||||
},
|
|
||||||
//"tun": m{"disabled": true},
|
|
||||||
"firewall": m{
|
|
||||||
"outbound": []m{{
|
|
||||||
"proto": "any",
|
|
||||||
"port": "any",
|
|
||||||
"host": "any",
|
|
||||||
}},
|
|
||||||
"inbound": []m{{
|
|
||||||
"proto": "any",
|
|
||||||
"port": "any",
|
|
||||||
"host": "any",
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
//"handshakes": m{
|
|
||||||
// "try_interval": "1s",
|
|
||||||
//},
|
|
||||||
"listen": m{
|
|
||||||
"host": udpAddr.Addr().String(),
|
|
||||||
"port": udpAddr.Port(),
|
|
||||||
},
|
|
||||||
"logging": m{
|
|
||||||
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
|
|
||||||
"level": l.Level.String(),
|
|
||||||
},
|
|
||||||
"timers": m{
|
|
||||||
"pending_deletion_interval": 2,
|
|
||||||
"connection_alive_interval": 2,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if overrides != nil {
|
|
||||||
final := m{}
|
|
||||||
err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
mc = final
|
|
||||||
}
|
|
||||||
|
|
||||||
cb, err := yaml.Marshal(mc)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c := config.NewC(l)
|
|
||||||
cStr := string(cb)
|
|
||||||
c.LoadString(cStr)
|
|
||||||
|
|
||||||
control, err := nebula.Main(c, false, "e2e-test", l, nil)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return control, vpnNetworks, udpAddr, c
|
|
||||||
}
|
|
||||||
|
|
||||||
type doneCb func()
|
type doneCb func()
|
||||||
|
|
||||||
func deadline(t *testing.T, seconds time.Duration) doneCb {
|
func deadline(t *testing.T, seconds time.Duration) doneCb {
|
||||||
|
|||||||
@@ -4,16 +4,12 @@
|
|||||||
package e2e
|
package e2e
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/cert_test"
|
"github.com/slackhq/nebula/cert_test"
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDropInactiveTunnels(t *testing.T) {
|
func TestDropInactiveTunnels(t *testing.T) {
|
||||||
@@ -59,262 +55,3 @@ func TestDropInactiveTunnels(t *testing.T) {
|
|||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertUpgrade(t *testing.T) {
|
|
||||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
|
||||||
// under ideal conditions
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
caB, err := ca.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
|
|
||||||
ca2B, err := ca2.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
|
||||||
|
|
||||||
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
|
||||||
_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
|
||||||
|
|
||||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
|
||||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
|
||||||
|
|
||||||
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
|
||||||
|
|
||||||
// Share our underlay information
|
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
|
||||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
r := router.NewR(t, myControl, theirControl)
|
|
||||||
defer r.RenderFlow()
|
|
||||||
|
|
||||||
r.Log("Assert the tunnel between me and them works")
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
r.Log("yay")
|
|
||||||
//todo ???
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
r.FlushAll()
|
|
||||||
|
|
||||||
mc := m{
|
|
||||||
"pki": m{
|
|
||||||
"ca": caStr,
|
|
||||||
"cert": string(myCert2Pem),
|
|
||||||
"key": string(myPrivKey),
|
|
||||||
},
|
|
||||||
//"tun": m{"disabled": true},
|
|
||||||
"firewall": myC.Settings["firewall"],
|
|
||||||
//"handshakes": m{
|
|
||||||
// "try_interval": "1s",
|
|
||||||
//},
|
|
||||||
"listen": myC.Settings["listen"],
|
|
||||||
"logging": myC.Settings["logging"],
|
|
||||||
"timers": myC.Settings["timers"],
|
|
||||||
}
|
|
||||||
|
|
||||||
cb, err := yaml.Marshal(mc)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Logf("reload new v2-only config")
|
|
||||||
err = myC.ReloadConfigString(string(cb))
|
|
||||||
assert.NoError(t, err)
|
|
||||||
r.Log("yay, spin until their sees it")
|
|
||||||
waitStart := time.Now()
|
|
||||||
for {
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
|
||||||
if c == nil {
|
|
||||||
r.Log("nil")
|
|
||||||
} else {
|
|
||||||
version := c.Cert.Version()
|
|
||||||
r.Logf("version %d", version)
|
|
||||||
if version == cert.Version2 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
since := time.Since(waitStart)
|
|
||||||
if since > time.Second*10 {
|
|
||||||
t.Fatal("Cert should be new by now")
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
|
||||||
|
|
||||||
myControl.Stop()
|
|
||||||
theirControl.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCertDowngrade(t *testing.T) {
|
|
||||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
|
||||||
// under ideal conditions
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
caB, err := ca.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
|
|
||||||
ca2B, err := ca2.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
|
||||||
|
|
||||||
myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
|
||||||
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
|
||||||
|
|
||||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
|
||||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
|
||||||
|
|
||||||
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
|
||||||
|
|
||||||
// Share our underlay information
|
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
|
||||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
r := router.NewR(t, myControl, theirControl)
|
|
||||||
defer r.RenderFlow()
|
|
||||||
|
|
||||||
r.Log("Assert the tunnel between me and them works")
|
|
||||||
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
|
||||||
//r.Log("yay")
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
r.Log("yay")
|
|
||||||
//todo ???
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
r.FlushAll()
|
|
||||||
|
|
||||||
mc := m{
|
|
||||||
"pki": m{
|
|
||||||
"ca": caStr,
|
|
||||||
"cert": string(myCertPem),
|
|
||||||
"key": string(myPrivKey),
|
|
||||||
},
|
|
||||||
"firewall": myC.Settings["firewall"],
|
|
||||||
"listen": myC.Settings["listen"],
|
|
||||||
"logging": myC.Settings["logging"],
|
|
||||||
"timers": myC.Settings["timers"],
|
|
||||||
}
|
|
||||||
|
|
||||||
cb, err := yaml.Marshal(mc)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Logf("reload new v1-only config")
|
|
||||||
err = myC.ReloadConfigString(string(cb))
|
|
||||||
assert.NoError(t, err)
|
|
||||||
r.Log("yay, spin until their sees it")
|
|
||||||
waitStart := time.Now()
|
|
||||||
for {
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
|
||||||
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
|
||||||
if c == nil || c2 == nil {
|
|
||||||
r.Log("nil")
|
|
||||||
} else {
|
|
||||||
version := c.Cert.Version()
|
|
||||||
theirVersion := c2.Cert.Version()
|
|
||||||
r.Logf("version %d,%d", version, theirVersion)
|
|
||||||
if version == cert.Version1 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
since := time.Since(waitStart)
|
|
||||||
if since > time.Second*5 {
|
|
||||||
r.Log("it is unusual that the cert is not new yet, but not a failure yet")
|
|
||||||
}
|
|
||||||
if since > time.Second*10 {
|
|
||||||
r.Log("wtf")
|
|
||||||
t.Fatal("Cert should be new by now")
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
|
||||||
|
|
||||||
myControl.Stop()
|
|
||||||
theirControl.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCertMismatchCorrection(t *testing.T) {
|
|
||||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
|
||||||
// under ideal conditions
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
|
|
||||||
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
|
||||||
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
|
||||||
|
|
||||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
|
||||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
|
||||||
|
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
|
||||||
|
|
||||||
// Share our underlay information
|
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
|
||||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
r := router.NewR(t, myControl, theirControl)
|
|
||||||
defer r.RenderFlow()
|
|
||||||
|
|
||||||
r.Log("Assert the tunnel between me and them works")
|
|
||||||
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
|
||||||
//r.Log("yay")
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
r.Log("yay")
|
|
||||||
//todo ???
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
r.FlushAll()
|
|
||||||
|
|
||||||
waitStart := time.Now()
|
|
||||||
for {
|
|
||||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
|
||||||
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
|
||||||
if c == nil || c2 == nil {
|
|
||||||
r.Log("nil")
|
|
||||||
} else {
|
|
||||||
version := c.Cert.Version()
|
|
||||||
theirVersion := c2.Cert.Version()
|
|
||||||
r.Logf("version %d,%d", version, theirVersion)
|
|
||||||
if version == theirVersion {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
since := time.Since(waitStart)
|
|
||||||
if since > time.Second*5 {
|
|
||||||
r.Log("wtf")
|
|
||||||
}
|
|
||||||
if since > time.Second*10 {
|
|
||||||
r.Log("wtf")
|
|
||||||
t.Fatal("Cert should be new by now")
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
|
||||||
|
|
||||||
myControl.Stop()
|
|
||||||
theirControl.Stop()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -23,19 +23,15 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we're connecting to a v6 address we must use a v2 cert
|
||||||
cs := f.pki.getCertState()
|
cs := f.pki.getCertState()
|
||||||
v := cs.initiatingVersion
|
v := cs.initiatingVersion
|
||||||
if hh.initiatingVersionOverride != cert.VersionPre1 {
|
|
||||||
v = hh.initiatingVersionOverride
|
|
||||||
} else if v < cert.Version2 {
|
|
||||||
// If we're connecting to a v6 address we should encourage use of a V2 cert
|
|
||||||
for _, a := range hh.hostinfo.vpnAddrs {
|
for _, a := range hh.hostinfo.vpnAddrs {
|
||||||
if a.Is6() {
|
if a.Is6() {
|
||||||
v = cert.Version2
|
v = cert.Version2
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
crt := cs.getCertificate(v)
|
crt := cs.getCertificate(v)
|
||||||
if crt == nil {
|
if crt == nil {
|
||||||
@@ -52,7 +48,6 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||||
WithField("certVersion", v).
|
WithField("certVersion", v).
|
||||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
||||||
@@ -108,7 +103,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||||
WithField("certVersion", cs.initiatingVersion).
|
WithField("certVersion", cs.initiatingVersion).
|
||||||
Error("Unable to handshake with host because no certificate is available")
|
Error("Unable to handshake with host because no certificate is available")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
||||||
@@ -149,8 +143,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
|
|
||||||
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fp, fperr := rc.Fingerprint()
|
fp, err := rc.Fingerprint()
|
||||||
if fperr != nil {
|
if err != nil {
|
||||||
fp = "<error generating certificate fingerprint>"
|
fp = "<error generating certificate fingerprint>"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,19 +163,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
|
|
||||||
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
||||||
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
||||||
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
|
rc := cs.getCertificate(remoteCert.Certificate.Version())
|
||||||
if myCertOtherVersion == nil {
|
if rc == nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
f.l.WithError(err).WithFields(m{
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
||||||
"udpAddr": addr,
|
Info("Unable to handshake with host due to missing certificate version")
|
||||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
return
|
||||||
"cert": remoteCert,
|
|
||||||
}).Debug("Might be unable to handshake with host due to missing certificate version")
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// Record the certificate we are actually using
|
// Record the certificate we are actually using
|
||||||
ci.myCert = myCertOtherVersion
|
ci.myCert = rc
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||||
|
|||||||
@@ -70,7 +70,6 @@ type HandshakeHostInfo struct {
|
|||||||
|
|
||||||
startTime time.Time // Time that we first started trying with this handshake
|
startTime time.Time // Time that we first started trying with this handshake
|
||||||
ready bool // Is the handshake ready
|
ready bool // Is the handshake ready
|
||||||
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
|
|
||||||
counter int64 // How many attempts have we made so far
|
counter int64 // How many attempts have we made so far
|
||||||
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
||||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||||
|
|||||||
143
inside.go
143
inside.go
@@ -11,149 +11,6 @@ import (
|
|||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
// consumeInsidePackets processes multiple packets in a batch for improved performance
|
|
||||||
// packets: slice of packet buffers to process
|
|
||||||
// sizes: slice of packet sizes
|
|
||||||
// count: number of packets to process
|
|
||||||
// outs: slice of output buffers (one per packet) with virtio headroom
|
|
||||||
// q: queue index
|
|
||||||
// localCache: firewall conntrack cache
|
|
||||||
// batchPackets: pre-allocated slice for accumulating encrypted packets
|
|
||||||
// batchAddrs: pre-allocated slice for accumulating destination addresses
|
|
||||||
func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, nb []byte, q int, localCache firewall.ConntrackCache, batchPackets *[][]byte, batchAddrs *[]netip.AddrPort) {
|
|
||||||
// Reusable per-packet state
|
|
||||||
fwPacket := &firewall.Packet{}
|
|
||||||
|
|
||||||
// Reset batch accumulation slices (reuse capacity)
|
|
||||||
*batchPackets = (*batchPackets)[:0]
|
|
||||||
*batchAddrs = (*batchAddrs)[:0]
|
|
||||||
|
|
||||||
// Process each packet in the batch
|
|
||||||
for i := 0; i < count; i++ {
|
|
||||||
packet := packets[i][:sizes[i]]
|
|
||||||
out := outs[i]
|
|
||||||
|
|
||||||
// Inline the consumeInsidePacket logic for better performance
|
|
||||||
err := newPacket(packet, false, fwPacket)
|
|
||||||
if err != nil {
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ignore local broadcast packets
|
|
||||||
if f.dropLocalBroadcast {
|
|
||||||
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) {
|
|
||||||
// Immediately forward packets from self to self.
|
|
||||||
if immediatelyForwardToSelf {
|
|
||||||
_, err := f.readers[q].Write(packet)
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).Error("Failed to forward to tun")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ignore multicast packets
|
|
||||||
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
|
|
||||||
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
|
||||||
})
|
|
||||||
|
|
||||||
if hostinfo == nil {
|
|
||||||
f.rejectInside(packet, out, q)
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
|
|
||||||
WithField("fwPacket", fwPacket).
|
|
||||||
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ready {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
|
||||||
if dropReason != nil {
|
|
||||||
f.rejectInside(packet, out, q)
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
hostinfo.logger(f.l).
|
|
||||||
WithField("fwPacket", fwPacket).
|
|
||||||
WithField("reason", dropReason).
|
|
||||||
Debugln("dropping outbound packet")
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encrypt and prepare packet for batch sending
|
|
||||||
ci := hostinfo.ConnectionState
|
|
||||||
if ci.eKey == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if this needs relay - if so, send immediately and skip batching
|
|
||||||
useRelay := !hostinfo.remote.IsValid()
|
|
||||||
if useRelay {
|
|
||||||
// Handle relay sends individually (less common path)
|
|
||||||
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, packet, nb, out, q)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encrypt the packet for batch sending
|
|
||||||
if noiseutil.EncryptLockNeeded {
|
|
||||||
ci.writeLock.Lock()
|
|
||||||
}
|
|
||||||
c := ci.messageCounter.Add(1)
|
|
||||||
out = header.Encode(out, header.Version, header.Message, 0, hostinfo.remoteIndexId, c)
|
|
||||||
f.connectionManager.Out(hostinfo)
|
|
||||||
|
|
||||||
// Query lighthouse if needed
|
|
||||||
if hostinfo.lastRebindCount != f.rebindCount {
|
|
||||||
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
|
|
||||||
hostinfo.lastRebindCount = f.rebindCount
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
out, err = ci.eKey.EncryptDanger(out, out, packet, c, nb)
|
|
||||||
if noiseutil.EncryptLockNeeded {
|
|
||||||
ci.writeLock.Unlock()
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithError(err).
|
|
||||||
WithField("counter", c).
|
|
||||||
Error("Failed to encrypt outgoing packet")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add to batch
|
|
||||||
*batchPackets = append(*batchPackets, out)
|
|
||||||
*batchAddrs = append(*batchAddrs, hostinfo.remote)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send all accumulated packets in one batch
|
|
||||||
if len(*batchPackets) > 0 {
|
|
||||||
batchSize := len(*batchPackets)
|
|
||||||
f.batchMetrics.udpWriteSize.Update(int64(batchSize))
|
|
||||||
|
|
||||||
n, err := f.writers[q].WriteMulti(*batchPackets, *batchAddrs)
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).WithField("sent", n).WithField("total", batchSize).Error("Failed to send batch")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
err := newPacket(packet, false, fwPacket)
|
err := newPacket(packet, false, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
423
interface.go
423
interface.go
@@ -4,9 +4,12 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/bits"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -20,8 +23,12 @@ import (
|
|||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
const mtu = 9001
|
const (
|
||||||
const virtioNetHdrLen = overlay.VirtioNetHdrLen
|
mtu = 9001
|
||||||
|
tunReadBufferSize = mtu * 8
|
||||||
|
defaultDecryptWorkerFactor = 2
|
||||||
|
defaultInboundQueueDepth = 1024
|
||||||
|
)
|
||||||
|
|
||||||
type InterfaceConfig struct {
|
type InterfaceConfig struct {
|
||||||
HostMap *HostMap
|
HostMap *HostMap
|
||||||
@@ -48,13 +55,8 @@ type InterfaceConfig struct {
|
|||||||
|
|
||||||
ConntrackCacheTimeout time.Duration
|
ConntrackCacheTimeout time.Duration
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
DecryptWorkers int
|
||||||
|
DecryptQueueDepth int
|
||||||
type batchMetrics struct {
|
|
||||||
udpReadSize metrics.Histogram
|
|
||||||
tunReadSize metrics.Histogram
|
|
||||||
udpWriteSize metrics.Histogram
|
|
||||||
tunWriteSize metrics.Histogram
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Interface struct {
|
type Interface struct {
|
||||||
@@ -93,14 +95,173 @@ type Interface struct {
|
|||||||
conntrackCacheTimeout time.Duration
|
conntrackCacheTimeout time.Duration
|
||||||
|
|
||||||
writers []udp.Conn
|
writers []udp.Conn
|
||||||
readers []overlay.BatchReadWriter
|
readers []io.ReadWriteCloser
|
||||||
|
|
||||||
metricHandshakes metrics.Histogram
|
metricHandshakes metrics.Histogram
|
||||||
messageMetrics *MessageMetrics
|
messageMetrics *MessageMetrics
|
||||||
cachedPacketMetrics *cachedPacketMetrics
|
cachedPacketMetrics *cachedPacketMetrics
|
||||||
batchMetrics *batchMetrics
|
|
||||||
|
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
ctx context.Context
|
||||||
|
udpListenWG sync.WaitGroup
|
||||||
|
inboundPool sync.Pool
|
||||||
|
decryptWG sync.WaitGroup
|
||||||
|
decryptQueues []*inboundRing
|
||||||
|
decryptWorkers int
|
||||||
|
decryptStates []decryptWorkerState
|
||||||
|
decryptCounter atomic.Uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type inboundPacket struct {
|
||||||
|
addr netip.AddrPort
|
||||||
|
payload []byte
|
||||||
|
release func()
|
||||||
|
queue int
|
||||||
|
}
|
||||||
|
|
||||||
|
type decryptWorkerState struct {
|
||||||
|
queue *inboundRing
|
||||||
|
notify chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type decryptContext struct {
|
||||||
|
ctTicker *firewall.ConntrackCacheTicker
|
||||||
|
plain []byte
|
||||||
|
head header.H
|
||||||
|
fwPacket firewall.Packet
|
||||||
|
light *LightHouseHandler
|
||||||
|
nebula []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type inboundCell struct {
|
||||||
|
seq atomic.Uint64
|
||||||
|
pkt *inboundPacket
|
||||||
|
}
|
||||||
|
|
||||||
|
type inboundRing struct {
|
||||||
|
mask uint64
|
||||||
|
cells []inboundCell
|
||||||
|
enqueuePos atomic.Uint64
|
||||||
|
dequeuePos atomic.Uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newInboundRing(capacity int) *inboundRing {
|
||||||
|
if capacity < 2 {
|
||||||
|
capacity = 2
|
||||||
|
}
|
||||||
|
size := nextPowerOfTwo(uint32(capacity))
|
||||||
|
if size < 2 {
|
||||||
|
size = 2
|
||||||
|
}
|
||||||
|
ring := &inboundRing{
|
||||||
|
mask: uint64(size - 1),
|
||||||
|
cells: make([]inboundCell, size),
|
||||||
|
}
|
||||||
|
for i := range ring.cells {
|
||||||
|
ring.cells[i].seq.Store(uint64(i))
|
||||||
|
}
|
||||||
|
return ring
|
||||||
|
}
|
||||||
|
|
||||||
|
func nextPowerOfTwo(v uint32) uint32 {
|
||||||
|
if v == 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 1 << (32 - bits.LeadingZeros32(v-1))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *inboundRing) Enqueue(pkt *inboundPacket) bool {
|
||||||
|
var cell *inboundCell
|
||||||
|
pos := r.enqueuePos.Load()
|
||||||
|
for {
|
||||||
|
cell = &r.cells[pos&r.mask]
|
||||||
|
seq := cell.seq.Load()
|
||||||
|
diff := int64(seq) - int64(pos)
|
||||||
|
if diff == 0 {
|
||||||
|
if r.enqueuePos.CompareAndSwap(pos, pos+1) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
} else if diff < 0 {
|
||||||
|
return false
|
||||||
|
} else {
|
||||||
|
pos = r.enqueuePos.Load()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cell.pkt = pkt
|
||||||
|
cell.seq.Store(pos + 1)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *inboundRing) Dequeue() (*inboundPacket, bool) {
|
||||||
|
var cell *inboundCell
|
||||||
|
pos := r.dequeuePos.Load()
|
||||||
|
for {
|
||||||
|
cell = &r.cells[pos&r.mask]
|
||||||
|
seq := cell.seq.Load()
|
||||||
|
diff := int64(seq) - int64(pos+1)
|
||||||
|
if diff == 0 {
|
||||||
|
if r.dequeuePos.CompareAndSwap(pos, pos+1) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
} else if diff < 0 {
|
||||||
|
return nil, false
|
||||||
|
} else {
|
||||||
|
pos = r.dequeuePos.Load()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pkt := cell.pkt
|
||||||
|
cell.pkt = nil
|
||||||
|
cell.seq.Store(pos + r.mask + 1)
|
||||||
|
return pkt, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) getInboundPacket() *inboundPacket {
|
||||||
|
if pkt, ok := f.inboundPool.Get().(*inboundPacket); ok && pkt != nil {
|
||||||
|
return pkt
|
||||||
|
}
|
||||||
|
return &inboundPacket{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) putInboundPacket(pkt *inboundPacket) {
|
||||||
|
if pkt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pkt.addr = netip.AddrPort{}
|
||||||
|
pkt.payload = nil
|
||||||
|
pkt.release = nil
|
||||||
|
pkt.queue = 0
|
||||||
|
f.inboundPool.Put(pkt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDecryptContext(f *Interface) *decryptContext {
|
||||||
|
return &decryptContext{
|
||||||
|
ctTicker: firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout),
|
||||||
|
plain: make([]byte, udp.MTU),
|
||||||
|
head: header.H{},
|
||||||
|
fwPacket: firewall.Packet{},
|
||||||
|
light: f.lightHouse.NewRequestHandler(),
|
||||||
|
nebula: make([]byte, 12, 12),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) processInboundPacket(pkt *inboundPacket, ctx *decryptContext) {
|
||||||
|
if pkt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if pkt.release != nil {
|
||||||
|
pkt.release()
|
||||||
|
}
|
||||||
|
f.putInboundPacket(pkt)
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx.head = header.H{}
|
||||||
|
ctx.fwPacket = firewall.Packet{}
|
||||||
|
var cache firewall.ConntrackCache
|
||||||
|
if ctx.ctTicker != nil {
|
||||||
|
cache = ctx.ctTicker.Get(f.l)
|
||||||
|
}
|
||||||
|
f.readOutsidePackets(pkt.addr, nil, ctx.plain[:0], pkt.payload, &ctx.head, &ctx.fwPacket, ctx.light, ctx.nebula, pkt.queue, cache)
|
||||||
}
|
}
|
||||||
|
|
||||||
type EncWriter interface {
|
type EncWriter interface {
|
||||||
@@ -170,6 +331,35 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cs := c.pki.getCertState()
|
cs := c.pki.getCertState()
|
||||||
|
decryptWorkers := c.DecryptWorkers
|
||||||
|
if decryptWorkers < 0 {
|
||||||
|
decryptWorkers = 0
|
||||||
|
}
|
||||||
|
if decryptWorkers == 0 {
|
||||||
|
decryptWorkers = c.routines * defaultDecryptWorkerFactor
|
||||||
|
if decryptWorkers < c.routines {
|
||||||
|
decryptWorkers = c.routines
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if decryptWorkers < 0 {
|
||||||
|
decryptWorkers = 0
|
||||||
|
}
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
decryptWorkers = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
queueDepth := c.DecryptQueueDepth
|
||||||
|
if queueDepth <= 0 {
|
||||||
|
queueDepth = defaultInboundQueueDepth
|
||||||
|
}
|
||||||
|
minDepth := c.routines * 64
|
||||||
|
if minDepth <= 0 {
|
||||||
|
minDepth = 64
|
||||||
|
}
|
||||||
|
if queueDepth < minDepth {
|
||||||
|
queueDepth = minDepth
|
||||||
|
}
|
||||||
|
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
pki: c.pki,
|
pki: c.pki,
|
||||||
hostMap: c.HostMap,
|
hostMap: c.HostMap,
|
||||||
@@ -185,7 +375,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
routines: c.routines,
|
routines: c.routines,
|
||||||
version: c.version,
|
version: c.version,
|
||||||
writers: make([]udp.Conn, c.routines),
|
writers: make([]udp.Conn, c.routines),
|
||||||
readers: make([]overlay.BatchReadWriter, c.routines),
|
readers: make([]io.ReadWriteCloser, c.routines),
|
||||||
myVpnNetworks: cs.myVpnNetworks,
|
myVpnNetworks: cs.myVpnNetworks,
|
||||||
myVpnNetworksTable: cs.myVpnNetworksTable,
|
myVpnNetworksTable: cs.myVpnNetworksTable,
|
||||||
myVpnAddrs: cs.myVpnAddrs,
|
myVpnAddrs: cs.myVpnAddrs,
|
||||||
@@ -201,14 +391,11 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
sent: metrics.GetOrRegisterCounter("hostinfo.cached_packets.sent", nil),
|
sent: metrics.GetOrRegisterCounter("hostinfo.cached_packets.sent", nil),
|
||||||
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
|
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
|
||||||
},
|
},
|
||||||
batchMetrics: &batchMetrics{
|
|
||||||
udpReadSize: metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024)),
|
|
||||||
tunReadSize: metrics.GetOrRegisterHistogram("batch.tun_read_size", nil, metrics.NewUniformSample(1024)),
|
|
||||||
udpWriteSize: metrics.GetOrRegisterHistogram("batch.udp_write_size", nil, metrics.NewUniformSample(1024)),
|
|
||||||
tunWriteSize: metrics.GetOrRegisterHistogram("batch.tun_write_size", nil, metrics.NewUniformSample(1024)),
|
|
||||||
},
|
|
||||||
|
|
||||||
l: c.l,
|
l: c.l,
|
||||||
|
ctx: ctx,
|
||||||
|
inboundPool: sync.Pool{New: func() any { return &inboundPacket{} }},
|
||||||
|
decryptWorkers: decryptWorkers,
|
||||||
}
|
}
|
||||||
|
|
||||||
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
||||||
@@ -217,6 +404,19 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
|
|
||||||
ifce.connectionManager.intf = ifce
|
ifce.connectionManager.intf = ifce
|
||||||
|
|
||||||
|
if decryptWorkers > 0 {
|
||||||
|
ifce.decryptQueues = make([]*inboundRing, decryptWorkers)
|
||||||
|
ifce.decryptStates = make([]decryptWorkerState, decryptWorkers)
|
||||||
|
for i := 0; i < decryptWorkers; i++ {
|
||||||
|
queue := newInboundRing(queueDepth)
|
||||||
|
ifce.decryptQueues[i] = queue
|
||||||
|
ifce.decryptStates[i] = decryptWorkerState{
|
||||||
|
queue: queue,
|
||||||
|
notify: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return ifce, nil
|
return ifce, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -239,7 +439,7 @@ func (f *Interface) activate() {
|
|||||||
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
||||||
|
|
||||||
// Prepare n tun queues
|
// Prepare n tun queues
|
||||||
var reader overlay.BatchReadWriter = f.inside
|
var reader io.ReadWriteCloser = f.inside
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
reader, err = f.inside.NewMultiQueueReader()
|
reader, err = f.inside.NewMultiQueueReader()
|
||||||
@@ -256,8 +456,68 @@ func (f *Interface) activate() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Interface) startDecryptWorkers() {
|
||||||
|
if f.decryptWorkers <= 0 || len(f.decryptQueues) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.decryptWG.Add(f.decryptWorkers)
|
||||||
|
for i := 0; i < f.decryptWorkers; i++ {
|
||||||
|
go f.decryptWorker(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) decryptWorker(id int) {
|
||||||
|
defer f.decryptWG.Done()
|
||||||
|
if id < 0 || id >= len(f.decryptStates) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state := f.decryptStates[id]
|
||||||
|
if state.queue == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx := newDecryptContext(f)
|
||||||
|
for {
|
||||||
|
for {
|
||||||
|
pkt, ok := state.queue.Dequeue()
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
f.processInboundPacket(pkt, ctx)
|
||||||
|
}
|
||||||
|
if f.closed.Load() || f.ctx.Err() != nil {
|
||||||
|
for {
|
||||||
|
pkt, ok := state.queue.Dequeue()
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.processInboundPacket(pkt, ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-f.ctx.Done():
|
||||||
|
case <-state.notify:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) notifyDecryptWorker(idx int) {
|
||||||
|
if idx < 0 || idx >= len(f.decryptStates) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state := f.decryptStates[idx]
|
||||||
|
if state.notify == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case state.notify <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (f *Interface) run() {
|
func (f *Interface) run() {
|
||||||
|
f.startDecryptWorkers()
|
||||||
// Launch n queues to read packets from udp
|
// Launch n queues to read packets from udp
|
||||||
|
f.udpListenWG.Add(f.routines)
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
go f.listenOut(i)
|
go f.listenOut(i)
|
||||||
}
|
}
|
||||||
@@ -270,6 +530,7 @@ func (f *Interface) run() {
|
|||||||
|
|
||||||
func (f *Interface) listenOut(i int) {
|
func (f *Interface) listenOut(i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
defer f.udpListenWG.Done()
|
||||||
|
|
||||||
var li udp.Conn
|
var li udp.Conn
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
@@ -278,71 +539,96 @@ func (f *Interface) listenOut(i int) {
|
|||||||
li = f.outside
|
li = f.outside
|
||||||
}
|
}
|
||||||
|
|
||||||
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
useWorkers := f.decryptWorkers > 0 && len(f.decryptQueues) > 0
|
||||||
lhh := f.lightHouse.NewRequestHandler()
|
var (
|
||||||
|
inlineTicker *firewall.ConntrackCacheTicker
|
||||||
|
inlineHandler *LightHouseHandler
|
||||||
|
inlinePlain []byte
|
||||||
|
inlineHeader header.H
|
||||||
|
inlinePacket firewall.Packet
|
||||||
|
inlineNB []byte
|
||||||
|
inlineCtx *decryptContext
|
||||||
|
)
|
||||||
|
|
||||||
// Pre-allocate output buffers for batch processing
|
if useWorkers {
|
||||||
batchSize := li.BatchSize()
|
inlineCtx = newDecryptContext(f)
|
||||||
outs := make([][]byte, batchSize)
|
} else {
|
||||||
for idx := range outs {
|
inlineTicker = firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
// Allocate full buffer with virtio header space
|
inlineHandler = f.lightHouse.NewRequestHandler()
|
||||||
outs[idx] = make([]byte, virtioNetHdrLen, virtioNetHdrLen+udp.MTU)
|
inlinePlain = make([]byte, udp.MTU)
|
||||||
|
inlineNB = make([]byte, 12, 12)
|
||||||
}
|
}
|
||||||
|
|
||||||
h := &header.H{}
|
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte, release func()) {
|
||||||
fwPacket := &firewall.Packet{}
|
if !useWorkers {
|
||||||
nb := make([]byte, 12)
|
if release != nil {
|
||||||
|
defer release()
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-f.ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
inlineHeader = header.H{}
|
||||||
|
inlinePacket = firewall.Packet{}
|
||||||
|
var cache firewall.ConntrackCache
|
||||||
|
if inlineTicker != nil {
|
||||||
|
cache = inlineTicker.Get(f.l)
|
||||||
|
}
|
||||||
|
f.readOutsidePackets(fromUdpAddr, nil, inlinePlain[:0], payload, &inlineHeader, &inlinePacket, inlineHandler, inlineNB, i, cache)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
li.ListenOutBatch(func(addrs []netip.AddrPort, payloads [][]byte, count int) {
|
if f.ctx.Err() != nil {
|
||||||
f.readOutsidePacketsBatch(addrs, payloads, count, outs[:count], nb, i, h, fwPacket, lhh, ctCache.Get(f.l))
|
if release != nil {
|
||||||
|
release()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pkt := f.getInboundPacket()
|
||||||
|
pkt.addr = fromUdpAddr
|
||||||
|
pkt.payload = payload
|
||||||
|
pkt.release = release
|
||||||
|
pkt.queue = i
|
||||||
|
|
||||||
|
queueCount := len(f.decryptQueues)
|
||||||
|
if queueCount == 0 {
|
||||||
|
f.processInboundPacket(pkt, inlineCtx)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w := int(f.decryptCounter.Add(1)-1) % queueCount
|
||||||
|
if w < 0 || w >= queueCount || !f.decryptQueues[w].Enqueue(pkt) {
|
||||||
|
f.processInboundPacket(pkt, inlineCtx)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.notifyDecryptWorker(w)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenIn(reader overlay.BatchReadWriter, i int) {
|
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
batchSize := reader.BatchSize()
|
packet := make([]byte, tunReadBufferSize)
|
||||||
|
out := make([]byte, tunReadBufferSize)
|
||||||
// Allocate buffers for batch reading
|
fwPacket := &firewall.Packet{}
|
||||||
bufs := make([][]byte, batchSize)
|
nb := make([]byte, 12, 12)
|
||||||
for idx := range bufs {
|
|
||||||
bufs[idx] = make([]byte, mtu)
|
|
||||||
}
|
|
||||||
sizes := make([]int, batchSize)
|
|
||||||
|
|
||||||
// Allocate output buffers for batch processing (one per packet)
|
|
||||||
// Each has virtio header headroom to avoid copies on write
|
|
||||||
outs := make([][]byte, batchSize)
|
|
||||||
for idx := range outs {
|
|
||||||
outBuf := make([]byte, virtioNetHdrLen+mtu)
|
|
||||||
outs[idx] = outBuf[virtioNetHdrLen:] // Slice starting after headroom
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pre-allocate batch accumulation buffers for sending
|
|
||||||
batchPackets := make([][]byte, 0, batchSize)
|
|
||||||
batchAddrs := make([]netip.AddrPort, 0, batchSize)
|
|
||||||
|
|
||||||
// Pre-allocate nonce buffer (reused for all encryptions)
|
|
||||||
nb := make([]byte, 12)
|
|
||||||
|
|
||||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, err := reader.BatchRead(bufs, sizes)
|
n, err := reader.Read(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.WithError(err).Error("Error while batch reading outbound packets")
|
f.l.WithError(err).Error("Error while reading outbound packet")
|
||||||
// This only seems to happen when something fatal happens to the fd, so exit.
|
// This only seems to happen when something fatal happens to the fd, so exit.
|
||||||
os.Exit(2)
|
os.Exit(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.batchMetrics.tunReadSize.Update(int64(n))
|
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
||||||
|
|
||||||
// Process all packets in the batch at once
|
|
||||||
f.consumeInsidePackets(bufs, sizes, n, outs, nb, i, conntrackCache.Get(f.l), &batchPackets, &batchAddrs)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -502,6 +788,19 @@ func (f *Interface) Close() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
f.udpListenWG.Wait()
|
||||||
|
if f.decryptWorkers > 0 {
|
||||||
|
for _, state := range f.decryptStates {
|
||||||
|
if state.notify != nil {
|
||||||
|
select {
|
||||||
|
case state.notify <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.decryptWG.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
// Release the tun device
|
// Release the tun device
|
||||||
return f.inside.Close()
|
return f.inside.Close()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1337,19 +1337,12 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteAllowList := lhh.lh.GetRemoteAllowList()
|
|
||||||
for _, a := range n.Details.V4AddrPorts {
|
for _, a := range n.Details.V4AddrPorts {
|
||||||
b := protoV4AddrPortToNetAddrPort(a)
|
punch(protoV4AddrPortToNetAddrPort(a), detailsVpnAddr)
|
||||||
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
|
||||||
punch(b, detailsVpnAddr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, a := range n.Details.V6AddrPorts {
|
for _, a := range n.Details.V6AddrPorts {
|
||||||
b := protoV6AddrPortToNetAddrPort(a)
|
punch(protoV6AddrPortToNetAddrPort(a), detailsVpnAddr)
|
||||||
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
|
||||||
punch(b, detailsVpnAddr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// This sends a nebula test packet to the host trying to contact us. In the case
|
// This sends a nebula test packet to the host trying to contact us. In the case
|
||||||
|
|||||||
12
main.go
12
main.go
@@ -75,8 +75,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
if c.GetBool("sshd.enabled", false) {
|
if c.GetBool("sshd.enabled", false) {
|
||||||
sshStart, err = configSSH(l, ssh, c)
|
sshStart, err = configSSH(l, ssh, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
|
return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
|
||||||
sshStart = nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,6 +120,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache")
|
l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
udp.SetDisableUDPCsum(c.GetBool("listen.disable_udp_checksum", false))
|
||||||
|
|
||||||
var tun overlay.Device
|
var tun overlay.Device
|
||||||
if !configTest {
|
if !configTest {
|
||||||
c.CatchHUP(ctx)
|
c.CatchHUP(ctx)
|
||||||
@@ -165,7 +166,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
|
|
||||||
for i := 0; i < routines; i++ {
|
for i := 0; i < routines; i++ {
|
||||||
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
|
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
|
||||||
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 128))
|
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||||
}
|
}
|
||||||
@@ -222,6 +223,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
decryptWorkers := c.GetInt("listen.decrypt_workers", 0)
|
||||||
|
decryptQueueDepth := c.GetInt("listen.decrypt_queue_depth", 0)
|
||||||
|
|
||||||
ifConfig := &InterfaceConfig{
|
ifConfig := &InterfaceConfig{
|
||||||
HostMap: hostMap,
|
HostMap: hostMap,
|
||||||
Inside: tun,
|
Inside: tun,
|
||||||
@@ -244,6 +248,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
punchy: punchy,
|
punchy: punchy,
|
||||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||||
l: l,
|
l: l,
|
||||||
|
DecryptWorkers: decryptWorkers,
|
||||||
|
DecryptQueueDepth: decryptQueueDepth,
|
||||||
}
|
}
|
||||||
|
|
||||||
var ifce *Interface
|
var ifce *Interface
|
||||||
|
|||||||
130
outside.go
130
outside.go
@@ -95,7 +95,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
switch relay.Type {
|
switch relay.Type {
|
||||||
case TerminalType:
|
case TerminalType:
|
||||||
// If I am the target of this relay, process the unwrapped packet
|
// If I am the target of this relay, process the unwrapped packet
|
||||||
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:virtioNetHdrLen], signedPayload, h, fwPacket, lhf, nb, q, localCache)
|
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
|
||||||
|
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
|
||||||
return
|
return
|
||||||
case ForwardingType:
|
case ForwardingType:
|
||||||
// Find the target HostInfo relay object
|
// Find the target HostInfo relay object
|
||||||
@@ -137,7 +138,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d[virtioNetHdrLen:], f)
|
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
// Fallthrough to the bottom to record incoming traffic
|
||||||
|
|
||||||
@@ -159,7 +160,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
// This testRequest might be from TryPromoteBest, so we should roam
|
// This testRequest might be from TryPromoteBest, so we should roam
|
||||||
// to the new IP address before responding
|
// to the new IP address before responding
|
||||||
f.handleHostRoaming(hostinfo, ip)
|
f.handleHostRoaming(hostinfo, ip)
|
||||||
f.send(header.Test, header.TestReply, ci, hostinfo, d[virtioNetHdrLen:], nb, out)
|
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
// Fallthrough to the bottom to record incoming traffic
|
||||||
@@ -202,7 +203,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.relayManager.HandleControlMsg(hostinfo, d[virtioNetHdrLen:], f)
|
f.relayManager.HandleControlMsg(hostinfo, d, f)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
@@ -469,15 +470,19 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
|
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
hostinfo.logger(f.l).
|
||||||
|
WithError(err).
|
||||||
|
WithField("tag", "decrypt-debug").
|
||||||
|
WithField("remoteIndexLocal", hostinfo.localIndexId).
|
||||||
|
WithField("messageCounter", messageCounter).
|
||||||
|
WithField("packet_len", len(packet)).
|
||||||
|
Error("Failed to decrypt packet")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
packetData := out[virtioNetHdrLen:]
|
err = newPacket(out, true, fwPacket)
|
||||||
|
|
||||||
err = newPacket(packetData, true, fwPacket)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("packet", packetData).
|
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
|
||||||
Warnf("Error while validating inbound packet")
|
Warnf("Error while validating inbound packet")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -492,7 +497,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
if dropReason != nil {
|
if dropReason != nil {
|
||||||
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
||||||
// This gives us a buffer to build the reject packet in
|
// This gives us a buffer to build the reject packet in
|
||||||
f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, packet, q)
|
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
||||||
WithField("reason", dropReason).
|
WithField("reason", dropReason).
|
||||||
@@ -549,108 +554,3 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
|
|||||||
// We also delete it from pending hostmap to allow for fast reconnect.
|
// We also delete it from pending hostmap to allow for fast reconnect.
|
||||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
// readOutsidePacketsBatch processes multiple packets received from UDP in a batch
|
|
||||||
// and writes all successfully decrypted packets to TUN in a single operation
|
|
||||||
func (f *Interface) readOutsidePacketsBatch(addrs []netip.AddrPort, payloads [][]byte, count int, outs [][]byte, nb []byte, q int, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, localCache firewall.ConntrackCache) {
|
|
||||||
// Pre-allocate slice for accumulating successful decryptions
|
|
||||||
tunPackets := make([][]byte, 0, count)
|
|
||||||
|
|
||||||
for i := 0; i < count; i++ {
|
|
||||||
payload := payloads[i]
|
|
||||||
addr := addrs[i]
|
|
||||||
out := outs[i]
|
|
||||||
|
|
||||||
// Parse header
|
|
||||||
err := h.Parse(payload)
|
|
||||||
if err != nil {
|
|
||||||
if len(payload) > 1 {
|
|
||||||
f.l.WithField("packet", payload).Infof("Error while parsing inbound packet from %s: %s", addr, err)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if addr.IsValid() {
|
|
||||||
if f.myVpnNetworksTable.Contains(addr.Addr()) {
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet")
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var hostinfo *HostInfo
|
|
||||||
if h.Type == header.Message && h.Subtype == header.MessageRelay {
|
|
||||||
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
|
|
||||||
} else {
|
|
||||||
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
|
|
||||||
}
|
|
||||||
|
|
||||||
var ci *ConnectionState
|
|
||||||
if hostinfo != nil {
|
|
||||||
ci = hostinfo.ConnectionState
|
|
||||||
}
|
|
||||||
|
|
||||||
switch h.Type {
|
|
||||||
case header.Message:
|
|
||||||
if !f.handleEncrypted(ci, addr, h) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
switch h.Subtype {
|
|
||||||
case header.MessageNone:
|
|
||||||
// Decrypt packet
|
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, payload[:header.Len], payload[header.Len:], h.MessageCounter, nb)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
packetData := out[virtioNetHdrLen:]
|
|
||||||
|
|
||||||
err = newPacket(packetData, true, fwPacket)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("packet", packetData).Warnf("Error while validating inbound packet")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if !hostinfo.ConnectionState.window.Update(f.l, h.MessageCounter) {
|
|
||||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).Debugln("dropping out of window packet")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
|
||||||
if dropReason != nil {
|
|
||||||
f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, payload, q)
|
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
|
||||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).WithField("reason", dropReason).Debugln("dropping inbound packet")
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
f.connectionManager.In(hostinfo)
|
|
||||||
// Add to batch for TUN write
|
|
||||||
tunPackets = append(tunPackets, out)
|
|
||||||
|
|
||||||
case header.MessageRelay:
|
|
||||||
// Skip relay packets in batch mode for now (less common path)
|
|
||||||
f.readOutsidePackets(addr, nil, out, payload, h, fwPacket, lhf, nb, q, localCache)
|
|
||||||
|
|
||||||
default:
|
|
||||||
hostinfo.logger(f.l).Debugf("unexpected message subtype %d", h.Subtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
// Handle non-Message types using single-packet path
|
|
||||||
f.readOutsidePackets(addr, nil, out, payload, h, fwPacket, lhf, nb, q, localCache)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(tunPackets) > 0 {
|
|
||||||
n, err := f.readers[q].WriteBatch(tunPackets, virtioNetHdrLen)
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).WithField("sent", n).WithField("total", len(tunPackets)).Error("Failed to batch write to tun")
|
|
||||||
}
|
|
||||||
f.batchMetrics.tunWriteSize.Update(int64(len(tunPackets)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -7,25 +7,11 @@ import (
|
|||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
// BatchReadWriter extends io.ReadWriteCloser with batch I/O operations
|
|
||||||
type BatchReadWriter interface {
|
|
||||||
io.ReadWriteCloser
|
|
||||||
|
|
||||||
// BatchRead reads multiple packets at once
|
|
||||||
BatchRead(bufs [][]byte, sizes []int) (int, error)
|
|
||||||
|
|
||||||
// WriteBatch writes multiple packets at once
|
|
||||||
WriteBatch(bufs [][]byte, offset int) (int, error)
|
|
||||||
|
|
||||||
// BatchSize returns the optimal batch size for this device
|
|
||||||
BatchSize() int
|
|
||||||
}
|
|
||||||
|
|
||||||
type Device interface {
|
type Device interface {
|
||||||
BatchReadWriter
|
io.ReadWriteCloser
|
||||||
Activate() error
|
Activate() error
|
||||||
Networks() []netip.Prefix
|
Networks() []netip.Prefix
|
||||||
Name() string
|
Name() string
|
||||||
RoutesFor(netip.Addr) routing.Gateways
|
RoutesFor(netip.Addr) routing.Gateways
|
||||||
NewMultiQueueReader() (BatchReadWriter, error)
|
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const DefaultMTU = 1300
|
const DefaultMTU = 1300
|
||||||
const VirtioNetHdrLen = 10 // Size of virtio_net_hdr structure
|
|
||||||
|
|
||||||
// TODO: We may be able to remove routines
|
// TODO: We may be able to remove routines
|
||||||
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
||||||
|
|||||||
@@ -95,29 +95,6 @@ func (t *tun) Name() string {
|
|||||||
return "android"
|
return "android"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
|
||||||
n, err := t.Read(bufs[0])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
sizes[0] = n
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
|
||||||
for i, buf := range bufs {
|
|
||||||
_, err := t.Write(buf[offset:])
|
|
||||||
if err != nil {
|
|
||||||
return i, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(bufs), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) BatchSize() int {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -549,32 +549,6 @@ func (t *tun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||||
}
|
}
|
||||||
|
|
||||||
// BatchRead reads a single packet (batch size 1 for non-Linux platforms)
|
|
||||||
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
|
||||||
n, err := t.Read(bufs[0])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
sizes[0] = n
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteBatch writes packets individually (no batching for non-Linux platforms)
|
|
||||||
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
|
||||||
for i, buf := range bufs {
|
|
||||||
_, err := t.Write(buf[offset:])
|
|
||||||
if err != nil {
|
|
||||||
return i, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(bufs), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchSize returns 1 for non-Linux platforms (no batching)
|
|
||||||
func (t *tun) BatchSize() int {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -105,36 +105,10 @@ func (t *disabledTun) Write(b []byte) (int, error) {
|
|||||||
return len(b), nil
|
return len(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *disabledTun) NewMultiQueueReader() (BatchReadWriter, error) {
|
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// BatchRead reads a single packet (batch size 1 for disabled tun)
|
|
||||||
func (t *disabledTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
|
||||||
n, err := t.Read(bufs[0])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
sizes[0] = n
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteBatch writes packets individually (no batching for disabled tun)
|
|
||||||
func (t *disabledTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
|
||||||
for i, buf := range bufs {
|
|
||||||
_, err := t.Write(buf[offset:])
|
|
||||||
if err != nil {
|
|
||||||
return i, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(bufs), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchSize returns 1 for disabled tun (no batching)
|
|
||||||
func (t *disabledTun) BatchSize() int {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *disabledTun) Close() error {
|
func (t *disabledTun) Close() error {
|
||||||
if t.read != nil {
|
if t.read != nil {
|
||||||
close(t.read)
|
close(t.read)
|
||||||
|
|||||||
@@ -450,36 +450,10 @@ func (t *tun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
||||||
}
|
}
|
||||||
|
|
||||||
// BatchRead reads a single packet (batch size 1 for FreeBSD)
|
|
||||||
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
|
||||||
n, err := t.Read(bufs[0])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
sizes[0] = n
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteBatch writes packets individually (no batching for FreeBSD)
|
|
||||||
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
|
||||||
for i, buf := range bufs {
|
|
||||||
_, err := t.Write(buf[offset:])
|
|
||||||
if err != nil {
|
|
||||||
return i, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(bufs), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchSize returns 1 for FreeBSD (no batching)
|
|
||||||
func (t *tun) BatchSize() int {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
|
|||||||
@@ -151,29 +151,6 @@ func (t *tun) Name() string {
|
|||||||
return "iOS"
|
return "iOS"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
|
||||||
n, err := t.Read(bufs[0])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
sizes[0] = n
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
|
||||||
for i, buf := range bufs {
|
|
||||||
_, err := t.Write(buf[offset:])
|
|
||||||
if err != nil {
|
|
||||||
return i, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(bufs), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) BatchSize() int {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
@@ -20,12 +21,10 @@ import (
|
|||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
wgtun "golang.zx2c4.com/wireguard/tun"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
io.ReadWriteCloser
|
io.ReadWriteCloser
|
||||||
wgDevice wgtun.Device
|
|
||||||
fd int
|
fd int
|
||||||
Device string
|
Device string
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
@@ -34,6 +33,9 @@ type tun struct {
|
|||||||
TXQueueLen int
|
TXQueueLen int
|
||||||
deviceIndex int
|
deviceIndex int
|
||||||
ioctlFd uintptr
|
ioctlFd uintptr
|
||||||
|
enableVnetHdr bool
|
||||||
|
vnetHdrLen int
|
||||||
|
queues []*tunQueue
|
||||||
|
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
@@ -66,174 +68,179 @@ type ifreqQLEN struct {
|
|||||||
pad [8]byte
|
pad [8]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// wgDeviceWrapper wraps a wireguard Device to implement io.ReadWriteCloser
|
const (
|
||||||
// This allows multiqueue readers to use the same wireguard Device batching as the main device
|
virtioNetHdrLen = 12
|
||||||
type wgDeviceWrapper struct {
|
tunDefaultMaxPacket = 65536
|
||||||
dev wgtun.Device
|
)
|
||||||
buf []byte // Reusable buffer for single packet reads
|
|
||||||
|
type tunQueue struct {
|
||||||
|
file *os.File
|
||||||
|
fd int
|
||||||
|
enableVnetHdr bool
|
||||||
|
vnetHdrLen int
|
||||||
|
maxPacket int
|
||||||
|
writeScratch []byte
|
||||||
|
readScratch []byte
|
||||||
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *wgDeviceWrapper) Read(b []byte) (int, error) {
|
func newTunQueue(file *os.File, enableVnetHdr bool, vnetHdrLen, maxPacket int, l *logrus.Logger) *tunQueue {
|
||||||
// Use wireguard Device's batch API for single packet
|
if maxPacket <= 0 {
|
||||||
bufs := [][]byte{b}
|
maxPacket = tunDefaultMaxPacket
|
||||||
sizes := make([]int, 1)
|
|
||||||
n, err := w.dev.Read(bufs, sizes, 0)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
}
|
||||||
if n == 0 {
|
q := &tunQueue{
|
||||||
return 0, io.EOF
|
file: file,
|
||||||
|
fd: int(file.Fd()),
|
||||||
|
enableVnetHdr: enableVnetHdr,
|
||||||
|
vnetHdrLen: vnetHdrLen,
|
||||||
|
maxPacket: maxPacket,
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
return sizes[0], nil
|
if enableVnetHdr {
|
||||||
}
|
q.growReadScratch(maxPacket)
|
||||||
|
|
||||||
func (w *wgDeviceWrapper) Write(b []byte) (int, error) {
|
|
||||||
// Buffer b should have virtio header space (10 bytes) at the beginning
|
|
||||||
// The decrypted packet data starts at offset 10
|
|
||||||
// Pass the full buffer to WireGuard with offset=virtioNetHdrLen
|
|
||||||
bufs := [][]byte{b}
|
|
||||||
n, err := w.dev.Write(bufs, VirtioNetHdrLen)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
}
|
||||||
if n == 0 {
|
return q
|
||||||
return 0, io.ErrShortWrite
|
}
|
||||||
|
|
||||||
|
func (q *tunQueue) growReadScratch(packetSize int) {
|
||||||
|
needed := q.vnetHdrLen + packetSize
|
||||||
|
if needed < q.vnetHdrLen+DefaultMTU {
|
||||||
|
needed = q.vnetHdrLen + DefaultMTU
|
||||||
|
}
|
||||||
|
if q.readScratch == nil || cap(q.readScratch) < needed {
|
||||||
|
q.readScratch = make([]byte, needed)
|
||||||
|
} else {
|
||||||
|
q.readScratch = q.readScratch[:needed]
|
||||||
}
|
}
|
||||||
return len(b), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *wgDeviceWrapper) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
func (q *tunQueue) setMaxPacket(packet int) {
|
||||||
// Pass all buffers to WireGuard's batch write
|
if packet <= 0 {
|
||||||
return w.dev.Write(bufs, offset)
|
packet = DefaultMTU
|
||||||
|
}
|
||||||
|
q.maxPacket = packet
|
||||||
|
if q.enableVnetHdr {
|
||||||
|
q.growReadScratch(packet)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *wgDeviceWrapper) Close() error {
|
func configureVnetHdr(fd int, hdrLen int, l *logrus.Logger) error {
|
||||||
return w.dev.Close()
|
features, err := unix.IoctlGetInt(fd, unix.TUNGETFEATURES)
|
||||||
}
|
if err == nil && features&unix.IFF_VNET_HDR == 0 {
|
||||||
|
return fmt.Errorf("kernel does not support IFF_VNET_HDR")
|
||||||
// BatchRead implements batching for multiqueue readers
|
}
|
||||||
func (w *wgDeviceWrapper) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
if err := unix.IoctlSetInt(fd, unix.TUNSETVNETHDRSZ, hdrLen); err != nil {
|
||||||
// The zero here is offset.
|
return err
|
||||||
return w.dev.Read(bufs, sizes, 0)
|
}
|
||||||
}
|
offload := unix.TUN_F_CSUM | unix.TUN_F_UFO
|
||||||
|
if err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, offload); err != nil {
|
||||||
// BatchSize returns the optimal batch size
|
if l != nil {
|
||||||
func (w *wgDeviceWrapper) BatchSize() int {
|
l.WithError(err).Warn("Failed to enable TUN offload features")
|
||||||
return w.dev.BatchSize()
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
wgDev, name, err := wgtun.CreateUnmonitoredTUNFromFD(deviceFd)
|
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||||
if err != nil {
|
enableVnetHdr := c.GetBool("tun.enable_vnet_hdr", false)
|
||||||
return nil, fmt.Errorf("failed to create TUN from FD: %w", err)
|
if enableVnetHdr {
|
||||||
|
if err := configureVnetHdr(deviceFd, virtioNetHdrLen, l); err != nil {
|
||||||
|
l.WithError(err).Warn("Failed to configure VNET header support on provided tun fd; disabling")
|
||||||
|
enableVnetHdr = false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
file := wgDev.File()
|
t, err := newTunGeneric(c, l, file, vpnNetworks, enableVnetHdr)
|
||||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = wgDev.Close()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
t.wgDevice = wgDev
|
t.Device = "tun0"
|
||||||
t.Device = name
|
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||||
// Check if /dev/net/tun exists, create if needed (for docker containers)
|
|
||||||
if _, err := os.Stat("/dev/net/tun"); os.IsNotExist(err) {
|
|
||||||
if err := os.MkdirAll("/dev/net", 0755); err != nil {
|
|
||||||
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
|
|
||||||
}
|
|
||||||
if err := unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
devName := c.GetString("tun.dev", "")
|
|
||||||
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
|
||||||
|
|
||||||
// Create TUN device manually to support multiqueue
|
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
err = os.MkdirAll("/dev/net", 0755)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
|
||||||
|
}
|
||||||
|
err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var req ifReq
|
var req ifReq
|
||||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
|
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
|
||||||
if multiqueue {
|
if multiqueue {
|
||||||
req.Flags |= unix.IFF_MULTI_QUEUE
|
req.Flags |= unix.IFF_MULTI_QUEUE
|
||||||
}
|
}
|
||||||
copy(req.Name[:], devName)
|
enableVnetHdr := c.GetBool("tun.enable_vnet_hdr", false)
|
||||||
|
if enableVnetHdr {
|
||||||
|
req.Flags |= unix.IFF_VNET_HDR
|
||||||
|
}
|
||||||
|
copy(req.Name[:], c.GetString("tun.dev", ""))
|
||||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
unix.Close(fd)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||||
|
|
||||||
// Set nonblocking
|
if enableVnetHdr {
|
||||||
if err = unix.SetNonblock(fd, true); err != nil {
|
if err := configureVnetHdr(fd, virtioNetHdrLen, l); err != nil {
|
||||||
unix.Close(fd)
|
l.WithError(err).Warn("Failed to configure VNET header support on tun device; disabling")
|
||||||
return nil, err
|
enableVnetHdr = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enable TCP and UDP offload (TSO/GRO) for performance
|
|
||||||
// This allows the kernel to handle segmentation/coalescing
|
|
||||||
const (
|
|
||||||
tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
|
|
||||||
tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
|
|
||||||
)
|
|
||||||
offloads := tunTCPOffloads | tunUDPOffloads
|
|
||||||
if err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, offloads); err != nil {
|
|
||||||
// Log warning but don't fail - offload is optional
|
|
||||||
l.WithError(err).Warn("Failed to enable TUN offload (TSO/GRO), performance may be reduced")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||||
|
t, err := newTunGeneric(c, l, file, vpnNetworks, enableVnetHdr)
|
||||||
// Create wireguard device from file descriptor
|
|
||||||
wgDev, err := wgtun.CreateTUNFromFile(file, mtu)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
file.Close()
|
|
||||||
return nil, fmt.Errorf("failed to create TUN from file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
name, err := wgDev.Name()
|
|
||||||
if err != nil {
|
|
||||||
_ = wgDev.Close()
|
|
||||||
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// file is now owned by wgDev, get a new reference
|
|
||||||
file = wgDev.File()
|
|
||||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
|
||||||
if err != nil {
|
|
||||||
_ = wgDev.Close()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
t.wgDevice = wgDev
|
|
||||||
t.Device = name
|
t.Device = name
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, enableVnetHdr bool) (*tun, error) {
|
||||||
|
queue := newTunQueue(file, enableVnetHdr, virtioNetHdrLen, tunDefaultMaxPacket, l)
|
||||||
t := &tun{
|
t := &tun{
|
||||||
ReadWriteCloser: file,
|
ReadWriteCloser: queue,
|
||||||
fd: int(file.Fd()),
|
fd: int(file.Fd()),
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||||
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
||||||
l: l,
|
l: l,
|
||||||
|
enableVnetHdr: enableVnetHdr,
|
||||||
|
vnetHdrLen: virtioNetHdrLen,
|
||||||
|
queues: []*tunQueue{queue},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := t.reload(c, true)
|
err := t.reload(c, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if enableVnetHdr {
|
||||||
|
for _, q := range t.queues {
|
||||||
|
q.setMaxPacket(t.MaxMTU)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
err := t.reload(c, false)
|
err := t.reload(c, false)
|
||||||
@@ -276,6 +283,11 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
|
|
||||||
t.MaxMTU = newMaxMTU
|
t.MaxMTU = newMaxMTU
|
||||||
t.DefaultMTU = newDefaultMTU
|
t.DefaultMTU = newDefaultMTU
|
||||||
|
if t.enableVnetHdr {
|
||||||
|
for _, q := range t.queues {
|
||||||
|
q.setMaxPacket(t.MaxMTU)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Teach nebula how to handle the routes before establishing them in the system table
|
// Teach nebula how to handle the routes before establishing them in the system table
|
||||||
oldRoutes := t.Routes.Swap(&routes)
|
oldRoutes := t.Routes.Swap(&routes)
|
||||||
@@ -312,44 +324,95 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var req ifReq
|
var req ifReq
|
||||||
// MUST match the flags used in newTun - includes IFF_VNET_HDR
|
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
||||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR | unix.IFF_MULTI_QUEUE)
|
if t.enableVnetHdr {
|
||||||
|
req.Flags |= unix.IFF_VNET_HDR
|
||||||
|
}
|
||||||
copy(req.Name[:], t.Device)
|
copy(req.Name[:], t.Device)
|
||||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
unix.Close(fd)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set nonblocking mode - CRITICAL for proper netpoller integration
|
|
||||||
if err = unix.SetNonblock(fd, true); err != nil {
|
|
||||||
unix.Close(fd)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get MTU from main device
|
|
||||||
mtu := t.MaxMTU
|
|
||||||
if mtu == 0 {
|
|
||||||
mtu = DefaultMTU
|
|
||||||
}
|
|
||||||
|
|
||||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||||
|
queue := newTunQueue(file, t.enableVnetHdr, t.vnetHdrLen, t.MaxMTU, t.l)
|
||||||
|
if t.enableVnetHdr {
|
||||||
|
if err := configureVnetHdr(fd, t.vnetHdrLen, t.l); err != nil {
|
||||||
|
queue.enableVnetHdr = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.queues = append(t.queues, queue)
|
||||||
|
|
||||||
// Create wireguard Device from the file descriptor (just like the main device)
|
return queue, nil
|
||||||
wgDev, err := wgtun.CreateTUNFromFile(file, mtu)
|
}
|
||||||
if err != nil {
|
|
||||||
file.Close()
|
func (q *tunQueue) Read(p []byte) (int, error) {
|
||||||
return nil, fmt.Errorf("failed to create multiqueue TUN device: %w", err)
|
if !q.enableVnetHdr {
|
||||||
|
return q.file.Read(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return a wrapper that uses the wireguard Device for all I/O
|
if len(p)+q.vnetHdrLen > cap(q.readScratch) {
|
||||||
return &wgDeviceWrapper{dev: wgDev}, nil
|
q.growReadScratch(len(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := q.readScratch[:cap(q.readScratch)]
|
||||||
|
n, err := q.file.Read(buf)
|
||||||
|
if n <= 0 {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
if n < q.vnetHdrLen {
|
||||||
|
if err == nil {
|
||||||
|
err = io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := buf[q.vnetHdrLen:n]
|
||||||
|
if len(payload) > len(p) {
|
||||||
|
copy(p, payload[:len(p)])
|
||||||
|
if err == nil {
|
||||||
|
err = io.ErrShortBuffer
|
||||||
|
}
|
||||||
|
return len(p), err
|
||||||
|
}
|
||||||
|
copy(p, payload)
|
||||||
|
return len(payload), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *tunQueue) Write(b []byte) (int, error) {
|
||||||
|
if !q.enableVnetHdr {
|
||||||
|
return unix.Write(q.fd, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
total := q.vnetHdrLen + len(b)
|
||||||
|
if cap(q.writeScratch) < total {
|
||||||
|
q.writeScratch = make([]byte, total)
|
||||||
|
} else {
|
||||||
|
q.writeScratch = q.writeScratch[:total]
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < q.vnetHdrLen; i++ {
|
||||||
|
q.writeScratch[i] = 0
|
||||||
|
}
|
||||||
|
copy(q.writeScratch[q.vnetHdrLen:], b)
|
||||||
|
|
||||||
|
n, err := unix.Write(q.fd, q.writeScratch)
|
||||||
|
if n >= q.vnetHdrLen {
|
||||||
|
n -= q.vnetHdrLen
|
||||||
|
} else {
|
||||||
|
n = 0
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *tunQueue) Close() error {
|
||||||
|
return q.file.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
@@ -357,68 +420,7 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Read(b []byte) (int, error) {
|
|
||||||
if t.wgDevice != nil {
|
|
||||||
// Use wireguard device which handles virtio headers internally
|
|
||||||
bufs := [][]byte{b}
|
|
||||||
sizes := make([]int, 1)
|
|
||||||
n, err := t.wgDevice.Read(bufs, sizes, 0)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if n == 0 {
|
|
||||||
return 0, io.EOF
|
|
||||||
}
|
|
||||||
return sizes[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback: direct read from file (shouldn't happen in normal operation)
|
|
||||||
return t.ReadWriteCloser.Read(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchRead reads multiple packets at once for improved performance
|
|
||||||
// bufs: slice of buffers to read into
|
|
||||||
// sizes: slice that will be filled with packet sizes
|
|
||||||
// Returns number of packets read
|
|
||||||
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
|
||||||
if t.wgDevice != nil {
|
|
||||||
return t.wgDevice.Read(bufs, sizes, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback: single packet read
|
|
||||||
n, err := t.ReadWriteCloser.Read(bufs[0])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
sizes[0] = n
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchSize returns the optimal number of packets to read/write in a batch
|
|
||||||
func (t *tun) BatchSize() int {
|
|
||||||
if t.wgDevice != nil {
|
|
||||||
return t.wgDevice.BatchSize()
|
|
||||||
}
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Write(b []byte) (int, error) {
|
func (t *tun) Write(b []byte) (int, error) {
|
||||||
if t.wgDevice != nil {
|
|
||||||
// Buffer b should have virtio header space (10 bytes) at the beginning
|
|
||||||
// The decrypted packet data starts at offset 10
|
|
||||||
// Pass the full buffer to WireGuard with offset=virtioNetHdrLen
|
|
||||||
bufs := [][]byte{b}
|
|
||||||
n, err := t.wgDevice.Write(bufs, VirtioNetHdrLen)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if n == 0 {
|
|
||||||
return 0, io.ErrShortWrite
|
|
||||||
}
|
|
||||||
return len(b), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback: direct write (shouldn't happen in normal operation)
|
|
||||||
var nn int
|
var nn int
|
||||||
maximum := len(b)
|
maximum := len(b)
|
||||||
|
|
||||||
@@ -441,22 +443,6 @@ func (t *tun) Write(b []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteBatch writes multiple packets to the TUN device in a single syscall
|
|
||||||
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
|
||||||
if t.wgDevice != nil {
|
|
||||||
return t.wgDevice.Write(bufs, offset)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback: write individually (shouldn't happen in normal operation)
|
|
||||||
for i, buf := range bufs {
|
|
||||||
_, err := t.Write(buf)
|
|
||||||
if err != nil {
|
|
||||||
return i, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(bufs), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) deviceBytes() (o [16]byte) {
|
func (t *tun) deviceBytes() (o [16]byte) {
|
||||||
for i, c := range t.Device {
|
for i, c := range t.Device {
|
||||||
o[i] = byte(c)
|
o[i] = byte(c)
|
||||||
@@ -869,10 +855,6 @@ func (t *tun) Close() error {
|
|||||||
close(t.routeChan)
|
close(t.routeChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.wgDevice != nil {
|
|
||||||
_ = t.wgDevice.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.ReadWriteCloser != nil {
|
if t.ReadWriteCloser != nil {
|
||||||
_ = t.ReadWriteCloser.Close()
|
_ = t.ReadWriteCloser.Close()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -390,33 +390,10 @@ func (t *tun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
|
||||||
n, err := t.Read(bufs[0])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
sizes[0] = n
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
|
||||||
for i, buf := range bufs {
|
|
||||||
_, err := t.Write(buf[offset:])
|
|
||||||
if err != nil {
|
|
||||||
return i, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(bufs), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) BatchSize() int {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
|
|
||||||
|
|||||||
@@ -310,33 +310,10 @@ func (t *tun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
|
||||||
n, err := t.Read(bufs[0])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
sizes[0] = n
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
|
||||||
for i, buf := range bufs {
|
|
||||||
_, err := t.Write(buf[offset:])
|
|
||||||
if err != nil {
|
|
||||||
return i, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(bufs), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) BatchSize() int {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
|
|
||||||
|
|||||||
@@ -132,29 +132,6 @@ func (t *TestTun) Read(b []byte) (int, error) {
|
|||||||
return len(p), nil
|
return len(p), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TestTun) NewMultiQueueReader() (BatchReadWriter, error) {
|
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TestTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
|
||||||
n, err := t.Read(bufs[0])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
sizes[0] = n
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TestTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
|
||||||
for i, buf := range bufs {
|
|
||||||
_, err := t.Write(buf[offset:])
|
|
||||||
if err != nil {
|
|
||||||
return i, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(bufs), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TestTun) BatchSize() int {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ package overlay
|
|||||||
import (
|
import (
|
||||||
"crypto"
|
"crypto"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -233,36 +234,10 @@ func (t *winTun) Write(b []byte) (int, error) {
|
|||||||
return t.tun.Write(b, 0)
|
return t.tun.Write(b, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *winTun) NewMultiQueueReader() (BatchReadWriter, error) {
|
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
// BatchRead reads a single packet (batch size 1 for Windows)
|
|
||||||
func (t *winTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
|
||||||
n, err := t.Read(bufs[0])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
sizes[0] = n
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteBatch writes packets individually (no batching for Windows)
|
|
||||||
func (t *winTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
|
||||||
for i, buf := range bufs {
|
|
||||||
_, err := t.Write(buf[offset:])
|
|
||||||
if err != nil {
|
|
||||||
return i, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(bufs), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchSize returns 1 for Windows (no batching)
|
|
||||||
func (t *winTun) BatchSize() int {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *winTun) Close() error {
|
func (t *winTun) Close() error {
|
||||||
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
|
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
|
||||||
// so to be certain, just remove everything before destroying.
|
// so to be certain, just remove everything before destroying.
|
||||||
|
|||||||
@@ -46,36 +46,10 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
|
|||||||
return routing.Gateways{routing.NewGateway(ip, 1)}
|
return routing.Gateways{routing.NewGateway(ip, 1)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *UserDevice) NewMultiQueueReader() (BatchReadWriter, error) {
|
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// BatchRead reads a single packet (batch size 1 for UserDevice)
|
|
||||||
func (d *UserDevice) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
|
||||||
n, err := d.Read(bufs[0])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
sizes[0] = n
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteBatch writes packets individually (no batching for UserDevice)
|
|
||||||
func (d *UserDevice) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
|
||||||
for i, buf := range bufs {
|
|
||||||
_, err := d.Write(buf[offset:])
|
|
||||||
if err != nil {
|
|
||||||
return i, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(bufs), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchSize returns 1 for UserDevice (no batching)
|
|
||||||
func (d *UserDevice) BatchSize() int {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) {
|
func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) {
|
||||||
return d.inboundReader, d.outboundWriter
|
return d.inboundReader, d.outboundWriter
|
||||||
}
|
}
|
||||||
|
|||||||
39
pki.go
39
pki.go
@@ -100,36 +100,41 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
|||||||
currentState := p.cs.Load()
|
currentState := p.cs.Load()
|
||||||
if newState.v1Cert != nil {
|
if newState.v1Cert != nil {
|
||||||
if currentState.v1Cert == nil {
|
if currentState.v1Cert == nil {
|
||||||
//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
|
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
|
||||||
} else {
|
}
|
||||||
|
|
||||||
// did IP in cert change? if so, don't set
|
// did IP in cert change? if so, don't set
|
||||||
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
||||||
return util.NewContextualError(
|
return util.NewContextualError(
|
||||||
"Networks in new cert was different from old",
|
"Networks in new cert was different from old",
|
||||||
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1},
|
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
||||||
return util.NewContextualError(
|
return util.NewContextualError(
|
||||||
"Curve in new v1 cert was different from old",
|
"Curve in new cert was different from old",
|
||||||
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1},
|
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
} else if currentState.v1Cert != nil {
|
||||||
|
//TODO: CERT-V2 we should be able to tear this down
|
||||||
|
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if newState.v2Cert != nil {
|
if newState.v2Cert != nil {
|
||||||
if currentState.v2Cert == nil {
|
if currentState.v2Cert == nil {
|
||||||
//adding certs is fine, actually
|
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
|
||||||
} else {
|
}
|
||||||
|
|
||||||
// did IP in cert change? if so, don't set
|
// did IP in cert change? if so, don't set
|
||||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
||||||
return util.NewContextualError(
|
return util.NewContextualError(
|
||||||
"Networks in new cert was different from old",
|
"Networks in new cert was different from old",
|
||||||
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2},
|
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -137,25 +142,13 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
|||||||
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
||||||
return util.NewContextualError(
|
return util.NewContextualError(
|
||||||
"Curve in new cert was different from old",
|
"Curve in new cert was different from old",
|
||||||
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2},
|
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
} else if currentState.v2Cert != nil {
|
} else if currentState.v2Cert != nil {
|
||||||
//newState.v1Cert is non-nil bc empty certstates aren't permitted
|
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
|
||||||
if newState.v1Cert == nil {
|
|
||||||
return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
|
|
||||||
}
|
|
||||||
//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
|
|
||||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
|
|
||||||
return util.NewContextualError(
|
|
||||||
"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
|
|
||||||
m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cipher cant be hot swapped so just leave it at what it was before
|
// Cipher cant be hot swapped so just leave it at what it was before
|
||||||
|
|||||||
1
stats.go
1
stats.go
@@ -6,7 +6,6 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
_ "net/http/pprof"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|||||||
16
udp/config.go
Normal file
16
udp/config.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package udp
|
||||||
|
|
||||||
|
import "sync/atomic"
|
||||||
|
|
||||||
|
var disableUDPCsum atomic.Bool
|
||||||
|
|
||||||
|
// SetDisableUDPCsum controls whether IPv4 UDP sockets opt out of kernel
|
||||||
|
// checksum calculation via SO_NO_CHECK. Only applicable on platforms that
|
||||||
|
// support the option (Linux). IPv6 always keeps the checksum enabled.
|
||||||
|
func SetDisableUDPCsum(disable bool) {
|
||||||
|
disableUDPCsum.Store(disable)
|
||||||
|
}
|
||||||
|
|
||||||
|
func udpChecksumDisabled() bool {
|
||||||
|
return disableUDPCsum.Load()
|
||||||
|
}
|
||||||
19
udp/conn.go
19
udp/conn.go
@@ -11,23 +11,15 @@ const MTU = 9001
|
|||||||
type EncReader func(
|
type EncReader func(
|
||||||
addr netip.AddrPort,
|
addr netip.AddrPort,
|
||||||
payload []byte,
|
payload []byte,
|
||||||
)
|
release func(),
|
||||||
|
|
||||||
type EncBatchReader func(
|
|
||||||
addrs []netip.AddrPort,
|
|
||||||
payloads [][]byte,
|
|
||||||
count int,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Conn interface {
|
type Conn interface {
|
||||||
Rebind() error
|
Rebind() error
|
||||||
LocalAddr() (netip.AddrPort, error)
|
LocalAddr() (netip.AddrPort, error)
|
||||||
ListenOut(r EncReader)
|
ListenOut(r EncReader)
|
||||||
ListenOutBatch(r EncBatchReader)
|
|
||||||
WriteTo(b []byte, addr netip.AddrPort) error
|
WriteTo(b []byte, addr netip.AddrPort) error
|
||||||
WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error)
|
|
||||||
ReloadConfig(c *config.C)
|
ReloadConfig(c *config.C)
|
||||||
BatchSize() int
|
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -42,21 +34,12 @@ func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
func (NoopConn) ListenOut(_ EncReader) {
|
func (NoopConn) ListenOut(_ EncReader) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
func (NoopConn) ListenOutBatch(_ EncBatchReader) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (NoopConn) WriteMulti(_ [][]byte, _ []netip.AddrPort) (int, error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
func (NoopConn) ReloadConfig(_ *config.C) {
|
func (NoopConn) ReloadConfig(_ *config.C) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
func (NoopConn) BatchSize() int {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
func (NoopConn) Close() error {
|
func (NoopConn) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
25
udp/msghdr_helper_linux_32.go
Normal file
25
udp/msghdr_helper_linux_32.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
//go:build linux && (386 || amd64p32 || arm || mips || mipsle) && !android && !e2e_testing
|
||||||
|
// +build linux
|
||||||
|
// +build 386 amd64p32 arm mips mipsle
|
||||||
|
// +build !android
|
||||||
|
// +build !e2e_testing
|
||||||
|
|
||||||
|
package udp
|
||||||
|
|
||||||
|
import "golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
func controllen(n int) uint32 {
|
||||||
|
return uint32(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setCmsgLen(h *unix.Cmsghdr, n int) {
|
||||||
|
h.Len = uint32(unix.CmsgLen(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
func setIovecLen(v *unix.Iovec, n int) {
|
||||||
|
v.Len = uint32(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setMsghdrIovlen(m *unix.Msghdr, n int) {
|
||||||
|
m.Iovlen = uint32(n)
|
||||||
|
}
|
||||||
25
udp/msghdr_helper_linux_64.go
Normal file
25
udp/msghdr_helper_linux_64.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
//go:build linux && (amd64 || arm64 || ppc64 || ppc64le || mips64 || mips64le || s390x || riscv64 || loong64) && !android && !e2e_testing
|
||||||
|
// +build linux
|
||||||
|
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x riscv64 loong64
|
||||||
|
// +build !android
|
||||||
|
// +build !e2e_testing
|
||||||
|
|
||||||
|
package udp
|
||||||
|
|
||||||
|
import "golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
func controllen(n int) uint64 {
|
||||||
|
return uint64(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setCmsgLen(h *unix.Cmsghdr, n int) {
|
||||||
|
h.Len = uint64(unix.CmsgLen(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
func setIovecLen(v *unix.Iovec, n int) {
|
||||||
|
v.Len = uint64(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setMsghdrIovlen(m *unix.Msghdr, n int) {
|
||||||
|
m.Iovlen = uint64(n)
|
||||||
|
}
|
||||||
25
udp/sendmmsg_linux_32.go
Normal file
25
udp/sendmmsg_linux_32.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
//go:build linux && (386 || amd64p32 || arm || mips || mipsle) && !android && !e2e_testing
|
||||||
|
|
||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
type linuxMmsgHdr struct {
|
||||||
|
Hdr unix.Msghdr
|
||||||
|
Len uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendmmsg(fd int, hdrs []linuxMmsgHdr, flags int) (int, error) {
|
||||||
|
if len(hdrs) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
n, _, errno := unix.Syscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&hdrs[0])), uintptr(len(hdrs)), uintptr(flags), 0, 0)
|
||||||
|
if errno != 0 {
|
||||||
|
return int(n), errno
|
||||||
|
}
|
||||||
|
return int(n), nil
|
||||||
|
}
|
||||||
26
udp/sendmmsg_linux_64.go
Normal file
26
udp/sendmmsg_linux_64.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
//go:build linux && (amd64 || arm64 || ppc64 || ppc64le || mips64 || mips64le || s390x || riscv64 || loong64) && !android && !e2e_testing
|
||||||
|
|
||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
type linuxMmsgHdr struct {
|
||||||
|
Hdr unix.Msghdr
|
||||||
|
Len uint32
|
||||||
|
_ uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendmmsg(fd int, hdrs []linuxMmsgHdr, flags int) (int, error) {
|
||||||
|
if len(hdrs) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
n, _, errno := unix.Syscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&hdrs[0])), uintptr(len(hdrs)), uintptr(flags), 0, 0)
|
||||||
|
if errno != 0 {
|
||||||
|
return int(n), errno
|
||||||
|
}
|
||||||
|
return int(n), nil
|
||||||
|
}
|
||||||
@@ -140,17 +140,6 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteMulti sends multiple packets - fallback implementation without sendmmsg
|
|
||||||
func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
|
||||||
for i := range packets {
|
|
||||||
err := u.WriteTo(packets[i], addrs[i])
|
|
||||||
if err != nil {
|
|
||||||
return i, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(packets), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
||||||
a := u.UDPConn.LocalAddr()
|
a := u.UDPConn.LocalAddr()
|
||||||
|
|
||||||
@@ -191,38 +180,10 @@ func (u *StdConn) ListenOut(r EncReader) {
|
|||||||
u.l.WithError(err).Error("unexpected udp socket receive error")
|
u.l.WithError(err).Error("unexpected udp socket receive error")
|
||||||
}
|
}
|
||||||
|
|
||||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenOutBatch - fallback to single-packet reads for Darwin
|
|
||||||
func (u *StdConn) ListenOutBatch(r EncBatchReader) {
|
|
||||||
buffer := make([]byte, MTU)
|
|
||||||
addrs := make([]netip.AddrPort, 1)
|
|
||||||
payloads := make([][]byte, 1)
|
|
||||||
|
|
||||||
for {
|
|
||||||
// Just read one packet at a time and call batch callback with count=1
|
|
||||||
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, net.ErrClosed) {
|
|
||||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
u.l.WithError(err).Error("unexpected udp socket receive error")
|
|
||||||
}
|
|
||||||
|
|
||||||
addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port())
|
|
||||||
payloads[0] = buffer[:n]
|
|
||||||
r(addrs, payloads, 1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) BatchSize() int {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) Rebind() error {
|
func (u *StdConn) Rebind() error {
|
||||||
var err error
|
var err error
|
||||||
if u.isV4 {
|
if u.isV4 {
|
||||||
|
|||||||
@@ -82,45 +82,6 @@ func (u *GenericConn) ListenOut(r EncReader) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenOutBatch - fallback to single-packet reads for generic platforms
|
|
||||||
func (u *GenericConn) ListenOutBatch(r EncBatchReader) {
|
|
||||||
buffer := make([]byte, MTU)
|
|
||||||
addrs := make([]netip.AddrPort, 1)
|
|
||||||
payloads := make([][]byte, 1)
|
|
||||||
|
|
||||||
for {
|
|
||||||
// Just read one packet at a time and call batch callback with count=1
|
|
||||||
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
|
||||||
if err != nil {
|
|
||||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port())
|
|
||||||
payloads[0] = buffer[:n]
|
|
||||||
r(addrs, payloads, 1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteMulti sends multiple packets - fallback implementation
|
|
||||||
func (u *GenericConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
|
||||||
for i := range packets {
|
|
||||||
err := u.WriteTo(packets[i], addrs[i])
|
|
||||||
if err != nil {
|
|
||||||
return i, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(packets), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *GenericConn) BatchSize() int {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *GenericConn) Rebind() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
1250
udp/udp_linux.go
1250
udp/udp_linux.go
File diff suppressed because it is too large
Load Diff
@@ -12,7 +12,7 @@ import (
|
|||||||
|
|
||||||
type iovec struct {
|
type iovec struct {
|
||||||
Base *byte
|
Base *byte
|
||||||
Len uint
|
Len uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
type msghdr struct {
|
type msghdr struct {
|
||||||
@@ -30,17 +30,29 @@ type rawMessage struct {
|
|||||||
Len uint32
|
Len uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte, [][]byte) {
|
||||||
|
controlLen := int(u.controlLen.Load())
|
||||||
|
|
||||||
msgs := make([]rawMessage, n)
|
msgs := make([]rawMessage, n)
|
||||||
buffers := make([][]byte, n)
|
buffers := make([][]byte, n)
|
||||||
names := make([][]byte, n)
|
names := make([][]byte, n)
|
||||||
|
|
||||||
|
var controls [][]byte
|
||||||
|
if controlLen > 0 {
|
||||||
|
controls = make([][]byte, n)
|
||||||
|
}
|
||||||
|
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
buffers[i] = make([]byte, MTU)
|
size := int(u.groBufSize.Load())
|
||||||
|
if size < MTU {
|
||||||
|
size = MTU
|
||||||
|
}
|
||||||
|
buf := u.borrowRxBuffer(size)
|
||||||
|
buffers[i] = buf
|
||||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
vs := []iovec{
|
vs := []iovec{
|
||||||
{Base: &buffers[i][0], Len: uint(len(buffers[i]))},
|
{Base: &buf[0], Len: uint32(len(buf))},
|
||||||
}
|
}
|
||||||
|
|
||||||
msgs[i].Hdr.Iov = &vs[0]
|
msgs[i].Hdr.Iov = &vs[0]
|
||||||
@@ -48,7 +60,22 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
|||||||
|
|
||||||
msgs[i].Hdr.Name = &names[i][0]
|
msgs[i].Hdr.Name = &names[i][0]
|
||||||
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
||||||
|
|
||||||
|
if controlLen > 0 {
|
||||||
|
controls[i] = make([]byte, controlLen)
|
||||||
|
msgs[i].Hdr.Control = &controls[i][0]
|
||||||
|
msgs[i].Hdr.Controllen = controllen(len(controls[i]))
|
||||||
|
} else {
|
||||||
|
msgs[i].Hdr.Control = nil
|
||||||
|
msgs[i].Hdr.Controllen = controllen(0)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return msgs, buffers, names
|
return msgs, buffers, names, controls
|
||||||
|
}
|
||||||
|
|
||||||
|
func setIovecBase(msg *rawMessage, buf []byte) {
|
||||||
|
iov := (*iovec)(msg.Hdr.Iov)
|
||||||
|
iov.Base = &buf[0]
|
||||||
|
iov.Len = uint32(len(buf))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
|
|
||||||
type iovec struct {
|
type iovec struct {
|
||||||
Base *byte
|
Base *byte
|
||||||
Len uint
|
Len uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
type msghdr struct {
|
type msghdr struct {
|
||||||
@@ -33,25 +33,50 @@ type rawMessage struct {
|
|||||||
Pad0 [4]byte
|
Pad0 [4]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte, [][]byte) {
|
||||||
|
controlLen := int(u.controlLen.Load())
|
||||||
|
|
||||||
msgs := make([]rawMessage, n)
|
msgs := make([]rawMessage, n)
|
||||||
buffers := make([][]byte, n)
|
buffers := make([][]byte, n)
|
||||||
names := make([][]byte, n)
|
names := make([][]byte, n)
|
||||||
|
|
||||||
|
var controls [][]byte
|
||||||
|
if controlLen > 0 {
|
||||||
|
controls = make([][]byte, n)
|
||||||
|
}
|
||||||
|
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
buffers[i] = make([]byte, MTU)
|
size := int(u.groBufSize.Load())
|
||||||
|
if size < MTU {
|
||||||
|
size = MTU
|
||||||
|
}
|
||||||
|
buf := u.borrowRxBuffer(size)
|
||||||
|
buffers[i] = buf
|
||||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
vs := []iovec{
|
vs := []iovec{{Base: &buf[0], Len: uint64(len(buf))}}
|
||||||
{Base: &buffers[i][0], Len: uint(len(buffers[i]))},
|
|
||||||
}
|
|
||||||
|
|
||||||
msgs[i].Hdr.Iov = &vs[0]
|
msgs[i].Hdr.Iov = &vs[0]
|
||||||
msgs[i].Hdr.Iovlen = uint64(len(vs))
|
msgs[i].Hdr.Iovlen = uint64(len(vs))
|
||||||
|
|
||||||
msgs[i].Hdr.Name = &names[i][0]
|
msgs[i].Hdr.Name = &names[i][0]
|
||||||
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
||||||
|
|
||||||
|
if controlLen > 0 {
|
||||||
|
controls[i] = make([]byte, controlLen)
|
||||||
|
msgs[i].Hdr.Control = &controls[i][0]
|
||||||
|
msgs[i].Hdr.Controllen = controllen(len(controls[i]))
|
||||||
|
} else {
|
||||||
|
msgs[i].Hdr.Control = nil
|
||||||
|
msgs[i].Hdr.Controllen = controllen(0)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return msgs, buffers, names
|
return msgs, buffers, names, controls
|
||||||
|
}
|
||||||
|
|
||||||
|
func setIovecBase(msg *rawMessage, buf []byte) {
|
||||||
|
iov := (*iovec)(msg.Hdr.Iov)
|
||||||
|
iov.Base = &buf[0]
|
||||||
|
iov.Len = uint64(len(buf))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -149,7 +149,7 @@ func (u *RIOConn) ListenOut(r EncReader) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n])
|
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n], nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -112,35 +112,10 @@ func (u *TesterConn) ListenOut(r EncReader) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r(p.From, p.Data)
|
r(p.From, p.Data, func() {})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TesterConn) ListenOutBatch(r EncBatchReader) {
|
|
||||||
addrs := make([]netip.AddrPort, 1)
|
|
||||||
payloads := make([][]byte, 1)
|
|
||||||
|
|
||||||
for {
|
|
||||||
p, ok := <-u.RxPackets
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
addrs[0] = p.From
|
|
||||||
payloads[0] = p.Data
|
|
||||||
r(addrs, payloads, 1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *TesterConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
|
||||||
for i := range packets {
|
|
||||||
err := u.WriteTo(packets[i], addrs[i])
|
|
||||||
if err != nil {
|
|
||||||
return i, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(packets), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *TesterConn) ReloadConfig(*config.C) {}
|
func (u *TesterConn) ReloadConfig(*config.C) {}
|
||||||
|
|
||||||
func NewUDPStatsEmitter(_ []Conn) func() {
|
func NewUDPStatsEmitter(_ []Conn) func() {
|
||||||
@@ -152,10 +127,6 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
return u.Addr, nil
|
return u.Addr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TesterConn) BatchSize() int {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *TesterConn) Rebind() error {
|
func (u *TesterConn) Rebind() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user