diff --git a/overlay/batch/tcp_coalesce.go b/overlay/batch/tcp_coalesce.go index 35b4db8e..604a802e 100644 --- a/overlay/batch/tcp_coalesce.go +++ b/overlay/batch/tcp_coalesce.go @@ -397,7 +397,7 @@ func (c *TCPCoalescer) flushSlot(s *coalesceSlot) error { tcsum := s.ipHdrLen + 16 binary.BigEndian.PutUint16(hdr[tcsum:tcsum+2], foldOnceNoInvert(psum)) - return c.gsoW.WriteGSO(hdr, s.payIovs, uint16(s.gsoSize), uint16(s.ipHdrLen)) + return c.gsoW.WriteGSO(hdr[:s.ipHdrLen], hdr[s.ipHdrLen:], s.payIovs) } // headersMatch compares two IP+TCP header prefixes for byte-for-byte diff --git a/overlay/tio/tio.go b/overlay/tio/tio.go index e8a2d902..adae230e 100644 --- a/overlay/tio/tio.go +++ b/overlay/tio/tio.go @@ -55,6 +55,13 @@ type Queue interface { // hdr's TCP checksum field must already hold the pseudo-header partial // sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics. type GSOWriter interface { - WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, csumStart uint16) error + // WriteGSO emits a TCP TSO superpacket in a single writev. hdr is the + // IPv4/IPv6 header and transportHdr is the TCP header that follows it (both + // already finalized — total length, IP csum, and TCP pseudo-header partial + // set by the caller). pays are payload fragments whose concatenation forms + // the full coalesced payload; each slice is read-only and must stay valid + // until return. Every segment in pays except possibly the last must be + // exactly the same size. The TCP checksum start offset is len(hdr). + WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte) error GSOSupported() bool } diff --git a/overlay/tio/tio_gso_linux.go b/overlay/tio/tio_gso_linux.go index b85e2e30..fcc1f231 100644 --- a/overlay/tio/tio_gso_linux.go +++ b/overlay/tio/tio_gso_linux.go @@ -291,33 +291,18 @@ func (r *Offload) rawWrite(iovs []unix.Iovec) (int, error) { // Write calls. func (r *Offload) GSOSupported() bool { return true } -// WriteGSO emits a TCP TSO superpacket in a single writev. 
hdr is the -// IPv4/IPv6 + TCP header prefix (already finalized — total length, IP csum, -// and TCP pseudo-header partial set by the caller). pays are payload -// fragments whose concatenation forms the full coalesced payload; each -// slice is read-only and must stay valid until return. gsoSize is the MSS; -// every segment except possibly the last is exactly gsoSize bytes. -// csumStart is the byte offset where the TCP header begins within hdr. -func (r *Offload) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, csumStart uint16) error { - if len(hdr) == 0 || len(pays) == 0 { +func (r *Offload) WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte) error { + if len(hdr) == 0 || len(pays) == 0 || len(transportHdr) == 0 { return nil } - - // Build the virtio_net_hdr. When pays total to <= gsoSize the kernel - // would produce a single segment; keep NEEDS_CSUM semantics but skip - // the GSO type so the kernel doesn't spuriously mark this as TSO. vhdr := VirtioNetHdr{ Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, - HdrLen: uint16(len(hdr)), + HdrLen: uint16(len(hdr) + len(transportHdr)), GSOSize: uint16(len(pays[0])), - CsumStart: csumStart, + CsumStart: uint16(len(hdr)), CsumOffset: 16, // TCP checksum field lives 16 bytes into the TCP header } - var totalPay int - for _, p := range pays { - totalPay += len(p) - } - if totalPay > int(gsoSize) { + if len(pays) > 1 { ipVer := hdr[0] >> 4 if ipVer == 6 { vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV6 @@ -333,9 +318,9 @@ func (r *Offload) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, csumStart } vhdr.encode(r.gsoHdrBuf[:]) - // Build the iovec array: [virtio_hdr, hdr, pays...]. r.gsoIovs[0] is + // Build the iovec array: [virtio_hdr, hdr, transportHdr, pays...]. r.gsoIovs[0] is // wired to gsoHdrBuf at construction and never changes. 
- need := 2 + len(pays) + need := 3 + len(pays) if cap(r.gsoIovs) < need { grown := make([]unix.Iovec, need) grown[0] = r.gsoIovs[0] @@ -345,9 +330,11 @@ func (r *Offload) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, csumStart } r.gsoIovs[1].Base = &hdr[0] r.gsoIovs[1].SetLen(len(hdr)) + r.gsoIovs[2].Base = &transportHdr[0] + r.gsoIovs[2].SetLen(len(transportHdr)) for i, p := range pays { - r.gsoIovs[2+i].Base = &p[0] - r.gsoIovs[2+i].SetLen(len(p)) + r.gsoIovs[3+i].Base = &p[0] + r.gsoIovs[3+i].SetLen(len(p)) } _, err := r.rawWrite(r.gsoIovs)