Compare commits

..

7 Commits

Author SHA1 Message Date
Ryan Huber
a4b7f624da sure 2025-11-03 17:23:57 +00:00
Ryan Huber
1c069a8e42 reuse control on gso 2025-11-03 11:14:52 +00:00
Ryan Huber
0d8bd11818 reuse GRO slices 2025-11-03 11:06:07 +00:00
Ryan Huber
5128e2653e reuse packet buffer 2025-11-03 10:52:09 +00:00
Ryan Huber
c73b2dfbc7 fixed fallback for non io_uring packet send/recv 2025-11-03 10:45:30 +00:00
Ryan Huber
3dea761530 fix compile for 386 2025-11-03 10:12:02 +00:00
Ryan Huber
b394112ad9 gso and gro with uring on send/receive for udp 2025-11-03 09:59:45 +00:00
29 changed files with 4876 additions and 721 deletions

View File

@@ -1,8 +1,10 @@
package cert
import (
"encoding/hex"
"encoding/pem"
"fmt"
"time"
"golang.org/x/crypto/ed25519"
)
@@ -189,3 +191,71 @@ func UnmarshalSigningPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error)
}
return k.Bytes, r, curve, nil
}
// Backward compatibility functions for older API
func MarshalX25519PublicKey(b []byte) []byte {
return MarshalPublicKeyToPEM(Curve_CURVE25519, b)
}
func MarshalX25519PrivateKey(b []byte) []byte {
return MarshalPrivateKeyToPEM(Curve_CURVE25519, b)
}
func MarshalPublicKey(curve Curve, b []byte) []byte {
return MarshalPublicKeyToPEM(curve, b)
}
func MarshalPrivateKey(curve Curve, b []byte) []byte {
return MarshalPrivateKeyToPEM(curve, b)
}
// NebulaCertificate is a compatibility wrapper for the old API
type NebulaCertificate struct {
Details NebulaCertificateDetails
Signature []byte
cert Certificate
}
// NebulaCertificateDetails is a compatibility wrapper for certificate details
type NebulaCertificateDetails struct {
Name string
NotBefore time.Time
NotAfter time.Time
PublicKey []byte
IsCA bool
Issuer []byte
Curve Curve
}
// UnmarshalNebulaCertificateFromPEM provides backward compatibility with the old API
func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
c, rest, err := UnmarshalCertificateFromPEM(b)
if err != nil {
return nil, rest, err
}
// Convert to old format
nc := &NebulaCertificate{
Details: NebulaCertificateDetails{
Name: c.Name(),
NotBefore: c.NotBefore(),
NotAfter: c.NotAfter(),
PublicKey: c.PublicKey(),
IsCA: c.IsCA(),
Curve: c.Curve(),
},
Signature: c.Signature(),
cert: c,
}
// Handle issuer
if c.Issuer() != "" {
issuerBytes, err := hex.DecodeString(c.Issuer())
if err != nil {
return nil, rest, fmt.Errorf("failed to decode issuer fingerprint: %w", err)
}
nc.Details.Issuer = issuerBytes
}
return nc, rest, nil
}

View File

