mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 08:24:25 +01:00
Compare commits
6 Commits
stinky
...
cross-stac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f597aa71e3 | ||
|
|
20b7219fbe | ||
|
|
3b53c27170 | ||
|
|
526236c5fa | ||
|
|
0ab2882b78 | ||
|
|
889d49ff82 |
@@ -7,13 +7,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
### Added
|
|
||||||
|
|
||||||
- Experimental Linux UDP offload support: enable `listen.enable_gso` and
|
|
||||||
`listen.enable_gro` to activate UDP_SEGMENT batching and GRO receive
|
|
||||||
splitting. Includes automatic capability probing, per-packet fallbacks, and
|
|
||||||
runtime metrics/logs for visibility.
|
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
|
||||||
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
|
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
|
||||||
|
|||||||
@@ -29,8 +29,6 @@ type m = map[string]any
|
|||||||
|
|
||||||
// newSimpleServer creates a nebula instance with many assumptions
|
// newSimpleServer creates a nebula instance with many assumptions
|
||||||
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||||
l := NewTestLogger()
|
|
||||||
|
|
||||||
var vpnNetworks []netip.Prefix
|
var vpnNetworks []netip.Prefix
|
||||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||||
@@ -56,6 +54,25 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
|||||||
budpIp[3] = 239
|
budpIp[3] = 239
|
||||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||||
}
|
}
|
||||||
|
return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||||
|
l := NewTestLogger()
|
||||||
|
|
||||||
|
var vpnNetworks []netip.Prefix
|
||||||
|
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||||
|
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
vpnNetworks = append(vpnNetworks, vpnIpNet)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(vpnNetworks) == 0 {
|
||||||
|
panic("no vpn networks")
|
||||||
|
}
|
||||||
|
|
||||||
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
|
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
|
||||||
|
|
||||||
caB, err := caCrt.MarshalPEM()
|
caB, err := caCrt.MarshalPEM()
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
package e2e
|
package e2e
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -55,3 +56,50 @@ func TestDropInactiveTunnels(t *testing.T) {
|
|||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCrossStackRelaysWork(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
|
||||||
|
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
|
||||||
|
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
|
||||||
|
|
||||||
|
//myVpnV4 := myVpnIpNet[0]
|
||||||
|
myVpnV6 := myVpnIpNet[1]
|
||||||
|
relayVpnV4 := relayVpnIpNet[0]
|
||||||
|
relayVpnV6 := relayVpnIpNet[1]
|
||||||
|
theirVpnV6 := theirVpnIpNet[0]
|
||||||
|
|
||||||
|
// Teach my how to get to the relay and that their can be reached via the relay
|
||||||
|
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
|
||||||
|
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
|
||||||
|
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
|
||||||
|
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
|
||||||
|
|
||||||
|
// Build a router so we don't have to reason who gets which packet
|
||||||
|
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
relayControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
t.Log("Trigger a handshake from me to them via the relay")
|
||||||
|
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
|
||||||
|
|
||||||
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
|
r.Log("Assert the tunnel works")
|
||||||
|
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
|
||||||
|
|
||||||
|
t.Log("reply?")
|
||||||
|
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
|
||||||
|
p = r.RouteForAllUntilTxTun(myControl)
|
||||||
|
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
|
||||||
|
|
||||||
|
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
||||||
|
//t.Log("finish up")
|
||||||
|
//myControl.Stop()
|
||||||
|
//theirControl.Stop()
|
||||||
|
//relayControl.Stop()
|
||||||
|
}
|
||||||
|
|||||||
29
firewall.go
29
firewall.go
@@ -417,6 +417,8 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ErrUnknownNetworkType = errors.New("unknown network type")
|
||||||
|
var ErrPeerRejected = errors.New("remote IP is not within a subnet that we handle")
|
||||||
var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
|
var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
|
||||||
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
|
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
|
||||||
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
||||||
@@ -429,18 +431,31 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure remote address matches nebula certificate
|
// Make sure remote address matches nebula certificate, and determine how to treat it
|
||||||
if h.networks != nil {
|
if h.networks == nil {
|
||||||
if !h.networks.Contains(fp.RemoteAddr) {
|
|
||||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
|
||||||
return ErrInvalidRemoteIP
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Simple case: Certificate has one address and no unsafe networks
|
// Simple case: Certificate has one address and no unsafe networks
|
||||||
if h.vpnAddrs[0] != fp.RemoteAddr {
|
if h.vpnAddrs[0] != fp.RemoteAddr {
|
||||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
return ErrInvalidRemoteIP
|
return ErrInvalidRemoteIP
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
nwType, ok := h.networks.Lookup(fp.RemoteAddr)
|
||||||
|
if !ok {
|
||||||
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
|
return ErrInvalidRemoteIP
|
||||||
|
}
|
||||||
|
switch nwType {
|
||||||
|
case NetworkTypeVPN:
|
||||||
|
break // nothing special
|
||||||
|
case NetworkTypeVPNPeer:
|
||||||
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
|
return ErrPeerRejected // reject for now, one day this may have different FW rules
|
||||||
|
case NetworkTypeUnsafe:
|
||||||
|
break // nothing special, one day this may have different FW rules
|
||||||
|
default:
|
||||||
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
|
return ErrUnknownNetworkType //should never happen
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure we are supposed to be handling this local ip address
|
// Make sure we are supposed to be handling this local ip address
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gaissmai/bart"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
@@ -149,7 +150,8 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -174,7 +176,7 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
||||||
}
|
}
|
||||||
h.buildNetworks(c.networks, c.unsafeNetworks)
|
h.buildNetworks(myVpnNetworksTable, c.networks, c.unsafeNetworks)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -226,6 +228,9 @@ func TestFirewall_DropV6(t *testing.T) {
|
|||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||||
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
||||||
@@ -250,7 +255,7 @@ func TestFirewall_DropV6(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
|
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
|
||||||
}
|
}
|
||||||
h.buildNetworks(c.networks, c.unsafeNetworks)
|
h.buildNetworks(myVpnNetworksTable, c.networks, c.unsafeNetworks)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -453,6 +458,8 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -478,7 +485,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
c1 := cert.CachedCertificate{
|
c1 := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -493,7 +500,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
peerCert: &c1,
|
peerCert: &c1,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -510,6 +517,8 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -541,7 +550,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
c2 := cert.CachedCertificate{
|
c2 := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -556,7 +565,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
|
h2.buildNetworks(myVpnNetworksTable, c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
c3 := cert.CachedCertificate{
|
c3 := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -571,7 +580,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
|
h3.buildNetworks(myVpnNetworksTable, c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -597,6 +606,8 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||||
@@ -620,7 +631,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
// Test a remote address match
|
// Test a remote address match
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
@@ -633,6 +644,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -659,7 +672,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -696,6 +709,8 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
|
||||||
|
|
||||||
c := cert.CachedCertificate{
|
c := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -717,7 +732,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
|
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
|
||||||
}
|
}
|
||||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
|
|
||||||
|
|||||||
108
handshake_ix.go
108
handshake_ix.go
@@ -183,17 +183,18 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var vpnAddrs []netip.Addr
|
|
||||||
var filteredNetworks []netip.Prefix
|
|
||||||
certName := remoteCert.Certificate.Name()
|
certName := remoteCert.Certificate.Name()
|
||||||
certVersion := remoteCert.Certificate.Version()
|
certVersion := remoteCert.Certificate.Version()
|
||||||
fingerprint := remoteCert.Fingerprint
|
fingerprint := remoteCert.Fingerprint
|
||||||
issuer := remoteCert.Certificate.Issuer()
|
issuer := remoteCert.Certificate.Issuer()
|
||||||
|
vpnNetworks := remoteCert.Certificate.Networks()
|
||||||
|
|
||||||
for _, network := range remoteCert.Certificate.Networks() {
|
anyVpnAddrsInCommon := false
|
||||||
|
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||||
|
for i, network := range vpnNetworks {
|
||||||
vpnAddr := network.Addr()
|
vpnAddr := network.Addr()
|
||||||
if f.myVpnAddrsTable.Contains(vpnAddr) {
|
if f.myVpnAddrsTable.Contains(vpnAddr) {
|
||||||
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
|
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -201,24 +202,10 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
vpnAddrs[i] = network.Addr()
|
||||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
if f.myVpnNetworksTable.Contains(vpnAddr) {
|
||||||
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
anyVpnAddrsInCommon = true
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
filteredNetworks = append(filteredNetworks, network)
|
|
||||||
vpnAddrs = append(vpnAddrs, vpnAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(vpnAddrs) == 0 {
|
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if addr.IsValid() {
|
if addr.IsValid() {
|
||||||
@@ -255,26 +242,30 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
msgRxL := f.l.WithFields(m{
|
||||||
WithField("certName", certName).
|
"vpnAddrs": vpnAddrs,
|
||||||
WithField("certVersion", certVersion).
|
"udpAddr": addr,
|
||||||
WithField("fingerprint", fingerprint).
|
"certName": certName,
|
||||||
WithField("issuer", issuer).
|
"certVersion": certVersion,
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
"fingerprint": fingerprint,
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"issuer": issuer,
|
||||||
Info("Handshake message received")
|
"initiatorIndex": hs.Details.InitiatorIndex,
|
||||||
|
"responderIndex": hs.Details.ResponderIndex,
|
||||||
|
"remoteIndex": h.RemoteIndex,
|
||||||
|
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
})
|
||||||
|
|
||||||
|
if anyVpnAddrsInCommon {
|
||||||
|
msgRxL.Info("Handshake message received")
|
||||||
|
} else {
|
||||||
|
//todo warn if not lighthouse or relay?
|
||||||
|
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
||||||
|
}
|
||||||
|
|
||||||
hs.Details.ResponderIndex = myIndex
|
hs.Details.ResponderIndex = myIndex
|
||||||
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
||||||
if hs.Details.Cert == nil {
|
if hs.Details.Cert == nil {
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
msgRxL.WithField("myCertVersion", ci.myCert.Version()).
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
WithField("certVersion", ci.myCert.Version()).
|
|
||||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -332,7 +323,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
|
|
||||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
||||||
hostinfo.SetRemote(addr)
|
hostinfo.SetRemote(addr)
|
||||||
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -573,30 +564,17 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
var vpnAddrs []netip.Addr
|
anyVpnAddrsInCommon := false
|
||||||
var filteredNetworks []netip.Prefix
|
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||||
for _, network := range vpnNetworks {
|
for i, network := range vpnNetworks {
|
||||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
vpnAddrs[i] = network.Addr()
|
||||||
vpnAddr := network.Addr()
|
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
||||||
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
anyVpnAddrsInCommon = true
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
filteredNetworks = append(filteredNetworks, network)
|
|
||||||
vpnAddrs = append(vpnAddrs, vpnAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(vpnAddrs) == 0 {
|
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the right host responded
|
// Ensure the right host responded
|
||||||
|
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
|
||||||
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
|
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
|
||||||
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
||||||
WithField("udpAddr", addr).
|
WithField("udpAddr", addr).
|
||||||
@@ -609,6 +587,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||||
|
|
||||||
// Create a new hostinfo/handshake for the intended vpn ip
|
// Create a new hostinfo/handshake for the intended vpn ip
|
||||||
|
//TODO is hostinfo.vpnAddrs[0] always the address to use?
|
||||||
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
||||||
// Block the current used address
|
// Block the current used address
|
||||||
newHH.hostinfo.remotes = hostinfo.remotes
|
newHH.hostinfo.remotes = hostinfo.remotes
|
||||||
@@ -635,7 +614,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
ci.window.Update(f.l, 2)
|
ci.window.Update(f.l, 2)
|
||||||
|
|
||||||
duration := time.Since(hh.startTime).Nanoseconds()
|
duration := time.Since(hh.startTime).Nanoseconds()
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -643,12 +622,17 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
WithField("durationNs", duration).
|
WithField("durationNs", duration).
|
||||||
WithField("sentCachedPackets", len(hh.packetStore)).
|
WithField("sentCachedPackets", len(hh.packetStore))
|
||||||
Info("Handshake message received")
|
if anyVpnAddrsInCommon {
|
||||||
|
msgRxL.Info("Handshake message received")
|
||||||
|
} else {
|
||||||
|
//todo warn if not lighthouse or relay?
|
||||||
|
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
||||||
|
}
|
||||||
|
|
||||||
// Build up the radix for the firewall if we have subnets in the cert
|
// Build up the radix for the firewall if we have subnets in the cert
|
||||||
hostinfo.vpnAddrs = vpnAddrs
|
hostinfo.vpnAddrs = vpnAddrs
|
||||||
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
||||||
f.handshakeManager.Complete(hostinfo, f)
|
f.handshakeManager.Complete(hostinfo, f)
|
||||||
|
|||||||
39
hostmap.go
39
hostmap.go
@@ -212,6 +212,18 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
|
|||||||
rs.relayForByIdx[idx] = r
|
rs.relayForByIdx[idx] = r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type NetworkType uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
NetworkTypeUnknown NetworkType = iota
|
||||||
|
// NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
|
||||||
|
NetworkTypeVPN
|
||||||
|
// NetworkTypeVPNPeer is a network that does not overlap one of our networks
|
||||||
|
NetworkTypeVPNPeer
|
||||||
|
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
|
||||||
|
NetworkTypeUnsafe
|
||||||
|
)
|
||||||
|
|
||||||
type HostInfo struct {
|
type HostInfo struct {
|
||||||
remote netip.AddrPort
|
remote netip.AddrPort
|
||||||
remotes *RemoteList
|
remotes *RemoteList
|
||||||
@@ -220,13 +232,11 @@ type HostInfo struct {
|
|||||||
remoteIndexId uint32
|
remoteIndexId uint32
|
||||||
localIndexId uint32
|
localIndexId uint32
|
||||||
|
|
||||||
// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
|
// vpnAddrs is a list of vpn addresses assigned to this host
|
||||||
// The host may have other vpn addresses that are outside our
|
|
||||||
// vpn networks but were removed because they are not usable
|
|
||||||
vpnAddrs []netip.Addr
|
vpnAddrs []netip.Addr
|
||||||
|
|
||||||
// networks are both all vpn and unsafe networks assigned to this host
|
// networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host.
|
||||||
networks *bart.Lite
|
networks *bart.Table[NetworkType]
|
||||||
relayState RelayState
|
relayState RelayState
|
||||||
|
|
||||||
// HandshakePacket records the packets used to create this hostinfo
|
// HandshakePacket records the packets used to create this hostinfo
|
||||||
@@ -730,20 +740,27 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, networks, unsafeNetworks []netip.Prefix) {
|
||||||
if len(networks) == 1 && len(unsafeNetworks) == 0 {
|
if len(networks) == 1 && len(unsafeNetworks) == 0 {
|
||||||
// Simple case, no CIDRTree needed
|
if myVpnNetworksTable.Contains(networks[0].Addr()) {
|
||||||
return
|
return // Simple case, no CIDRTree needed
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
i.networks = new(bart.Lite)
|
i.networks = new(bart.Table[NetworkType])
|
||||||
for _, network := range networks {
|
for _, network := range networks {
|
||||||
|
var nwType NetworkType
|
||||||
|
if myVpnNetworksTable.Contains(network.Addr()) {
|
||||||
|
nwType = NetworkTypeVPN
|
||||||
|
} else {
|
||||||
|
nwType = NetworkTypeVPNPeer
|
||||||
|
}
|
||||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||||
i.networks.Insert(nprefix)
|
i.networks.Insert(nprefix, nwType)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, network := range unsafeNetworks {
|
for _, network := range unsafeNetworks {
|
||||||
i.networks.Insert(network)
|
i.networks.Insert(network, NetworkTypeUnsafe)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1017,17 +1017,17 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
|
|||||||
return lhh.meta
|
return lhh.meta
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs []netip.Addr, p []byte, w EncWriter) {
|
func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, hostinfo *HostInfo, p []byte, w EncWriter) {
|
||||||
n := lhh.resetMeta()
|
n := lhh.resetMeta()
|
||||||
err := n.Unmarshal(p)
|
err := n.Unmarshal(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
|
lhh.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", rAddr).
|
||||||
Error("Failed to unmarshal lighthouse packet")
|
Error("Failed to unmarshal lighthouse packet")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if n.Details == nil {
|
if n.Details == nil {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
|
lhh.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", rAddr).
|
||||||
Error("Invalid lighthouse update")
|
Error("Invalid lighthouse update")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1036,24 +1036,24 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
|
|||||||
|
|
||||||
switch n.Type {
|
switch n.Type {
|
||||||
case NebulaMeta_HostQuery:
|
case NebulaMeta_HostQuery:
|
||||||
lhh.handleHostQuery(n, fromVpnAddrs, rAddr, w)
|
lhh.handleHostQuery(n, hostinfo, rAddr, w)
|
||||||
|
|
||||||
case NebulaMeta_HostQueryReply:
|
case NebulaMeta_HostQueryReply:
|
||||||
lhh.handleHostQueryReply(n, fromVpnAddrs)
|
lhh.handleHostQueryReply(n, hostinfo.vpnAddrs)
|
||||||
|
|
||||||
case NebulaMeta_HostUpdateNotification:
|
case NebulaMeta_HostUpdateNotification:
|
||||||
lhh.handleHostUpdateNotification(n, fromVpnAddrs, w)
|
lhh.handleHostUpdateNotification(n, hostinfo, w)
|
||||||
|
|
||||||
case NebulaMeta_HostMovedNotification:
|
case NebulaMeta_HostMovedNotification:
|
||||||
case NebulaMeta_HostPunchNotification:
|
case NebulaMeta_HostPunchNotification:
|
||||||
lhh.handleHostPunchNotification(n, fromVpnAddrs, w)
|
lhh.handleHostPunchNotification(n, hostinfo.vpnAddrs, w)
|
||||||
|
|
||||||
case NebulaMeta_HostUpdateNotificationAck:
|
case NebulaMeta_HostUpdateNotificationAck:
|
||||||
// noop
|
// noop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) {
|
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, hostinfo *HostInfo, addr netip.AddrPort, w EncWriter) {
|
||||||
// Exit if we don't answer queries
|
// Exit if we don't answer queries
|
||||||
if !lhh.lh.amLighthouse {
|
if !lhh.lh.amLighthouse {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
@@ -1065,7 +1065,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
|||||||
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
|
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
|
lhh.l.WithField("from", hostinfo.vpnAddrs).WithField("details", n.Details).
|
||||||
Debugln("Dropping malformed HostQuery")
|
Debugln("Dropping malformed HostQuery")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -1073,7 +1073,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
|||||||
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
|
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
|
||||||
// this case really shouldn't be possible to represent, but reject it anyway.
|
// this case really shouldn't be possible to represent, but reject it anyway.
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
|
lhh.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
|
||||||
Debugln("invalid vpn addr for v1 handleHostQuery")
|
Debugln("invalid vpn addr for v1 handleHostQuery")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -1099,14 +1099,14 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply")
|
lhh.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).Error("Failed to marshal lighthouse host query reply")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
|
lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
|
||||||
w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
w.SendMessageToHostInfo(header.LightHouse, 0, hostinfo, lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
||||||
|
|
||||||
lhh.sendHostPunchNotification(n, fromVpnAddrs, queryVpnAddr, w)
|
lhh.sendHostPunchNotification(n, hostinfo.vpnAddrs, queryVpnAddr, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendHostPunchNotification signals the other side to punch some zero byte udp packets
|
// sendHostPunchNotification signals the other side to punch some zero byte udp packets
|
||||||
@@ -1115,20 +1115,34 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
|||||||
found, ln, err := lhh.lh.queryAndPrepMessage(whereToPunch, func(c *cache) (int, error) {
|
found, ln, err := lhh.lh.queryAndPrepMessage(whereToPunch, func(c *cache) (int, error) {
|
||||||
n = lhh.resetMeta()
|
n = lhh.resetMeta()
|
||||||
n.Type = NebulaMeta_HostPunchNotification
|
n.Type = NebulaMeta_HostPunchNotification
|
||||||
targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
|
punchNotifDestHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
|
||||||
var useVersion cert.Version
|
var useVersion cert.Version
|
||||||
if targetHI == nil {
|
if punchNotifDestHI == nil {
|
||||||
useVersion = lhh.lh.ifce.GetCertState().initiatingVersion
|
useVersion = lhh.lh.ifce.GetCertState().initiatingVersion
|
||||||
} else {
|
} else {
|
||||||
crt := targetHI.GetCert().Certificate
|
|
||||||
useVersion = crt.Version()
|
|
||||||
// we can only retarget if we have a hostinfo
|
// we can only retarget if we have a hostinfo
|
||||||
newDest, ok := findNetworkUnion(crt.Networks(), fromVpnAddrs)
|
punchNotifDestCrt := punchNotifDestHI.GetCert().Certificate
|
||||||
|
useVersion = punchNotifDestCrt.Version()
|
||||||
|
punchNotifDestNetworks := punchNotifDestCrt.Networks()
|
||||||
|
|
||||||
|
//if we (the lighthouse) don't have a network in common with punchNotifDest, try to find one
|
||||||
|
if !lhh.lh.myVpnNetworksTable.Contains(punchNotifDest) {
|
||||||
|
newPunchNotifDest, ok := findNetworkUnion(lhh.lh.myVpnNetworks, punchNotifDestHI.vpnAddrs)
|
||||||
|
if ok {
|
||||||
|
punchNotifDest = newPunchNotifDest
|
||||||
|
} else {
|
||||||
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
|
lhh.l.WithField("to", punchNotifDestNetworks).Debugln("unable to notify host to host, no addresses in common")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newWhereToPunch, ok := findNetworkUnion(punchNotifDestNetworks, fromVpnAddrs)
|
||||||
if ok {
|
if ok {
|
||||||
whereToPunch = newDest
|
whereToPunch = newWhereToPunch
|
||||||
} else {
|
} else {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
|
lhh.l.WithFields(m{"from": fromVpnAddrs, "to": punchNotifDestNetworks}).Debugln("unable to punch to host, no addresses in common with requestor")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1234,7 +1248,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, hostinfo *HostInfo, w EncWriter) {
|
||||||
|
fromVpnAddrs := hostinfo.vpnAddrs
|
||||||
if !lhh.lh.amLighthouse {
|
if !lhh.lh.amLighthouse {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs)
|
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs)
|
||||||
@@ -1302,7 +1317,7 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
|||||||
}
|
}
|
||||||
|
|
||||||
lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1)
|
lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1)
|
||||||
w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
w.SendMessageToHostInfo(header.LightHouse, 0, hostinfo, lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
||||||
|
|||||||
@@ -132,8 +132,13 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
mw := &mockEncWriter{}
|
mw := &mockEncWriter{}
|
||||||
|
hostinfo := &HostInfo{
|
||||||
hi := []netip.Addr{vpnIp2}
|
ConnectionState: &ConnectionState{
|
||||||
|
eKey: nil,
|
||||||
|
dKey: nil,
|
||||||
|
},
|
||||||
|
vpnAddrs: []netip.Addr{vpnIp2},
|
||||||
|
}
|
||||||
b.Run("notfound", func(b *testing.B) {
|
b.Run("notfound", func(b *testing.B) {
|
||||||
lhh := lh.NewRequestHandler()
|
lhh := lh.NewRequestHandler()
|
||||||
req := &NebulaMeta{
|
req := &NebulaMeta{
|
||||||
@@ -146,7 +151,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||||||
p, err := req.Marshal()
|
p, err := req.Marshal()
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
lhh.HandleRequest(rAddr, hi, p, mw)
|
lhh.HandleRequest(rAddr, hostinfo, p, mw)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("found", func(b *testing.B) {
|
b.Run("found", func(b *testing.B) {
|
||||||
@@ -162,7 +167,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
lhh.HandleRequest(rAddr, hi, p, mw)
|
lhh.HandleRequest(rAddr, hostinfo, p, mw)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -326,7 +331,14 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
|
|||||||
w := &testEncWriter{
|
w := &testEncWriter{
|
||||||
metaFilter: &filter,
|
metaFilter: &filter,
|
||||||
}
|
}
|
||||||
lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w)
|
hostinfo := &HostInfo{
|
||||||
|
ConnectionState: &ConnectionState{
|
||||||
|
eKey: nil,
|
||||||
|
dKey: nil,
|
||||||
|
},
|
||||||
|
vpnAddrs: []netip.Addr{myVpnIp},
|
||||||
|
}
|
||||||
|
lhh.HandleRequest(fromAddr, hostinfo, b, w)
|
||||||
return w.lastReply
|
return w.lastReply
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,9 +367,15 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
hostinfo := &HostInfo{
|
||||||
|
ConnectionState: &ConnectionState{
|
||||||
|
eKey: nil,
|
||||||
|
dKey: nil,
|
||||||
|
},
|
||||||
|
vpnAddrs: []netip.Addr{vpnIp},
|
||||||
|
}
|
||||||
w := &testEncWriter{}
|
w := &testEncWriter{}
|
||||||
lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w)
|
lhh.HandleRequest(fromAddr, hostinfo, b, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
type testLhReply struct {
|
type testLhReply struct {
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
|
lhf.HandleRequest(ip, hostinfo, d, f)
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
// Fallthrough to the bottom to record incoming traffic
|
||||||
|
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
//go:build linux && (386 || amd64p32 || arm || mips || mipsle) && !android && !e2e_testing
|
|
||||||
// +build linux
|
|
||||||
// +build 386 amd64p32 arm mips mipsle
|
|
||||||
// +build !android
|
|
||||||
// +build !e2e_testing
|
|
||||||
|
|
||||||
package udp
|
|
||||||
|
|
||||||
import "golang.org/x/sys/unix"
|
|
||||||
|
|
||||||
func controllen(n int) uint32 {
|
|
||||||
return uint32(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
func setCmsgLen(h *unix.Cmsghdr, n int) {
|
|
||||||
h.Len = uint32(unix.CmsgLen(n))
|
|
||||||
}
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
//go:build linux && (amd64 || arm64 || ppc64 || ppc64le || mips64 || mips64le || s390x || riscv64 || loong64) && !android && !e2e_testing
|
|
||||||
// +build linux
|
|
||||||
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x riscv64 loong64
|
|
||||||
// +build !android
|
|
||||||
// +build !e2e_testing
|
|
||||||
|
|
||||||
package udp
|
|
||||||
|
|
||||||
import "golang.org/x/sys/unix"
|
|
||||||
|
|
||||||
func controllen(n int) uint64 {
|
|
||||||
return uint64(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
func setCmsgLen(h *unix.Cmsghdr, n int) {
|
|
||||||
h.Len = uint64(unix.CmsgLen(n))
|
|
||||||
}
|
|
||||||
595
udp/udp_linux.go
595
udp/udp_linux.go
@@ -5,14 +5,10 @@ package udp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
@@ -21,38 +17,11 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
defaultGSOMaxSegments = 64
|
|
||||||
defaultGSOMaxBytes = 64000
|
|
||||||
defaultGROReadBufferSize = 2 * defaultGSOMaxBytes
|
|
||||||
defaultGSOFlushTimeout = 100 * time.Microsecond
|
|
||||||
)
|
|
||||||
|
|
||||||
type StdConn struct {
|
type StdConn struct {
|
||||||
sysFd int
|
sysFd int
|
||||||
isV4 bool
|
isV4 bool
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
batch int
|
batch int
|
||||||
|
|
||||||
enableGRO bool
|
|
||||||
enableGSO bool
|
|
||||||
|
|
||||||
controlLen atomic.Int32
|
|
||||||
|
|
||||||
gsoMu sync.Mutex
|
|
||||||
gsoPendingBuf []byte
|
|
||||||
gsoPendingSegments int
|
|
||||||
gsoPendingAddr netip.AddrPort
|
|
||||||
gsoPendingSegSize int
|
|
||||||
gsoMaxSegments int
|
|
||||||
gsoMaxBytes int
|
|
||||||
gsoFlushTimeout time.Duration
|
|
||||||
gsoFlushTimer *time.Timer
|
|
||||||
gsoControlBuf []byte
|
|
||||||
|
|
||||||
gsoBatches metrics.Counter
|
|
||||||
gsoSegments metrics.Counter
|
|
||||||
groSegments metrics.Counter
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func maybeIPV4(ip net.IP) (net.IP, bool) {
|
func maybeIPV4(ip net.IP) (net.IP, bool) {
|
||||||
@@ -100,18 +69,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
|||||||
return nil, fmt.Errorf("unable to bind to socket: %s", err)
|
return nil, fmt.Errorf("unable to bind to socket: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &StdConn{
|
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
|
||||||
sysFd: fd,
|
|
||||||
isV4: ip.Is4(),
|
|
||||||
l: l,
|
|
||||||
batch: batch,
|
|
||||||
gsoMaxSegments: defaultGSOMaxSegments,
|
|
||||||
gsoMaxBytes: defaultGSOMaxBytes,
|
|
||||||
gsoFlushTimeout: defaultGSOFlushTimeout,
|
|
||||||
gsoBatches: metrics.GetOrRegisterCounter("udp.gso.batches", nil),
|
|
||||||
gsoSegments: metrics.GetOrRegisterCounter("udp.gso.segments", nil),
|
|
||||||
groSegments: metrics.GetOrRegisterCounter("udp.gro.segments", nil),
|
|
||||||
}, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) Rebind() error {
|
func (u *StdConn) Rebind() error {
|
||||||
@@ -163,27 +121,13 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
func (u *StdConn) ListenOut(r EncReader) {
|
func (u *StdConn) ListenOut(r EncReader) {
|
||||||
var ip netip.Addr
|
var ip netip.Addr
|
||||||
|
|
||||||
msgs, buffers, names, controls := u.PrepareRawMessages(u.batch)
|
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
||||||
read := u.ReadMulti
|
read := u.ReadMulti
|
||||||
if u.batch == 1 {
|
if u.batch == 1 {
|
||||||
read = u.ReadSingle
|
read = u.ReadSingle
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
//desiredControl := int(u.controlLen.Load())
|
|
||||||
//hasControl := len(controls) > 0
|
|
||||||
//if (desiredControl > 0) != hasControl || (desiredControl > 0 && hasControl && len(controls[0]) != desiredControl) {
|
|
||||||
// msgs, buffers, names, controls = u.PrepareRawMessages(u.batch)
|
|
||||||
// hasControl = len(controls) > 0
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
for i := range msgs {
|
|
||||||
if len(controls) <= i || len(controls[i]) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
msgs[i].Hdr.Controllen = controllen(len(controls[i]))
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err := read(msgs)
|
n, err := read(msgs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||||
@@ -191,38 +135,13 @@ func (u *StdConn) ListenOut(r EncReader) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
payloadLen := int(msgs[i].Len)
|
|
||||||
if payloadLen == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
||||||
if u.isV4 {
|
if u.isV4 {
|
||||||
ip, _ = netip.AddrFromSlice(names[i][4:8])
|
ip, _ = netip.AddrFromSlice(names[i][4:8])
|
||||||
} else {
|
} else {
|
||||||
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
||||||
}
|
}
|
||||||
addr := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
|
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
|
||||||
|
|
||||||
if len(controls) > i && len(controls[i]) > 0 {
|
|
||||||
if segSize, segCount := u.parseGROSegment(&msgs[i], controls[i]); segSize > 0 && segSize < payloadLen {
|
|
||||||
if u.emitSegments(r, addr, buffers[i][:payloadLen], segSize, segCount) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if segCount > 1 {
|
|
||||||
u.l.WithFields(logrus.Fields{
|
|
||||||
"tag": "gro-debug",
|
|
||||||
"stage": "listen_out",
|
|
||||||
"reason": "emit_failed",
|
|
||||||
"payload_len": payloadLen,
|
|
||||||
"seg_size": segSize,
|
|
||||||
"seg_count": segCount,
|
|
||||||
}).Debug("gro-debug fallback to single packet")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
r(addr, buffers[i][:payloadLen])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -269,13 +188,6 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
||||||
if u.enableGSO {
|
|
||||||
if err := u.writeToGSO(b, ip); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.isV4 {
|
if u.isV4 {
|
||||||
return u.writeTo4(b, ip)
|
return u.writeTo4(b, ip)
|
||||||
}
|
}
|
||||||
@@ -336,494 +248,6 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) writeToGSO(b []byte, addr netip.AddrPort) error {
|
|
||||||
if len(b) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !addr.IsValid() {
|
|
||||||
return u.directWrite(b, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
u.gsoMu.Lock()
|
|
||||||
defer u.gsoMu.Unlock()
|
|
||||||
|
|
||||||
if cap(u.gsoPendingBuf) < u.gsoMaxBytes { //I feel like this is bad?
|
|
||||||
u.gsoPendingBuf = make([]byte, 0, u.gsoMaxBytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.gsoPendingSegments > 0 && u.gsoPendingAddr != addr {
|
|
||||||
if err := u.flushPendingLocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(b) > u.gsoMaxBytes || u.gsoMaxSegments <= 1 {
|
|
||||||
if err := u.flushPendingLocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return u.directWrite(b, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.gsoPendingSegments == 0 {
|
|
||||||
u.gsoPendingAddr = addr
|
|
||||||
u.gsoPendingSegSize = len(b)
|
|
||||||
} else {
|
|
||||||
if len(b) > u.gsoPendingSegSize {
|
|
||||||
if err := u.flushPendingLocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
u.gsoPendingAddr = addr
|
|
||||||
u.gsoPendingSegSize = len(b)
|
|
||||||
} else if len(b) < u.gsoPendingSegSize {
|
|
||||||
if err := u.flushPendingLocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
u.gsoPendingAddr = addr
|
|
||||||
u.gsoPendingSegSize = len(b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inBuf := len(u.gsoPendingBuf) + len(b)
|
|
||||||
if len(u.gsoPendingBuf)+len(b) > u.gsoMaxBytes {
|
|
||||||
if err := u.flushPendingLocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
u.gsoPendingAddr = addr
|
|
||||||
u.gsoPendingSegSize = len(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
u.gsoPendingBuf = append(u.gsoPendingBuf, b...)
|
|
||||||
u.gsoPendingSegments++
|
|
||||||
|
|
||||||
if u.gsoPendingSegments >= u.gsoMaxSegments {
|
|
||||||
return u.flushPendingLocked()
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.gsoFlushTimeout <= 0 {
|
|
||||||
return u.flushPendingLocked()
|
|
||||||
}
|
|
||||||
|
|
||||||
u.scheduleFlushLocked(inBuf)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) flushPendingLocked() error {
|
|
||||||
if u.gsoPendingSegments == 0 {
|
|
||||||
u.stopFlushTimerLocked()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := u.gsoPendingBuf[:len(u.gsoPendingBuf)]
|
|
||||||
addr := u.gsoPendingAddr
|
|
||||||
segSize := u.gsoPendingSegSize
|
|
||||||
segments := u.gsoPendingSegments
|
|
||||||
|
|
||||||
u.stopFlushTimerLocked()
|
|
||||||
|
|
||||||
var err error
|
|
||||||
if segments <= 1 || !u.enableGSO {
|
|
||||||
err = u.directWrite(buf, addr)
|
|
||||||
} else {
|
|
||||||
err = u.sendSegmentedLocked(buf, addr, segSize)
|
|
||||||
if err != nil && (errors.Is(err, unix.EOPNOTSUPP) || errors.Is(err, unix.ENOTSUP)) {
|
|
||||||
u.enableGSO = false
|
|
||||||
u.l.WithError(err).Warn("UDP GSO not supported, disabling")
|
|
||||||
err = u.sendSequentialLocked(buf, addr, segSize)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err == nil && segments > 1 && u.enableGSO {
|
|
||||||
if u.gsoBatches != nil {
|
|
||||||
u.gsoBatches.Inc(1)
|
|
||||||
}
|
|
||||||
if u.gsoSegments != nil {
|
|
||||||
u.gsoSegments.Inc(int64(segments))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
u.gsoPendingBuf = u.gsoPendingBuf[:0]
|
|
||||||
u.gsoPendingSegments = 0
|
|
||||||
u.gsoPendingSegSize = 0
|
|
||||||
u.gsoPendingAddr = netip.AddrPort{}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) sendSegmentedLocked(buf []byte, addr netip.AddrPort, segSize int) error {
|
|
||||||
if len(buf) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if segSize <= 0 {
|
|
||||||
segSize = len(buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(u.gsoControlBuf) < unix.CmsgSpace(2) {
|
|
||||||
u.gsoControlBuf = make([]byte, unix.CmsgSpace(2))
|
|
||||||
}
|
|
||||||
control := u.gsoControlBuf[:unix.CmsgSpace(2)]
|
|
||||||
for i := range control {
|
|
||||||
control[i] = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
|
||||||
setCmsgLen(hdr, 2)
|
|
||||||
hdr.Level = unix.SOL_UDP
|
|
||||||
hdr.Type = unix.UDP_SEGMENT
|
|
||||||
|
|
||||||
dataOff := unix.CmsgLen(0)
|
|
||||||
binary.NativeEndian.PutUint16(control[dataOff:dataOff+2], uint16(segSize))
|
|
||||||
|
|
||||||
var sa unix.Sockaddr
|
|
||||||
if u.isV4 {
|
|
||||||
sa4 := &unix.SockaddrInet4{Port: int(addr.Port())}
|
|
||||||
sa4.Addr = addr.Addr().As4()
|
|
||||||
sa = sa4
|
|
||||||
} else {
|
|
||||||
sa6 := &unix.SockaddrInet6{Port: int(addr.Port())}
|
|
||||||
sa6.Addr = addr.Addr().As16()
|
|
||||||
sa = sa6
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
n, err := unix.SendmsgN(u.sysFd, buf, control[:unix.CmsgSpace(2)], sa, 0)
|
|
||||||
if err != nil {
|
|
||||||
if err == unix.EINTR {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return &net.OpError{Op: "sendmsg", Err: err}
|
|
||||||
}
|
|
||||||
if n != len(buf) {
|
|
||||||
return &net.OpError{Op: "sendmsg", Err: unix.EIO}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) sendSequentialLocked(buf []byte, addr netip.AddrPort, segSize int) error {
|
|
||||||
if len(buf) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if segSize <= 0 {
|
|
||||||
segSize = len(buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
for offset := 0; offset < len(buf); offset += segSize {
|
|
||||||
end := offset + segSize
|
|
||||||
if end > len(buf) {
|
|
||||||
end = len(buf)
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
if u.isV4 {
|
|
||||||
err = u.writeTo4(buf[offset:end], addr)
|
|
||||||
} else {
|
|
||||||
err = u.writeTo6(buf[offset:end], addr)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if end == len(buf) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) scheduleFlushLocked(inBuf int) {
|
|
||||||
if u.gsoFlushTimeout <= 0 {
|
|
||||||
_ = u.flushPendingLocked()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
t := u.gsoFlushTimeout
|
|
||||||
if inBuf > u.gsoMaxBytes/2 {
|
|
||||||
t = t / 2
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.gsoFlushTimer == nil {
|
|
||||||
u.gsoFlushTimer = time.AfterFunc(t, u.flushTimerHandler)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !u.gsoFlushTimer.Stop() {
|
|
||||||
// timer already fired or running; allow handler to exit if no data
|
|
||||||
}
|
|
||||||
u.gsoFlushTimer.Reset(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) stopFlushTimerLocked() {
|
|
||||||
if u.gsoFlushTimer != nil {
|
|
||||||
u.gsoFlushTimer.Stop()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) flushTimerHandler() {
|
|
||||||
//u.l.Warn("timer hit")
|
|
||||||
u.gsoMu.Lock()
|
|
||||||
defer u.gsoMu.Unlock()
|
|
||||||
|
|
||||||
if u.gsoPendingSegments == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := u.flushPendingLocked(); err != nil {
|
|
||||||
u.l.WithError(err).Warn("Failed to flush GSO batch")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) directWrite(b []byte, addr netip.AddrPort) error {
|
|
||||||
if u.isV4 {
|
|
||||||
return u.writeTo4(b, addr)
|
|
||||||
}
|
|
||||||
return u.writeTo6(b, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) emitSegments(r EncReader, addr netip.AddrPort, payload []byte, segSize, segCount int) bool {
|
|
||||||
if segSize <= 0 || segSize >= len(payload) {
|
|
||||||
if u.l.Level >= logrus.DebugLevel {
|
|
||||||
u.l.WithFields(logrus.Fields{
|
|
||||||
"tag": "gro-debug",
|
|
||||||
"stage": "emit",
|
|
||||||
"reason": "invalid_seg_size",
|
|
||||||
"payload_len": len(payload),
|
|
||||||
"seg_size": segSize,
|
|
||||||
"seg_count": segCount,
|
|
||||||
}).Debug("gro-debug skip emit")
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
totalLen := len(payload)
|
|
||||||
if segCount <= 0 {
|
|
||||||
segCount = (totalLen + segSize - 1) / segSize
|
|
||||||
}
|
|
||||||
if segCount <= 1 {
|
|
||||||
if u.l.Level >= logrus.DebugLevel {
|
|
||||||
u.l.WithFields(logrus.Fields{
|
|
||||||
"tag": "gro-debug",
|
|
||||||
"stage": "emit",
|
|
||||||
"reason": "single_segment",
|
|
||||||
"payload_len": totalLen,
|
|
||||||
"seg_size": segSize,
|
|
||||||
"seg_count": segCount,
|
|
||||||
}).Debug("gro-debug skip emit")
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
//segments := make([][]byte, 0, segCount)
|
|
||||||
start := 0
|
|
||||||
//var firstHeader header.H
|
|
||||||
//firstParsed := false
|
|
||||||
//var firstCounter uint64
|
|
||||||
//var firstRemote uint32
|
|
||||||
numSegments := 0
|
|
||||||
//for start < totalLen && len(segments) < segCount {
|
|
||||||
for start < totalLen && numSegments < segCount {
|
|
||||||
end := start + segSize
|
|
||||||
if end > totalLen {
|
|
||||||
end = totalLen
|
|
||||||
}
|
|
||||||
|
|
||||||
//segment := append([]byte(nil), payload[start:end]...)
|
|
||||||
//q := numSegments % 4 //TODO
|
|
||||||
r(addr, payload[start:end])
|
|
||||||
numSegments++
|
|
||||||
//segments = append(segments, segment)
|
|
||||||
start = end
|
|
||||||
|
|
||||||
//if !firstParsed {
|
|
||||||
// if err := firstHeader.Parse(segment); err == nil {
|
|
||||||
// firstParsed = true
|
|
||||||
// firstCounter = firstHeader.MessageCounter
|
|
||||||
// firstRemote = firstHeader.RemoteIndex
|
|
||||||
// } else if u.l.IsLevelEnabled(logrus.DebugLevel) {
|
|
||||||
// u.l.WithFields(logrus.Fields{
|
|
||||||
// "tag": "gro-debug",
|
|
||||||
// "stage": "emit",
|
|
||||||
// "event": "parse_fail",
|
|
||||||
// "seg_index": len(segments) - 1,
|
|
||||||
// "seg_size": segSize,
|
|
||||||
// "seg_count": segCount,
|
|
||||||
// "payload_len": totalLen,
|
|
||||||
// "err": err,
|
|
||||||
// }).Debug("gro-debug segment parse failed")
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
}
|
|
||||||
|
|
||||||
//for idx, segment := range segments {
|
|
||||||
// r(addr, segment)
|
|
||||||
//if idx == len(segments)-1 && len(segment) < segSize && u.l.IsLevelEnabled(logrus.DebugLevel) {
|
|
||||||
// var tail header.H
|
|
||||||
// if err := tail.Parse(segment); err == nil {
|
|
||||||
// u.l.WithFields(logrus.Fields{
|
|
||||||
// "tag": "gro-debug",
|
|
||||||
// "stage": "emit",
|
|
||||||
// "event": "tail_segment",
|
|
||||||
// "segment_len": len(segment),
|
|
||||||
// "remote_index": tail.RemoteIndex,
|
|
||||||
// "message_counter": tail.MessageCounter,
|
|
||||||
// }).Debug("gro-debug tail segment metadata")
|
|
||||||
// } else {
|
|
||||||
// u.l.WithError(err).Warn("Failed to parse tail segment")
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
//}
|
|
||||||
|
|
||||||
if u.groSegments != nil {
|
|
||||||
//u.groSegments.Inc(int64(len(segments)))
|
|
||||||
u.groSegments.Inc(int64(numSegments))
|
|
||||||
}
|
|
||||||
|
|
||||||
//if len(segments) > 0 {
|
|
||||||
// lastLen := len(segments[len(segments)-1])
|
|
||||||
// if u.l.IsLevelEnabled(logrus.DebugLevel) {
|
|
||||||
// u.l.WithFields(logrus.Fields{
|
|
||||||
// "tag": "gro-debug",
|
|
||||||
// "stage": "emit",
|
|
||||||
// "event": "success",
|
|
||||||
// "payload_len": totalLen,
|
|
||||||
// "seg_size": segSize,
|
|
||||||
// "seg_count": segCount,
|
|
||||||
// "actual_segs": len(segments),
|
|
||||||
// "last_seg_len": lastLen,
|
|
||||||
// "addr": addr.String(),
|
|
||||||
// "first_remote": firstRemote,
|
|
||||||
// "first_counter": firstCounter,
|
|
||||||
// }).Debug("gro-debug emit")
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) parseGROSegment(msg *rawMessage, control []byte) (int, int) {
|
|
||||||
ctrlLen := int(msg.Hdr.Controllen)
|
|
||||||
if ctrlLen <= 0 {
|
|
||||||
return 0, 0
|
|
||||||
}
|
|
||||||
if ctrlLen > len(control) {
|
|
||||||
ctrlLen = len(control)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmsgs, err := unix.ParseSocketControlMessage(control[:ctrlLen])
|
|
||||||
if err != nil {
|
|
||||||
u.l.WithError(err).Debug("failed to parse UDP GRO control message")
|
|
||||||
return 0, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range cmsgs {
|
|
||||||
if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
|
|
||||||
segSize := int(binary.NativeEndian.Uint16(c.Data[:2]))
|
|
||||||
segCount := 0
|
|
||||||
if len(c.Data) >= 4 {
|
|
||||||
segCount = int(binary.NativeEndian.Uint16(c.Data[2:4]))
|
|
||||||
}
|
|
||||||
if u.l.Level >= logrus.DebugLevel {
|
|
||||||
u.l.WithFields(logrus.Fields{
|
|
||||||
"tag": "gro-debug",
|
|
||||||
"stage": "parse",
|
|
||||||
"seg_size": segSize,
|
|
||||||
"seg_count": segCount,
|
|
||||||
}).Debug("gro-debug control parsed")
|
|
||||||
}
|
|
||||||
return segSize, segCount
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) configureGRO(enable bool) {
|
|
||||||
if enable == u.enableGRO {
|
|
||||||
if enable {
|
|
||||||
u.controlLen.Store(int32(unix.CmsgSpace(2)))
|
|
||||||
} else {
|
|
||||||
u.controlLen.Store(0)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if enable {
|
|
||||||
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
|
|
||||||
u.l.WithError(err).Warn("Failed to enable UDP GRO")
|
|
||||||
u.enableGRO = false
|
|
||||||
u.controlLen.Store(0)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
u.enableGRO = true
|
|
||||||
u.controlLen.Store(int32(unix.CmsgSpace(2)))
|
|
||||||
u.l.Info("UDP GRO enabled")
|
|
||||||
} else {
|
|
||||||
if u.enableGRO {
|
|
||||||
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil {
|
|
||||||
u.l.WithError(err).Warn("Failed to disable UDP GRO")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
u.enableGRO = false
|
|
||||||
u.controlLen.Store(0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) configureGSO(enable bool, c *config.C) {
|
|
||||||
u.gsoMu.Lock()
|
|
||||||
defer u.gsoMu.Unlock()
|
|
||||||
|
|
||||||
if !enable {
|
|
||||||
if u.enableGSO {
|
|
||||||
if err := u.flushPendingLocked(); err != nil {
|
|
||||||
u.l.WithError(err).Warn("Failed to flush GSO buffers while disabling")
|
|
||||||
}
|
|
||||||
u.enableGSO = false
|
|
||||||
if u.gsoFlushTimer != nil {
|
|
||||||
u.gsoFlushTimer.Stop()
|
|
||||||
}
|
|
||||||
u.l.Info("UDP GSO disabled")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
maxSegments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
|
|
||||||
if maxSegments < 2 {
|
|
||||||
maxSegments = 2
|
|
||||||
}
|
|
||||||
|
|
||||||
maxBytes := c.GetInt("listen.gso_max_bytes", 0)
|
|
||||||
if maxBytes <= 0 {
|
|
||||||
maxBytes = defaultGSOMaxBytes
|
|
||||||
}
|
|
||||||
if maxBytes < MTU {
|
|
||||||
maxBytes = MTU
|
|
||||||
}
|
|
||||||
|
|
||||||
flushTimeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushTimeout)
|
|
||||||
if flushTimeout < 0 {
|
|
||||||
flushTimeout = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
u.enableGSO = true
|
|
||||||
u.gsoMaxSegments = maxSegments
|
|
||||||
u.gsoMaxBytes = maxBytes
|
|
||||||
u.gsoFlushTimeout = flushTimeout
|
|
||||||
|
|
||||||
if cap(u.gsoPendingBuf) < u.gsoMaxBytes {
|
|
||||||
u.gsoPendingBuf = make([]byte, 0, u.gsoMaxBytes)
|
|
||||||
} else {
|
|
||||||
u.gsoPendingBuf = u.gsoPendingBuf[:0]
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(u.gsoControlBuf) < unix.CmsgSpace(2) {
|
|
||||||
u.gsoControlBuf = make([]byte, unix.CmsgSpace(2))
|
|
||||||
}
|
|
||||||
|
|
||||||
u.l.WithFields(logrus.Fields{
|
|
||||||
"segments": u.gsoMaxSegments,
|
|
||||||
"bytes": u.gsoMaxBytes,
|
|
||||||
"flush_timeout": u.gsoFlushTimeout,
|
|
||||||
}).Info("UDP GSO configured")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) ReloadConfig(c *config.C) {
|
func (u *StdConn) ReloadConfig(c *config.C) {
|
||||||
b := c.GetInt("listen.read_buffer", 0)
|
b := c.GetInt("listen.read_buffer", 0)
|
||||||
if b > 0 {
|
if b > 0 {
|
||||||
@@ -870,9 +294,6 @@ func (u *StdConn) ReloadConfig(c *config.C) {
|
|||||||
u.l.WithError(err).Error("Failed to set listen.so_mark")
|
u.l.WithError(err).Error("Failed to set listen.so_mark")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
u.configureGRO(c.GetBool("listen.enable_gro", false))
|
|
||||||
u.configureGSO(c.GetBool("listen.enable_gso", false), c)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
||||||
@@ -885,15 +306,7 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) Close() error {
|
func (u *StdConn) Close() error {
|
||||||
u.gsoMu.Lock()
|
return syscall.Close(u.sysFd)
|
||||||
flushErr := u.flushPendingLocked()
|
|
||||||
u.gsoMu.Unlock()
|
|
||||||
|
|
||||||
closeErr := syscall.Close(u.sysFd)
|
|
||||||
if flushErr != nil {
|
|
||||||
return flushErr
|
|
||||||
}
|
|
||||||
return closeErr
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUDPStatsEmitter(udpConns []Conn) func() {
|
func NewUDPStatsEmitter(udpConns []Conn) func() {
|
||||||
|
|||||||
@@ -30,24 +30,13 @@ type rawMessage struct {
|
|||||||
Len uint32
|
Len uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte, [][]byte) {
|
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||||
controlLen := int(u.controlLen.Load())
|
|
||||||
|
|
||||||
msgs := make([]rawMessage, n)
|
msgs := make([]rawMessage, n)
|
||||||
buffers := make([][]byte, n)
|
buffers := make([][]byte, n)
|
||||||
names := make([][]byte, n)
|
names := make([][]byte, n)
|
||||||
|
|
||||||
var controls [][]byte
|
|
||||||
if controlLen > 0 {
|
|
||||||
controls = make([][]byte, n)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
size := MTU
|
buffers[i] = make([]byte, MTU)
|
||||||
if defaultGROReadBufferSize > size {
|
|
||||||
size = defaultGROReadBufferSize
|
|
||||||
}
|
|
||||||
buffers[i] = make([]byte, size)
|
|
||||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
vs := []iovec{
|
vs := []iovec{
|
||||||
@@ -59,16 +48,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte, [
|
|||||||
|
|
||||||
msgs[i].Hdr.Name = &names[i][0]
|
msgs[i].Hdr.Name = &names[i][0]
|
||||||
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
||||||
|
|
||||||
if controlLen > 0 {
|
|
||||||
controls[i] = make([]byte, controlLen)
|
|
||||||
msgs[i].Hdr.Control = &controls[i][0]
|
|
||||||
msgs[i].Hdr.Controllen = controllen(len(controls[i]))
|
|
||||||
} else {
|
|
||||||
msgs[i].Hdr.Control = nil
|
|
||||||
msgs[i].Hdr.Controllen = controllen(0)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return msgs, buffers, names, controls
|
return msgs, buffers, names
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,43 +33,25 @@ type rawMessage struct {
|
|||||||
Pad0 [4]byte
|
Pad0 [4]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte, [][]byte) {
|
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||||
controlLen := int(u.controlLen.Load())
|
|
||||||
|
|
||||||
msgs := make([]rawMessage, n)
|
msgs := make([]rawMessage, n)
|
||||||
buffers := make([][]byte, n)
|
buffers := make([][]byte, n)
|
||||||
names := make([][]byte, n)
|
names := make([][]byte, n)
|
||||||
|
|
||||||
var controls [][]byte
|
|
||||||
if controlLen > 0 {
|
|
||||||
controls = make([][]byte, n)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
size := MTU
|
buffers[i] = make([]byte, MTU)
|
||||||
if defaultGROReadBufferSize > size {
|
|
||||||
size = defaultGROReadBufferSize
|
|
||||||
}
|
|
||||||
buffers[i] = make([]byte, size)
|
|
||||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
vs := []iovec{{Base: &buffers[i][0], Len: uint64(len(buffers[i]))}}
|
vs := []iovec{
|
||||||
|
{Base: &buffers[i][0], Len: uint64(len(buffers[i]))},
|
||||||
|
}
|
||||||
|
|
||||||
msgs[i].Hdr.Iov = &vs[0]
|
msgs[i].Hdr.Iov = &vs[0]
|
||||||
msgs[i].Hdr.Iovlen = uint64(len(vs))
|
msgs[i].Hdr.Iovlen = uint64(len(vs))
|
||||||
|
|
||||||
msgs[i].Hdr.Name = &names[i][0]
|
msgs[i].Hdr.Name = &names[i][0]
|
||||||
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
||||||
|
|
||||||
if controlLen > 0 {
|
|
||||||
controls[i] = make([]byte, controlLen)
|
|
||||||
msgs[i].Hdr.Control = &controls[i][0]
|
|
||||||
msgs[i].Hdr.Controllen = controllen(len(controls[i]))
|
|
||||||
} else {
|
|
||||||
msgs[i].Hdr.Control = nil
|
|
||||||
msgs[i].Hdr.Controllen = controllen(0)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return msgs, buffers, names, controls
|
return msgs, buffers, names
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user