slightly nicer contract?

This commit is contained in:
JackDoan
2026-04-28 16:50:50 -05:00
parent f5db77f214
commit c62f27d4b4
3 changed files with 20 additions and 26 deletions

View File

@@ -397,7 +397,7 @@ func (c *TCPCoalescer) flushSlot(s *coalesceSlot) error {
tcsum := s.ipHdrLen + 16 tcsum := s.ipHdrLen + 16
binary.BigEndian.PutUint16(hdr[tcsum:tcsum+2], foldOnceNoInvert(psum)) binary.BigEndian.PutUint16(hdr[tcsum:tcsum+2], foldOnceNoInvert(psum))
return c.gsoW.WriteGSO(hdr, s.payIovs, uint16(s.gsoSize), uint16(s.ipHdrLen)) return c.gsoW.WriteGSO(hdr[:s.ipHdrLen], hdr[s.ipHdrLen:], s.payIovs)
} }
// headersMatch compares two IP+TCP header prefixes for byte-for-byte // headersMatch compares two IP+TCP header prefixes for byte-for-byte

View File

@@ -55,6 +55,13 @@ type Queue interface {
// hdr's TCP checksum field must already hold the pseudo-header partial // hdr's TCP checksum field must already hold the pseudo-header partial
// sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics. // sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics.
type GSOWriter interface { type GSOWriter interface {
WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, csumStart uint16) error // WriteGSO emits a TCP TSO superpacket in a single writev. hdr is the
// IPv4/IPv6 + TCP header prefix (already finalized — total length, IP csum,
// and TCP pseudo-header partial set by the caller). pays are payload
// fragments whose concatenation forms the full coalesced payload; each
// slice is read-only and must stay valid until return.
// every segment in pays except possibly the last is exactly the same size.
// csumStart is the byte offset where the TCP header begins within hdr.
WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte) error
GSOSupported() bool GSOSupported() bool
} }

View File

@@ -291,33 +291,18 @@ func (r *Offload) rawWrite(iovs []unix.Iovec) (int, error) {
// Write calls. // Write calls.
func (r *Offload) GSOSupported() bool { return true } func (r *Offload) GSOSupported() bool { return true }
// WriteGSO emits a TCP TSO superpacket in a single writev. hdr is the func (r *Offload) WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte) error {
// IPv4/IPv6 + TCP header prefix (already finalized — total length, IP csum, if len(hdr) == 0 || len(pays) == 0 || len(transportHdr) == 0 {
// and TCP pseudo-header partial set by the caller). pays are payload
// fragments whose concatenation forms the full coalesced payload; each
// slice is read-only and must stay valid until return. gsoSize is the MSS;
// every segment except possibly the last is exactly gsoSize bytes.
// csumStart is the byte offset where the TCP header begins within hdr.
func (r *Offload) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, csumStart uint16) error {
if len(hdr) == 0 || len(pays) == 0 {
return nil return nil
} }
// Build the virtio_net_hdr. When pays total to <= gsoSize the kernel
// would produce a single segment; keep NEEDS_CSUM semantics but skip
// the GSO type so the kernel doesn't spuriously mark this as TSO.
vhdr := VirtioNetHdr{ vhdr := VirtioNetHdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
HdrLen: uint16(len(hdr)), HdrLen: uint16(len(hdr) + len(transportHdr)),
GSOSize: uint16(len(pays[0])), GSOSize: uint16(len(pays[0])),
CsumStart: csumStart, CsumStart: uint16(len(hdr)),
CsumOffset: 16, // TCP checksum field lives 16 bytes into the TCP header CsumOffset: 16, // TCP checksum field lives 16 bytes into the TCP header
} }
var totalPay int if len(pays) > 1 {
for _, p := range pays {
totalPay += len(p)
}
if totalPay > int(gsoSize) {
ipVer := hdr[0] >> 4 ipVer := hdr[0] >> 4
if ipVer == 6 { if ipVer == 6 {
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV6 vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV6
@@ -333,9 +318,9 @@ func (r *Offload) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, csumStart
} }
vhdr.encode(r.gsoHdrBuf[:]) vhdr.encode(r.gsoHdrBuf[:])
// Build the iovec array: [virtio_hdr, hdr, pays...]. r.gsoIovs[0] is // Build the iovec array: [virtio_hdr, hdr, transportHdr, pays...]. r.gsoIovs[0] is
// wired to gsoHdrBuf at construction and never changes. // wired to gsoHdrBuf at construction and never changes.
need := 2 + len(pays) need := 3 + len(pays)
if cap(r.gsoIovs) < need { if cap(r.gsoIovs) < need {
grown := make([]unix.Iovec, need) grown := make([]unix.Iovec, need)
grown[0] = r.gsoIovs[0] grown[0] = r.gsoIovs[0]
@@ -345,9 +330,11 @@ func (r *Offload) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, csumStart
} }
r.gsoIovs[1].Base = &hdr[0] r.gsoIovs[1].Base = &hdr[0]
r.gsoIovs[1].SetLen(len(hdr)) r.gsoIovs[1].SetLen(len(hdr))
r.gsoIovs[2].Base = &transportHdr[0]
r.gsoIovs[2].SetLen(len(transportHdr))
for i, p := range pays { for i, p := range pays {
r.gsoIovs[2+i].Base = &p[0] r.gsoIovs[3+i].Base = &p[0]
r.gsoIovs[2+i].SetLen(len(p)) r.gsoIovs[3+i].SetLen(len(p))
} }
_, err := r.rawWrite(r.gsoIovs) _, err := r.rawWrite(r.gsoIovs)