@@ -114,33 +114,6 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
}
func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
nc := &cert.TBSCertificate{
Version: v,
Curve: c.Curve(),
Name: c.Name(),
Networks: c.Networks(),
UnsafeNetworks: c.UnsafeNetworks(),
Groups: c.Groups(),
NotBefore: time.Unix(c.NotBefore().Unix(), 0),
NotAfter: time.Unix(c.NotAfter().Unix(), 0),
PublicKey: c.PublicKey(),
IsCA: false,
}
c, err := nc.Sign(ca, ca.Curve(), key)
if err != nil {
panic(err)
}
pem, err := c.MarshalPEM()
if err != nil {
panic(err)
}
return c, pem
}
func X25519Keypair() ([]byte, []byte) {
privkey := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {

View File

@@ -65,16 +65,8 @@ func main() {
}
if !*configTest {
wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
wait()
l.Info("Goodbye")
ctrl.Start()
ctrl.ShutdownBlock()
}
os.Exit(0)

View File

@@ -3,9 +3,6 @@ package main
import (
"flag"
"fmt"
"log"
"net/http"
_ "net/http/pprof"
"os"
"github.com/sirupsen/logrus"
@@ -61,22 +58,10 @@ func main() {
os.Exit(1)
}
go func() {
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
}()
if !*configTest {
wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
ctrl.Start()
notifyReady(l)
wait()
l.Info("Goodbye")
ctrl.ShutdownBlock()
}
os.Exit(0)

View File

@@ -354,6 +354,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
if mainHostInfo {
decision = tryRehandshake
} else {
if cm.shouldSwapPrimary(hostinfo) {
decision = swapPrimary
@@ -460,10 +461,6 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
}
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
if crt == nil {
//my cert was reloaded away. We should definitely swap from this tunnel
return true
}
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
// settle down.
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
@@ -478,34 +475,31 @@ func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
cm.hostMap.Unlock()
}
// isInvalidCertificate decides if we should destroy a tunnel.
// returns true if pki.disconnect_invalid is true and the certificate is no longer valid.
// Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
// check and return true.
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
remoteCert := hostinfo.GetCert()
if remoteCert == nil {
return false //don't tear down tunnels for handshakes in progress
return false
}
caPool := cm.intf.pki.GetCAPool()
err := caPool.VerifyCachedCertificate(now, remoteCert)
if err == nil {
return false //cert is still valid! yay!
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
// Block listed certificates should always be disconnected
hostinfo.logger(cm.l).WithError(err).
WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is blocked, tearing down the tunnel")
return true
} else if cm.intf.disconnectInvalid.Load() {
hostinfo.logger(cm.l).WithError(err).
WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is no longer valid, tearing down the tunnel")
return true
} else {
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
return false
}
if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
// Block listed certificates should always be disconnected
return false
}
hostinfo.logger(cm.l).WithError(err).
WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is no longer valid, tearing down the tunnel")
return true
}
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
@@ -536,45 +530,15 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
cs := cm.intf.pki.getCertState()
curCrt := hostinfo.ConnectionState.myCert
curCrtVersion := curCrt.Version()
myCrt := cs.getCertificate(curCrtVersion)
if myCrt == nil {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("version", curCrtVersion).
WithField("reason", "local certificate removed").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
myCrt := cs.getCertificate(curCrt.Version())
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
// The current tunnel is using the latest certificate and version, no need to rehandshake.
return
}
peerCrt := hostinfo.ConnectionState.peerCert
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("version", curCrtVersion).
WithField("peerVersion", peerCrt.Certificate.Version()).
WithField("reason", "local certificate version lower than peer, attempting to correct").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
})
return
}
}
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return
}
if curCrtVersion < cs.initiatingVersion {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "current cert version < pki.initiatingVersion").
Info("Re-handshaking with remote")
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return
}
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
}

View File

