Compare commits

..

8 Commits

Author SHA1 Message Date
Ryan Huber
a4b7f624da sure 2025-11-03 17:23:57 +00:00
Ryan Huber
1c069a8e42 reuse control on gso 2025-11-03 11:14:52 +00:00
Ryan Huber
0d8bd11818 reuse GRO slices 2025-11-03 11:06:07 +00:00
Ryan Huber
5128e2653e reuse packet buffer 2025-11-03 10:52:09 +00:00
Ryan Huber
c73b2dfbc7 fixed fallback for non io_uring packet send/recv 2025-11-03 10:45:30 +00:00
Ryan Huber
3dea761530 fix compile for 386 2025-11-03 10:12:02 +00:00
Ryan Huber
b394112ad9 gso and gro with uring on send/receive for udp 2025-11-03 09:59:45 +00:00
Jack Doan
770147264d fix make bench (#1510) 2025-10-21 11:32:34 -05:00
25 changed files with 4861 additions and 288 deletions

View File

@@ -1,8 +1,10 @@
package cert
import (
"encoding/hex"
"encoding/pem"
"fmt"
"time"
"golang.org/x/crypto/ed25519"
)
@@ -189,3 +191,71 @@ func UnmarshalSigningPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error)
}
return k.Bytes, r, curve, nil
}
// Backward compatibility functions for older API
func MarshalX25519PublicKey(b []byte) []byte {
return MarshalPublicKeyToPEM(Curve_CURVE25519, b)
}
func MarshalX25519PrivateKey(b []byte) []byte {
return MarshalPrivateKeyToPEM(Curve_CURVE25519, b)
}
func MarshalPublicKey(curve Curve, b []byte) []byte {
return MarshalPublicKeyToPEM(curve, b)
}
func MarshalPrivateKey(curve Curve, b []byte) []byte {
return MarshalPrivateKeyToPEM(curve, b)
}
// NebulaCertificate is a compatibility wrapper for the old API
type NebulaCertificate struct {
Details NebulaCertificateDetails
Signature []byte
cert Certificate
}
// NebulaCertificateDetails is a compatibility wrapper for certificate details
type NebulaCertificateDetails struct {
Name string
NotBefore time.Time
NotAfter time.Time
PublicKey []byte
IsCA bool
Issuer []byte
Curve Curve
}
// UnmarshalNebulaCertificateFromPEM provides backward compatibility with the old API
func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
c, rest, err := UnmarshalCertificateFromPEM(b)
if err != nil {
return nil, rest, err
}
// Convert to old format
nc := &NebulaCertificate{
Details: NebulaCertificateDetails{
Name: c.Name(),
NotBefore: c.NotBefore(),
NotAfter: c.NotAfter(),
PublicKey: c.PublicKey(),
IsCA: c.IsCA(),
Curve: c.Curve(),
},
Signature: c.Signature(),
cert: c,
}
// Handle issuer
if c.Issuer() != "" {
issuerBytes, err := hex.DecodeString(c.Issuer())
if err != nil {
return nil, rest, fmt.Errorf("failed to decode issuer fingerprint: %w", err)
}
nc.Details.Issuer = issuerBytes
}
return nc, rest, nil
}

View File

@@ -29,6 +29,8 @@ type m = map[string]any
// newSimpleServer creates a nebula instance with many assumptions
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
var vpnNetworks []netip.Prefix
for _, sn := range strings.Split(sVpnNetworks, ",") {
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
@@ -54,25 +56,6 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
budpIp[3] = 239
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
}
return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
}
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
var vpnNetworks []netip.Prefix
for _, sn := range strings.Split(sVpnNetworks, ",") {
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
if err != nil {
panic(err)
}
vpnNetworks = append(vpnNetworks, vpnIpNet)
}
if len(vpnNetworks) == 0 {
panic("no vpn networks")
}
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
caB, err := caCrt.MarshalPEM()

View File

