diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 719f5b67..5e42c486 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -24,6 +24,11 @@ import ( "golang.org/x/sys/unix" ) +const ( + // virtioNetHdrLen is the length of virtio_net_hdr (without mergeable buffers) + virtioNetHdrLen = 10 +) + type tun struct { io.ReadWriteCloser fd int @@ -35,6 +40,12 @@ type tun struct { deviceIndex int ioctlFd uintptr nonBlocking bool // true if fd is in non-blocking mode + vnetHdr bool // true if IFF_VNET_HDR is enabled on the TUN device + + // readBuf is used when vnetHdr is enabled to read the full packet+header + // before stripping the header. This is needed because caller-provided + // buffers are sized for MTU but kernel writes MTU+10 with virtio header. + readBuf []byte Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] @@ -54,6 +65,23 @@ func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } +// tunVnetHdrSupported checks if the kernel supports IFF_VNET_HDR on TUN devices +func tunVnetHdrSupported() bool { + fd, err := unix.Open("/dev/net/tun", unix.O_RDONLY, 0) + if err != nil { + return false + } + defer unix.Close(fd) + + var features uint32 + err = ioctl(uintptr(fd), uintptr(unix.TUNGETFEATURES), uintptr(unsafe.Pointer(&features))) + if err != nil { + return false + } + + return features&unix.IFF_VNET_HDR != 0 +} + type ifReq struct { Name [16]byte Flags uint16 @@ -108,11 +136,18 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu } } + // Check if VNET_HDR is supported before trying to use it + useVnetHdr := tunVnetHdrSupported() + var req ifReq req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI) if multiqueue { req.Flags |= unix.IFF_MULTI_QUEUE } + if useVnetHdr { + req.Flags |= unix.IFF_VNET_HDR + } + nameStr := c.GetString("tun.dev", "") copy(req.Name[:], nameStr) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { @@ -123,6 +158,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu } name := strings.Trim(string(req.Name[:]), "\x00") + // Track if VNET_HDR is in use + // Note: We don't call TUNSETOFFLOAD - just handle the headers manually + vnetHdrEnabled := useVnetHdr + if vnetHdrEnabled { + l.Info("TUN VNET_HDR enabled") + } + file := os.NewFile(uintptr(fd), "/dev/net/tun") t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { @@ -130,6 +172,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu } t.Device = name + t.vnetHdr = vnetHdrEnabled + + // Allocate read buffer for virtio header handling + // Buffer needs to be large enough for virtio header + max packet + if t.vnetHdr { + t.readBuf = make([]byte, t.MaxMTU+virtioNetHdrLen) + } return t, nil } @@ -247,25 +296,48 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { var req ifReq req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) + if t.vnetHdr { + req.Flags |= unix.IFF_VNET_HDR + } copy(req.Name[:], t.Device) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { return nil, err } - return &tunBatchReader{fd: fd, device: t.Device}, nil + reader := &tunBatchReader{fd: fd, device: t.Device, vnetHdr: t.vnetHdr} + if t.vnetHdr { + reader.readBuf = make([]byte, t.MaxMTU+virtioNetHdrLen) + } + return reader, nil } // tunBatchReader implements BatchReader for efficient batch packet reading type tunBatchReader struct { - fd int - device string + fd int + device string + vnetHdr bool + readBuf []byte // internal buffer for virtio header handling } func (r *tunBatchReader) Read(b []byte) (int, error) { + // Choose buffer: use internal buffer for vnetHdr, caller's buffer otherwise + readBuf := b + if r.vnetHdr { + readBuf = r.readBuf + } + // Use poll to wait for data, then read for { - n, err := unix.Read(r.fd, b) + n, err := unix.Read(r.fd, readBuf) if err == nil { + if r.vnetHdr && n > virtioNetHdrLen { + packetLen := n - virtioNetHdrLen + copy(b, readBuf[virtioNetHdrLen:n]) + return packetLen, nil + } + if r.vnetHdr { + return 0, nil // No packet data + } return n, nil } if err == unix.EAGAIN || err == unix.EWOULDBLOCK { @@ -285,7 +357,24 @@ func (r *tunBatchReader) Read(b []byte) (int, error) { } func (r *tunBatchReader) Write(b []byte) (int, error) { - return unix.Write(r.fd, b) + if !r.vnetHdr { + return unix.Write(r.fd, b) + } + + // Use writev to prepend virtio header without copying the packet data + // Header is all zeros = no GSO, no checksum offload + var hdr [virtioNetHdrLen]byte + bufs := [][]byte{hdr[:], b} + + n, err := unix.Writev(r.fd, bufs) + if err != nil { + return 0, err + } + // Return only the packet bytes written (exclude header) + if n > virtioNetHdrLen { + return n - virtioNetHdrLen, nil + } + return 0, nil } func (r *tunBatchReader) Close() error { @@ -302,10 +391,31 @@ func (r *tunBatchReader) ReadBatch(packets [][]byte, sizes []int) (int, error) { maxPackets = len(sizes) } + // Choose read buffer based on vnetHdr + readBuf := packets[0] // Will be updated in loop for non-vnetHdr + if r.vnetHdr { + readBuf = r.readBuf + } + for count < maxPackets { - n, err := unix.Read(r.fd, packets[count]) + if !r.vnetHdr { + readBuf = packets[count] + } + + n, err := unix.Read(r.fd, readBuf) if err == nil && n > 0 { - sizes[count] = n + if r.vnetHdr { + if n > virtioNetHdrLen { + packetLen := n - virtioNetHdrLen + copy(packets[count], readBuf[virtioNetHdrLen:n]) + sizes[count] = packetLen + } else { + // Malformed packet (no data after header), skip + continue + } + } else { + sizes[count] = n + } count++ continue } @@ -353,6 +463,27 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { } func (t *tun) Write(b []byte) (int, error) { + if !t.vnetHdr { + return t.writeSimple(b) + } + + // Use writev to prepend virtio header without copying the packet data + // Header is all zeros = no GSO, no checksum offload + var hdr [virtioNetHdrLen]byte + bufs := [][]byte{hdr[:], b} + + n, err := unix.Writev(t.fd, bufs) + if err != nil { + return 0, err + } + // Return only the packet bytes written (exclude header) + if n > virtioNetHdrLen { + return n - virtioNetHdrLen, nil + } + return 0, nil +} + +func (t *tun) writeSimple(b []byte) (int, error) { var nn int maximum := len(b) @@ -389,21 +520,64 @@ func (t *tun) EnableBatchReading() error { return nil } -// Read overrides the default Read to handle non-blocking mode +// Read overrides the default Read to handle non-blocking mode and virtio headers func (t *tun) Read(b []byte) (int, error) { + if !t.vnetHdr { + return t.readSimple(b) + } + + // With VNET_HDR, read into internal buffer (which has space for header) + // then copy packet data to caller's buffer if !t.nonBlocking { - // Use the embedded ReadWriteCloser for blocking reads - return t.ReadWriteCloser.Read(b) + n, err := t.ReadWriteCloser.Read(t.readBuf) + if err != nil { + return 0, err + } + if n <= virtioNetHdrLen { + return 0, nil // No packet data + } + packetLen := n - virtioNetHdrLen + copy(b, t.readBuf[virtioNetHdrLen:n]) + return packetLen, nil } // Non-blocking read with poll + for { + n, err := unix.Read(t.fd, t.readBuf) + if err == nil { + if n <= virtioNetHdrLen { + return 0, nil // No packet data + } + packetLen := n - virtioNetHdrLen + copy(b, t.readBuf[virtioNetHdrLen:n]) + return packetLen, nil + } + if err == unix.EAGAIN || err == unix.EWOULDBLOCK { + pfds := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}} + _, err = unix.Poll(pfds, -1) + if err != nil { + if err == unix.EINTR { + continue + } + return 0, err + } + continue + } + return n, err + } +} + +func (t *tun) readSimple(b []byte) (int, error) { + if !t.nonBlocking { + return t.ReadWriteCloser.Read(b) + } + for { n, err := unix.Read(t.fd, b) if err == nil { return n, nil } if err == unix.EAGAIN || err == unix.EWOULDBLOCK { - // Wait for data pfds := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}} _, err = unix.Poll(pfds, -1) if err != nil { @@ -423,7 +597,7 @@ func (t *tun) Read(b []byte) (int, error) { func (t *tun) ReadBatch(packets [][]byte, sizes []int) (int, error) { if !t.nonBlocking { // Fallback to single read if non-blocking not enabled - n, err := t.ReadWriteCloser.Read(packets[0]) + n, err := t.Read(packets[0]) if err != nil { return 0, err } @@ -437,10 +611,33 @@ func (t *tun) ReadBatch(packets [][]byte, sizes []int) (int, error) { maxPackets = len(sizes) } + // Choose read buffer based on vnetHdr + // With vnetHdr, we need to read into internal buffer (has space for header) + // then copy packet data to caller's buffer + readBuf := packets[0] // Will be updated in the loop + if t.vnetHdr { + readBuf = t.readBuf + } + for count < maxPackets { - n, err := unix.Read(t.fd, packets[count]) + if !t.vnetHdr { + readBuf = packets[count] + } + + n, err := unix.Read(t.fd, readBuf) if err == nil && n > 0 { - sizes[count] = n + if t.vnetHdr { + if n > virtioNetHdrLen { + packetLen := n - virtioNetHdrLen + copy(packets[count], readBuf[virtioNetHdrLen:n]) + sizes[count] = packetLen + } else { + // Malformed packet (no data after header), skip + continue + } + } else { + sizes[count] = n + } count++ continue }