fix tests

This commit is contained in:
JackDoan
2026-04-23 11:35:51 -05:00
parent 382b15ac52
commit f76ac2e216
8 changed files with 97 additions and 164 deletions

View File

@@ -4,7 +4,6 @@ import (
"fmt"
"io"
"os"
"runtime"
"sync/atomic"
"syscall"
"unsafe"
@@ -28,7 +27,7 @@ const tunSegBufCap = tunSegBufSize * 2
const tunDrainCap = 64
// gsoInitialPayIovs is the starting capacity (in payload fragments) of
// tunFile.gsoIovs. Sized to cover the default coalesce segment cap without
// Offload.gsoIovs. Sized to cover the default coalesce segment cap without
// any reallocations.
const gsoInitialPayIovs = 66
@@ -42,9 +41,9 @@ const gsoInitialPayIovs = 66
// safe.
var validVnetHdr = [virtioNetHdrLen]byte{unix.VIRTIO_NET_HDR_F_DATA_VALID}
// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
// Offload wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
// A shared eventfd allows Close to wake all readers blocked in poll.
type tunFile struct { //todo rename GSO
type Offload struct {
fd int
shutdownFd int
readPoll [2]unix.PollFd
@@ -71,12 +70,12 @@ type tunFile struct { //todo rename GSO
gsoIovs []unix.Iovec
}
func newTunFd(fd int, shutdownFd int) (*tunFile, error) {
func newOffload(fd int, shutdownFd int) (*Offload, error) {
if err := unix.SetNonblock(fd, true); err != nil {
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
}
out := &tunFile{
out := &Offload{
fd: fd,
shutdownFd: shutdownFd,
closed: atomic.Bool{},
@@ -104,7 +103,7 @@ func newTunFd(fd int, shutdownFd int) (*tunFile, error) {
return out, nil
}
func (r *tunFile) blockOnRead() error {
func (r *Offload) blockOnRead() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
@@ -130,7 +129,7 @@ func (r *tunFile) blockOnRead() error {
return nil
}
func (r *tunFile) blockOnWrite() error {
func (r *Offload) blockOnWrite() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
@@ -156,7 +155,7 @@ func (r *tunFile) blockOnWrite() error {
return nil
}
func (r *tunFile) readRaw(buf []byte) (int, error) {
func (r *Offload) readRaw(buf []byte) (int, error) {
for {
if n, err := unix.Read(r.fd, buf); err == nil {
return n, nil
@@ -180,9 +179,9 @@ func (r *tunFile) readRaw(buf []byte) (int, error) {
// readable we drain additional packets non-blocking until the kernel queue
// is empty (EAGAIN), we've collected tunDrainCap packets, or we're out of
// segBuf headroom. This amortizes the poll wake over bursts of small
// packets (e.g. TCP ACKs). Slices point into the tunFile's internal buffers
// packets (e.g. TCP ACKs). Slices point into the Offload's internal buffers
// and are only valid until the next Read or Close on this Queue.
func (r *tunFile) Read() ([][]byte, error) {
func (r *Offload) Read() ([][]byte, error) {
r.pending = r.pending[:0]
r.segOff = 0
@@ -226,7 +225,7 @@ func (r *tunFile) Read() ([][]byte, error) {
// decodeRead decodes the virtio header plus payload in r.readBuf[:n], appends
// the segments to r.pending, and advances r.segOff by the total scratch used.
// Caller must have already ensured r.vnetHdr is true.
func (r *tunFile) decodeRead(n int) error {
func (r *Offload) decodeRead(n int) error {
if n < virtioNetHdrLen {
return fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen)
}
@@ -242,7 +241,7 @@ func (r *tunFile) decodeRead(n int) error {
return nil
}
func (r *tunFile) Write(buf []byte) (int, error) {
func (r *Offload) Write(buf []byte) (int, error) {
return r.writeWithScratch(buf, &r.writeIovs)
}
@@ -250,36 +249,33 @@ func (r *tunFile) Write(buf []byte) (int, error) {
// distinct from the one used by the coalescer's Write path. This avoids a
// data race between the inside (listenIn) goroutine emitting reject or
// self-forward packets and the outside (listenOut) goroutine flushing TCP
// coalescer passthroughs on the same tunFile.
func (r *tunFile) WriteReject(buf []byte) (int, error) {
// coalescer passthroughs on the same Offload.
func (r *Offload) WriteReject(buf []byte) (int, error) {
return r.writeWithScratch(buf, &r.rejectIovs)
}
func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) {
func (r *Offload) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) {
if len(buf) == 0 {
return 0, nil
}
// Point the payload iovec at the caller's buffer. iovs[0] is pre-wired
// to validVnetHdr during tunFile construction so we don't rebuild it here.
// to validVnetHdr during Offload construction so we don't rebuild it here.
iovs[1].Base = &buf[0]
iovs[1].SetLen(len(buf))
iovPtr := uintptr(unsafe.Pointer(&iovs[0]))
// The TUN fd is non-blocking (set in newTunFd / newFriend), so writev
// either completes promptly or returns EAGAIN — it cannot park the
// goroutine inside the kernel. That lets us use syscall.RawSyscall and
// skip the runtime.entersyscall / exitsyscall bookkeeping on every
// packet; we only pay that cost when we fall through to blockOnWrite.
iovPtr := unsafe.Pointer(&iovs[0])
return r.rawWrite(iovPtr, 2)
}
func (r *Offload) rawWrite(iovs unsafe.Pointer, iovcnt int) (int, error) {
for {
n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, 2)
n, _, errno := syscall.Syscall(unix.SYS_WRITEV, uintptr(r.fd), uintptr(iovs), uintptr(iovcnt))
if errno == 0 {
runtime.KeepAlive(buf)
if int(n) < virtioNetHdrLen {
return 0, io.ErrShortWrite
}
return int(n) - virtioNetHdrLen, nil
}
if errno == unix.EAGAIN {
runtime.KeepAlive(buf)
if err := r.blockOnWrite(); err != nil {
return 0, err
}
@@ -291,7 +287,6 @@ func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error)
if errno == unix.EBADF {
return 0, os.ErrClosed
}
runtime.KeepAlive(buf)
return 0, errno
}
}
@@ -299,7 +294,7 @@ func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error)
// GSOSupported reports whether this queue was opened with IFF_VNET_HDR and
// can accept WriteGSO. When false, callers should fall back to per-segment
// Write calls.
func (r *tunFile) GSOSupported() bool { return true }
func (r *Offload) GSOSupported() bool { return true }
// WriteGSO emits a TCP TSO superpacket in a single writev. hdr is the
// IPv4/IPv6 + TCP header prefix (already finalized — total length, IP csum,
@@ -308,7 +303,7 @@ func (r *tunFile) GSOSupported() bool { return true }
// slice is read-only and must stay valid until return. gsoSize is the MSS;
// every segment except possibly the last is exactly gsoSize bytes.
// csumStart is the byte offset where the TCP header begins within hdr.
func (r *tunFile) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error {
func (r *Offload) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error {
if len(hdr) == 0 || len(pays) == 0 {
return nil
}
@@ -356,45 +351,18 @@ func (r *tunFile) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool,
r.gsoIovs[2+i].SetLen(len(p))
}
iovPtr := uintptr(unsafe.Pointer(&r.gsoIovs[0]))
iovCnt := uintptr(len(r.gsoIovs))
for {
n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, iovCnt)
if errno == 0 {
runtime.KeepAlive(hdr)
runtime.KeepAlive(pays)
if int(n) < virtioNetHdrLen {
return io.ErrShortWrite
}
return nil
}
if errno == unix.EAGAIN {
runtime.KeepAlive(hdr)
runtime.KeepAlive(pays)
if err := r.blockOnWrite(); err != nil {
return err
}
continue
}
if errno == unix.EINTR {
continue
}
if errno == unix.EBADF {
return os.ErrClosed
}
runtime.KeepAlive(hdr)
runtime.KeepAlive(pays)
return errno
}
iovPtr := unsafe.Pointer(&r.gsoIovs[0])
iovCnt := len(r.gsoIovs)
_, err := r.rawWrite(iovPtr, iovCnt)
return err
}
func (r *tunFile) Close() error {
func (r *Offload) Close() error {
if r.closed.Swap(true) {
return nil
}
//shutdownFd is owned by the container, so we should not close it
var err error
if r.fd >= 0 {
err = unix.Close(r.fd)