@@ -2,11 +2,9 @@ package nebula
import (
"context"
"errors"
"net/netip"
"os"
"os/signal"
"sync"
"syscall"
"github.com/sirupsen/logrus"
@@ -15,16 +13,6 @@ import (
"github.com/slackhq/nebula/overlay"
)
type RunState int
const (
Stopped RunState = 0 // The control has yet to be started
Started RunState = 1 // The control has been started
Stopping RunState = 2 // The control is stopping
)
var ErrAlreadyStarted = errors.New("nebula is already started")
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
@@ -38,9 +26,6 @@ type controlHostLister interface {
}
type Control struct {
stateLock sync.Mutex
state RunState
f *Interface
l *logrus.Logger
ctx context.Context
@@ -64,21 +49,10 @@ type ControlHostInfo struct {
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
}
// Start actually runs nebula, this is a nonblocking call.
// The returned function can be used to wait for nebula to fully stop.
func (c *Control) Start() (func(), error) {
c.stateLock.Lock()
if c.state != Stopped {
c.stateLock.Unlock()
return nil, ErrAlreadyStarted
}
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
func (c *Control) Start() {
// Activate the interface
err := c.f.activate()
if err != nil {
c.stateLock.Unlock()
return nil, err
}
c.f.activate()
// Call all the delayed funcs that waited patiently for the interface to be created.
if c.sshStart != nil {
@@ -98,33 +72,15 @@ func (c *Control) Start() (func(), error) {
}
// Start reading packets.
c.state = Started
c.stateLock.Unlock()
return c.f.run()
}
func (c *Control) State() RunState {
c.stateLock.Lock()
defer c.stateLock.Unlock()
return c.state
c.f.run()
}
func (c *Control) Context() context.Context {
return c.ctx
}
// Stop is a non-blocking call that signals nebula to close all tunnels and shut down
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
func (c *Control) Stop() {
c.stateLock.Lock()
if c.state != Started {
c.stateLock.Unlock()
// We are stopping or stopped already
return
}
c.state = Stopping
c.stateLock.Unlock()
// Stop the handshakeManager (and other services), to prevent new tunnels from
// being created while we're shutting them all down.
c.cancel()
@@ -133,7 +89,7 @@ func (c *Control) Stop() {
if err := c.f.Close(); err != nil {
c.l.WithError(err).Error("Close interface failed")
}
c.state = Stopped
c.l.Info("Goodbye")
}
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled

View File

@@ -129,109 +129,6 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
return control, vpnNetworks, udpAddr, c
}
// newServer creates a nebula instance with fewer assumptions
func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
vpnNetworks := certs[len(certs)-1].Networks()
var udpAddr netip.AddrPort
if vpnNetworks[0].Addr().Is4() {
budpIp := vpnNetworks[0].Addr().As4()
budpIp[1] -= 128
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
} else {
budpIp := vpnNetworks[0].Addr().As16()
// beef for funsies
budpIp[2] = 190
budpIp[3] = 239
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
}
caStr := ""
for _, ca := range caCrt {
x, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
caStr += string(x)
}
certStr := ""
for _, c := range certs {
x, err := c.MarshalPEM()
if err != nil {
panic(err)
}
certStr += string(x)
}
mc := m{
"pki": m{
"ca": caStr,
"cert": certStr,
"key": string(key),
},
//"tun": m{"disabled": true},
"firewall": m{
"outbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
"inbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
}},
},
//"handshakes": m{
// "try_interval": "1s",
//},
"listen": m{
"host": udpAddr.Addr().String(),
"port": udpAddr.Port(),
},
"logging": m{
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
"level": l.Level.String(),
},
"timers": m{
"pending_deletion_interval": 2,
"connection_alive_interval": 2,
},
}
if overrides != nil {
final := m{}
err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
mc = final
}
cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}
c := config.NewC(l)
cStr := string(cb)
c.LoadString(cStr)
control, err := nebula.Main(c, false, "e2e-test", l, nil)
if err != nil {
panic(err)
}
return control, vpnNetworks, udpAddr, c
}
type doneCb func()
func deadline(t *testing.T, seconds time.Duration) doneCb {

View File

@@ -4,16 +4,12 @@
package e2e
import (
"fmt"
"net/netip"
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3"
)
func TestDropInactiveTunnels(t *testing.T) {
@@ -59,262 +55,3 @@ func TestDropInactiveTunnels(t *testing.T) {
myControl.Stop()
theirControl.Stop()
}
func TestCertUpgrade(t *testing.T) {
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
caB, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca2B, err := ca2.MarshalPEM()
if err != nil {
panic(err)
}
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
r.Log("Assert the tunnel between me and them works")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("yay")
//todo ???
time.Sleep(1 * time.Second)
r.FlushAll()
mc := m{
"pki": m{
"ca": caStr,
"cert": string(myCert2Pem),
"key": string(myPrivKey),
},
//"tun": m{"disabled": true},
"firewall": myC.Settings["firewall"],
//"handshakes": m{
// "try_interval": "1s",
//},
"listen": myC.Settings["listen"],
"logging": myC.Settings["logging"],
"timers": myC.Settings["timers"],
}
cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}
r.Logf("reload new v2-only config")
err = myC.ReloadConfigString(string(cb))
assert.NoError(t, err)
r.Log("yay, spin until their sees it")
waitStart := time.Now()
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
if c == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
r.Logf("version %d", version)
if version == cert.Version2 {
break
}
}
since := time.Since(waitStart)
if since > time.Second*10 {
t.Fatal("Cert should be new by now")
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestCertDowngrade(t *testing.T) {
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
caB, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca2B, err := ca2.MarshalPEM()
if err != nil {
panic(err)
}
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
r.Log("Assert the tunnel between me and them works")
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
//r.Log("yay")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("yay")
//todo ???
time.Sleep(1 * time.Second)
r.FlushAll()
mc := m{
"pki": m{
"ca": caStr,
"cert": string(myCertPem),
"key": string(myPrivKey),
},
"firewall": myC.Settings["firewall"],
"listen": myC.Settings["listen"],
"logging": myC.Settings["logging"],
"timers": myC.Settings["timers"],
}
cb, err := yaml.Marshal(mc)
if err != nil {
panic(err)
}
r.Logf("reload new v1-only config")
err = myC.ReloadConfigString(string(cb))
assert.NoError(t, err)
r.Log("yay, spin until their sees it")
waitStart := time.Now()
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
if c == nil || c2 == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
theirVersion := c2.Cert.Version()
r.Logf("version %d,%d", version, theirVersion)
if version == cert.Version1 {
break
}
}
since := time.Since(waitStart)
if since > time.Second*5 {
r.Log("it is unusual that the cert is not new yet, but not a failure yet")
}
if since > time.Second*10 {
r.Log("wtf")
t.Fatal("Cert should be new by now")
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}
func TestCertMismatchCorrection(t *testing.T) {
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
r.Log("Assert the tunnel between me and them works")
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
//r.Log("yay")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("yay")
//todo ???
time.Sleep(1 * time.Second)
r.FlushAll()
waitStart := time.Now()
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
if c == nil || c2 == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
theirVersion := c2.Cert.Version()
r.Logf("version %d,%d", version, theirVersion)
if version == theirVersion {
break
}
}
since := time.Since(waitStart)
if since > time.Second*5 {
r.Log("wtf")
}
if since > time.Second*10 {
r.Log("wtf")
t.Fatal("Cert should be new by now")
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}

View File

@@ -23,17 +23,13 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
return false
}
// If we're connecting to a v6 address we must use a v2 cert
cs := f.pki.getCertState()
v := cs.initiatingVersion
if hh.initiatingVersionOverride != cert.VersionPre1 {
v = hh.initiatingVersionOverride
} else if v < cert.Version2 {
// If we're connecting to a v6 address we should encourage use of a V2 cert
for _, a := range hh.hostinfo.vpnAddrs {
if a.Is6() {
v = cert.Version2
break
}
for _, a := range hh.hostinfo.vpnAddrs {
if a.Is6() {
v = cert.Version2
break
}
}
@@ -52,7 +48,6 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate handshake bytes is available")
return false
}
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
@@ -108,7 +103,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.initiatingVersion).
Error("Unable to handshake with host because no certificate is available")
return
}
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
@@ -149,8 +143,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
if err != nil {
fp, fperr := rc.Fingerprint()
if fperr != nil {
fp, err := rc.Fingerprint()
if err != nil {
fp = "<error generating certificate fingerprint>"
}
@@ -169,19 +163,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if remoteCert.Certificate.Version() != ci.myCert.Version() {
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
if myCertOtherVersion == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithError(err).WithFields(m{
"udpAddr": addr,
"handshake": m{"stage": 1, "style": "ix_psk0"},
"cert": remoteCert,
}).Debug("Might be unable to handshake with host due to missing certificate version")
}
} else {
// Record the certificate we are actually using
ci.myCert = myCertOtherVersion
rc := cs.getCertificate(remoteCert.Certificate.Version())
if rc == nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
Info("Unable to handshake with host due to missing certificate version")
return
}
// Record the certificate we are actually using
ci.myCert = rc
}
if len(remoteCert.Certificate.Networks()) == 0 {

View File

@@ -68,12 +68,11 @@ type HandshakeManager struct {
type HandshakeHostInfo struct {
sync.Mutex
startTime time.Time // Time that we first started trying with this handshake
ready bool // Is the handshake ready
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
counter int64 // How many attempts have we made so far
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
startTime time.Time // Time that we first started trying with this handshake
ready bool // Is the handshake ready
counter int64 // How many attempts have we made so far
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
hostinfo *HostInfo
}

View File

@@ -6,8 +6,8 @@ import (
"fmt"
"io"
"net/netip"
"os"
"runtime"
"sync"
"sync/atomic"
"time"
@@ -87,7 +87,6 @@ type Interface struct {
writers []udp.Conn
readers []io.ReadWriteCloser
wg sync.WaitGroup
metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
@@ -210,7 +209,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
// activate creates the interface on the host. After the interface is created, any
// other services that want to bind listeners to its IP may do so successfully. However,
// the interface isn't going to process anything until run() is called.
func (f *Interface) activate() error {
func (f *Interface) activate() {
// actually turn on tun dev
addr, err := f.outside.LocalAddr()
@@ -231,38 +230,33 @@ func (f *Interface) activate() error {
if i > 0 {
reader, err = f.inside.NewMultiQueueReader()
if err != nil {
return err
f.l.Fatal(err)
}
}
f.readers[i] = reader
}
if err = f.inside.Activate(); err != nil {
if err := f.inside.Activate(); err != nil {
f.inside.Close()
return err
f.l.Fatal(err)
}
return nil
}
func (f *Interface) run() (func(), error) {
func (f *Interface) run() {
// Launch n queues to read packets from udp
for i := 0; i < f.routines; i++ {
f.wg.Add(1)
go f.listenOut(i)
}
// Launch n queues to read packets from tun dev
for i := 0; i < f.routines; i++ {
f.wg.Add(1)
go f.listenIn(f.readers[i], i)
}
return f.wg.Wait, nil
}
func (f *Interface) listenOut(i int) {
runtime.LockOSThread()
var li udp.Conn
if i > 0 {
li = f.writers[i]
@@ -277,21 +271,17 @@ func (f *Interface) listenOut(i int) {
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte, release func()) {
if release != nil {
defer release()
}
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
})
if err != nil && !f.closed.Load() {
f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
//TODO: Trigger Control to close
}
f.l.Debugf("underlay reader %v is done", i)
f.wg.Done()
}
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread()
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
@@ -302,18 +292,17 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
for {
n, err := reader.Read(packet)
if err != nil {
if !f.closed.Load() {
f.l.WithError(err).Error("Error while reading outbound packet, closing")
//TODO: Trigger Control to close
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return
}
break
f.l.WithError(err).Error("Error while reading outbound packet")
// This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2)
}
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
}
f.l.Debugf("overlay reader %v is done", i)
f.wg.Done()
}
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
@@ -465,7 +454,6 @@ func (f *Interface) GetCertState() *CertState {
func (f *Interface) Close() error {
f.closed.Store(true)
// Release the udp readers
for _, u := range f.writers {
err := u.Close()
if err != nil {
@@ -473,13 +461,6 @@ func (f *Interface) Close() error {
}
}
// Release the tun readers
for _, u := range f.readers {
err := u.Close()
if err != nil {
f.l.WithError(err).Error("Error while closing tun device")
}
}
return nil
// Release the tun device
return f.inside.Close()
}

18
main.go
View File

@@ -284,14 +284,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
return &Control{
f: ifce,
l: l,
ctx: ctx,
cancel: cancel,
sshStart: sshStart,
statsStart: statsStart,
dnsStart: dnsStart,
lighthouseStart: lightHouse.StartUpdateWorker,
connectionManagerStart: connManager.Start,
ifce,
l,
ctx,
cancel,
sshStart,
statsStart,
dnsStart,
lightHouse.StartUpdateWorker,
connManager.Start,
}, nil
}

View File

@@ -29,7 +29,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return
}
//f.l.Error("in packet ", h)
//l.Error("in packet ", header, packet[HeaderLen:])
if ip.IsValid() {
if f.myVpnNetworksTable.Contains(ip.Addr()) {
if f.l.Level >= logrus.DebugLevel {

83
pki.go
View File

@@ -100,62 +100,55 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
currentState := p.cs.Load()
if newState.v1Cert != nil {
if currentState.v1Cert == nil {
//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
} else {
// did IP in cert change? if so, don't set
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
return util.NewContextualError(
"Networks in new cert was different from old",
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1},
nil,
)
}
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
return util.NewContextualError(
"Curve in new v1 cert was different from old",
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1},
nil,
)
}
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
}
// did IP in cert change? if so, don't set
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
return util.NewContextualError(
"Networks in new cert was different from old",
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
nil,
)
}
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
return util.NewContextualError(
"Curve in new cert was different from old",
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
nil,
)
}
} else if currentState.v1Cert != nil {
//TODO: CERT-V2 we should be able to tear this down
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
}
if newState.v2Cert != nil {
if currentState.v2Cert == nil {
//adding certs is fine, actually
} else {
// did IP in cert change? if so, don't set
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
return util.NewContextualError(
"Networks in new cert was different from old",
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2},
nil,
)
}
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
return util.NewContextualError(
"Curve in new cert was different from old",
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2},
nil,
)
}
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
}
} else if currentState.v2Cert != nil {
//newState.v1Cert is non-nil bc empty certstates aren't permitted
if newState.v1Cert == nil {
return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
}
//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
// did IP in cert change? if so, don't set
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
return util.NewContextualError(
"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
"Networks in new cert was different from old",
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
nil,
)
}
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
return util.NewContextualError(
"Curve in new cert was different from old",
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
nil,
)
}
} else if currentState.v2Cert != nil {
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
}
// Cipher cant be hot swapped so just leave it at what it was before