@@ -4,7 +4,6 @@
package e2e
import (
"net/netip"
"testing"
"time"
@@ -56,50 +55,3 @@ func TestDropInactiveTunnels(t *testing.T) {
myControl.Stop()
theirControl.Stop()
}
func TestCrossStackRelaysWork(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
//myVpnV4 := myVpnIpNet[0]
myVpnV6 := myVpnIpNet[1]
relayVpnV4 := relayVpnIpNet[0]
relayVpnV6 := relayVpnIpNet[1]
theirVpnV6 := theirVpnIpNet[0]
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
defer r.RenderFlow()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
t.Log("reply?")
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
p = r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
//t.Log("finish up")
//myControl.Stop()
//theirControl.Stop()
//relayControl.Stop()
}

View File

@@ -417,8 +417,6 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return nil
}
var ErrUnknownNetworkType = errors.New("unknown network type")
var ErrPeerRejected = errors.New("remote IP is not within a subnet that we handle")
var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
@@ -431,31 +429,18 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
return nil
}
// Make sure remote address matches nebula certificate, and determine how to treat it
if h.networks == nil {
// Simple case: Certificate has one address and no unsafe networks
if h.vpnAddrs[0] != fp.RemoteAddr {
// Make sure remote address matches nebula certificate
if h.networks != nil {
if !h.networks.Contains(fp.RemoteAddr) {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
} else {
nwType, ok := h.networks.Lookup(fp.RemoteAddr)
if !ok {
// Simple case: Certificate has one address and no unsafe networks
if h.vpnAddrs[0] != fp.RemoteAddr {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
switch nwType {
case NetworkTypeVPN:
break // nothing special
case NetworkTypeVPNPeer:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrPeerRejected // reject for now, one day this may have different FW rules
case NetworkTypeUnsafe:
break // nothing special, one day this may have different FW rules
default:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrUnknownNetworkType //should never happen
}
}
// Make sure we are supposed to be handling this local ip address

View File

@@ -8,7 +8,6 @@ import (
"testing"
"time"
"github.com/gaissmai/bart"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
@@ -150,8 +149,7 @@ func TestFirewall_Drop(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
@@ -176,7 +174,7 @@ func TestFirewall_Drop(t *testing.T) {
},
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
}
h.buildNetworks(myVpnNetworksTable, c.networks, c.unsafeNetworks)
h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -228,9 +226,6 @@ func TestFirewall_DropV6(t *testing.T) {
ob := &bytes.Buffer{}
l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"),
RemoteAddr: netip.MustParseAddr("fd12::34"),
@@ -255,7 +250,7 @@ func TestFirewall_DropV6(t *testing.T) {
},
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
}
h.buildNetworks(myVpnNetworksTable, c.networks, c.unsafeNetworks)
h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -458,8 +453,6 @@ func TestFirewall_Drop2(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -485,7 +478,7 @@ func TestFirewall_Drop2(t *testing.T) {
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
c1 := cert.CachedCertificate{
Certificate: &dummyCert{
@@ -500,7 +493,7 @@ func TestFirewall_Drop2(t *testing.T) {
peerCert: &c1,
},
}
h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -517,8 +510,6 @@ func TestFirewall_Drop3(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -550,7 +541,7 @@ func TestFirewall_Drop3(t *testing.T) {
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
c2 := cert.CachedCertificate{
Certificate: &dummyCert{
@@ -565,7 +556,7 @@ func TestFirewall_Drop3(t *testing.T) {
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h2.buildNetworks(myVpnNetworksTable, c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
c3 := cert.CachedCertificate{
Certificate: &dummyCert{
@@ -580,7 +571,7 @@ func TestFirewall_Drop3(t *testing.T) {
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h3.buildNetworks(myVpnNetworksTable, c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -606,8 +597,6 @@ func TestFirewall_Drop3V6(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("fd12::34"),
@@ -631,7 +620,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
// Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
@@ -644,8 +633,6 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
p := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"),
@@ -672,7 +659,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -709,8 +696,6 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
c := cert.CachedCertificate{
Certificate: &dummyCert{
@@ -732,7 +717,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
},
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
}
h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)

View File

@@ -183,18 +183,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return
}
var vpnAddrs []netip.Addr
var filteredNetworks []netip.Prefix
certName := remoteCert.Certificate.Name()
certVersion := remoteCert.Certificate.Version()
fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer()
vpnNetworks := remoteCert.Certificate.Networks()
anyVpnAddrsInCommon := false
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
for i, network := range vpnNetworks {
for _, network := range remoteCert.Certificate.Networks() {
vpnAddr := network.Addr()
if f.myVpnAddrsTable.Contains(vpnAddr) {
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
@@ -202,10 +201,24 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
return
}
vpnAddrs[i] = network.Addr()
if f.myVpnNetworksTable.Contains(vpnAddr) {
anyVpnAddrsInCommon = true
// vpnAddrs outside our vpn networks are of no use to us, filter them out
if !f.myVpnNetworksTable.Contains(vpnAddr) {
continue
}
filteredNetworks = append(filteredNetworks, network)
vpnAddrs = append(vpnAddrs, vpnAddr)
}
if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
return
}
if addr.IsValid() {
@@ -242,30 +255,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
},
}
msgRxL := f.l.WithFields(m{
"vpnAddrs": vpnAddrs,
"udpAddr": addr,
"certName": certName,
"certVersion": certVersion,
"fingerprint": fingerprint,
"issuer": issuer,
"initiatorIndex": hs.Details.InitiatorIndex,
"responderIndex": hs.Details.ResponderIndex,
"remoteIndex": h.RemoteIndex,
"handshake": m{"stage": 1, "style": "ix_psk0"},
})
if anyVpnAddrsInCommon {
msgRxL.Info("Handshake message received")
} else {
//todo warn if not lighthouse or relay?
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
}
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message received")
hs.Details.ResponderIndex = myIndex
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
if hs.Details.Cert == nil {
msgRxL.WithField("myCertVersion", ci.myCert.Version()).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVersion", ci.myCert.Version()).
Error("Unable to handshake with host because no certificate handshake bytes is available")
return
}
@@ -323,7 +332,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
hostinfo.SetRemote(addr)
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks())
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
if err != nil {
@@ -564,17 +573,30 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
}
anyVpnAddrsInCommon := false
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
for i, network := range vpnNetworks {
vpnAddrs[i] = network.Addr()
if f.myVpnNetworksTable.Contains(network.Addr()) {
anyVpnAddrsInCommon = true
var vpnAddrs []netip.Addr
var filteredNetworks []netip.Prefix
for _, network := range vpnNetworks {
// vpnAddrs outside our vpn networks are of no use to us, filter them out
vpnAddr := network.Addr()
if !f.myVpnNetworksTable.Contains(vpnAddr) {
continue
}
filteredNetworks = append(filteredNetworks, network)
vpnAddrs = append(vpnAddrs, vpnAddr)
}
if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
return true
}
// Ensure the right host responded
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("udpAddr", addr).
@@ -587,7 +609,6 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
f.handshakeManager.DeleteHostInfo(hostinfo)
// Create a new hostinfo/handshake for the intended vpn ip
//TODO is hostinfo.vpnAddrs[0] always the address to use?
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
// Block the current used address
newHH.hostinfo.remotes = hostinfo.remotes
@@ -614,7 +635,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds()
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
@@ -622,17 +643,12 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("durationNs", duration).
WithField("sentCachedPackets", len(hh.packetStore))
if anyVpnAddrsInCommon {
msgRxL.Info("Handshake message received")
} else {
//todo warn if not lighthouse or relay?
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
}
WithField("sentCachedPackets", len(hh.packetStore)).
Info("Handshake message received")
// Build up the radix for the firewall if we have subnets in the cert
hostinfo.vpnAddrs = vpnAddrs
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks())
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
f.handshakeManager.Complete(hostinfo, f)

View File

@@ -212,18 +212,6 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
rs.relayForByIdx[idx] = r
}
type NetworkType uint8
const (
NetworkTypeUnknown NetworkType = iota
// NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
NetworkTypeVPN
// NetworkTypeVPNPeer is a network that does not overlap one of our networks
NetworkTypeVPNPeer
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
NetworkTypeUnsafe
)
type HostInfo struct {
remote netip.AddrPort
remotes *RemoteList
@@ -232,11 +220,13 @@ type HostInfo struct {
remoteIndexId uint32
localIndexId uint32
// vpnAddrs is a list of vpn addresses assigned to this host
// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
// The host may have other vpn addresses that are outside our
// vpn networks but were removed because they are not usable
vpnAddrs []netip.Addr
// networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host.
networks *bart.Table[NetworkType]
// networks are both all vpn and unsafe networks assigned to this host
networks *bart.Lite
relayState RelayState
// HandshakePacket records the packets used to create this hostinfo
@@ -740,27 +730,20 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
return false
}
func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, networks, unsafeNetworks []netip.Prefix) {
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
if len(networks) == 1 && len(unsafeNetworks) == 0 {
if myVpnNetworksTable.Contains(networks[0].Addr()) {
return // Simple case, no CIDRTree needed
}
// Simple case, no CIDRTree needed
return
}
i.networks = new(bart.Table[NetworkType])
i.networks = new(bart.Lite)
for _, network := range networks {
var nwType NetworkType
if myVpnNetworksTable.Contains(network.Addr()) {
nwType = NetworkTypeVPN
} else {
nwType = NetworkTypeVPNPeer
}
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
i.networks.Insert(nprefix, nwType)
i.networks.Insert(nprefix)
}
for _, network := range unsafeNetworks {
i.networks.Insert(network, NetworkTypeUnsafe)
i.networks.Insert(network)
}
}

View File

@@ -271,7 +271,10 @@ func (f *Interface) listenOut(i int) {
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte, release func()) {
if release != nil {
defer release()
}
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
})
}

View File

@@ -1017,17 +1017,17 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
return lhh.meta
}
func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, hostinfo *HostInfo, p []byte, w EncWriter) {
func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs []netip.Addr, p []byte, w EncWriter) {
n := lhh.resetMeta()
err := n.Unmarshal(p)
if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", rAddr).
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
Error("Failed to unmarshal lighthouse packet")
return
}
if n.Details == nil {
lhh.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", rAddr).
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
Error("Invalid lighthouse update")
return
}
@@ -1036,24 +1036,24 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, hostinfo *Host
switch n.Type {
case NebulaMeta_HostQuery:
lhh.handleHostQuery(n, hostinfo, rAddr, w)
lhh.handleHostQuery(n, fromVpnAddrs, rAddr, w)
case NebulaMeta_HostQueryReply:
lhh.handleHostQueryReply(n, hostinfo.vpnAddrs)
lhh.handleHostQueryReply(n, fromVpnAddrs)
case NebulaMeta_HostUpdateNotification:
lhh.handleHostUpdateNotification(n, hostinfo, w)
lhh.handleHostUpdateNotification(n, fromVpnAddrs, w)
case NebulaMeta_HostMovedNotification:
case NebulaMeta_HostPunchNotification:
lhh.handleHostPunchNotification(n, hostinfo.vpnAddrs, w)
lhh.handleHostPunchNotification(n, fromVpnAddrs, w)
case NebulaMeta_HostUpdateNotificationAck:
// noop
}
}
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, hostinfo *HostInfo, addr netip.AddrPort, w EncWriter) {
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) {
// Exit if we don't answer queries
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
@@ -1065,7 +1065,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, hostinfo *HostInfo,
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
if err != nil {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("from", hostinfo.vpnAddrs).WithField("details", n.Details).
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
Debugln("Dropping malformed HostQuery")
}
return
@@ -1073,7 +1073,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, hostinfo *HostInfo,
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
// this case really shouldn't be possible to represent, but reject it anyway.
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
Debugln("invalid vpn addr for v1 handleHostQuery")
}
return
@@ -1099,14 +1099,14 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, hostinfo *HostInfo,
}
if err != nil {
lhh.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).Error("Failed to marshal lighthouse host query reply")
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply")
return
}
lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
w.SendMessageToHostInfo(header.LightHouse, 0, hostinfo, lhh.pb[:ln], lhh.nb, lhh.out[:0])
w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])
lhh.sendHostPunchNotification(n, hostinfo.vpnAddrs, queryVpnAddr, w)
lhh.sendHostPunchNotification(n, fromVpnAddrs, queryVpnAddr, w)
}
// sendHostPunchNotification signals the other side to punch some zero byte udp packets
@@ -1115,34 +1115,20 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
found, ln, err := lhh.lh.queryAndPrepMessage(whereToPunch, func(c *cache) (int, error) {
n = lhh.resetMeta()
n.Type = NebulaMeta_HostPunchNotification
punchNotifDestHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
var useVersion cert.Version
if punchNotifDestHI == nil {
if targetHI == nil {
useVersion = lhh.lh.ifce.GetCertState().initiatingVersion
} else {
crt := targetHI.GetCert().Certificate
useVersion = crt.Version()
// we can only retarget if we have a hostinfo
punchNotifDestCrt := punchNotifDestHI.GetCert().Certificate
useVersion = punchNotifDestCrt.Version()
punchNotifDestNetworks := punchNotifDestCrt.Networks()
//if we (the lighthouse) don't have a network in common with punchNotifDest, try to find one
if !lhh.lh.myVpnNetworksTable.Contains(punchNotifDest) {
newPunchNotifDest, ok := findNetworkUnion(lhh.lh.myVpnNetworks, punchNotifDestHI.vpnAddrs)
if ok {
punchNotifDest = newPunchNotifDest
} else {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("to", punchNotifDestNetworks).Debugln("unable to notify host to host, no addresses in common")
}
}
}
newWhereToPunch, ok := findNetworkUnion(punchNotifDestNetworks, fromVpnAddrs)
newDest, ok := findNetworkUnion(crt.Networks(), fromVpnAddrs)
if ok {
whereToPunch = newWhereToPunch
whereToPunch = newDest
} else {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithFields(m{"from": fromVpnAddrs, "to": punchNotifDestNetworks}).Debugln("unable to punch to host, no addresses in common with requestor")
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
}
}
}
@@ -1248,8 +1234,7 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
}
}
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, hostinfo *HostInfo, w EncWriter) {
fromVpnAddrs := hostinfo.vpnAddrs
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs)
@@ -1317,7 +1302,7 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, hostin
}
lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1)
w.SendMessageToHostInfo(header.LightHouse, 0, hostinfo, lhh.pb[:ln], lhh.nb, lhh.out[:0])
w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])
}
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {

View File

@@ -132,13 +132,8 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
)
mw := &mockEncWriter{}
hostinfo := &HostInfo{
ConnectionState: &ConnectionState{
eKey: nil,
dKey: nil,
},
vpnAddrs: []netip.Addr{vpnIp2},
}
hi := []netip.Addr{vpnIp2}
b.Run("notfound", func(b *testing.B) {
lhh := lh.NewRequestHandler()
req := &NebulaMeta{
@@ -151,7 +146,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
p, err := req.Marshal()
require.NoError(b, err)
for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, hostinfo, p, mw)
lhh.HandleRequest(rAddr, hi, p, mw)
}
})
b.Run("found", func(b *testing.B) {
@@ -167,7 +162,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
require.NoError(b, err)
for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, hostinfo, p, mw)
lhh.HandleRequest(rAddr, hi, p, mw)
}
})
}
@@ -331,14 +326,7 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
w := &testEncWriter{
metaFilter: &filter,
}
hostinfo := &HostInfo{
ConnectionState: &ConnectionState{
eKey: nil,
dKey: nil,
},
vpnAddrs: []netip.Addr{myVpnIp},
}
lhh.HandleRequest(fromAddr, hostinfo, b, w)
lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w)
return w.lastReply
}
@@ -367,15 +355,9 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
if err != nil {
panic(err)
}
hostinfo := &HostInfo{
ConnectionState: &ConnectionState{
eKey: nil,
dKey: nil,
},
vpnAddrs: []netip.Addr{vpnIp},
}
w := &testEncWriter{}
lhh.HandleRequest(fromAddr, hostinfo, b, w)
lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w)
}
type testLhReply struct {

View File

@@ -138,7 +138,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return
}
lhf.HandleRequest(ip, hostinfo, d, f)
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic

16
udp/config.go Normal file
View File

@@ -0,0 +1,16 @@
package udp
import "sync/atomic"
var disableUDPCsum atomic.Bool
// SetDisableUDPCsum controls whether IPv4 UDP sockets opt out of kernel
// checksum calculation via SO_NO_CHECK. Only applicable on platforms that
// support the option (Linux). IPv6 always keeps the checksum enabled.
func SetDisableUDPCsum(disable bool) {
disableUDPCsum.Store(disable)
}
func udpChecksumDisabled() bool {
return disableUDPCsum.Load()
}

View File

@@ -11,6 +11,7 @@ const MTU = 9001
type EncReader func(
addr netip.AddrPort,
payload []byte,
release func(),
)
type Conn interface {

1740
udp/io_uring_linux.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,25 @@
//go:build linux && (386 || amd64p32 || arm || mips || mipsle) && !android && !e2e_testing
// +build linux
// +build 386 amd64p32 arm mips mipsle
// +build !android
// +build !e2e_testing
package udp
import "golang.org/x/sys/unix"
func controllen(n int) uint32 {
return uint32(n)
}
func setCmsgLen(h *unix.Cmsghdr, n int) {
h.Len = uint32(unix.CmsgLen(n))
}
func setIovecLen(v *unix.Iovec, n int) {
v.Len = uint32(n)
}
func setMsghdrIovlen(m *unix.Msghdr, n int) {
m.Iovlen = uint32(n)
}

View File

@@ -0,0 +1,25 @@
//go:build linux && (amd64 || arm64 || ppc64 || ppc64le || mips64 || mips64le || s390x || riscv64 || loong64) && !android && !e2e_testing
// +build linux
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x riscv64 loong64
// +build !android
// +build !e2e_testing
package udp
import "golang.org/x/sys/unix"
func controllen(n int) uint64 {
return uint64(n)
}
func setCmsgLen(h *unix.Cmsghdr, n int) {
h.Len = uint64(unix.CmsgLen(n))
}
func setIovecLen(v *unix.Iovec, n int) {
v.Len = uint64(n)
}
func setMsghdrIovlen(m *unix.Msghdr, n int) {
m.Iovlen = uint64(n)
}

25
udp/sendmmsg_linux_32.go Normal file
View File

@@ -0,0 +1,25 @@
//go:build linux && (386 || amd64p32 || arm || mips || mipsle) && !android && !e2e_testing
package udp
import (
"unsafe"
"golang.org/x/sys/unix"
)
type linuxMmsgHdr struct {
Hdr unix.Msghdr
Len uint32
}
func sendmmsg(fd int, hdrs []linuxMmsgHdr, flags int) (int, error) {
if len(hdrs) == 0 {
return 0, nil
}
n, _, errno := unix.Syscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&hdrs[0])), uintptr(len(hdrs)), uintptr(flags), 0, 0)
if errno != 0 {
return int(n), errno
}
return int(n), nil
}

