mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 08:24:25 +01:00
Compare commits
16 Commits
io-uring-g
...
channels-s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9253f36a3c | ||
|
|
c9a695c2bf | ||
|
|
2c6f81c224 | ||
|
|
ad37749c5e | ||
|
|
a0f8cb2098 | ||
|
|
d18d1aea67 | ||
|
|
f5ff534671 | ||
|
|
2ea8a72d5c | ||
|
|
663232e1fc | ||
|
|
2f48529e8b | ||
|
|
f3e1ad64cd | ||
|
|
1d8112a329 | ||
|
|
31eea0cc94 | ||
|
|
dbba4a4c77 | ||
|
|
194fde45da | ||
|
|
f46b83f2c4 |
3
bits.go
3
bits.go
@@ -5,6 +5,7 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// TODO: Pretty sure this is just all sorts of racy now, we need it to be atomic
|
||||
type Bits struct {
|
||||
length uint64
|
||||
current uint64
|
||||
@@ -43,7 +44,7 @@ func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
|
||||
}
|
||||
|
||||
// Not within the window
|
||||
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
|
||||
// Errorf, not Error: the message contains printf verbs, which Error would emit verbatim.
l.Errorf("rejected a packet (top) %d %d\n", b.current, i)
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
97
cert/pem.go
97
cert/pem.go
@@ -1,8 +1,10 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ed25519"
|
||||
)
|
||||
@@ -138,6 +140,101 @@ func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte {
|
||||
}
|
||||
}
|
||||
|
||||
// Backward compatibility functions for older API
|
||||
func MarshalX25519PublicKey(b []byte) []byte {
|
||||
return MarshalPublicKeyToPEM(Curve_CURVE25519, b)
|
||||
}
|
||||
|
||||
func MarshalX25519PrivateKey(b []byte) []byte {
|
||||
return MarshalPrivateKeyToPEM(Curve_CURVE25519, b)
|
||||
}
|
||||
|
||||
func MarshalPublicKey(curve Curve, b []byte) []byte {
|
||||
return MarshalPublicKeyToPEM(curve, b)
|
||||
}
|
||||
|
||||
func MarshalPrivateKey(curve Curve, b []byte) []byte {
|
||||
return MarshalPrivateKeyToPEM(curve, b)
|
||||
}
|
||||
|
||||
// NebulaCertificate is a compatibility wrapper for the old API. It pairs a
// field-style snapshot of a parsed certificate (Details, Signature) with the
// Certificate value it was built from.
|
||||
type NebulaCertificate struct {
|
||||
// Details holds values read from the certificate's accessor methods.
Details NebulaCertificateDetails
|
||||
// Signature is a copy of the bytes returned by the certificate's Signature().
Signature []byte
|
||||
// cert is the wrapped certificate; it is returned by the Certificate method.
cert Certificate
|
||||
}
|
||||
|
||||
// NebulaCertificateDetails is a compatibility wrapper for certificate details.
// UnmarshalNebulaCertificateFromPEM fills each field from the matching
// accessor on the parsed Certificate.
|
||||
type NebulaCertificateDetails struct {
|
||||
// Name is the value of the certificate's Name().
Name string
|
||||
// NotBefore is the value of the certificate's NotBefore().
NotBefore time.Time
|
||||
// NotAfter is the value of the certificate's NotAfter().
NotAfter time.Time
|
||||
// PublicKey is a copy of the bytes returned by the certificate's PublicKey().
PublicKey []byte
|
||||
// IsCA is the value of the certificate's IsCA().
IsCA bool
|
||||
// Issuer is the hex-decoded form of the certificate's Issuer() string; nil when that string is empty.
Issuer []byte
|
||||
// Curve is the value of the certificate's Curve().
Curve Curve
|
||||
}
|
||||
|
||||
// UnmarshalNebulaCertificateFromPEM provides backward compatibility with the old API
|
||||
func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
|
||||
c, rest, err := UnmarshalCertificateFromPEM(b)
|
||||
if err != nil {
|
||||
return nil, rest, err
|
||||
}
|
||||
|
||||
issuerBytes, err := func() ([]byte, error) {
|
||||
issuer := c.Issuer()
|
||||
if issuer == "" {
|
||||
return nil, nil
|
||||
}
|
||||
decoded, err := hex.DecodeString(issuer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode issuer fingerprint: %w", err)
|
||||
}
|
||||
return decoded, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, rest, err
|
||||
}
|
||||
|
||||
pubKey := c.PublicKey()
|
||||
if pubKey != nil {
|
||||
pubKey = append([]byte(nil), pubKey...)
|
||||
}
|
||||
|
||||
sig := c.Signature()
|
||||
if sig != nil {
|
||||
sig = append([]byte(nil), sig...)
|
||||
}
|
||||
|
||||
return &NebulaCertificate{
|
||||
Details: NebulaCertificateDetails{
|
||||
Name: c.Name(),
|
||||
NotBefore: c.NotBefore(),
|
||||
NotAfter: c.NotAfter(),
|
||||
PublicKey: pubKey,
|
||||
IsCA: c.IsCA(),
|
||||
Issuer: issuerBytes,
|
||||
Curve: c.Curve(),
|
||||
},
|
||||
Signature: sig,
|
||||
cert: c,
|
||||
}, rest, nil
|
||||
}
|
||||
|
||||
// IssuerString returns the issuer in hex format for compatibility
|
||||
func (n *NebulaCertificate) IssuerString() string {
|
||||
if n.Details.Issuer == nil {
|
||||
return ""
|
||||
}
|
||||
return hex.EncodeToString(n.Details.Issuer)
|
||||
}
|
||||
|
||||
// Certificate returns the underlying certificate (read-only)
|
||||
func (n *NebulaCertificate) Certificate() Certificate {
|
||||
return n.cert
|
||||
}
|
||||
|
||||
// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
|
||||
// consumed data or an error on failure
|
||||
func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
|
||||
|
||||
@@ -65,8 +65,16 @@ func main() {
|
||||
}
|
||||
|
||||
if !*configTest {
|
||||
ctrl.Start()
|
||||
ctrl.ShutdownBlock()
|
||||
wait, err := ctrl.Start()
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("Error while running", err, l)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
go ctrl.ShutdownBlock()
|
||||
wait()
|
||||
|
||||
l.Info("Goodbye")
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
|
||||
@@ -3,6 +3,9 @@ package main
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -58,10 +61,22 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
|
||||
}()
|
||||
|
||||
if !*configTest {
|
||||
ctrl.Start()
|
||||
wait, err := ctrl.Start()
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("Error while running", err, l)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
go ctrl.ShutdownBlock()
|
||||
notifyReady(l)
|
||||
ctrl.ShutdownBlock()
|
||||
wait()
|
||||
|
||||
l.Info("Goodbye")
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
|
||||
@@ -13,7 +13,9 @@ import (
|
||||
"github.com/slackhq/nebula/noiseutil"
|
||||
)
|
||||
|
||||
const ReplayWindow = 1024
|
||||
// TODO: In a 5Gbps test, 1024 is not sufficient. With a 1400 MTU this is about 1.4Gbps of window, assuming full packets.
|
||||
// 4096 should be sufficient for 5Gbps
|
||||
const ReplayWindow = 4096
|
||||
|
||||
type ConnectionState struct {
|
||||
eKey *NebulaCipherState
|
||||
|
||||
56
control.go
56
control.go
@@ -2,9 +2,11 @@ package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -13,6 +15,16 @@ import (
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
)
|
||||
|
||||
// RunState describes where a Control is in its start/stop lifecycle.
type RunState int
|
||||
|
||||
const (
|
||||
Stopped RunState = 0 // The control is not running: it has yet to be started, or Stop has finished
|
||||
Started RunState = 1 // The control has been started
|
||||
Stopping RunState = 2 // The control is stopping
|
||||
)
|
||||
|
||||
// ErrAlreadyStarted is returned by Control.Start when the state is not Stopped.
var ErrAlreadyStarted = errors.New("nebula is already started")
|
||||
|
||||
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
||||
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
||||
|
||||
@@ -26,6 +38,9 @@ type controlHostLister interface {
|
||||
}
|
||||
|
||||
type Control struct {
|
||||
stateLock sync.Mutex
|
||||
state RunState
|
||||
|
||||
f *Interface
|
||||
l *logrus.Logger
|
||||
ctx context.Context
|
||||
@@ -49,10 +64,21 @@ type ControlHostInfo struct {
|
||||
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
||||
}
|
||||
|
||||
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
|
||||
func (c *Control) Start() {
|
||||
// Start actually runs nebula, this is a nonblocking call.
|
||||
// The returned function can be used to wait for nebula to fully stop.
|
||||
func (c *Control) Start() (func(), error) {
|
||||
c.stateLock.Lock()
|
||||
if c.state != Stopped {
|
||||
c.stateLock.Unlock()
|
||||
return nil, ErrAlreadyStarted
|
||||
}
|
||||
|
||||
// Activate the interface
|
||||
c.f.activate()
|
||||
err := c.f.activate()
|
||||
if err != nil {
|
||||
c.stateLock.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Call all the delayed funcs that waited patiently for the interface to be created.
|
||||
if c.sshStart != nil {
|
||||
@@ -72,15 +98,33 @@ func (c *Control) Start() {
|
||||
}
|
||||
|
||||
// Start reading packets.
|
||||
c.f.run()
|
||||
c.state = Started
|
||||
c.stateLock.Unlock()
|
||||
return c.f.run(c.ctx)
|
||||
}
|
||||
|
||||
func (c *Control) State() RunState {
|
||||
c.stateLock.Lock()
|
||||
defer c.stateLock.Unlock()
|
||||
return c.state
|
||||
}
|
||||
|
||||
func (c *Control) Context() context.Context {
|
||||
return c.ctx
|
||||
}
|
||||
|
||||
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
|
||||
// Stop is a non-blocking call that signals nebula to close all tunnels and shut down
|
||||
func (c *Control) Stop() {
|
||||
c.stateLock.Lock()
|
||||
if c.state != Started {
|
||||
c.stateLock.Unlock()
|
||||
// We are stopping or stopped already
|
||||
return
|
||||
}
|
||||
|
||||
c.state = Stopping
|
||||
c.stateLock.Unlock()
|
||||
|
||||
// Stop the handshakeManager (and other services), to prevent new tunnels from
|
||||
// being created while we're shutting them all down.
|
||||
c.cancel()
|
||||
@@ -89,7 +133,7 @@ func (c *Control) Stop() {
|
||||
if err := c.f.Close(); err != nil {
|
||||
c.l.WithError(err).Error("Close interface failed")
|
||||
}
|
||||
c.l.Info("Goodbye")
|
||||
c.state = Stopped
|
||||
}
|
||||
|
||||
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
||||
|
||||
@@ -132,6 +132,13 @@ listen:
|
||||
# Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg)
|
||||
# default is 64, does not support reload
|
||||
#batch: 64
|
||||
|
||||
# Control batching between UDP and TUN pipelines
|
||||
#batch:
|
||||
# inbound_size: 32 # packets to queue from UDP before handing to workers
|
||||
# outbound_size: 32 # packets to queue from TUN before handing to workers
|
||||
# flush_interval: 50us # flush partially filled batches after this duration
|
||||
# max_outstanding: 1028 # batches buffered per routine on each channel
|
||||
# Configure socket buffers for the udp side (outside), leave unset to use the system defaults. Values will be doubled by the kernel
|
||||
# Default is net.core.rmem_default and net.core.wmem_default (/proc/sys/net/core/rmem_default and /proc/sys/net/core/wmem_default)
|
||||
# Maximum is limited by memory in the system, SO_RCVBUFFORCE and SO_SNDBUFFORCE is used to avoid having to raise the system wide
|
||||
|
||||
106
inside.go
106
inside.go
@@ -11,19 +11,19 @@ import (
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, queue func(netip.AddrPort, int), q int, localCache firewall.ConntrackCache) bool {
|
||||
err := newPacket(packet, false, fwPacket)
|
||||
if err != nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
||||
}
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
// Ignore local broadcast packets
|
||||
if f.dropLocalBroadcast {
|
||||
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
|
||||
return
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,12 +40,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
}
|
||||
// Otherwise, drop. On linux, we should never see these packets - Linux
|
||||
// routes packets from the nebula addr to the nebula addr through the loopback device.
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
// Ignore multicast packets
|
||||
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
|
||||
@@ -59,26 +59,26 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
WithField("fwPacket", fwPacket).
|
||||
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
|
||||
}
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
if !ready {
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
if dropReason == nil {
|
||||
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
|
||||
|
||||
} else {
|
||||
f.rejectInside(packet, out, q)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(f.l).
|
||||
WithField("fwPacket", fwPacket).
|
||||
WithField("reason", dropReason).
|
||||
Debugln("dropping outbound packet")
|
||||
}
|
||||
return f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, queue, q)
|
||||
}
|
||||
|
||||
f.rejectInside(packet, out, q)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(f.l).
|
||||
WithField("fwPacket", fwPacket).
|
||||
WithField("reason", dropReason).
|
||||
Debugln("dropping outbound packet")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
|
||||
@@ -117,7 +117,7 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
|
||||
return
|
||||
}
|
||||
|
||||
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
|
||||
_ = f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, nil, q)
|
||||
}
|
||||
|
||||
// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
|
||||
@@ -228,7 +228,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
||||
return
|
||||
}
|
||||
|
||||
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
|
||||
_ = f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, nil, 0)
|
||||
}
|
||||
|
||||
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
|
||||
@@ -258,12 +258,12 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag
|
||||
|
||||
func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) {
|
||||
f.messageMetrics.Tx(t, st, 1)
|
||||
f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0)
|
||||
_ = f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, nil, 0)
|
||||
}
|
||||
|
||||
func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) {
|
||||
f.messageMetrics.Tx(t, st, 1)
|
||||
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
|
||||
_ = f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, nil, 0)
|
||||
}
|
||||
|
||||
// SendVia sends a payload through a Relay tunnel. No authentication or encryption is done
|
||||
@@ -331,9 +331,12 @@ func (f *Interface) SendVia(via *HostInfo,
|
||||
f.connectionManager.RelayUsed(relay.LocalIndex)
|
||||
}
|
||||
|
||||
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) {
|
||||
// sendNoMetrics encrypts and writes/queues an outbound packet. It returns true
|
||||
// when the payload has been handed to a caller-provided queue (meaning the
|
||||
// caller is responsible for flushing it later).
|
||||
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, queue func(netip.AddrPort, int), q int) bool {
|
||||
if ci.eKey == nil {
|
||||
return
|
||||
return false
|
||||
}
|
||||
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
|
||||
fullOut := out
|
||||
@@ -380,32 +383,39 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
||||
WithField("udpAddr", remote).WithField("counter", c).
|
||||
WithField("attemptedCounter", c).
|
||||
Error("Failed to encrypt outgoing packet")
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
if remote.IsValid() {
|
||||
err = f.writers[q].WriteTo(out, remote)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).
|
||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
||||
}
|
||||
} else if hostinfo.remote.IsValid() {
|
||||
err = f.writers[q].WriteTo(out, hostinfo.remote)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).
|
||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
||||
}
|
||||
} else {
|
||||
// Try to send via a relay
|
||||
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
|
||||
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
||||
if err != nil {
|
||||
hostinfo.relayState.DeleteRelay(relayIP)
|
||||
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
|
||||
continue
|
||||
}
|
||||
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
|
||||
break
|
||||
}
|
||||
dest := remote
|
||||
if !dest.IsValid() {
|
||||
dest = hostinfo.remote
|
||||
}
|
||||
|
||||
if dest.IsValid() {
|
||||
if queue != nil {
|
||||
queue(dest, len(out))
|
||||
return true
|
||||
}
|
||||
|
||||
err = f.writers[q].WriteTo(out, dest)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).
|
||||
WithField("udpAddr", dest).Error("Failed to write outgoing packet")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Try to send via a relay
|
||||
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
|
||||
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
||||
if err != nil {
|
||||
hostinfo.relayState.DeleteRelay(relayIP)
|
||||
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
|
||||
continue
|
||||
}
|
||||
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
|
||||
break
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
471
interface.go
471
interface.go
@@ -6,8 +6,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -18,10 +18,22 @@ import (
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
"github.com/slackhq/nebula/packet"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
||||
const mtu = 9001
|
||||
const (
|
||||
mtu = 9001
|
||||
|
||||
inboundBatchSizeDefault = 128
|
||||
outboundBatchSizeDefault = 64
|
||||
batchFlushIntervalDefault = 12 * time.Microsecond
|
||||
maxOutstandingBatchesDefault = 8
|
||||
sendBatchSizeDefault = 64
|
||||
maxPendingPacketsDefault = 32
|
||||
maxPendingBytesDefault = 64 * 1024
|
||||
maxSendBufPerRoutineDefault = 16
|
||||
)
|
||||
|
||||
type InterfaceConfig struct {
|
||||
HostMap *HostMap
|
||||
@@ -47,9 +59,20 @@ type InterfaceConfig struct {
|
||||
reQueryWait time.Duration
|
||||
|
||||
ConntrackCacheTimeout time.Duration
|
||||
BatchConfig BatchConfig
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
// BatchConfig tunes the batching between the reader goroutines and the
// workers. NewInterface replaces any value that is <= 0 with its built-in default.
type BatchConfig struct {
|
||||
// InboundBatchSize is the packet count at which listenOut hands a batch to the inbound channel.
InboundBatchSize int
|
||||
// OutboundBatchSize is the payload count at which listenIn hands a batch to the outbound channel; NewInterface also uses it for sendBatchSize.
OutboundBatchSize int
|
||||
// FlushInterval is the age after which a partially filled batch is handed off; the readers check it when a packet arrives.
FlushInterval time.Duration
|
||||
// MaxOutstandingPerChan is the buffer capacity of each per-routine inbound and outbound channel.
MaxOutstandingPerChan int
|
||||
// MaxPendingPackets is stored as Interface.maxPendingPackets. NOTE(review): presumably caps queued sends before a flush — confirm against workerOut.
MaxPendingPackets int
|
||||
// MaxPendingBytes is stored as Interface.maxPendingBytes. NOTE(review): presumably caps queued bytes before a flush — confirm against workerOut.
MaxPendingBytes int
|
||||
// MaxSendBuffersPerChan limits how many send buffers releaseSendBuffer keeps in each per-routine cache before falling back to sendPool.
MaxSendBuffersPerChan int
|
||||
}
|
||||
|
||||
type Interface struct {
|
||||
hostMap *HostMap
|
||||
outside udp.Conn
|
||||
@@ -87,12 +110,165 @@ type Interface struct {
|
||||
|
||||
writers []udp.Conn
|
||||
readers []io.ReadWriteCloser
|
||||
wg sync.WaitGroup
|
||||
|
||||
metricHandshakes metrics.Histogram
|
||||
messageMetrics *MessageMetrics
|
||||
cachedPacketMetrics *cachedPacketMetrics
|
||||
|
||||
l *logrus.Logger
|
||||
|
||||
inPool sync.Pool
|
||||
inbound []chan *packetBatch
|
||||
|
||||
outPool sync.Pool
|
||||
outbound []chan *outboundBatch
|
||||
|
||||
packetBatchPool sync.Pool
|
||||
outboundBatchPool sync.Pool
|
||||
|
||||
sendPool sync.Pool
|
||||
sendBufCache [][]*[]byte
|
||||
sendBatchSize int
|
||||
|
||||
inboundBatchSize int
|
||||
outboundBatchSize int
|
||||
batchFlushInterval time.Duration
|
||||
maxOutstandingPerChan int
|
||||
maxPendingPackets int
|
||||
maxPendingBytes int
|
||||
maxSendBufPerRoutine int
|
||||
}
|
||||
|
||||
// outboundSend is one queued UDP write: flushSendQueue sends the first length
// bytes of buf to addr and then releases buf.
type outboundSend struct {
|
||||
// buf is the send buffer holding the payload.
buf *[]byte
|
||||
// length is the number of bytes of buf to send.
length int
|
||||
// addr is the destination of the write.
addr netip.AddrPort
|
||||
}
|
||||
|
||||
type packetBatch struct {
|
||||
packets []*packet.Packet
|
||||
}
|
||||
|
||||
func newPacketBatch(capacity int) *packetBatch {
|
||||
return &packetBatch{
|
||||
packets: make([]*packet.Packet, 0, capacity),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *packetBatch) add(p *packet.Packet) {
|
||||
b.packets = append(b.packets, p)
|
||||
}
|
||||
|
||||
func (b *packetBatch) reset() {
|
||||
for i := range b.packets {
|
||||
b.packets[i] = nil
|
||||
}
|
||||
b.packets = b.packets[:0]
|
||||
}
|
||||
|
||||
func (f *Interface) getPacketBatch() *packetBatch {
|
||||
if v := f.packetBatchPool.Get(); v != nil {
|
||||
b := v.(*packetBatch)
|
||||
b.reset()
|
||||
return b
|
||||
}
|
||||
return newPacketBatch(f.inboundBatchSize)
|
||||
}
|
||||
|
||||
func (f *Interface) releasePacketBatch(b *packetBatch) {
|
||||
b.reset()
|
||||
f.packetBatchPool.Put(b)
|
||||
}
|
||||
|
||||
// outboundBatch is a reusable group of payload buffers that is passed over a
// channel as a single unit.
type outboundBatch struct {
	payloads []*[]byte
}

// newOutboundBatch returns an empty batch with room for capacity payloads.
func newOutboundBatch(capacity int) *outboundBatch {
	b := new(outboundBatch)
	b.payloads = make([]*[]byte, 0, capacity)
	return b
}

// add appends buf to the batch.
func (b *outboundBatch) add(buf *[]byte) {
	b.payloads = append(b.payloads, buf)
}

// reset empties the batch but keeps its backing array. Every used slot is set
// to nil first so the array no longer references the buffers.
func (b *outboundBatch) reset() {
	held := b.payloads
	for idx := 0; idx < len(held); idx++ {
		held[idx] = nil
	}
	b.payloads = held[:0]
}
|
||||
|
||||
func (f *Interface) getOutboundBatch() *outboundBatch {
|
||||
if v := f.outboundBatchPool.Get(); v != nil {
|
||||
b := v.(*outboundBatch)
|
||||
b.reset()
|
||||
return b
|
||||
}
|
||||
return newOutboundBatch(f.outboundBatchSize)
|
||||
}
|
||||
|
||||
func (f *Interface) releaseOutboundBatch(b *outboundBatch) {
|
||||
b.reset()
|
||||
f.outboundBatchPool.Put(b)
|
||||
}
|
||||
|
||||
func (f *Interface) getSendBuffer(q int) *[]byte {
|
||||
cache := f.sendBufCache[q]
|
||||
if n := len(cache); n > 0 {
|
||||
buf := cache[n-1]
|
||||
f.sendBufCache[q] = cache[:n-1]
|
||||
*buf = (*buf)[:0]
|
||||
return buf
|
||||
}
|
||||
if v := f.sendPool.Get(); v != nil {
|
||||
buf := v.(*[]byte)
|
||||
*buf = (*buf)[:0]
|
||||
return buf
|
||||
}
|
||||
b := make([]byte, mtu)
|
||||
return &b
|
||||
}
|
||||
|
||||
func (f *Interface) releaseSendBuffer(q int, buf *[]byte) {
|
||||
if buf == nil {
|
||||
return
|
||||
}
|
||||
*buf = (*buf)[:0]
|
||||
cache := f.sendBufCache[q]
|
||||
if len(cache) < f.maxSendBufPerRoutine {
|
||||
f.sendBufCache[q] = append(cache, buf)
|
||||
return
|
||||
}
|
||||
f.sendPool.Put(buf)
|
||||
}
|
||||
|
||||
func (f *Interface) flushSendQueue(q int, pending *[]outboundSend, pendingBytes *int) {
|
||||
if len(*pending) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
batch := make([]udp.BatchPacket, len(*pending))
|
||||
for i, entry := range *pending {
|
||||
batch[i] = udp.BatchPacket{
|
||||
Payload: (*entry.buf)[:entry.length],
|
||||
Addr: entry.addr,
|
||||
}
|
||||
}
|
||||
|
||||
sent, err := f.writers[q].WriteBatch(batch)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("sent", sent).Error("Failed to batch send packets")
|
||||
}
|
||||
|
||||
for _, entry := range *pending {
|
||||
f.releaseSendBuffer(q, entry.buf)
|
||||
}
|
||||
*pending = (*pending)[:0]
|
||||
if pendingBytes != nil {
|
||||
*pendingBytes = 0
|
||||
}
|
||||
}
|
||||
|
||||
type EncWriter interface {
|
||||
@@ -162,6 +338,29 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
}
|
||||
|
||||
cs := c.pki.getCertState()
|
||||
|
||||
bc := c.BatchConfig
|
||||
if bc.InboundBatchSize <= 0 {
|
||||
bc.InboundBatchSize = inboundBatchSizeDefault
|
||||
}
|
||||
if bc.OutboundBatchSize <= 0 {
|
||||
bc.OutboundBatchSize = outboundBatchSizeDefault
|
||||
}
|
||||
if bc.FlushInterval <= 0 {
|
||||
bc.FlushInterval = batchFlushIntervalDefault
|
||||
}
|
||||
if bc.MaxOutstandingPerChan <= 0 {
|
||||
bc.MaxOutstandingPerChan = maxOutstandingBatchesDefault
|
||||
}
|
||||
if bc.MaxPendingPackets <= 0 {
|
||||
bc.MaxPendingPackets = maxPendingPacketsDefault
|
||||
}
|
||||
if bc.MaxPendingBytes <= 0 {
|
||||
bc.MaxPendingBytes = maxPendingBytesDefault
|
||||
}
|
||||
if bc.MaxSendBuffersPerChan <= 0 {
|
||||
bc.MaxSendBuffersPerChan = maxSendBufPerRoutineDefault
|
||||
}
|
||||
ifce := &Interface{
|
||||
pki: c.pki,
|
||||
hostMap: c.HostMap,
|
||||
@@ -194,9 +393,49 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
|
||||
},
|
||||
|
||||
inbound: make([]chan *packetBatch, c.routines),
|
||||
outbound: make([]chan *outboundBatch, c.routines),
|
||||
|
||||
l: c.l,
|
||||
|
||||
inboundBatchSize: bc.InboundBatchSize,
|
||||
outboundBatchSize: bc.OutboundBatchSize,
|
||||
batchFlushInterval: bc.FlushInterval,
|
||||
maxOutstandingPerChan: bc.MaxOutstandingPerChan,
|
||||
maxPendingPackets: bc.MaxPendingPackets,
|
||||
maxPendingBytes: bc.MaxPendingBytes,
|
||||
maxSendBufPerRoutine: bc.MaxSendBuffersPerChan,
|
||||
sendBatchSize: bc.OutboundBatchSize,
|
||||
}
|
||||
|
||||
for i := 0; i < c.routines; i++ {
|
||||
ifce.inbound[i] = make(chan *packetBatch, ifce.maxOutstandingPerChan)
|
||||
ifce.outbound[i] = make(chan *outboundBatch, ifce.maxOutstandingPerChan)
|
||||
}
|
||||
|
||||
ifce.inPool = sync.Pool{New: func() any {
|
||||
return packet.New()
|
||||
}}
|
||||
|
||||
ifce.outPool = sync.Pool{New: func() any {
|
||||
t := make([]byte, mtu)
|
||||
return &t
|
||||
}}
|
||||
|
||||
ifce.packetBatchPool = sync.Pool{New: func() any {
|
||||
return newPacketBatch(ifce.inboundBatchSize)
|
||||
}}
|
||||
|
||||
ifce.outboundBatchPool = sync.Pool{New: func() any {
|
||||
return newOutboundBatch(ifce.outboundBatchSize)
|
||||
}}
|
||||
|
||||
ifce.sendPool = sync.Pool{New: func() any {
|
||||
buf := make([]byte, mtu)
|
||||
return &buf
|
||||
}}
|
||||
ifce.sendBufCache = make([][]*[]byte, c.routines)
|
||||
|
||||
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
||||
ifce.reQueryEvery.Store(c.reQueryEvery)
|
||||
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
||||
@@ -209,7 +448,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
// activate creates the interface on the host. After the interface is created, any
|
||||
// other services that want to bind listeners to its IP may do so successfully. However,
|
||||
// the interface isn't going to process anything until run() is called.
|
||||
func (f *Interface) activate() {
|
||||
func (f *Interface) activate() error {
|
||||
// actually turn on tun dev
|
||||
|
||||
addr, err := f.outside.LocalAddr()
|
||||
@@ -230,33 +469,44 @@ func (f *Interface) activate() {
|
||||
if i > 0 {
|
||||
reader, err = f.inside.NewMultiQueueReader()
|
||||
if err != nil {
|
||||
f.l.Fatal(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
f.readers[i] = reader
|
||||
}
|
||||
|
||||
if err := f.inside.Activate(); err != nil {
|
||||
if err = f.inside.Activate(); err != nil {
|
||||
f.inside.Close()
|
||||
f.l.Fatal(err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Interface) run() {
|
||||
// Launch n queues to read packets from udp
|
||||
func (f *Interface) run(c context.Context) (func(), error) {
|
||||
for i := 0; i < f.routines; i++ {
|
||||
// Launch n queues to read packets from udp
|
||||
f.wg.Add(1)
|
||||
go f.listenOut(i)
|
||||
|
||||
// Launch n queues to read packets from tun dev
|
||||
f.wg.Add(1)
|
||||
go f.listenIn(f.readers[i], i)
|
||||
|
||||
// Launch n workers to process the inbound batches queued by the udp readers
|
||||
f.wg.Add(1)
|
||||
go f.workerIn(i, c)
|
||||
|
||||
// Launch n workerOut routines for the outbound side
|
||||
f.wg.Add(1)
|
||||
go f.workerOut(i, c)
|
||||
}
|
||||
|
||||
// Launch n queues to read packets from tun dev
|
||||
for i := 0; i < f.routines; i++ {
|
||||
go f.listenIn(f.readers[i], i)
|
||||
}
|
||||
return f.wg.Wait, nil
|
||||
}
|
||||
|
||||
func (f *Interface) listenOut(i int) {
|
||||
runtime.LockOSThread()
|
||||
|
||||
var li udp.Conn
|
||||
if i > 0 {
|
||||
li = f.writers[i]
|
||||
@@ -264,41 +514,176 @@ func (f *Interface) listenOut(i int) {
|
||||
li = f.outside
|
||||
}
|
||||
|
||||
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
lhh := f.lightHouse.NewRequestHandler()
|
||||
plaintext := make([]byte, udp.MTU)
|
||||
h := &header.H{}
|
||||
fwPacket := &firewall.Packet{}
|
||||
nb := make([]byte, 12, 12)
|
||||
batch := f.getPacketBatch()
|
||||
lastFlush := time.Now()
|
||||
|
||||
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
||||
flush := func(force bool) {
|
||||
if len(batch.packets) == 0 {
|
||||
if force {
|
||||
f.releasePacketBatch(batch)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
f.inbound[i] <- batch
|
||||
batch = f.getPacketBatch()
|
||||
lastFlush = time.Now()
|
||||
}
|
||||
|
||||
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||
p := f.inPool.Get().(*packet.Packet)
|
||||
p.Payload = p.Payload[:mtu]
|
||||
copy(p.Payload, payload)
|
||||
p.Payload = p.Payload[:len(payload)]
|
||||
p.Addr = fromUdpAddr
|
||||
batch.add(p)
|
||||
|
||||
if len(batch.packets) >= f.inboundBatchSize || time.Since(lastFlush) >= f.batchFlushInterval {
|
||||
flush(false)
|
||||
}
|
||||
})
|
||||
|
||||
if len(batch.packets) > 0 {
|
||||
f.inbound[i] <- batch
|
||||
} else {
|
||||
f.releasePacketBatch(batch)
|
||||
}
|
||||
|
||||
if err != nil && !f.closed.Load() {
|
||||
f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
|
||||
//TODO: Trigger Control to close
|
||||
}
|
||||
|
||||
f.l.Debugf("underlay reader %v is done", i)
|
||||
f.wg.Done()
|
||||
}
|
||||
|
||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||
runtime.LockOSThread()
|
||||
|
||||
packet := make([]byte, mtu)
|
||||
out := make([]byte, mtu)
|
||||
fwPacket := &firewall.Packet{}
|
||||
nb := make([]byte, 12, 12)
|
||||
batch := f.getOutboundBatch()
|
||||
lastFlush := time.Now()
|
||||
|
||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
|
||||
for {
|
||||
n, err := reader.Read(packet)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
||||
return
|
||||
flush := func(force bool) {
|
||||
if len(batch.payloads) == 0 {
|
||||
if force {
|
||||
f.releaseOutboundBatch(batch)
|
||||
}
|
||||
|
||||
f.l.WithError(err).Error("Error while reading outbound packet")
|
||||
// This only seems to happen when something fatal happens to the fd, so exit.
|
||||
os.Exit(2)
|
||||
return
|
||||
}
|
||||
|
||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
||||
f.outbound[i] <- batch
|
||||
batch = f.getOutboundBatch()
|
||||
lastFlush = time.Now()
|
||||
}
|
||||
|
||||
for {
|
||||
p := f.outPool.Get().(*[]byte)
|
||||
*p = (*p)[:mtu]
|
||||
n, err := reader.Read(*p)
|
||||
if err != nil {
|
||||
if !f.closed.Load() {
|
||||
f.l.WithError(err).Error("Error while reading outbound packet, closing")
|
||||
//TODO: Trigger Control to close
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
*p = (*p)[:n]
|
||||
batch.add(p)
|
||||
|
||||
if len(batch.payloads) >= f.outboundBatchSize || time.Since(lastFlush) >= f.batchFlushInterval {
|
||||
flush(false)
|
||||
}
|
||||
}
|
||||
|
||||
if len(batch.payloads) > 0 {
|
||||
f.outbound[i] <- batch
|
||||
} else {
|
||||
f.releaseOutboundBatch(batch)
|
||||
}
|
||||
|
||||
f.l.Debugf("overlay reader %v is done", i)
|
||||
f.wg.Done()
|
||||
}
|
||||
|
||||
func (f *Interface) workerIn(i int, ctx context.Context) {
|
||||
lhh := f.lightHouse.NewRequestHandler()
|
||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
fwPacket2 := &firewall.Packet{}
|
||||
nb2 := make([]byte, 12, 12)
|
||||
result2 := make([]byte, mtu)
|
||||
h := &header.H{}
|
||||
|
||||
for {
|
||||
select {
|
||||
case batch := <-f.inbound[i]:
|
||||
for _, p := range batch.packets {
|
||||
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
|
||||
p.Payload = p.Payload[:mtu]
|
||||
f.inPool.Put(p)
|
||||
}
|
||||
f.releasePacketBatch(batch)
|
||||
case <-ctx.Done():
|
||||
f.wg.Done()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) workerOut(i int, ctx context.Context) {
|
||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
fwPacket1 := &firewall.Packet{}
|
||||
nb1 := make([]byte, 12, 12)
|
||||
pending := make([]outboundSend, 0, f.sendBatchSize)
|
||||
pendingBytes := 0
|
||||
maxPendingPackets := f.maxPendingPackets
|
||||
if maxPendingPackets <= 0 {
|
||||
maxPendingPackets = f.sendBatchSize
|
||||
}
|
||||
maxPendingBytes := f.maxPendingBytes
|
||||
if maxPendingBytes <= 0 {
|
||||
maxPendingBytes = f.sendBatchSize * mtu
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case batch := <-f.outbound[i]:
|
||||
for _, data := range batch.payloads {
|
||||
sendBuf := f.getSendBuffer(i)
|
||||
buf := (*sendBuf)[:0]
|
||||
queue := func(addr netip.AddrPort, length int) {
|
||||
if len(pending) >= maxPendingPackets || pendingBytes+length > maxPendingBytes {
|
||||
f.flushSendQueue(i, &pending, &pendingBytes)
|
||||
}
|
||||
pending = append(pending, outboundSend{
|
||||
buf: sendBuf,
|
||||
length: length,
|
||||
addr: addr,
|
||||
})
|
||||
pendingBytes += length
|
||||
if len(pending) >= f.sendBatchSize || pendingBytes >= maxPendingBytes {
|
||||
f.flushSendQueue(i, &pending, &pendingBytes)
|
||||
}
|
||||
}
|
||||
sent := f.consumeInsidePacket(*data, fwPacket1, nb1, buf, queue, i, conntrackCache.Get(f.l))
|
||||
if !sent {
|
||||
f.releaseSendBuffer(i, sendBuf)
|
||||
}
|
||||
*data = (*data)[:mtu]
|
||||
f.outPool.Put(data)
|
||||
}
|
||||
f.releaseOutboundBatch(batch)
|
||||
if len(pending) > 0 {
|
||||
f.flushSendQueue(i, &pending, &pendingBytes)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
if len(pending) > 0 {
|
||||
f.flushSendQueue(i, &pending, &pendingBytes)
|
||||
}
|
||||
f.wg.Done()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -451,6 +836,7 @@ func (f *Interface) GetCertState() *CertState {
|
||||
func (f *Interface) Close() error {
|
||||
f.closed.Store(true)
|
||||
|
||||
// Release the udp readers
|
||||
for _, u := range f.writers {
|
||||
err := u.Close()
|
||||
if err != nil {
|
||||
@@ -458,6 +844,13 @@ func (f *Interface) Close() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Release the tun device
|
||||
return f.inside.Close()
|
||||
// Release the tun readers
|
||||
for _, u := range f.readers {
|
||||
err := u.Close()
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Error while closing tun device")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
31
main.go
31
main.go
@@ -164,7 +164,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
|
||||
for i := 0; i < routines; i++ {
|
||||
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
|
||||
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
|
||||
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 128))
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||
}
|
||||
@@ -221,6 +221,16 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
}
|
||||
}
|
||||
|
||||
batchCfg := BatchConfig{
|
||||
InboundBatchSize: c.GetInt("batch.inbound_size", inboundBatchSizeDefault),
|
||||
OutboundBatchSize: c.GetInt("batch.outbound_size", outboundBatchSizeDefault),
|
||||
FlushInterval: c.GetDuration("batch.flush_interval", batchFlushIntervalDefault),
|
||||
MaxOutstandingPerChan: c.GetInt("batch.max_outstanding", maxOutstandingBatchesDefault),
|
||||
MaxPendingPackets: c.GetInt("batch.max_pending_packets", 0),
|
||||
MaxPendingBytes: c.GetInt("batch.max_pending_bytes", 0),
|
||||
MaxSendBuffersPerChan: c.GetInt("batch.max_send_buffers_per_routine", 0),
|
||||
}
|
||||
|
||||
ifConfig := &InterfaceConfig{
|
||||
HostMap: hostMap,
|
||||
Inside: tun,
|
||||
@@ -242,6 +252,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
relayManager: NewRelayManager(ctx, l, hostMap, c),
|
||||
punchy: punchy,
|
||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||
BatchConfig: batchCfg,
|
||||
l: l,
|
||||
}
|
||||
|
||||
@@ -284,14 +295,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
}
|
||||
|
||||
return &Control{
|
||||
ifce,
|
||||
l,
|
||||
ctx,
|
||||
cancel,
|
||||
sshStart,
|
||||
statsStart,
|
||||
dnsStart,
|
||||
lightHouse.StartUpdateWorker,
|
||||
connManager.Start,
|
||||
f: ifce,
|
||||
l: l,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
sshStart: sshStart,
|
||||
statsStart: statsStart,
|
||||
dnsStart: dnsStart,
|
||||
lighthouseStart: lightHouse.StartUpdateWorker,
|
||||
connectionManagerStart: connManager.Start,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
return
|
||||
}
|
||||
|
||||
//l.Error("in packet ", header, packet[HeaderLen:])
|
||||
//f.l.Error("in packet ", h)
|
||||
if ip.IsValid() {
|
||||
if f.myVpnNetworksTable.Contains(ip.Addr()) {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
@@ -245,6 +245,7 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
|
||||
return
|
||||
}
|
||||
|
||||
//TODO: Seems we have a bunch of stuff racing here, since we don't have a lock on hostinfo anymore we announce roaming in bursts
|
||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
|
||||
Info("Host roamed to new udp ip/port.")
|
||||
hostinfo.lastRoam = time.Now()
|
||||
@@ -470,7 +471,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||
|
||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
||||
hostinfo.logger(f.l).WithError(err).WithField("fwPacket", fwPacket).Error("Failed to decrypt packet")
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
12
packet/packet.go
Normal file
12
packet/packet.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package packet
|
||||
|
||||
import "net/netip"
|
||||
|
||||
// Packet pairs a payload buffer with a UDP address.
type Packet struct {
	// Payload is the packet's data buffer.
	Payload []byte
	// Addr is the address associated with Payload.
	Addr netip.AddrPort
}

// New returns a Packet whose Payload is a freshly allocated 9001 byte buffer
// and whose Addr is the zero value.
func New() *Packet {
	const payloadSize = 9001

	p := new(Packet)
	p.Payload = make([]byte, payloadSize)
	return p
}
|
||||
@@ -44,7 +44,10 @@ type Service struct {
|
||||
}
|
||||
|
||||
func New(control *nebula.Control) (*Service, error) {
|
||||
control.Start()
|
||||
wait, err := control.Start()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx := control.Context()
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
@@ -141,6 +144,12 @@ func New(control *nebula.Control) (*Service, error) {
|
||||
}
|
||||
})
|
||||
|
||||
// Add the nebula wait function to the group
|
||||
eg.Go(func() error {
|
||||
wait()
|
||||
return nil
|
||||
})
|
||||
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
|
||||
15
udp/conn.go
15
udp/conn.go
@@ -16,12 +16,18 @@ type EncReader func(
|
||||
type Conn interface {
|
||||
Rebind() error
|
||||
LocalAddr() (netip.AddrPort, error)
|
||||
ListenOut(r EncReader)
|
||||
ListenOut(r EncReader) error
|
||||
WriteTo(b []byte, addr netip.AddrPort) error
|
||||
WriteBatch(pkts []BatchPacket) (int, error)
|
||||
ReloadConfig(c *config.C)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// BatchPacket is one element of a Conn.WriteBatch call: a payload and the
// address it should be written to.
type BatchPacket struct {
	// Payload is the data to write.
	Payload []byte
	// Addr is the destination for Payload.
	Addr netip.AddrPort
}
|
||||
|
||||
type NoopConn struct{}
|
||||
|
||||
func (NoopConn) Rebind() error {
|
||||
@@ -30,12 +36,15 @@ func (NoopConn) Rebind() error {
|
||||
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
||||
return netip.AddrPort{}, nil
|
||||
}
|
||||
func (NoopConn) ListenOut(_ EncReader) {
|
||||
return
|
||||
func (NoopConn) ListenOut(_ EncReader) error {
|
||||
return nil
|
||||
}
|
||||
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
||||
return nil
|
||||
}
|
||||
func (NoopConn) WriteBatch(_ []BatchPacket) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (NoopConn) ReloadConfig(_ *config.C) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -140,6 +140,17 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) WriteBatch(pkts []BatchPacket) (int, error) {
|
||||
sent := 0
|
||||
for _, pkt := range pkts {
|
||||
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
|
||||
return sent, err
|
||||
}
|
||||
sent++
|
||||
}
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
||||
a := u.UDPConn.LocalAddr()
|
||||
|
||||
@@ -165,7 +176,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
|
||||
return func() {}
|
||||
}
|
||||
|
||||
func (u *StdConn) ListenOut(r EncReader) {
|
||||
func (u *StdConn) ListenOut(r EncReader) error {
|
||||
buffer := make([]byte, MTU)
|
||||
|
||||
for {
|
||||
@@ -174,14 +185,17 @@ func (u *StdConn) ListenOut(r EncReader) {
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
u.l.WithError(err).Error("unexpected udp socket receive error")
|
||||
continue
|
||||
}
|
||||
|
||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *StdConn) Rebind() error {
|
||||
|
||||
@@ -42,6 +42,17 @@ func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *GenericConn) WriteBatch(pkts []BatchPacket) (int, error) {
|
||||
sent := 0
|
||||
for _, pkt := range pkts {
|
||||
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
|
||||
return sent, err
|
||||
}
|
||||
sent++
|
||||
}
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
|
||||
a := u.UDPConn.LocalAddr()
|
||||
|
||||
@@ -71,15 +82,14 @@ type rawMessage struct {
|
||||
Len uint32
|
||||
}
|
||||
|
||||
func (u *GenericConn) ListenOut(r EncReader) {
|
||||
func (u *GenericConn) ListenOut(r EncReader) error {
|
||||
buffer := make([]byte, MTU)
|
||||
|
||||
for {
|
||||
// Just read one packet at a time
|
||||
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
||||
if err != nil {
|
||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||
|
||||
611
udp/udp_linux.go
611
udp/udp_linux.go
@@ -5,10 +5,13 @@ package udp
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
@@ -17,19 +20,40 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
|
||||
|
||||
const (
|
||||
defaultGSOMaxSegments = 128
|
||||
defaultGSOFlushTimeout = 80 * time.Microsecond
|
||||
defaultGROReadBufferSize = MTU * defaultGSOMaxSegments
|
||||
maxGSOBatchBytes = 0xFFFF
|
||||
)
|
||||
|
||||
var (
|
||||
errGSOFallback = errors.New("udp gso fallback")
|
||||
errGSODisabled = errors.New("udp gso disabled")
|
||||
)
|
||||
|
||||
type StdConn struct {
|
||||
sysFd int
|
||||
isV4 bool
|
||||
l *logrus.Logger
|
||||
batch int
|
||||
}
|
||||
|
||||
func maybeIPV4(ip net.IP) (net.IP, bool) {
|
||||
ip4 := ip.To4()
|
||||
if ip4 != nil {
|
||||
return ip4, true
|
||||
}
|
||||
return ip, false
|
||||
enableGRO bool
|
||||
enableGSO bool
|
||||
|
||||
gsoMu sync.Mutex
|
||||
gsoBuf []byte
|
||||
gsoAddr netip.AddrPort
|
||||
gsoSegSize int
|
||||
gsoSegments int
|
||||
gsoMaxSegments int
|
||||
gsoMaxBytes int
|
||||
gsoFlushTimeout time.Duration
|
||||
gsoTimer *time.Timer
|
||||
|
||||
groBufSize int
|
||||
}
|
||||
|
||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
@@ -55,6 +79,11 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
||||
}
|
||||
}
|
||||
|
||||
// Set a read timeout
|
||||
if err = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &readTimeout); err != nil {
|
||||
return nil, fmt.Errorf("unable to set SO_RCVTIMEO: %s", err)
|
||||
}
|
||||
|
||||
var sa unix.Sockaddr
|
||||
if ip.Is4() {
|
||||
sa4 := &unix.SockaddrInet4{Port: port}
|
||||
@@ -69,7 +98,16 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
||||
return nil, fmt.Errorf("unable to bind to socket: %s", err)
|
||||
}
|
||||
|
||||
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
|
||||
return &StdConn{
|
||||
sysFd: fd,
|
||||
isV4: ip.Is4(),
|
||||
l: l,
|
||||
batch: batch,
|
||||
gsoMaxSegments: defaultGSOMaxSegments,
|
||||
gsoMaxBytes: MTU * defaultGSOMaxSegments,
|
||||
gsoFlushTimeout: defaultGSOFlushTimeout,
|
||||
groBufSize: MTU,
|
||||
}, err
|
||||
}
|
||||
|
||||
func (u *StdConn) Rebind() error {
|
||||
@@ -118,20 +156,46 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) ListenOut(r EncReader) {
|
||||
var ip netip.Addr
|
||||
func (u *StdConn) ListenOut(r EncReader) error {
|
||||
var (
|
||||
ip netip.Addr
|
||||
controls [][]byte
|
||||
)
|
||||
|
||||
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
||||
bufSize := u.readBufferSize()
|
||||
msgs, buffers, names := u.PrepareRawMessages(u.batch, bufSize)
|
||||
read := u.ReadMulti
|
||||
if u.batch == 1 {
|
||||
read = u.ReadSingle
|
||||
}
|
||||
|
||||
for {
|
||||
desired := u.readBufferSize()
|
||||
if len(buffers) == 0 || cap(buffers[0]) < desired {
|
||||
msgs, buffers, names = u.PrepareRawMessages(u.batch, desired)
|
||||
controls = nil
|
||||
}
|
||||
|
||||
if u.enableGRO {
|
||||
if controls == nil {
|
||||
controls = make([][]byte, len(msgs))
|
||||
for i := range controls {
|
||||
controls[i] = make([]byte, unix.CmsgSpace(4))
|
||||
}
|
||||
}
|
||||
for i := range msgs {
|
||||
setRawMessageControl(&msgs[i], controls[i])
|
||||
}
|
||||
} else if controls != nil {
|
||||
for i := range msgs {
|
||||
setRawMessageControl(&msgs[i], nil)
|
||||
}
|
||||
controls = nil
|
||||
}
|
||||
|
||||
n, err := read(msgs)
|
||||
if err != nil {
|
||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
@@ -141,11 +205,82 @@ func (u *StdConn) ListenOut(r EncReader) {
|
||||
} else {
|
||||
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
||||
}
|
||||
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
|
||||
addr := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
|
||||
payload := buffers[i][:msgs[i].Len]
|
||||
|
||||
if u.enableGRO && u.l.IsLevelEnabled(logrus.DebugLevel) {
|
||||
ctrlLen := getRawMessageControlLen(&msgs[i])
|
||||
msgFlags := getRawMessageFlags(&msgs[i])
|
||||
u.l.WithFields(logrus.Fields{
|
||||
"tag": "gro-debug",
|
||||
"stage": "recv",
|
||||
"payload_len": len(payload),
|
||||
"ctrl_len": ctrlLen,
|
||||
"msg_flags": msgFlags,
|
||||
}).Debug("gro batch data")
|
||||
if controls != nil && ctrlLen > 0 {
|
||||
maxDump := ctrlLen
|
||||
if maxDump > 16 {
|
||||
maxDump = 16
|
||||
}
|
||||
u.l.WithFields(logrus.Fields{
|
||||
"tag": "gro-debug",
|
||||
"stage": "control-bytes",
|
||||
"control_hex": fmt.Sprintf("%x", controls[i][:maxDump]),
|
||||
"datalen": ctrlLen,
|
||||
}).Debug("gro control dump")
|
||||
}
|
||||
}
|
||||
|
||||
sawControl := false
|
||||
if controls != nil {
|
||||
if ctrlLen := getRawMessageControlLen(&msgs[i]); ctrlLen > 0 {
|
||||
if segSize, segCount := parseGROControl(controls[i][:ctrlLen]); segSize > 0 {
|
||||
sawControl = true
|
||||
if u.l.IsLevelEnabled(logrus.DebugLevel) {
|
||||
u.l.WithFields(logrus.Fields{
|
||||
"tag": "gro-debug",
|
||||
"stage": "control",
|
||||
"seg_size": segSize,
|
||||
"seg_count": segCount,
|
||||
"payloadLen": len(payload),
|
||||
}).Debug("gro control parsed")
|
||||
}
|
||||
segSize = normalizeGROSegSize(segSize, segCount, len(payload))
|
||||
if segSize > 0 && segSize < len(payload) {
|
||||
if u.emitGROSegments(r, addr, payload, segSize) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if u.enableGRO && len(payload) > MTU {
|
||||
if !sawControl && u.l.IsLevelEnabled(logrus.DebugLevel) {
|
||||
u.l.WithFields(logrus.Fields{
|
||||
"tag": "gro-debug",
|
||||
"stage": "fallback",
|
||||
"payload_len": len(payload),
|
||||
}).Debug("gro control missing; splitting payload by MTU")
|
||||
}
|
||||
if u.emitGROSegments(r, addr, payload, MTU) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
r(addr, payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) readBufferSize() int {
|
||||
if u.enableGRO && u.groBufSize > MTU {
|
||||
return u.groBufSize
|
||||
}
|
||||
return MTU
|
||||
}
|
||||
|
||||
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
|
||||
for {
|
||||
n, _, err := unix.Syscall6(
|
||||
@@ -159,6 +294,9 @@ func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
|
||||
)
|
||||
|
||||
if err != 0 {
|
||||
if err == unix.EAGAIN || err == unix.EINTR {
|
||||
continue
|
||||
}
|
||||
return 0, &net.OpError{Op: "recvmsg", Err: err}
|
||||
}
|
||||
|
||||
@@ -180,6 +318,9 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
|
||||
)
|
||||
|
||||
if err != 0 {
|
||||
if err == unix.EAGAIN || err == unix.EINTR {
|
||||
continue
|
||||
}
|
||||
return 0, &net.OpError{Op: "recvmmsg", Err: err}
|
||||
}
|
||||
|
||||
@@ -188,12 +329,132 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
|
||||
}
|
||||
|
||||
func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
||||
if u.enableGSO && ip.IsValid() {
|
||||
if err := u.queueGSOPacket(b, ip); err == nil {
|
||||
return nil
|
||||
} else if !errors.Is(err, errGSOFallback) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if u.isV4 {
|
||||
return u.writeTo4(b, ip)
|
||||
}
|
||||
return u.writeTo6(b, ip)
|
||||
}
|
||||
|
||||
func (u *StdConn) WriteBatch(pkts []BatchPacket) (int, error) {
|
||||
if len(pkts) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
msgs := make([]rawMessage, 0, len(pkts))
|
||||
iovs := make([]iovec, 0, len(pkts))
|
||||
names := make([][unix.SizeofSockaddrInet6]byte, 0, len(pkts))
|
||||
|
||||
sent := 0
|
||||
|
||||
for _, pkt := range pkts {
|
||||
if len(pkt.Payload) == 0 {
|
||||
sent++
|
||||
continue
|
||||
}
|
||||
|
||||
if u.enableGSO && pkt.Addr.IsValid() {
|
||||
if err := u.queueGSOPacket(pkt.Payload, pkt.Addr); err == nil {
|
||||
sent++
|
||||
continue
|
||||
} else if !errors.Is(err, errGSOFallback) {
|
||||
return sent, err
|
||||
}
|
||||
}
|
||||
|
||||
if !pkt.Addr.IsValid() {
|
||||
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
|
||||
return sent, err
|
||||
}
|
||||
sent++
|
||||
continue
|
||||
}
|
||||
|
||||
msgs = append(msgs, rawMessage{})
|
||||
iovs = append(iovs, iovec{})
|
||||
names = append(names, [unix.SizeofSockaddrInet6]byte{})
|
||||
|
||||
idx := len(msgs) - 1
|
||||
msg := &msgs[idx]
|
||||
iov := &iovs[idx]
|
||||
name := &names[idx]
|
||||
|
||||
setIovecSlice(iov, pkt.Payload)
|
||||
msg.Hdr.Iov = iov
|
||||
msg.Hdr.Iovlen = 1
|
||||
setRawMessageControl(msg, nil)
|
||||
msg.Hdr.Flags = 0
|
||||
|
||||
nameLen, err := u.encodeSockaddr(name[:], pkt.Addr)
|
||||
if err != nil {
|
||||
return sent, err
|
||||
}
|
||||
msg.Hdr.Name = &name[0]
|
||||
msg.Hdr.Namelen = nameLen
|
||||
}
|
||||
|
||||
if len(msgs) == 0 {
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
offset := 0
|
||||
for offset < len(msgs) {
|
||||
n, _, errno := unix.Syscall6(
|
||||
unix.SYS_SENDMMSG,
|
||||
uintptr(u.sysFd),
|
||||
uintptr(unsafe.Pointer(&msgs[offset])),
|
||||
uintptr(len(msgs)-offset),
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
|
||||
if errno != 0 {
|
||||
if errno == unix.EINTR {
|
||||
continue
|
||||
}
|
||||
return sent + offset, &net.OpError{Op: "sendmmsg", Err: errno}
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
offset += int(n)
|
||||
}
|
||||
|
||||
return sent + len(msgs), nil
|
||||
}
|
||||
|
||||
func (u *StdConn) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
|
||||
if u.isV4 {
|
||||
if !addr.Addr().Is4() {
|
||||
return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
|
||||
}
|
||||
var sa unix.RawSockaddrInet4
|
||||
sa.Family = unix.AF_INET
|
||||
sa.Addr = addr.Addr().As4()
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
|
||||
size := unix.SizeofSockaddrInet4
|
||||
copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
|
||||
return uint32(size), nil
|
||||
}
|
||||
|
||||
var sa unix.RawSockaddrInet6
|
||||
sa.Family = unix.AF_INET6
|
||||
sa.Addr = addr.Addr().As16()
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
|
||||
size := unix.SizeofSockaddrInet6
|
||||
copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
|
||||
return uint32(size), nil
|
||||
}
|
||||
|
||||
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
|
||||
var rsa unix.RawSockaddrInet6
|
||||
rsa.Family = unix.AF_INET6
|
||||
@@ -221,7 +482,7 @@ func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
|
||||
|
||||
func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
|
||||
if !ip.Addr().Is4() {
|
||||
return ErrInvalidIPv6RemoteForSocket
|
||||
return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
|
||||
}
|
||||
|
||||
var rsa unix.RawSockaddrInet4
|
||||
@@ -294,6 +555,94 @@ func (u *StdConn) ReloadConfig(c *config.C) {
|
||||
u.l.WithError(err).Error("Failed to set listen.so_mark")
|
||||
}
|
||||
}
|
||||
|
||||
u.configureGRO(c)
|
||||
u.configureGSO(c)
|
||||
}
|
||||
|
||||
func (u *StdConn) configureGRO(c *config.C) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
enable := c.GetBool("listen.enable_gro", true)
|
||||
if enable == u.enableGRO {
|
||||
if enable {
|
||||
if size := c.GetInt("listen.gro_read_buffer", 0); size > 0 {
|
||||
u.setGROBufferSize(size)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if enable {
|
||||
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
|
||||
u.l.WithError(err).Warn("Failed to enable UDP GRO")
|
||||
return
|
||||
}
|
||||
u.enableGRO = true
|
||||
u.setGROBufferSize(c.GetInt("listen.gro_read_buffer", defaultGROReadBufferSize))
|
||||
u.l.WithField("buffer_size", u.groBufSize).Info("UDP GRO enabled")
|
||||
return
|
||||
}
|
||||
|
||||
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil && err != unix.ENOPROTOOPT {
|
||||
u.l.WithError(err).Warn("Failed to disable UDP GRO")
|
||||
}
|
||||
u.enableGRO = false
|
||||
u.groBufSize = MTU
|
||||
}
|
||||
|
||||
func (u *StdConn) configureGSO(c *config.C) {
|
||||
enable := c.GetBool("listen.enable_gso", true)
|
||||
if !enable {
|
||||
u.disableGSO()
|
||||
} else {
|
||||
u.enableGSO = true
|
||||
}
|
||||
|
||||
segments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
|
||||
if segments < 1 {
|
||||
segments = 1
|
||||
}
|
||||
u.gsoMaxSegments = segments
|
||||
|
||||
maxBytes := c.GetInt("listen.gso_max_bytes", 0)
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = MTU * segments
|
||||
}
|
||||
if maxBytes > maxGSOBatchBytes {
|
||||
u.l.WithField("requested", maxBytes).Warn("listen.gso_max_bytes larger than UDP limit; clamping")
|
||||
maxBytes = maxGSOBatchBytes
|
||||
}
|
||||
u.gsoMaxBytes = maxBytes
|
||||
|
||||
timeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushTimeout)
|
||||
if timeout < 0 {
|
||||
timeout = 0
|
||||
}
|
||||
u.gsoFlushTimeout = timeout
|
||||
}
|
||||
|
||||
func (u *StdConn) setGROBufferSize(size int) {
|
||||
if size < MTU {
|
||||
size = defaultGROReadBufferSize
|
||||
}
|
||||
if size > maxGSOBatchBytes {
|
||||
size = maxGSOBatchBytes
|
||||
}
|
||||
u.groBufSize = size
|
||||
}
|
||||
|
||||
func (u *StdConn) disableGSO() {
|
||||
u.gsoMu.Lock()
|
||||
defer u.gsoMu.Unlock()
|
||||
u.enableGSO = false
|
||||
_ = u.flushGSOlocked()
|
||||
u.gsoBuf = nil
|
||||
u.gsoSegments = 0
|
||||
u.gsoSegSize = 0
|
||||
u.stopGSOTimerLocked()
|
||||
}
|
||||
|
||||
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
||||
@@ -305,7 +654,239 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *StdConn) queueGSOPacket(b []byte, addr netip.AddrPort) error {
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.gsoMu.Lock()
|
||||
defer u.gsoMu.Unlock()
|
||||
|
||||
if !u.enableGSO || !addr.IsValid() || len(b) > u.gsoMaxBytes {
|
||||
if err := u.flushGSOlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
return errGSOFallback
|
||||
}
|
||||
|
||||
if u.gsoSegments == 0 {
|
||||
if cap(u.gsoBuf) < u.gsoMaxBytes {
|
||||
u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
|
||||
}
|
||||
u.gsoAddr = addr
|
||||
u.gsoSegSize = len(b)
|
||||
} else if addr != u.gsoAddr || len(b) != u.gsoSegSize {
|
||||
if err := u.flushGSOlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
if cap(u.gsoBuf) < u.gsoMaxBytes {
|
||||
u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
|
||||
}
|
||||
u.gsoAddr = addr
|
||||
u.gsoSegSize = len(b)
|
||||
}
|
||||
|
||||
if len(u.gsoBuf)+len(b) > u.gsoMaxBytes {
|
||||
if err := u.flushGSOlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
if cap(u.gsoBuf) < u.gsoMaxBytes {
|
||||
u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
|
||||
}
|
||||
u.gsoAddr = addr
|
||||
u.gsoSegSize = len(b)
|
||||
}
|
||||
|
||||
u.gsoBuf = append(u.gsoBuf, b...)
|
||||
u.gsoSegments++
|
||||
|
||||
if u.gsoSegments >= u.gsoMaxSegments || u.gsoFlushTimeout <= 0 {
|
||||
return u.flushGSOlocked()
|
||||
}
|
||||
|
||||
u.scheduleGSOFlushLocked()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *StdConn) flushGSOlocked() error {
|
||||
if u.gsoSegments == 0 {
|
||||
u.stopGSOTimerLocked()
|
||||
return nil
|
||||
}
|
||||
|
||||
payload := append([]byte(nil), u.gsoBuf...)
|
||||
addr := u.gsoAddr
|
||||
segSize := u.gsoSegSize
|
||||
|
||||
u.gsoBuf = u.gsoBuf[:0]
|
||||
u.gsoSegments = 0
|
||||
u.gsoSegSize = 0
|
||||
u.stopGSOTimerLocked()
|
||||
|
||||
if segSize <= 0 {
|
||||
return errGSOFallback
|
||||
}
|
||||
|
||||
err := u.sendSegmented(payload, addr, segSize)
|
||||
if errors.Is(err, errGSODisabled) {
|
||||
u.l.WithField("addr", addr).Warn("UDP GSO disabled by kernel, falling back to sendto")
|
||||
u.enableGSO = false
|
||||
return u.sendSegmentsIndividually(payload, addr, segSize)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *StdConn) sendSegmented(payload []byte, addr netip.AddrPort, segSize int) error {
|
||||
if len(payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
control := make([]byte, unix.CmsgSpace(2))
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||
hdr.Level = unix.SOL_UDP
|
||||
hdr.Type = unix.UDP_SEGMENT
|
||||
setCmsgLen(hdr, unix.CmsgLen(2))
|
||||
binary.NativeEndian.PutUint16(control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(segSize))
|
||||
|
||||
var sa unix.Sockaddr
|
||||
if addr.Addr().Is4() {
|
||||
var sa4 unix.SockaddrInet4
|
||||
sa4.Port = int(addr.Port())
|
||||
sa4.Addr = addr.Addr().As4()
|
||||
sa = &sa4
|
||||
} else {
|
||||
var sa6 unix.SockaddrInet6
|
||||
sa6.Port = int(addr.Port())
|
||||
sa6.Addr = addr.Addr().As16()
|
||||
sa = &sa6
|
||||
}
|
||||
|
||||
if _, err := unix.SendmsgN(u.sysFd, payload, control, sa, 0); err != nil {
|
||||
if errno, ok := err.(syscall.Errno); ok && (errno == unix.EINVAL || errno == unix.ENOTSUP || errno == unix.EOPNOTSUPP) {
|
||||
return errGSODisabled
|
||||
}
|
||||
return &net.OpError{Op: "sendmsg", Err: err}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *StdConn) sendSegmentsIndividually(buf []byte, addr netip.AddrPort, segSize int) error {
|
||||
if segSize <= 0 {
|
||||
return errGSOFallback
|
||||
}
|
||||
|
||||
for offset := 0; offset < len(buf); offset += segSize {
|
||||
end := offset + segSize
|
||||
if end > len(buf) {
|
||||
end = len(buf)
|
||||
}
|
||||
var err error
|
||||
if u.isV4 {
|
||||
err = u.writeTo4(buf[offset:end], addr)
|
||||
} else {
|
||||
err = u.writeTo6(buf[offset:end], addr)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *StdConn) scheduleGSOFlushLocked() {
|
||||
if u.gsoTimer == nil {
|
||||
u.gsoTimer = time.AfterFunc(u.gsoFlushTimeout, u.gsoFlushTimer)
|
||||
return
|
||||
}
|
||||
u.gsoTimer.Reset(u.gsoFlushTimeout)
|
||||
}
|
||||
|
||||
func (u *StdConn) stopGSOTimerLocked() {
|
||||
if u.gsoTimer != nil {
|
||||
u.gsoTimer.Stop()
|
||||
u.gsoTimer = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) gsoFlushTimer() {
|
||||
u.gsoMu.Lock()
|
||||
defer u.gsoMu.Unlock()
|
||||
_ = u.flushGSOlocked()
|
||||
}
|
||||
|
||||
func parseGROControl(control []byte) (int, int) {
|
||||
if len(control) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
cmsgs, err := unix.ParseSocketControlMessage(control)
|
||||
if err != nil {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
for _, c := range cmsgs {
|
||||
if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
|
||||
segSize := int(binary.NativeEndian.Uint16(c.Data[:2]))
|
||||
segCount := 0
|
||||
if len(c.Data) >= 4 {
|
||||
segCount = int(binary.NativeEndian.Uint16(c.Data[2:4]))
|
||||
}
|
||||
return segSize, segCount
|
||||
}
|
||||
}
|
||||
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
func (u *StdConn) emitGROSegments(r EncReader, addr netip.AddrPort, payload []byte, segSize int) bool {
|
||||
if segSize <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for offset := 0; offset < len(payload); offset += segSize {
|
||||
end := offset + segSize
|
||||
if end > len(payload) {
|
||||
end = len(payload)
|
||||
}
|
||||
segment := make([]byte, end-offset)
|
||||
copy(segment, payload[offset:end])
|
||||
r(addr, segment)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func normalizeGROSegSize(segSize, segCount, total int) int {
|
||||
if segSize <= 0 || total <= 0 {
|
||||
return segSize
|
||||
}
|
||||
|
||||
if segSize > total && segCount > 0 {
|
||||
segSize = total / segCount
|
||||
if segSize == 0 {
|
||||
segSize = total
|
||||
}
|
||||
}
|
||||
|
||||
if segCount <= 1 && segSize > 0 && total > segSize {
|
||||
calculated := total / segSize
|
||||
if calculated <= 1 {
|
||||
calculated = (total + segSize - 1) / segSize
|
||||
}
|
||||
if calculated > 1 {
|
||||
segCount = calculated
|
||||
}
|
||||
}
|
||||
|
||||
if segSize > MTU {
|
||||
return MTU
|
||||
}
|
||||
|
||||
return segSize
|
||||
}
|
||||
|
||||
func (u *StdConn) Close() error {
|
||||
u.disableGSO()
|
||||
return syscall.Close(u.sysFd)
|
||||
}
|
||||
|
||||
|
||||
@@ -30,13 +30,16 @@ type rawMessage struct {
|
||||
Len uint32
|
||||
}
|
||||
|
||||
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||
func (u *StdConn) PrepareRawMessages(n int, bufSize int) ([]rawMessage, [][]byte, [][]byte) {
|
||||
if bufSize <= 0 {
|
||||
bufSize = MTU
|
||||
}
|
||||
msgs := make([]rawMessage, n)
|
||||
buffers := make([][]byte, n)
|
||||
names := make([][]byte, n)
|
||||
|
||||
for i := range msgs {
|
||||
buffers[i] = make([]byte, MTU)
|
||||
buffers[i] = make([]byte, bufSize)
|
||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||
|
||||
vs := []iovec{
|
||||
@@ -52,3 +55,35 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||
|
||||
return msgs, buffers, names
|
||||
}
|
||||
|
||||
func setRawMessageControl(msg *rawMessage, buf []byte) {
|
||||
if len(buf) == 0 {
|
||||
msg.Hdr.Control = nil
|
||||
msg.Hdr.Controllen = 0
|
||||
return
|
||||
}
|
||||
msg.Hdr.Control = &buf[0]
|
||||
msg.Hdr.Controllen = uint32(len(buf))
|
||||
}
|
||||
|
||||
func getRawMessageControlLen(msg *rawMessage) int {
|
||||
return int(msg.Hdr.Controllen)
|
||||
}
|
||||
|
||||
func getRawMessageFlags(msg *rawMessage) int {
|
||||
return int(msg.Hdr.Flags)
|
||||
}
|
||||
|
||||
func setCmsgLen(h *unix.Cmsghdr, l int) {
|
||||
h.Len = uint32(l)
|
||||
}
|
||||
|
||||
func setIovecSlice(iov *iovec, b []byte) {
|
||||
if len(b) == 0 {
|
||||
iov.Base = nil
|
||||
iov.Len = 0
|
||||
return
|
||||
}
|
||||
iov.Base = &b[0]
|
||||
iov.Len = uint32(len(b))
|
||||
}
|
||||
|
||||
@@ -33,13 +33,16 @@ type rawMessage struct {
|
||||
Pad0 [4]byte
|
||||
}
|
||||
|
||||
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||
func (u *StdConn) PrepareRawMessages(n int, bufSize int) ([]rawMessage, [][]byte, [][]byte) {
|
||||
if bufSize <= 0 {
|
||||
bufSize = MTU
|
||||
}
|
||||
msgs := make([]rawMessage, n)
|
||||
buffers := make([][]byte, n)
|
||||
names := make([][]byte, n)
|
||||
|
||||
for i := range msgs {
|
||||
buffers[i] = make([]byte, MTU)
|
||||
buffers[i] = make([]byte, bufSize)
|
||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||
|
||||
vs := []iovec{
|
||||
@@ -55,3 +58,35 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||
|
||||
return msgs, buffers, names
|
||||
}
|
||||
|
||||
func setRawMessageControl(msg *rawMessage, buf []byte) {
|
||||
if len(buf) == 0 {
|
||||
msg.Hdr.Control = nil
|
||||
msg.Hdr.Controllen = 0
|
||||
return
|
||||
}
|
||||
msg.Hdr.Control = &buf[0]
|
||||
msg.Hdr.Controllen = uint64(len(buf))
|
||||
}
|
||||
|
||||
func getRawMessageControlLen(msg *rawMessage) int {
|
||||
return int(msg.Hdr.Controllen)
|
||||
}
|
||||
|
||||
func getRawMessageFlags(msg *rawMessage) int {
|
||||
return int(msg.Hdr.Flags)
|
||||
}
|
||||
|
||||
func setCmsgLen(h *unix.Cmsghdr, l int) {
|
||||
h.Len = uint64(l)
|
||||
}
|
||||
|
||||
func setIovecSlice(iov *iovec, b []byte) {
|
||||
if len(b) == 0 {
|
||||
iov.Base = nil
|
||||
iov.Len = 0
|
||||
return
|
||||
}
|
||||
iov.Base = &b[0]
|
||||
iov.Len = uint64(len(b))
|
||||
}
|
||||
|
||||
@@ -134,7 +134,7 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *RIOConn) ListenOut(r EncReader) {
|
||||
func (u *RIOConn) ListenOut(r EncReader) error {
|
||||
buffer := make([]byte, MTU)
|
||||
|
||||
for {
|
||||
@@ -304,6 +304,17 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error {
|
||||
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
|
||||
}
|
||||
|
||||
func (u *RIOConn) WriteBatch(pkts []BatchPacket) (int, error) {
|
||||
sent := 0
|
||||
for _, pkt := range pkts {
|
||||
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
|
||||
return sent, err
|
||||
}
|
||||
sent++
|
||||
}
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
|
||||
sa, err := windows.Getsockname(u.sock)
|
||||
if err != nil {
|
||||
|
||||
@@ -106,6 +106,17 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *TesterConn) WriteBatch(pkts []BatchPacket) (int, error) {
|
||||
sent := 0
|
||||
for _, pkt := range pkts {
|
||||
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
|
||||
return sent, err
|
||||
}
|
||||
sent++
|
||||
}
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
func (u *TesterConn) ListenOut(r EncReader) {
|
||||
for {
|
||||
p, ok := <-u.RxPackets
|
||||
|
||||
Reference in New Issue
Block a user