View File

@@ -44,10 +44,7 @@ type Service struct {
}
func New(control *nebula.Control) (*Service, error) {
wait, err := control.Start()
if err != nil {
return nil, err
}
control.Start()
ctx := control.Context()
eg, ctx := errgroup.WithContext(ctx)
@@ -144,12 +141,6 @@ func New(control *nebula.Control) (*Service, error) {
}
})
// Add the nebula wait function to the group
eg.Go(func() error {
wait()
return nil
})
return &s, nil
}

16
udp/config.go Normal file
View File

@@ -0,0 +1,16 @@
package udp
import "sync/atomic"
var disableUDPCsum atomic.Bool
// SetDisableUDPCsum controls whether IPv4 UDP sockets opt out of kernel
// checksum calculation via SO_NO_CHECK. Only applicable on platforms that
// support the option (Linux). IPv6 always keeps the checksum enabled.
func SetDisableUDPCsum(disable bool) {
disableUDPCsum.Store(disable)
}
func udpChecksumDisabled() bool {
return disableUDPCsum.Load()
}

View File

@@ -11,12 +11,13 @@ const MTU = 9001
type EncReader func(
addr netip.AddrPort,
payload []byte,
release func(),
)
type Conn interface {
Rebind() error
LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader) error
ListenOut(r EncReader)
WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C)
Close() error
@@ -30,8 +31,8 @@ func (NoopConn) Rebind() error {
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
return netip.AddrPort{}, nil
}
func (NoopConn) ListenOut(_ EncReader) error {
return nil
func (NoopConn) ListenOut(_ EncReader) {
return
}
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil

1740
udp/io_uring_linux.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,25 @@
//go:build linux && (386 || amd64p32 || arm || mips || mipsle) && !android && !e2e_testing
// +build linux
// +build 386 amd64p32 arm mips mipsle
// +build !android
// +build !e2e_testing
package udp
import "golang.org/x/sys/unix"
func controllen(n int) uint32 {
return uint32(n)
}
func setCmsgLen(h *unix.Cmsghdr, n int) {
h.Len = uint32(unix.CmsgLen(n))
}
func setIovecLen(v *unix.Iovec, n int) {
v.Len = uint32(n)
}
func setMsghdrIovlen(m *unix.Msghdr, n int) {
m.Iovlen = uint32(n)
}

View File

@@ -0,0 +1,25 @@
//go:build linux && (amd64 || arm64 || ppc64 || ppc64le || mips64 || mips64le || s390x || riscv64 || loong64) && !android && !e2e_testing
// +build linux
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x riscv64 loong64
// +build !android
// +build !e2e_testing
package udp
import "golang.org/x/sys/unix"
func controllen(n int) uint64 {
return uint64(n)
}
func setCmsgLen(h *unix.Cmsghdr, n int) {
h.Len = uint64(unix.CmsgLen(n))
}
func setIovecLen(v *unix.Iovec, n int) {
v.Len = uint64(n)
}
func setMsghdrIovlen(m *unix.Msghdr, n int) {
m.Iovlen = uint64(n)
}

25
udp/sendmmsg_linux_32.go Normal file
View File

@@ -0,0 +1,25 @@
//go:build linux && (386 || amd64p32 || arm || mips || mipsle) && !android && !e2e_testing
package udp
import (
"unsafe"
"golang.org/x/sys/unix"
)
type linuxMmsgHdr struct {
Hdr unix.Msghdr
Len uint32
}
func sendmmsg(fd int, hdrs []linuxMmsgHdr, flags int) (int, error) {
if len(hdrs) == 0 {
return 0, nil
}
n, _, errno := unix.Syscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&hdrs[0])), uintptr(len(hdrs)), uintptr(flags), 0, 0)
if errno != 0 {
return int(n), errno
}
return int(n), nil
}