26
udp/sendmmsg_linux_64.go Normal file
View File

@@ -0,0 +1,26 @@
//go:build linux && (amd64 || arm64 || ppc64 || ppc64le || mips64 || mips64le || s390x || riscv64 || loong64) && !android && !e2e_testing
package udp
import (
"unsafe"
"golang.org/x/sys/unix"
)
type linuxMmsgHdr struct {
Hdr unix.Msghdr
Len uint32
_ uint32
}
func sendmmsg(fd int, hdrs []linuxMmsgHdr, flags int) (int, error) {
if len(hdrs) == 0 {
return 0, nil
}
n, _, errno := unix.Syscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&hdrs[0])), uintptr(len(hdrs)), uintptr(flags), 0, 0)
if errno != 0 {
return int(n), errno
}
return int(n), nil
}

View File

@@ -180,7 +180,7 @@ func (u *StdConn) ListenOut(r EncReader) {
u.l.WithError(err).Error("unexpected udp socket receive error")
}
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], nil)
}
}

View File

@@ -82,6 +82,6 @@ func (u *GenericConn) ListenOut(r EncReader) {
return
}
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], nil)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -7,6 +7,9 @@
package udp
import (
"errors"
"fmt"
"golang.org/x/sys/unix"
)
@@ -30,17 +33,29 @@ type rawMessage struct {
Len uint32
}
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte, [][]byte) {
controlLen := int(u.controlLen.Load())
msgs := make([]rawMessage, n)
buffers := make([][]byte, n)
names := make([][]byte, n)
var controls [][]byte
if controlLen > 0 {
controls = make([][]byte, n)
}
for i := range msgs {
buffers[i] = make([]byte, MTU)
size := int(u.groBufSize.Load())
if size < MTU {
size = MTU
}
buf := u.borrowRxBuffer(size)
buffers[i] = buf
names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{
{Base: &buffers[i][0], Len: uint32(len(buffers[i]))},
{Base: &buf[0], Len: uint32(len(buf))},
}
msgs[i].Hdr.Iov = &vs[0]
@@ -48,7 +63,71 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
msgs[i].Hdr.Name = &names[i][0]
msgs[i].Hdr.Namelen = uint32(len(names[i]))
if controlLen > 0 {
controls[i] = make([]byte, controlLen)
msgs[i].Hdr.Control = &controls[i][0]
msgs[i].Hdr.Controllen = controllen(len(controls[i]))
} else {
msgs[i].Hdr.Control = nil
msgs[i].Hdr.Controllen = controllen(0)
}
}
return msgs, buffers, names
return msgs, buffers, names, controls
}
func setIovecBase(msg *rawMessage, buf []byte) {
iov := (*iovec)(msg.Hdr.Iov)
iov.Base = &buf[0]
iov.Len = uint32(len(buf))
}
func rawMessageToUnixMsghdr(msg *rawMessage) (unix.Msghdr, unix.Iovec, error) {
var hdr unix.Msghdr
var iov unix.Iovec
if msg == nil {
return hdr, iov, errors.New("nil rawMessage")
}
if msg.Hdr.Iov == nil || msg.Hdr.Iov.Base == nil {
return hdr, iov, errors.New("rawMessage missing payload buffer")
}
payloadLen := int(msg.Hdr.Iov.Len)
if payloadLen < 0 {
return hdr, iov, fmt.Errorf("invalid payload length: %d", payloadLen)
}
iov.Base = msg.Hdr.Iov.Base
iov.Len = uint32(payloadLen)
hdr.Iov = &iov
hdr.Iovlen = 1
hdr.Name = msg.Hdr.Name
// CRITICAL: Always set to full buffer size for receive, not what kernel wrote last time
if hdr.Name != nil {
hdr.Namelen = uint32(unix.SizeofSockaddrInet6)
} else {
hdr.Namelen = 0
}
hdr.Control = msg.Hdr.Control
// CRITICAL: Use the allocated size, not what was previously returned
if hdr.Control != nil {
// Control buffer size is stored in Controllen from PrepareRawMessages
hdr.Controllen = msg.Hdr.Controllen
} else {
hdr.Controllen = 0
}
hdr.Flags = 0 // Reset flags for new receive
return hdr, iov, nil
}
func updateRawMessageFromUnixMsghdr(msg *rawMessage, hdr *unix.Msghdr, n int) {
if msg == nil || hdr == nil {
return
}
msg.Hdr.Namelen = hdr.Namelen
msg.Hdr.Controllen = hdr.Controllen
msg.Hdr.Flags = hdr.Flags
if n < 0 {
n = 0
}
msg.Len = uint32(n)
}

