mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 04:47:38 +02:00
holy crap 2x
This commit is contained in:
13
interface.go
13
interface.go
@@ -86,7 +86,11 @@ type Interface struct {
|
|||||||
|
|
||||||
writers []udp.Conn
|
writers []udp.Conn
|
||||||
readers []overlay.Queue
|
readers []overlay.Queue
|
||||||
wg sync.WaitGroup
|
// tunCoalescers is one tcpCoalescer per tun queue, wrapping readers[i].
|
||||||
|
// decryptToTun sends plaintext into the coalescer; listenOut calls its
|
||||||
|
// Flush at the end of each UDP recvmmsg batch.
|
||||||
|
tunCoalescers []*tcpCoalescer
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
// fatalErr holds the first unexpected reader error that caused shutdown.
|
// fatalErr holds the first unexpected reader error that caused shutdown.
|
||||||
// nil means "no fatal error" (yet)
|
// nil means "no fatal error" (yet)
|
||||||
@@ -184,6 +188,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
version: c.version,
|
version: c.version,
|
||||||
writers: make([]udp.Conn, c.routines),
|
writers: make([]udp.Conn, c.routines),
|
||||||
readers: make([]overlay.Queue, c.routines),
|
readers: make([]overlay.Queue, c.routines),
|
||||||
|
tunCoalescers: make([]*tcpCoalescer, c.routines),
|
||||||
myVpnNetworks: cs.myVpnNetworks,
|
myVpnNetworks: cs.myVpnNetworks,
|
||||||
myVpnNetworksTable: cs.myVpnNetworksTable,
|
myVpnNetworksTable: cs.myVpnNetworksTable,
|
||||||
myVpnAddrs: cs.myVpnAddrs,
|
myVpnAddrs: cs.myVpnAddrs,
|
||||||
@@ -247,6 +252,7 @@ func (f *Interface) activate() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
f.readers[i] = reader
|
f.readers[i] = reader
|
||||||
|
f.tunCoalescers[i] = newTCPCoalescer(reader)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.wg.Add(1) // for us to wait on Close() to return
|
f.wg.Add(1) // for us to wait on Close() to return
|
||||||
@@ -308,8 +314,13 @@ func (f *Interface) listenOut(i int) {
|
|||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
|
coalescer := f.tunCoalescers[i]
|
||||||
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||||
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
||||||
|
}, func() {
|
||||||
|
if err := coalescer.Flush(); err != nil {
|
||||||
|
f.l.WithError(err).Error("Failed to flush tun coalescer")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil && !f.closed.Load() {
|
if err != nil && !f.closed.Load() {
|
||||||
|
|||||||
@@ -535,7 +535,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
}
|
}
|
||||||
|
|
||||||
f.connectionManager.In(hostinfo)
|
f.connectionManager.In(hostinfo)
|
||||||
_, err = f.readers[q].Write(out)
|
err = f.tunCoalescers[q].Add(out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Failed to write to tun")
|
f.l.WithError(err).Error("Failed to write to tun")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,3 +30,22 @@ type Device interface {
|
|||||||
SupportsMultiqueue() bool
|
SupportsMultiqueue() bool
|
||||||
NewMultiQueueReader() (Queue, error)
|
NewMultiQueueReader() (Queue, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GSOWriter is implemented by Queues that can write a TCP TSO superpacket as
|
||||||
|
// a single virtio_net_hdr + payload writev, letting the kernel segment on
|
||||||
|
// egress. Callers type-assert on it; backends that don't support GSO return
|
||||||
|
// false from Supported and all coalescing logic is skipped.
|
||||||
|
//
|
||||||
|
// pkt must contain the IPv4/IPv6 + TCP header plus the concatenated
|
||||||
|
// coalesced payload. hdrLen is the total L3+L4 header length (where the
|
||||||
|
// payload starts). csumStart is the byte offset where the TCP header
|
||||||
|
// begins (= IP header length). gsoSize is the MSS — every segment except
|
||||||
|
// possibly the last must be exactly this many payload bytes. isV6 selects
|
||||||
|
// GSO_TCPV4 vs GSO_TCPV6.
|
||||||
|
//
|
||||||
|
// pkt's TCP checksum field must already hold the pseudo-header partial
|
||||||
|
// sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics.
|
||||||
|
type GSOWriter interface {
|
||||||
|
WriteGSO(pkt []byte, gsoSize uint16, isV6 bool, hdrLen, csumStart uint16) error
|
||||||
|
GSOSupported() bool
|
||||||
|
}
|
||||||
|
|||||||
@@ -48,6 +48,12 @@ type tunFile struct {
|
|||||||
pending [][]byte // segments waiting to be drained by Read
|
pending [][]byte // segments waiting to be drained by Read
|
||||||
pendingIdx int
|
pendingIdx int
|
||||||
writeIovs [2]unix.Iovec // preallocated iovecs for vnetHdr writes; iovs[0] is fixed to zeroVnetHdr
|
writeIovs [2]unix.Iovec // preallocated iovecs for vnetHdr writes; iovs[0] is fixed to zeroVnetHdr
|
||||||
|
|
||||||
|
// gsoHdrBuf is a per-queue 10-byte scratch for the virtio_net_hdr emitted
|
||||||
|
// by WriteGSO. Separate from zeroVnetHdr so a concurrent non-GSO Write on
|
||||||
|
// another queue never observes a half-written header.
|
||||||
|
gsoHdrBuf [virtioNetHdrLen]byte
|
||||||
|
gsoIovs [2]unix.Iovec
|
||||||
}
|
}
|
||||||
|
|
||||||
// zeroVnetHdr is the 10-byte virtio_net_hdr we prepend to every TUN write when
|
// zeroVnetHdr is the 10-byte virtio_net_hdr we prepend to every TUN write when
|
||||||
@@ -78,6 +84,8 @@ func (r *tunFile) newFriend(fd int) (*tunFile, error) {
|
|||||||
out.segBuf = make([]byte, tunSegBufCap)
|
out.segBuf = make([]byte, tunSegBufCap)
|
||||||
out.writeIovs[0].Base = &zeroVnetHdr[0]
|
out.writeIovs[0].Base = &zeroVnetHdr[0]
|
||||||
out.writeIovs[0].SetLen(virtioNetHdrLen)
|
out.writeIovs[0].SetLen(virtioNetHdrLen)
|
||||||
|
out.gsoIovs[0].Base = &out.gsoHdrBuf[0]
|
||||||
|
out.gsoIovs[0].SetLen(virtioNetHdrLen)
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
@@ -111,6 +119,8 @@ func newTunFd(fd int, vnetHdr bool) (*tunFile, error) {
|
|||||||
out.segBuf = make([]byte, tunSegBufCap)
|
out.segBuf = make([]byte, tunSegBufCap)
|
||||||
out.writeIovs[0].Base = &zeroVnetHdr[0]
|
out.writeIovs[0].Base = &zeroVnetHdr[0]
|
||||||
out.writeIovs[0].SetLen(virtioNetHdrLen)
|
out.writeIovs[0].SetLen(virtioNetHdrLen)
|
||||||
|
out.gsoIovs[0].Base = &out.gsoHdrBuf[0]
|
||||||
|
out.gsoIovs[0].SetLen(virtioNetHdrLen)
|
||||||
}
|
}
|
||||||
|
|
||||||
return out, nil
|
return out, nil
|
||||||
@@ -331,6 +341,64 @@ func (r *tunFile) Write(buf []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GSOSupported reports whether this queue was opened with IFF_VNET_HDR and
|
||||||
|
// can accept WriteGSO. When false, callers should fall back to per-segment
|
||||||
|
// Write calls.
|
||||||
|
func (r *tunFile) GSOSupported() bool { return r.vnetHdr }
|
||||||
|
|
||||||
|
// WriteGSO emits pkt as a single TCP TSO superpacket via writev. pkt must
|
||||||
|
// contain a full IPv4/IPv6 + TCP header prefix followed by the concatenated
|
||||||
|
// coalesced payload. The TCP checksum field must already hold the
|
||||||
|
// pseudo-header partial (NEEDS_CSUM semantics). gsoSize is the MSS; every
|
||||||
|
// segment except the last must be exactly that many payload bytes.
|
||||||
|
func (r *tunFile) WriteGSO(pkt []byte, gsoSize uint16, isV6 bool, hdrLen, csumStart uint16) error {
|
||||||
|
if !r.vnetHdr {
|
||||||
|
return fmt.Errorf("WriteGSO called on tun without IFF_VNET_HDR")
|
||||||
|
}
|
||||||
|
if len(pkt) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
hdr := virtioNetHdr{
|
||||||
|
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
||||||
|
HdrLen: hdrLen,
|
||||||
|
GSOSize: gsoSize,
|
||||||
|
CsumStart: csumStart,
|
||||||
|
CsumOffset: 16, // TCP checksum field lives 16 bytes into the TCP header
|
||||||
|
}
|
||||||
|
if isV6 {
|
||||||
|
hdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV6
|
||||||
|
} else {
|
||||||
|
hdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV4
|
||||||
|
}
|
||||||
|
hdr.encode(r.gsoHdrBuf[:])
|
||||||
|
|
||||||
|
r.gsoIovs[1].Base = &pkt[0]
|
||||||
|
r.gsoIovs[1].SetLen(len(pkt))
|
||||||
|
iovPtr := uintptr(unsafe.Pointer(&r.gsoIovs[0]))
|
||||||
|
for {
|
||||||
|
n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, 2)
|
||||||
|
if errno == 0 {
|
||||||
|
runtime.KeepAlive(pkt)
|
||||||
|
if int(n) < virtioNetHdrLen {
|
||||||
|
return io.ErrShortWrite
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if errno == unix.EAGAIN {
|
||||||
|
runtime.KeepAlive(pkt)
|
||||||
|
if err := r.blockOnWrite(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if errno == unix.EINTR {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
runtime.KeepAlive(pkt)
|
||||||
|
return errno
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *tunFile) wakeForShutdown() error {
|
func (r *tunFile) wakeForShutdown() error {
|
||||||
var buf [8]byte
|
var buf [8]byte
|
||||||
binary.NativeEndian.PutUint64(buf[:], 1)
|
binary.NativeEndian.PutUint64(buf[:], 1)
|
||||||
|
|||||||
@@ -54,6 +54,18 @@ func (h *virtioNetHdr) decode(b []byte) {
|
|||||||
h.CsumOffset = binary.NativeEndian.Uint16(b[8:10])
|
h.CsumOffset = binary.NativeEndian.Uint16(b[8:10])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// encode is the inverse of decode: writes the virtio_net_hdr fields into b
|
||||||
|
// (must be at least virtioNetHdrLen bytes). Used to emit a TSO superpacket
|
||||||
|
// on egress.
|
||||||
|
func (h *virtioNetHdr) encode(b []byte) {
|
||||||
|
b[0] = h.Flags
|
||||||
|
b[1] = h.GSOType
|
||||||
|
binary.NativeEndian.PutUint16(b[2:4], h.HdrLen)
|
||||||
|
binary.NativeEndian.PutUint16(b[4:6], h.GSOSize)
|
||||||
|
binary.NativeEndian.PutUint16(b[6:8], h.CsumStart)
|
||||||
|
binary.NativeEndian.PutUint16(b[8:10], h.CsumOffset)
|
||||||
|
}
|
||||||
|
|
||||||
// segmentInto splits a TUN-side packet described by hdr into one or more
|
// segmentInto splits a TUN-side packet described by hdr into one or more
|
||||||
// IP packets, each appended to *out as a slice of scratch. scratch must be
|
// IP packets, each appended to *out as a slice of scratch. scratch must be
|
||||||
// sized to hold every segment (including replicated headers).
|
// sized to hold every segment (including replicated headers).
|
||||||
|
|||||||
436
tcp_coalesce.go
Normal file
436
tcp_coalesce.go
Normal file
@@ -0,0 +1,436 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/overlay"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IPPROTO_TCP is the IANA protocol number for TCP. Hardcoded instead of
|
||||||
|
// reaching for ipProtoTCP because golang.org/x/sys/unix doesn't
|
||||||
|
// define that constant on Windows, which would break cross-compiles even
|
||||||
|
// though this file runs unchanged on every platform.
|
||||||
|
const ipProtoTCP = 6
|
||||||
|
|
||||||
|
// tcpCoalesceBufSize bounds the largest coalesced superpacket we will buffer.
|
||||||
|
// Linux caps sk_gso_max_size around 64KiB; 65535 bytes covers IP hdr + TCP
|
||||||
|
// hdr + up to ~65KB of payload, which is the most the kernel's TSO can
|
||||||
|
// segment in one shot.
|
||||||
|
const tcpCoalesceBufSize = 65535
|
||||||
|
|
||||||
|
// tcpCoalesceMaxSegs caps how many segments we are willing to coalesce into
|
||||||
|
// a single superpacket regardless of byte budget. Kernel allows up to 64
|
||||||
|
// for UDP GSO and 128 for many TSO engines; stop well before either limit
|
||||||
|
// to keep latency bounded.
|
||||||
|
const tcpCoalesceMaxSegs = 64
|
||||||
|
|
||||||
|
// tcpCoalescer accumulates adjacent in-flow TCP data segments into a single
|
||||||
|
// TSO superpacket and emits them via overlay.GSOWriter in one writev. When
|
||||||
|
// a packet fails admission or fails to extend the pending flow, the
|
||||||
|
// pending superpacket is flushed and the non-matching packet is written
|
||||||
|
// through as-is. Owns no locks — one coalescer per TUN write queue.
|
||||||
|
type tcpCoalescer struct {
|
||||||
|
plainW io.Writer
|
||||||
|
gsoW overlay.GSOWriter // nil when the queue doesn't support TSO
|
||||||
|
|
||||||
|
buf []byte
|
||||||
|
bufLen int // valid bytes in buf — hdrLen plus accumulated payload
|
||||||
|
active bool // a seed packet is present
|
||||||
|
numSeg int
|
||||||
|
gsoSize int // payload length of each segment (= MSS of the seed)
|
||||||
|
isV6 bool
|
||||||
|
ipHdrLen int
|
||||||
|
hdrLen int // ipHdrLen + tcpHdrLen, the offset where payload starts
|
||||||
|
nextSeq uint32 // expected TCP seq of the next packet to coalesce
|
||||||
|
// psh indicates the last-accepted segment had PSH set. We accept a PSH
|
||||||
|
// packet as the final segment but reject any further Adds after that.
|
||||||
|
psh bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTCPCoalescer(w io.Writer) *tcpCoalescer {
|
||||||
|
c := &tcpCoalescer{plainW: w, buf: make([]byte, tcpCoalesceBufSize)}
|
||||||
|
if gw, ok := w.(overlay.GSOWriter); ok && gw.GSOSupported() {
|
||||||
|
c.gsoW = gw
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsedTCP holds the byte offsets / values we extract from one admission
|
||||||
|
// check so Add and canAppend don't re-parse the same header twice.
|
||||||
|
type parsedTCP struct {
|
||||||
|
isV6 bool
|
||||||
|
ipHdrLen int
|
||||||
|
tcpHdrLen int
|
||||||
|
hdrLen int // ipHdrLen + tcpHdrLen
|
||||||
|
payLen int
|
||||||
|
seq uint32
|
||||||
|
flags byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseCoalesceable decides whether pkt is eligible for TCP coalescing. It
|
||||||
|
// accepts IPv4 (no options, DF set, no fragmentation) and IPv6 (no
|
||||||
|
// extension headers) carrying a TCP segment with flags in {ACK, ACK|PSH}
|
||||||
|
// and a non-empty payload. On success it returns the parsed offsets.
|
||||||
|
func parseCoalesceable(pkt []byte) (parsedTCP, bool) {
|
||||||
|
var p parsedTCP
|
||||||
|
if len(pkt) < 20 {
|
||||||
|
return p, false
|
||||||
|
}
|
||||||
|
v := pkt[0] >> 4
|
||||||
|
switch v {
|
||||||
|
case 4:
|
||||||
|
if len(pkt) < 20 {
|
||||||
|
return p, false
|
||||||
|
}
|
||||||
|
ihl := int(pkt[0]&0x0f) * 4
|
||||||
|
if ihl != 20 {
|
||||||
|
return p, false // reject IP options
|
||||||
|
}
|
||||||
|
if pkt[9] != ipProtoTCP {
|
||||||
|
return p, false
|
||||||
|
}
|
||||||
|
// Fragment check: MF=0 and frag offset=0. Accept DF=1 or DF=0 —
|
||||||
|
// just reject any actual fragmentation.
|
||||||
|
fragField := binary.BigEndian.Uint16(pkt[6:8])
|
||||||
|
if fragField&0x3fff != 0 {
|
||||||
|
return p, false
|
||||||
|
}
|
||||||
|
totalLen := int(binary.BigEndian.Uint16(pkt[2:4]))
|
||||||
|
if totalLen > len(pkt) || totalLen < ihl {
|
||||||
|
return p, false
|
||||||
|
}
|
||||||
|
p.isV6 = false
|
||||||
|
p.ipHdrLen = ihl
|
||||||
|
pkt = pkt[:totalLen]
|
||||||
|
case 6:
|
||||||
|
if len(pkt) < 40 {
|
||||||
|
return p, false
|
||||||
|
}
|
||||||
|
if pkt[6] != ipProtoTCP {
|
||||||
|
return p, false // reject ext headers
|
||||||
|
}
|
||||||
|
payloadLen := int(binary.BigEndian.Uint16(pkt[4:6]))
|
||||||
|
if 40+payloadLen > len(pkt) {
|
||||||
|
return p, false
|
||||||
|
}
|
||||||
|
p.isV6 = true
|
||||||
|
p.ipHdrLen = 40
|
||||||
|
pkt = pkt[:40+payloadLen]
|
||||||
|
default:
|
||||||
|
return p, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(pkt) < p.ipHdrLen+20 {
|
||||||
|
return p, false
|
||||||
|
}
|
||||||
|
tcpOff := int(pkt[p.ipHdrLen+12]>>4) * 4
|
||||||
|
if tcpOff < 20 || tcpOff > 60 {
|
||||||
|
return p, false
|
||||||
|
}
|
||||||
|
if len(pkt) < p.ipHdrLen+tcpOff {
|
||||||
|
return p, false
|
||||||
|
}
|
||||||
|
flags := pkt[p.ipHdrLen+13]
|
||||||
|
// Allow only ACK and ACK|PSH. In particular: no SYN/FIN/RST/URG/CWR/ECE.
|
||||||
|
const ack = 0x10
|
||||||
|
const psh = 0x08
|
||||||
|
if flags&^(ack|psh) != 0 || flags&ack == 0 {
|
||||||
|
return p, false
|
||||||
|
}
|
||||||
|
p.tcpHdrLen = tcpOff
|
||||||
|
p.hdrLen = p.ipHdrLen + tcpOff
|
||||||
|
p.payLen = len(pkt) - p.hdrLen
|
||||||
|
if p.payLen <= 0 {
|
||||||
|
return p, false
|
||||||
|
}
|
||||||
|
p.seq = binary.BigEndian.Uint32(pkt[p.ipHdrLen+4 : p.ipHdrLen+8])
|
||||||
|
p.flags = flags
|
||||||
|
return p, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add takes a plaintext inbound packet destined for the tun. If GSO is
|
||||||
|
// unavailable or the packet isn't coalesceable, Add falls through to a
|
||||||
|
// plain Write on the underlying queue (flushing any pending superpacket
|
||||||
|
// first).
|
||||||
|
func (c *tcpCoalescer) Add(pkt []byte) error {
|
||||||
|
if c.gsoW == nil {
|
||||||
|
_, err := c.plainW.Write(pkt)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
info, ok := parseCoalesceable(pkt)
|
||||||
|
if !ok {
|
||||||
|
if c.active {
|
||||||
|
if err := c.flushLocked(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err := c.plainW.Write(pkt)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.active {
|
||||||
|
if c.canAppend(pkt, info) {
|
||||||
|
c.appendPayload(pkt, info)
|
||||||
|
if info.flags&0x08 != 0 {
|
||||||
|
c.psh = true
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := c.flushLocked(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c.seed(pkt, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush emits any pending superpacket. Called by the UDP read loop at
|
||||||
|
// recvmmsg batch boundaries — "no more packets coming right now".
|
||||||
|
func (c *tcpCoalescer) Flush() error {
|
||||||
|
if !c.active {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.flushLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tcpCoalescer) reset() {
|
||||||
|
c.active = false
|
||||||
|
c.bufLen = 0
|
||||||
|
c.numSeg = 0
|
||||||
|
c.gsoSize = 0
|
||||||
|
c.hdrLen = 0
|
||||||
|
c.ipHdrLen = 0
|
||||||
|
c.nextSeq = 0
|
||||||
|
c.psh = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tcpCoalescer) seed(pkt []byte, info parsedTCP) error {
|
||||||
|
if info.hdrLen+info.payLen > len(c.buf) {
|
||||||
|
// Oversize single packet — flush (already done above) and passthrough.
|
||||||
|
_, err := c.plainW.Write(pkt)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
copy(c.buf, pkt[:info.hdrLen+info.payLen])
|
||||||
|
c.active = true
|
||||||
|
c.bufLen = info.hdrLen + info.payLen
|
||||||
|
c.numSeg = 1
|
||||||
|
c.gsoSize = info.payLen
|
||||||
|
c.isV6 = info.isV6
|
||||||
|
c.ipHdrLen = info.ipHdrLen
|
||||||
|
c.hdrLen = info.hdrLen
|
||||||
|
c.nextSeq = info.seq + uint32(info.payLen)
|
||||||
|
c.psh = info.flags&0x08 != 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// canAppend reports whether info's packet extends the current seed: same
|
||||||
|
// flow, adjacent seq, payload size rule, and no-PSH-mid-chain.
|
||||||
|
func (c *tcpCoalescer) canAppend(pkt []byte, info parsedTCP) bool {
|
||||||
|
if c.psh {
|
||||||
|
return false // we already accepted a PSH — chain is closed
|
||||||
|
}
|
||||||
|
if info.isV6 != c.isV6 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if info.hdrLen != c.hdrLen {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if info.seq != c.nextSeq {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if c.numSeg >= tcpCoalesceMaxSegs {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if c.bufLen+info.payLen > len(c.buf) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Every mid-chain segment must be exactly gsoSize. The final segment may
|
||||||
|
// be shorter, but once a short segment is appended we can't add another.
|
||||||
|
if info.payLen > c.gsoSize {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if info.payLen < c.gsoSize {
|
||||||
|
// Will become the last segment — always OK to append, just no more.
|
||||||
|
}
|
||||||
|
// Compare the stable parts of the header.
|
||||||
|
if !headersMatch(c.buf[:c.hdrLen], pkt[:info.hdrLen], c.isV6, c.ipHdrLen) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tcpCoalescer) appendPayload(pkt []byte, info parsedTCP) {
|
||||||
|
copy(c.buf[c.bufLen:], pkt[info.hdrLen:info.hdrLen+info.payLen])
|
||||||
|
c.bufLen += info.payLen
|
||||||
|
c.numSeg++
|
||||||
|
c.nextSeq = info.seq + uint32(info.payLen)
|
||||||
|
// If this was a sub-gsoSize last segment, mark chain as closed.
|
||||||
|
if info.payLen < c.gsoSize {
|
||||||
|
c.psh = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// headersMatch compares two IP+TCP header prefixes for byte-for-byte
|
||||||
|
// equality on every field that must be identical across coalesced
|
||||||
|
// segments. Size/IPID/IPCsum/seq/flags/tcpCsum are masked out.
|
||||||
|
func headersMatch(a, b []byte, isV6 bool, ipHdrLen int) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if isV6 {
|
||||||
|
// IPv6: bytes [0:4] = version/TC/flow-label, [6:8] = next_hdr/hop,
|
||||||
|
// [8:40] = src+dst. Skip [4:6] payload length.
|
||||||
|
if !bytesEq(a[0:4], b[0:4]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !bytesEq(a[6:40], b[6:40]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// IPv4: [0:2] version/IHL/TOS, [6:10] flags/fragoff/TTL/proto,
|
||||||
|
// [12:20] src+dst. Skip [2:4] total len, [4:6] id, [10:12] csum.
|
||||||
|
if !bytesEq(a[0:2], b[0:2]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !bytesEq(a[6:10], b[6:10]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !bytesEq(a[12:20], b[12:20]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// TCP: compare [0:4] ports, [8:13] ack+dataoff, [14:16] window,
|
||||||
|
// [18:tcpHdrLen] options (incl. urgent).
|
||||||
|
tcp := ipHdrLen
|
||||||
|
if !bytesEq(a[tcp:tcp+4], b[tcp:tcp+4]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !bytesEq(a[tcp+8:tcp+13], b[tcp+8:tcp+13]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !bytesEq(a[tcp+14:tcp+16], b[tcp+14:tcp+16]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !bytesEq(a[tcp+18:], b[tcp+18:]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func bytesEq(a, b []byte) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if a[i] != b[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tcpCoalescer) flushLocked() error {
|
||||||
|
// Guarantee the coalescer is empty on exit regardless of how we leave.
|
||||||
|
defer c.reset()
|
||||||
|
|
||||||
|
if c.numSeg <= 1 {
|
||||||
|
_, err := c.plainW.Write(c.buf[:c.bufLen])
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
total := c.bufLen
|
||||||
|
l4Len := total - c.ipHdrLen
|
||||||
|
|
||||||
|
// Fix IP header length field.
|
||||||
|
if c.isV6 {
|
||||||
|
if l4Len > 0xffff {
|
||||||
|
// Shouldn't happen given buffer size, but guard against it.
|
||||||
|
return c.flushAsPerSegment()
|
||||||
|
}
|
||||||
|
binary.BigEndian.PutUint16(c.buf[4:6], uint16(l4Len))
|
||||||
|
} else {
|
||||||
|
if total > 0xffff {
|
||||||
|
return c.flushAsPerSegment()
|
||||||
|
}
|
||||||
|
binary.BigEndian.PutUint16(c.buf[2:4], uint16(total))
|
||||||
|
// Recompute IPv4 header checksum.
|
||||||
|
c.buf[10] = 0
|
||||||
|
c.buf[11] = 0
|
||||||
|
binary.BigEndian.PutUint16(c.buf[10:12], ipv4HdrChecksum(c.buf[:c.ipHdrLen]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the virtio NEEDS_CSUM pseudo-header partial into the TCP csum field.
|
||||||
|
var psum uint32
|
||||||
|
if c.isV6 {
|
||||||
|
psum = pseudoSumIPv6(c.buf[8:24], c.buf[24:40], ipProtoTCP, l4Len)
|
||||||
|
} else {
|
||||||
|
psum = pseudoSumIPv4(c.buf[12:16], c.buf[16:20], ipProtoTCP, l4Len)
|
||||||
|
}
|
||||||
|
tcsum := c.ipHdrLen + 16
|
||||||
|
binary.BigEndian.PutUint16(c.buf[tcsum:tcsum+2], foldOnceNoInvert(psum))
|
||||||
|
|
||||||
|
return c.gsoW.WriteGSO(c.buf[:total], uint16(c.gsoSize), c.isV6, uint16(c.hdrLen), uint16(c.ipHdrLen))
|
||||||
|
}
|
||||||
|
|
||||||
|
// flushAsPerSegment is a defensive fallback used if the coalesced superpacket
|
||||||
|
// somehow exceeds 16-bit length fields. It writes the packet as-is through
|
||||||
|
// the plain writer (the kernel will reject it, but that's a visible error
|
||||||
|
// rather than silent corruption).
|
||||||
|
func (c *tcpCoalescer) flushAsPerSegment() error {
|
||||||
|
_, err := c.plainW.Write(c.buf[:c.bufLen])
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ipv4HdrChecksum computes the IPv4 header checksum over hdr (which must
|
||||||
|
// already have its checksum field zeroed) and returns the folded/inverted
|
||||||
|
// 16-bit value to store.
|
||||||
|
func ipv4HdrChecksum(hdr []byte) uint16 {
|
||||||
|
var sum uint32
|
||||||
|
for i := 0; i+1 < len(hdr); i += 2 {
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(hdr[i : i+2]))
|
||||||
|
}
|
||||||
|
if len(hdr)%2 == 1 {
|
||||||
|
sum += uint32(hdr[len(hdr)-1]) << 8
|
||||||
|
}
|
||||||
|
for sum>>16 != 0 {
|
||||||
|
sum = (sum & 0xffff) + (sum >> 16)
|
||||||
|
}
|
||||||
|
return ^uint16(sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
// pseudoSumIPv4 / pseudoSumIPv6 build the TCP pseudo-header partial sum
|
||||||
|
// expected by the virtio NEEDS_CSUM kernel path: the 32-bit accumulator
|
||||||
|
// before folding.
|
||||||
|
func pseudoSumIPv4(src, dst []byte, proto byte, l4Len int) uint32 {
|
||||||
|
var sum uint32
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(src[0:2]))
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(src[2:4]))
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(dst[0:2]))
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(dst[2:4]))
|
||||||
|
sum += uint32(proto)
|
||||||
|
sum += uint32(l4Len)
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
|
||||||
|
func pseudoSumIPv6(src, dst []byte, proto byte, l4Len int) uint32 {
|
||||||
|
var sum uint32
|
||||||
|
for i := 0; i < 16; i += 2 {
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(src[i : i+2]))
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(dst[i : i+2]))
|
||||||
|
}
|
||||||
|
sum += uint32(l4Len >> 16)
|
||||||
|
sum += uint32(l4Len & 0xffff)
|
||||||
|
sum += uint32(proto)
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
|
||||||
|
// foldOnceNoInvert folds the 32-bit accumulator to 16 bits and returns it
|
||||||
|
// unchanged (no one's complement). This is what virtio NEEDS_CSUM wants in
|
||||||
|
// the L4 checksum field — the kernel will add the payload sum and invert.
|
||||||
|
func foldOnceNoInvert(sum uint32) uint16 {
|
||||||
|
for sum>>16 != 0 {
|
||||||
|
sum = (sum & 0xffff) + (sum >> 16)
|
||||||
|
}
|
||||||
|
return uint16(sum)
|
||||||
|
}
|
||||||
356
tcp_coalesce_test.go
Normal file
356
tcp_coalesce_test.go
Normal file
@@ -0,0 +1,356 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A minimal stub writer that records each plain Write and each WriteGSO
|
||||||
|
// call without touching a real TUN fd.
|
||||||
|
type fakeTunWriter struct {
|
||||||
|
gsoEnabled bool
|
||||||
|
writes [][]byte
|
||||||
|
gsoWrites []fakeGSOWrite
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeGSOWrite struct {
|
||||||
|
pkt []byte
|
||||||
|
gsoSize uint16
|
||||||
|
isV6 bool
|
||||||
|
hdrLen uint16
|
||||||
|
csumStart uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *fakeTunWriter) Write(p []byte) (int, error) {
|
||||||
|
buf := make([]byte, len(p))
|
||||||
|
copy(buf, p)
|
||||||
|
w.writes = append(w.writes, buf)
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *fakeTunWriter) WriteGSO(pkt []byte, gsoSize uint16, isV6 bool, hdrLen, csumStart uint16) error {
|
||||||
|
buf := make([]byte, len(pkt))
|
||||||
|
copy(buf, pkt)
|
||||||
|
w.gsoWrites = append(w.gsoWrites, fakeGSOWrite{pkt: buf, gsoSize: gsoSize, isV6: isV6, hdrLen: hdrLen, csumStart: csumStart})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *fakeTunWriter) GSOSupported() bool { return w.gsoEnabled }
|
||||||
|
|
||||||
|
// buildTCPv4 constructs a minimal IPv4+TCP packet with the given payload,
|
||||||
|
// seq, and flags. Assumes no IP options and a 20-byte TCP header.
|
||||||
|
func buildTCPv4(seq uint32, flags byte, payload []byte) []byte {
|
||||||
|
const ipHdrLen = 20
|
||||||
|
const tcpHdrLen = 20
|
||||||
|
total := ipHdrLen + tcpHdrLen + len(payload)
|
||||||
|
pkt := make([]byte, total)
|
||||||
|
|
||||||
|
// IPv4 header.
|
||||||
|
pkt[0] = 0x45 // version 4, IHL 5
|
||||||
|
pkt[1] = 0x00 // TOS
|
||||||
|
binary.BigEndian.PutUint16(pkt[2:4], uint16(total))
|
||||||
|
binary.BigEndian.PutUint16(pkt[4:6], 0) // id
|
||||||
|
binary.BigEndian.PutUint16(pkt[6:8], 0x4000) // DF
|
||||||
|
pkt[8] = 64 // TTL
|
||||||
|
pkt[9] = ipProtoTCP
|
||||||
|
// csum left zero — coalescer recomputes on emit.
|
||||||
|
copy(pkt[12:16], []byte{10, 0, 0, 1}) // src
|
||||||
|
copy(pkt[16:20], []byte{10, 0, 0, 2}) // dst
|
||||||
|
|
||||||
|
// TCP header.
|
||||||
|
binary.BigEndian.PutUint16(pkt[20:22], 1000) // sport
|
||||||
|
binary.BigEndian.PutUint16(pkt[22:24], 2000) // dport
|
||||||
|
binary.BigEndian.PutUint32(pkt[24:28], seq)
|
||||||
|
binary.BigEndian.PutUint32(pkt[28:32], 12345) // ack
|
||||||
|
pkt[32] = 0x50 // data offset = 5 << 4
|
||||||
|
pkt[33] = flags
|
||||||
|
binary.BigEndian.PutUint16(pkt[34:36], 0xffff) // window
|
||||||
|
// tcp csum zero
|
||||||
|
// urgent zero
|
||||||
|
|
||||||
|
copy(pkt[40:], payload)
|
||||||
|
return pkt
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
tcpAck = 0x10
|
||||||
|
tcpPsh = 0x08
|
||||||
|
tcpSyn = 0x02
|
||||||
|
tcpFin = 0x01
|
||||||
|
tcpAckPsh = tcpAck | tcpPsh
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCoalescerPassthroughWhenGSOUnavailable(t *testing.T) {
|
||||||
|
w := &fakeTunWriter{gsoEnabled: false}
|
||||||
|
c := newTCPCoalescer(w)
|
||||||
|
pkt := buildTCPv4(1000, tcpAck, []byte("hello"))
|
||||||
|
if err := c.Add(pkt); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(w.writes) != 1 || len(w.gsoWrites) != 0 {
|
||||||
|
t.Fatalf("want single plain write, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCoalescerNonTCPPassthrough(t *testing.T) {
|
||||||
|
w := &fakeTunWriter{gsoEnabled: true}
|
||||||
|
c := newTCPCoalescer(w)
|
||||||
|
// ICMP packet: proto=1.
|
||||||
|
pkt := make([]byte, 28)
|
||||||
|
pkt[0] = 0x45
|
||||||
|
binary.BigEndian.PutUint16(pkt[2:4], 28)
|
||||||
|
pkt[9] = 1
|
||||||
|
copy(pkt[12:16], []byte{10, 0, 0, 1})
|
||||||
|
copy(pkt[16:20], []byte{10, 0, 0, 2})
|
||||||
|
if err := c.Add(pkt); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(w.writes) != 1 || len(w.gsoWrites) != 0 {
|
||||||
|
t.Fatalf("ICMP should pass through unchanged")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCoalescerSeedThenFlushAlone(t *testing.T) {
|
||||||
|
w := &fakeTunWriter{gsoEnabled: true}
|
||||||
|
c := newTCPCoalescer(w)
|
||||||
|
pkt := buildTCPv4(1000, tcpAck, make([]byte, 1000))
|
||||||
|
if err := c.Add(pkt); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// No flush yet — still pending.
|
||||||
|
if len(w.writes) != 0 || len(w.gsoWrites) != 0 {
|
||||||
|
t.Fatalf("unexpected output before flush")
|
||||||
|
}
|
||||||
|
if err := c.Flush(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// Single segment — should use plain write, not gso.
|
||||||
|
if len(w.writes) != 1 || len(w.gsoWrites) != 0 {
|
||||||
|
t.Fatalf("single-seg flush: writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCoalescerCoalescesAdjacentACKs(t *testing.T) {
|
||||||
|
w := &fakeTunWriter{gsoEnabled: true}
|
||||||
|
c := newTCPCoalescer(w)
|
||||||
|
pay := make([]byte, 1200)
|
||||||
|
if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.Add(buildTCPv4(2200, tcpAck, pay)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.Add(buildTCPv4(3400, tcpAck, pay)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.Flush(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(w.gsoWrites) != 1 {
|
||||||
|
t.Fatalf("want 1 gso write, got %d (plain=%d)", len(w.gsoWrites), len(w.writes))
|
||||||
|
}
|
||||||
|
g := w.gsoWrites[0]
|
||||||
|
if g.gsoSize != 1200 {
|
||||||
|
t.Errorf("gsoSize=%d want 1200", g.gsoSize)
|
||||||
|
}
|
||||||
|
if g.hdrLen != 40 {
|
||||||
|
t.Errorf("hdrLen=%d want 40", g.hdrLen)
|
||||||
|
}
|
||||||
|
if g.csumStart != 20 {
|
||||||
|
t.Errorf("csumStart=%d want 20", g.csumStart)
|
||||||
|
}
|
||||||
|
if len(g.pkt) != 40+3*1200 {
|
||||||
|
t.Errorf("superpacket len=%d want %d", len(g.pkt), 40+3*1200)
|
||||||
|
}
|
||||||
|
// IP total length should reflect superpacket.
|
||||||
|
if tot := binary.BigEndian.Uint16(g.pkt[2:4]); int(tot) != len(g.pkt) {
|
||||||
|
t.Errorf("ip total_length=%d want %d", tot, len(g.pkt))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCoalescerRejectsSeqGap(t *testing.T) {
|
||||||
|
w := &fakeTunWriter{gsoEnabled: true}
|
||||||
|
c := newTCPCoalescer(w)
|
||||||
|
pay := make([]byte, 1200)
|
||||||
|
if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// seq should be 2200; use 3000 to simulate a gap.
|
||||||
|
if err := c.Add(buildTCPv4(3000, tcpAck, pay)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.Flush(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// First packet should have been flushed as a plain write (single seg),
|
||||||
|
// then second packet seeded and flushed likewise.
|
||||||
|
if len(w.writes) != 2 || len(w.gsoWrites) != 0 {
|
||||||
|
t.Fatalf("seq gap: want 2 plain writes got writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCoalescerRejectsFlagMismatch(t *testing.T) {
|
||||||
|
w := &fakeTunWriter{gsoEnabled: true}
|
||||||
|
c := newTCPCoalescer(w)
|
||||||
|
pay := make([]byte, 1200)
|
||||||
|
if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// SYN flag — not admissible at all. Should flush first packet + plain-write second.
|
||||||
|
syn := buildTCPv4(2200, tcpSyn|tcpAck, pay)
|
||||||
|
if err := c.Add(syn); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.Flush(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(w.writes) != 2 || len(w.gsoWrites) != 0 {
|
||||||
|
t.Fatalf("flag mismatch: want 2 plain writes got writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCoalescerRejectsFIN(t *testing.T) {
|
||||||
|
w := &fakeTunWriter{gsoEnabled: true}
|
||||||
|
c := newTCPCoalescer(w)
|
||||||
|
fin := buildTCPv4(1000, tcpAck|tcpFin, []byte("x"))
|
||||||
|
if err := c.Add(fin); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.Flush(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(w.writes) != 1 || len(w.gsoWrites) != 0 {
|
||||||
|
t.Fatalf("FIN should be passthrough, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCoalescerShortLastSegmentClosesChain(t *testing.T) {
|
||||||
|
w := &fakeTunWriter{gsoEnabled: true}
|
||||||
|
c := newTCPCoalescer(w)
|
||||||
|
full := make([]byte, 1200)
|
||||||
|
half := make([]byte, 500)
|
||||||
|
if err := c.Add(buildTCPv4(1000, tcpAck, full)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.Add(buildTCPv4(2200, tcpAck, half)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// Next full-size would have to start at 2700 but chain is closed —
|
||||||
|
// should flush + seed.
|
||||||
|
if err := c.Add(buildTCPv4(2700, tcpAck, full)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.Flush(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// Expect: one gso write (first two coalesced) + one plain write (the
|
||||||
|
// third, flushed alone).
|
||||||
|
if len(w.gsoWrites) != 1 {
|
||||||
|
t.Fatalf("want 1 gso write got %d", len(w.gsoWrites))
|
||||||
|
}
|
||||||
|
if len(w.writes) != 1 {
|
||||||
|
t.Fatalf("want 1 plain write got %d", len(w.writes))
|
||||||
|
}
|
||||||
|
if w.gsoWrites[0].gsoSize != 1200 {
|
||||||
|
t.Errorf("gsoSize=%d want 1200", w.gsoWrites[0].gsoSize)
|
||||||
|
}
|
||||||
|
if got, want := len(w.gsoWrites[0].pkt), 40+1200+500; got != want {
|
||||||
|
t.Errorf("super len=%d want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCoalescerPSHFinalizesChain(t *testing.T) {
|
||||||
|
w := &fakeTunWriter{gsoEnabled: true}
|
||||||
|
c := newTCPCoalescer(w)
|
||||||
|
pay := make([]byte, 1200)
|
||||||
|
if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// Last full-size segment with PSH — admitted but chain is now closed.
|
||||||
|
if err := c.Add(buildTCPv4(2200, tcpAckPsh, pay)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// Further full-size would not coalesce.
|
||||||
|
if err := c.Add(buildTCPv4(3400, tcpAck, pay)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.Flush(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(w.gsoWrites) != 1 {
|
||||||
|
t.Fatalf("want 1 gso write got %d", len(w.gsoWrites))
|
||||||
|
}
|
||||||
|
if len(w.writes) != 1 {
|
||||||
|
t.Fatalf("want 1 plain write got %d", len(w.writes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCoalescerRejectsDifferentFlow(t *testing.T) {
|
||||||
|
w := &fakeTunWriter{gsoEnabled: true}
|
||||||
|
c := newTCPCoalescer(w)
|
||||||
|
pay := make([]byte, 1200)
|
||||||
|
p1 := buildTCPv4(1000, tcpAck, pay)
|
||||||
|
p2 := buildTCPv4(2200, tcpAck, pay)
|
||||||
|
// Mutate p2's source port to break flow match.
|
||||||
|
binary.BigEndian.PutUint16(p2[20:22], 9999)
|
||||||
|
if err := c.Add(p1); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.Add(p2); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.Flush(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// Both flushed as plain writes.
|
||||||
|
if len(w.writes) != 2 || len(w.gsoWrites) != 0 {
|
||||||
|
t.Fatalf("diff flow: writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCoalescerRejectsIPOptions(t *testing.T) {
|
||||||
|
w := &fakeTunWriter{gsoEnabled: true}
|
||||||
|
c := newTCPCoalescer(w)
|
||||||
|
pay := make([]byte, 500)
|
||||||
|
pkt := buildTCPv4(1000, tcpAck, pay)
|
||||||
|
// Bump IHL to 6 to simulate 4 bytes of IP options. Don't actually add
|
||||||
|
// bytes — parser should bail before it matters.
|
||||||
|
pkt[0] = 0x46
|
||||||
|
if err := c.Add(pkt); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.Flush(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(w.writes) != 1 || len(w.gsoWrites) != 0 {
|
||||||
|
t.Fatalf("IP options should passthrough, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCoalescerCapBySegments(t *testing.T) {
|
||||||
|
w := &fakeTunWriter{gsoEnabled: true}
|
||||||
|
c := newTCPCoalescer(w)
|
||||||
|
pay := make([]byte, 512) // small so we can fit many before byte cap
|
||||||
|
seq := uint32(1000)
|
||||||
|
for i := 0; i < tcpCoalesceMaxSegs+5; i++ {
|
||||||
|
if err := c.Add(buildTCPv4(seq, tcpAck, pay)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
seq += uint32(len(pay))
|
||||||
|
}
|
||||||
|
if err := c.Flush(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// We expect the first tcpCoalesceMaxSegs to form one gso, then 5 more:
|
||||||
|
// The 5 follow-ons seed a new super that completes as another gso if >=2,
|
||||||
|
// or a mix. Just assert we never exceed the cap per super.
|
||||||
|
for _, g := range w.gsoWrites {
|
||||||
|
segs := (len(g.pkt) - int(g.hdrLen)) / int(g.gsoSize)
|
||||||
|
if rem := (len(g.pkt) - int(g.hdrLen)) % int(g.gsoSize); rem != 0 {
|
||||||
|
segs++
|
||||||
|
}
|
||||||
|
if segs > tcpCoalesceMaxSegs {
|
||||||
|
t.Fatalf("super exceeded seg cap: %d > %d", segs, tcpCoalesceMaxSegs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -22,7 +22,12 @@ type EncReader func(
|
|||||||
type Conn interface {
|
type Conn interface {
|
||||||
Rebind() error
|
Rebind() error
|
||||||
LocalAddr() (netip.AddrPort, error)
|
LocalAddr() (netip.AddrPort, error)
|
||||||
ListenOut(r EncReader) error
|
// ListenOut invokes r for each received packet. On batch-capable
|
||||||
|
// backends (recvmmsg), flush is called after each batch is fully
|
||||||
|
// delivered — callers use it to flush per-batch accumulators such as
|
||||||
|
// TUN write coalescers. Single-packet backends call flush after each
|
||||||
|
// packet. flush must not be nil.
|
||||||
|
ListenOut(r EncReader, flush func()) error
|
||||||
WriteTo(b []byte, addr netip.AddrPort) error
|
WriteTo(b []byte, addr netip.AddrPort) error
|
||||||
// WriteBatch sends a contiguous batch of packets, each with its own
|
// WriteBatch sends a contiguous batch of packets, each with its own
|
||||||
// destination. bufs and addrs must have the same length. Linux uses
|
// destination. bufs and addrs must have the same length. Linux uses
|
||||||
@@ -53,7 +58,7 @@ func (NoopConn) Rebind() error {
|
|||||||
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
||||||
return netip.AddrPort{}, nil
|
return netip.AddrPort{}, nil
|
||||||
}
|
}
|
||||||
func (NoopConn) ListenOut(_ EncReader) error {
|
func (NoopConn) ListenOut(_ EncReader, _ func()) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (NoopConn) SupportsMultipleReaders() bool {
|
func (NoopConn) SupportsMultipleReaders() bool {
|
||||||
|
|||||||
@@ -185,7 +185,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
|
|||||||
return func() {}
|
return func() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) ListenOut(r EncReader) error {
|
func (u *StdConn) ListenOut(r EncReader, flush func()) error {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -200,6 +200,7 @@ func (u *StdConn) ListenOut(r EncReader) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||||
|
flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ type rawMessage struct {
|
|||||||
Len uint32
|
Len uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *GenericConn) ListenOut(r EncReader) error {
|
func (u *GenericConn) ListenOut(r EncReader, flush func()) error {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -102,6 +102,7 @@ func (u *GenericConn) ListenOut(r EncReader) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||||
|
flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -249,7 +249,7 @@ func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) {
|
|||||||
return int(n), true, nil
|
return int(n), true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) listenOutSingle(r EncReader) error {
|
func (u *StdConn) listenOutSingle(r EncReader, flush func()) error {
|
||||||
var err error
|
var err error
|
||||||
var n int
|
var n int
|
||||||
var from netip.AddrPort
|
var from netip.AddrPort
|
||||||
@@ -262,10 +262,11 @@ func (u *StdConn) listenOutSingle(r EncReader) error {
|
|||||||
}
|
}
|
||||||
from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port())
|
from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port())
|
||||||
r(from, buffer[:n])
|
r(from, buffer[:n])
|
||||||
|
flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) listenOutBatch(r EncReader) error {
|
func (u *StdConn) listenOutBatch(r EncReader, flush func()) error {
|
||||||
var ip netip.Addr
|
var ip netip.Addr
|
||||||
var n int
|
var n int
|
||||||
var operr error
|
var operr error
|
||||||
@@ -297,14 +298,17 @@ func (u *StdConn) listenOutBatch(r EncReader) error {
|
|||||||
}
|
}
|
||||||
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
|
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
|
||||||
}
|
}
|
||||||
|
// End-of-batch: let callers (e.g. TUN write coalescer) flush any
|
||||||
|
// state they accumulated across this batch.
|
||||||
|
flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) ListenOut(r EncReader) error {
|
func (u *StdConn) ListenOut(r EncReader, flush func()) error {
|
||||||
if u.batch == 1 {
|
if u.batch == 1 {
|
||||||
return u.listenOutSingle(r)
|
return u.listenOutSingle(r, flush)
|
||||||
} else {
|
} else {
|
||||||
return u.listenOutBatch(r)
|
return u.listenOutBatch(r, flush)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *RIOConn) ListenOut(r EncReader) error {
|
func (u *RIOConn) ListenOut(r EncReader, flush func()) error {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
var lastRecvErr time.Time
|
var lastRecvErr time.Time
|
||||||
@@ -162,6 +162,7 @@ func (u *RIOConn) ListenOut(r EncReader) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n])
|
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n])
|
||||||
|
flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -127,13 +127,14 @@ func (u *TesterConn) WriteSegmented(bufs [][]byte, addr netip.AddrPort, _ int) e
|
|||||||
|
|
||||||
func (u *TesterConn) SupportsGSO() bool { return false }
|
func (u *TesterConn) SupportsGSO() bool { return false }
|
||||||
|
|
||||||
func (u *TesterConn) ListenOut(r EncReader) error {
|
func (u *TesterConn) ListenOut(r EncReader, flush func()) error {
|
||||||
for {
|
for {
|
||||||
p, ok := <-u.RxPackets
|
p, ok := <-u.RxPackets
|
||||||
if !ok {
|
if !ok {
|
||||||
return os.ErrClosed
|
return os.ErrClosed
|
||||||
}
|
}
|
||||||
r(p.From, p.Data)
|
r(p.From, p.Data)
|
||||||
|
flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user