26
udp/sendmmsg_linux_64.go Normal file
View File

@@ -0,0 +1,26 @@
//go:build linux && (amd64 || arm64 || ppc64 || ppc64le || mips64 || mips64le || s390x || riscv64 || loong64) && !android && !e2e_testing
package udp
import (
"unsafe"
"golang.org/x/sys/unix"
)
type linuxMmsgHdr struct {
Hdr unix.Msghdr
Len uint32
_ uint32
}
func sendmmsg(fd int, hdrs []linuxMmsgHdr, flags int) (int, error) {
if len(hdrs) == 0 {
return 0, nil
}
n, _, errno := unix.Syscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&hdrs[0])), uintptr(len(hdrs)), uintptr(flags), 0, 0)
if errno != 0 {
return int(n), errno
}
return int(n), nil
}

View File

@@ -165,7 +165,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
return func() {}
}
func (u *StdConn) ListenOut(r EncReader) error {
func (u *StdConn) ListenOut(r EncReader) {
buffer := make([]byte, MTU)
for {
@@ -173,13 +173,14 @@ func (u *StdConn) ListenOut(r EncReader) error {
n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
if errors.Is(err, net.ErrClosed) {
return err
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
u.l.WithError(err).Error("unexpected udp socket receive error")
}
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], nil)
}
}