View File

@@ -7,6 +7,9 @@
package udp
import (
"errors"
"fmt"
"golang.org/x/sys/unix"
)
@@ -33,25 +36,99 @@ type rawMessage struct {
Pad0 [4]byte
}
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte, [][]byte) {
controlLen := int(u.controlLen.Load())
msgs := make([]rawMessage, n)
buffers := make([][]byte, n)
names := make([][]byte, n)
var controls [][]byte
if controlLen > 0 {
controls = make([][]byte, n)
}
for i := range msgs {
buffers[i] = make([]byte, MTU)
size := int(u.groBufSize.Load())
if size < MTU {
size = MTU
}
buf := u.borrowRxBuffer(size)
buffers[i] = buf
names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{
{Base: &buffers[i][0], Len: uint64(len(buffers[i]))},
}
vs := []iovec{{Base: &buf[0], Len: uint64(len(buf))}}
msgs[i].Hdr.Iov = &vs[0]
msgs[i].Hdr.Iovlen = uint64(len(vs))
msgs[i].Hdr.Name = &names[i][0]
msgs[i].Hdr.Namelen = uint32(len(names[i]))
if controlLen > 0 {
controls[i] = make([]byte, controlLen)
msgs[i].Hdr.Control = &controls[i][0]
msgs[i].Hdr.Controllen = controllen(len(controls[i]))
} else {
msgs[i].Hdr.Control = nil
msgs[i].Hdr.Controllen = controllen(0)
}
}
return msgs, buffers, names
return msgs, buffers, names, controls
}
func setIovecBase(msg *rawMessage, buf []byte) {
iov := (*iovec)(msg.Hdr.Iov)
iov.Base = &buf[0]
iov.Len = uint64(len(buf))
}
func rawMessageToUnixMsghdr(msg *rawMessage) (unix.Msghdr, unix.Iovec, error) {
var hdr unix.Msghdr
var iov unix.Iovec
if msg == nil {
return hdr, iov, errors.New("nil rawMessage")
}
if msg.Hdr.Iov == nil || msg.Hdr.Iov.Base == nil {
return hdr, iov, errors.New("rawMessage missing payload buffer")
}
payloadLen := int(msg.Hdr.Iov.Len)
if payloadLen < 0 {
return hdr, iov, fmt.Errorf("invalid payload length: %d", payloadLen)
}
iov.Base = msg.Hdr.Iov.Base
iov.Len = uint64(payloadLen)
hdr.Iov = &iov
hdr.Iovlen = 1
hdr.Name = msg.Hdr.Name
// CRITICAL: Always set to full buffer size for receive, not what kernel wrote last time
if hdr.Name != nil {
hdr.Namelen = uint32(unix.SizeofSockaddrInet6)
} else {
hdr.Namelen = 0
}
hdr.Control = msg.Hdr.Control
// CRITICAL: Use the allocated size, not what was previously returned
if hdr.Control != nil {
// Control buffer size is stored in Controllen from PrepareRawMessages
hdr.Controllen = msg.Hdr.Controllen
} else {
hdr.Controllen = 0
}
hdr.Flags = 0 // Reset flags for new receive
return hdr, iov, nil
}
func updateRawMessageFromUnixMsghdr(msg *rawMessage, hdr *unix.Msghdr, n int) {
if msg == nil || hdr == nil {
return
}
msg.Hdr.Namelen = hdr.Namelen
msg.Hdr.Controllen = hdr.Controllen
msg.Hdr.Flags = hdr.Flags
if n < 0 {
n = 0
}
msg.Len = uint32(n)
}

View File

@@ -149,7 +149,7 @@ func (u *RIOConn) ListenOut(r EncReader) {
continue
}
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n])
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n], nil)
}
}

View File

@@ -112,7 +112,7 @@ func (u *TesterConn) ListenOut(r EncReader) {
if !ok {
return
}
r(p.From, p.Data)
r(p.From, p.Data, func() {})
}
}