View File

@@ -71,16 +71,17 @@ type rawMessage struct {
Len uint32
}
func (u *GenericConn) ListenOut(r EncReader) error {
func (u *GenericConn) ListenOut(r EncReader) {
buffer := make([]byte, MTU)
for {
// Just read one packet at a time
n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
return err
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], nil)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -7,6 +7,9 @@
package udp
import (
"errors"
"fmt"
"golang.org/x/sys/unix"
)
@@ -30,17 +33,29 @@ type rawMessage struct {
Len uint32
}
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte, [][]byte) {
controlLen := int(u.controlLen.Load())
msgs := make([]rawMessage, n)
buffers := make([][]byte, n)
names := make([][]byte, n)
var controls [][]byte
if controlLen > 0 {
controls = make([][]byte, n)
}
for i := range msgs {
buffers[i] = make([]byte, MTU)
size := int(u.groBufSize.Load())
if size < MTU {
size = MTU
}
buf := u.borrowRxBuffer(size)
buffers[i] = buf
names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{
{Base: &buffers[i][0], Len: uint32(len(buffers[i]))},
{Base: &buf[0], Len: uint32(len(buf))},
}
msgs[i].Hdr.Iov = &vs[0]
@@ -48,7 +63,71 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
msgs[i].Hdr.Name = &names[i][0]
msgs[i].Hdr.Namelen = uint32(len(names[i]))
if controlLen > 0 {
controls[i] = make([]byte, controlLen)
msgs[i].Hdr.Control = &controls[i][0]
msgs[i].Hdr.Controllen = controllen(len(controls[i]))
} else {
msgs[i].Hdr.Control = nil
msgs[i].Hdr.Controllen = controllen(0)
}
}
return msgs, buffers, names
return msgs, buffers, names, controls
}
func setIovecBase(msg *rawMessage, buf []byte) {
iov := (*iovec)(msg.Hdr.Iov)
iov.Base = &buf[0]
iov.Len = uint32(len(buf))
}
func rawMessageToUnixMsghdr(msg *rawMessage) (unix.Msghdr, unix.Iovec, error) {
var hdr unix.Msghdr
var iov unix.Iovec
if msg == nil {
return hdr, iov, errors.New("nil rawMessage")
}
if msg.Hdr.Iov == nil || msg.Hdr.Iov.Base == nil {
return hdr, iov, errors.New("rawMessage missing payload buffer")
}
payloadLen := int(msg.Hdr.Iov.Len)
if payloadLen < 0 {
return hdr, iov, fmt.Errorf("invalid payload length: %d", payloadLen)
}
iov.Base = msg.Hdr.Iov.Base
iov.Len = uint32(payloadLen)
hdr.Iov = &iov
hdr.Iovlen = 1
hdr.Name = msg.Hdr.Name
// CRITICAL: Always set to full buffer size for receive, not what kernel wrote last time
if hdr.Name != nil {
hdr.Namelen = uint32(unix.SizeofSockaddrInet6)
} else {
hdr.Namelen = 0
}
hdr.Control = msg.Hdr.Control
// CRITICAL: Use the allocated size, not what was previously returned
if hdr.Control != nil {
// Control buffer size is stored in Controllen from PrepareRawMessages
hdr.Controllen = msg.Hdr.Controllen
} else {
hdr.Controllen = 0
}
hdr.Flags = 0 // Reset flags for new receive
return hdr, iov, nil
}
func updateRawMessageFromUnixMsghdr(msg *rawMessage, hdr *unix.Msghdr, n int) {
if msg == nil || hdr == nil {
return
}
msg.Hdr.Namelen = hdr.Namelen
msg.Hdr.Controllen = hdr.Controllen
msg.Hdr.Flags = hdr.Flags
if n < 0 {
n = 0
}
msg.Len = uint32(n)
}

View File

@@ -7,6 +7,9 @@
package udp
import (
"errors"
"fmt"
"golang.org/x/sys/unix"
)
@@ -33,25 +36,99 @@ type rawMessage struct {
Pad0 [4]byte
}
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte, [][]byte) {
controlLen := int(u.controlLen.Load())
msgs := make([]rawMessage, n)
buffers := make([][]byte, n)
names := make([][]byte, n)
var controls [][]byte
if controlLen > 0 {
controls = make([][]byte, n)
}
for i := range msgs {
buffers[i] = make([]byte, MTU)
size := int(u.groBufSize.Load())
if size < MTU {
size = MTU
}
buf := u.borrowRxBuffer(size)
buffers[i] = buf
names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{
{Base: &buffers[i][0], Len: uint64(len(buffers[i]))},
}
vs := []iovec{{Base: &buf[0], Len: uint64(len(buf))}}
msgs[i].Hdr.Iov = &vs[0]
msgs[i].Hdr.Iovlen = uint64(len(vs))
msgs[i].Hdr.Name = &names[i][0]
msgs[i].Hdr.Namelen = uint32(len(names[i]))
if controlLen > 0 {
controls[i] = make([]byte, controlLen)
msgs[i].Hdr.Control = &controls[i][0]
msgs[i].Hdr.Controllen = controllen(len(controls[i]))
} else {
msgs[i].Hdr.Control = nil
msgs[i].Hdr.Controllen = controllen(0)
}
}
return msgs, buffers, names
return msgs, buffers, names, controls
}
func setIovecBase(msg *rawMessage, buf []byte) {
iov := (*iovec)(msg.Hdr.Iov)
iov.Base = &buf[0]
iov.Len = uint64(len(buf))
}
func rawMessageToUnixMsghdr(msg *rawMessage) (unix.Msghdr, unix.Iovec, error) {
var hdr unix.Msghdr
var iov unix.Iovec
if msg == nil {
return hdr, iov, errors.New("nil rawMessage")
}
if msg.Hdr.Iov == nil || msg.Hdr.Iov.Base == nil {
return hdr, iov, errors.New("rawMessage missing payload buffer")
}
payloadLen := int(msg.Hdr.Iov.Len)
if payloadLen < 0 {
return hdr, iov, fmt.Errorf("invalid payload length: %d", payloadLen)
}
iov.Base = msg.Hdr.Iov.Base
iov.Len = uint64(payloadLen)
hdr.Iov = &iov
hdr.Iovlen = 1
hdr.Name = msg.Hdr.Name
// CRITICAL: Always set to full buffer size for receive, not what kernel wrote last time
if hdr.Name != nil {
hdr.Namelen = uint32(unix.SizeofSockaddrInet6)
} else {
hdr.Namelen = 0
}
hdr.Control = msg.Hdr.Control
// CRITICAL: Use the allocated size, not what was previously returned
if hdr.Control != nil {
// Control buffer size is stored in Controllen from PrepareRawMessages
hdr.Controllen = msg.Hdr.Controllen
} else {
hdr.Controllen = 0
}
hdr.Flags = 0 // Reset flags for new receive
return hdr, iov, nil
}
func updateRawMessageFromUnixMsghdr(msg *rawMessage, hdr *unix.Msghdr, n int) {
if msg == nil || hdr == nil {
return
}
msg.Hdr.Namelen = hdr.Namelen
msg.Hdr.Controllen = hdr.Controllen
msg.Hdr.Flags = hdr.Flags
if n < 0 {
n = 0
}
msg.Len = uint32(n)
}

View File

@@ -134,7 +134,7 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
return nil
}
func (u *RIOConn) ListenOut(r EncReader) error {
func (u *RIOConn) ListenOut(r EncReader) {
buffer := make([]byte, MTU)
for {
@@ -142,13 +142,14 @@ func (u *RIOConn) ListenOut(r EncReader) error {
n, rua, err := u.receive(buffer)
if err != nil {
if errors.Is(err, net.ErrClosed) {
return err
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
u.l.WithError(err).Error("unexpected udp socket receive error")
continue
}
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n])
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n], nil)
}
}

View File

@@ -6,7 +6,6 @@ package udp
import (
"io"
"net/netip"
"os"
"sync/atomic"
"github.com/sirupsen/logrus"
@@ -107,13 +106,13 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
return nil
}
func (u *TesterConn) ListenOut(r EncReader) error {
func (u *TesterConn) ListenOut(r EncReader) {
for {
p, ok := <-u.RxPackets
if !ok {
return os.ErrClosed
return
}
r(p.From, p.Data)
r(p.From, p.Data, func() {})
}
}