mirror of
https://github.com/slackhq/nebula.git
synced 2026-02-14 08:44:24 +01:00
claude does TUN virtio header support
This commit is contained in:
@@ -24,6 +24,11 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
	// virtioNetHdrLen is the length in bytes of virtio_net_hdr (the layout
	// without the mergeable-buffers field). When IFF_VNET_HDR is enabled on
	// the TUN device, this many header bytes precede the packet data on every
	// read and must be prepended to the packet on every write.
	virtioNetHdrLen = 10
)
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
io.ReadWriteCloser
|
io.ReadWriteCloser
|
||||||
fd int
|
fd int
|
||||||
@@ -35,6 +40,12 @@ type tun struct {
|
|||||||
deviceIndex int
|
deviceIndex int
|
||||||
ioctlFd uintptr
|
ioctlFd uintptr
|
||||||
nonBlocking bool // true if fd is in non-blocking mode
|
nonBlocking bool // true if fd is in non-blocking mode
|
||||||
|
vnetHdr bool // true if IFF_VNET_HDR is enabled on the TUN device
|
||||||
|
|
||||||
|
// readBuf is used when vnetHdr is enabled to read the full packet+header
|
||||||
|
// before stripping the header. This is needed because caller-provided
|
||||||
|
// buffers are sized for MTU but kernel writes MTU+10 with virtio header.
|
||||||
|
readBuf []byte
|
||||||
|
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
@@ -54,6 +65,23 @@ func (t *tun) Networks() []netip.Prefix {
|
|||||||
return t.vpnNetworks
|
return t.vpnNetworks
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// tunVnetHdrSupported checks if the kernel supports IFF_VNET_HDR on TUN devices
|
||||||
|
func tunVnetHdrSupported() bool {
|
||||||
|
fd, err := unix.Open("/dev/net/tun", unix.O_RDONLY, 0)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
defer unix.Close(fd)
|
||||||
|
|
||||||
|
var features uint32
|
||||||
|
err = ioctl(uintptr(fd), uintptr(unix.TUNGETFEATURES), uintptr(unsafe.Pointer(&features)))
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return features&unix.IFF_VNET_HDR != 0
|
||||||
|
}
|
||||||
|
|
||||||
type ifReq struct {
|
type ifReq struct {
|
||||||
Name [16]byte
|
Name [16]byte
|
||||||
Flags uint16
|
Flags uint16
|
||||||
@@ -108,11 +136,18 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if VNET_HDR is supported before trying to use it
|
||||||
|
useVnetHdr := tunVnetHdrSupported()
|
||||||
|
|
||||||
var req ifReq
|
var req ifReq
|
||||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
|
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
|
||||||
if multiqueue {
|
if multiqueue {
|
||||||
req.Flags |= unix.IFF_MULTI_QUEUE
|
req.Flags |= unix.IFF_MULTI_QUEUE
|
||||||
}
|
}
|
||||||
|
if useVnetHdr {
|
||||||
|
req.Flags |= unix.IFF_VNET_HDR
|
||||||
|
}
|
||||||
|
|
||||||
nameStr := c.GetString("tun.dev", "")
|
nameStr := c.GetString("tun.dev", "")
|
||||||
copy(req.Name[:], nameStr)
|
copy(req.Name[:], nameStr)
|
||||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
@@ -123,6 +158,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||||||
}
|
}
|
||||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||||
|
|
||||||
|
// Track if VNET_HDR is in use
|
||||||
|
// Note: We don't call TUNSETOFFLOAD - just handle the headers manually
|
||||||
|
vnetHdrEnabled := useVnetHdr
|
||||||
|
if vnetHdrEnabled {
|
||||||
|
l.Info("TUN VNET_HDR enabled")
|
||||||
|
}
|
||||||
|
|
||||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -130,6 +172,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Device = name
|
t.Device = name
|
||||||
|
t.vnetHdr = vnetHdrEnabled
|
||||||
|
|
||||||
|
// Allocate read buffer for virtio header handling
|
||||||
|
// Buffer needs to be large enough for virtio header + max packet
|
||||||
|
if t.vnetHdr {
|
||||||
|
t.readBuf = make([]byte, t.MaxMTU+virtioNetHdrLen)
|
||||||
|
}
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
@@ -247,25 +296,48 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|||||||
|
|
||||||
var req ifReq
|
var req ifReq
|
||||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
||||||
|
if t.vnetHdr {
|
||||||
|
req.Flags |= unix.IFF_VNET_HDR
|
||||||
|
}
|
||||||
copy(req.Name[:], t.Device)
|
copy(req.Name[:], t.Device)
|
||||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &tunBatchReader{fd: fd, device: t.Device}, nil
|
reader := &tunBatchReader{fd: fd, device: t.Device, vnetHdr: t.vnetHdr}
|
||||||
|
if t.vnetHdr {
|
||||||
|
reader.readBuf = make([]byte, t.MaxMTU+virtioNetHdrLen)
|
||||||
|
}
|
||||||
|
return reader, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// tunBatchReader implements BatchReader for efficient batch packet reading
// on its own TUN queue file descriptor.
type tunBatchReader struct {
	fd      int    // file descriptor of this TUN queue; used for both reads and writes
	device  string // name of the TUN interface the queue is attached to
	vnetHdr bool   // true if the queue was opened with IFF_VNET_HDR

	// readBuf is the internal buffer used when vnetHdr is enabled: packets are
	// read into it together with their virtio header, and the payload is then
	// copied into the caller's buffer. It is reused by every Read/ReadBatch
	// call without locking, so a single reader must not be read from
	// concurrently.
	readBuf []byte
}
|
|
||||||
func (r *tunBatchReader) Read(b []byte) (int, error) {
|
func (r *tunBatchReader) Read(b []byte) (int, error) {
|
||||||
|
// Choose buffer: use internal buffer for vnetHdr, caller's buffer otherwise
|
||||||
|
readBuf := b
|
||||||
|
if r.vnetHdr {
|
||||||
|
readBuf = r.readBuf
|
||||||
|
}
|
||||||
|
|
||||||
// Use poll to wait for data, then read
|
// Use poll to wait for data, then read
|
||||||
for {
|
for {
|
||||||
n, err := unix.Read(r.fd, b)
|
n, err := unix.Read(r.fd, readBuf)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
if r.vnetHdr && n > virtioNetHdrLen {
|
||||||
|
packetLen := n - virtioNetHdrLen
|
||||||
|
copy(b, readBuf[virtioNetHdrLen:n])
|
||||||
|
return packetLen, nil
|
||||||
|
}
|
||||||
|
if r.vnetHdr {
|
||||||
|
return 0, nil // No packet data
|
||||||
|
}
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
|
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
|
||||||
@@ -285,7 +357,24 @@ func (r *tunBatchReader) Read(b []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *tunBatchReader) Write(b []byte) (int, error) {
|
func (r *tunBatchReader) Write(b []byte) (int, error) {
|
||||||
return unix.Write(r.fd, b)
|
if !r.vnetHdr {
|
||||||
|
return unix.Write(r.fd, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use writev to prepend virtio header without copying the packet data
|
||||||
|
// Header is all zeros = no GSO, no checksum offload
|
||||||
|
var hdr [virtioNetHdrLen]byte
|
||||||
|
bufs := [][]byte{hdr[:], b}
|
||||||
|
|
||||||
|
n, err := unix.Writev(r.fd, bufs)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
// Return only the packet bytes written (exclude header)
|
||||||
|
if n > virtioNetHdrLen {
|
||||||
|
return n - virtioNetHdrLen, nil
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *tunBatchReader) Close() error {
|
func (r *tunBatchReader) Close() error {
|
||||||
@@ -302,10 +391,31 @@ func (r *tunBatchReader) ReadBatch(packets [][]byte, sizes []int) (int, error) {
|
|||||||
maxPackets = len(sizes)
|
maxPackets = len(sizes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Choose read buffer based on vnetHdr
|
||||||
|
readBuf := packets[0] // Will be updated in loop for non-vnetHdr
|
||||||
|
if r.vnetHdr {
|
||||||
|
readBuf = r.readBuf
|
||||||
|
}
|
||||||
|
|
||||||
for count < maxPackets {
|
for count < maxPackets {
|
||||||
n, err := unix.Read(r.fd, packets[count])
|
if !r.vnetHdr {
|
||||||
|
readBuf = packets[count]
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := unix.Read(r.fd, readBuf)
|
||||||
if err == nil && n > 0 {
|
if err == nil && n > 0 {
|
||||||
sizes[count] = n
|
if r.vnetHdr {
|
||||||
|
if n > virtioNetHdrLen {
|
||||||
|
packetLen := n - virtioNetHdrLen
|
||||||
|
copy(packets[count], readBuf[virtioNetHdrLen:n])
|
||||||
|
sizes[count] = packetLen
|
||||||
|
} else {
|
||||||
|
// Malformed packet (no data after header), skip
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sizes[count] = n
|
||||||
|
}
|
||||||
count++
|
count++
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -353,6 +463,27 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Write(b []byte) (int, error) {
|
func (t *tun) Write(b []byte) (int, error) {
|
||||||
|
if !t.vnetHdr {
|
||||||
|
return t.writeSimple(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use writev to prepend virtio header without copying the packet data
|
||||||
|
// Header is all zeros = no GSO, no checksum offload
|
||||||
|
var hdr [virtioNetHdrLen]byte
|
||||||
|
bufs := [][]byte{hdr[:], b}
|
||||||
|
|
||||||
|
n, err := unix.Writev(t.fd, bufs)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
// Return only the packet bytes written (exclude header)
|
||||||
|
if n > virtioNetHdrLen {
|
||||||
|
return n - virtioNetHdrLen, nil
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) writeSimple(b []byte) (int, error) {
|
||||||
var nn int
|
var nn int
|
||||||
maximum := len(b)
|
maximum := len(b)
|
||||||
|
|
||||||
@@ -389,21 +520,64 @@ func (t *tun) EnableBatchReading() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read overrides the default Read to handle non-blocking mode
|
// Read overrides the default Read to handle non-blocking mode and virtio headers
|
||||||
func (t *tun) Read(b []byte) (int, error) {
|
func (t *tun) Read(b []byte) (int, error) {
|
||||||
|
if !t.vnetHdr {
|
||||||
|
return t.readSimple(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// With VNET_HDR, read into internal buffer (which has space for header)
|
||||||
|
// then copy packet data to caller's buffer
|
||||||
if !t.nonBlocking {
|
if !t.nonBlocking {
|
||||||
// Use the embedded ReadWriteCloser for blocking reads
|
n, err := t.ReadWriteCloser.Read(t.readBuf)
|
||||||
return t.ReadWriteCloser.Read(b)
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if n <= virtioNetHdrLen {
|
||||||
|
return 0, nil // No packet data
|
||||||
|
}
|
||||||
|
packetLen := n - virtioNetHdrLen
|
||||||
|
copy(b, t.readBuf[virtioNetHdrLen:n])
|
||||||
|
return packetLen, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Non-blocking read with poll
|
// Non-blocking read with poll
|
||||||
|
for {
|
||||||
|
n, err := unix.Read(t.fd, t.readBuf)
|
||||||
|
if err == nil {
|
||||||
|
if n <= virtioNetHdrLen {
|
||||||
|
return 0, nil // No packet data
|
||||||
|
}
|
||||||
|
packetLen := n - virtioNetHdrLen
|
||||||
|
copy(b, t.readBuf[virtioNetHdrLen:n])
|
||||||
|
return packetLen, nil
|
||||||
|
}
|
||||||
|
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
|
||||||
|
pfds := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}}
|
||||||
|
_, err = unix.Poll(pfds, -1)
|
||||||
|
if err != nil {
|
||||||
|
if err == unix.EINTR {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) readSimple(b []byte) (int, error) {
|
||||||
|
if !t.nonBlocking {
|
||||||
|
return t.ReadWriteCloser.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, err := unix.Read(t.fd, b)
|
n, err := unix.Read(t.fd, b)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
|
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
|
||||||
// Wait for data
|
|
||||||
pfds := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}}
|
pfds := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}}
|
||||||
_, err = unix.Poll(pfds, -1)
|
_, err = unix.Poll(pfds, -1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -423,7 +597,7 @@ func (t *tun) Read(b []byte) (int, error) {
|
|||||||
func (t *tun) ReadBatch(packets [][]byte, sizes []int) (int, error) {
|
func (t *tun) ReadBatch(packets [][]byte, sizes []int) (int, error) {
|
||||||
if !t.nonBlocking {
|
if !t.nonBlocking {
|
||||||
// Fallback to single read if non-blocking not enabled
|
// Fallback to single read if non-blocking not enabled
|
||||||
n, err := t.ReadWriteCloser.Read(packets[0])
|
n, err := t.Read(packets[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -437,10 +611,33 @@ func (t *tun) ReadBatch(packets [][]byte, sizes []int) (int, error) {
|
|||||||
maxPackets = len(sizes)
|
maxPackets = len(sizes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Choose read buffer based on vnetHdr
|
||||||
|
// With vnetHdr, we need to read into internal buffer (has space for header)
|
||||||
|
// then copy packet data to caller's buffer
|
||||||
|
readBuf := packets[0] // Will be updated in the loop
|
||||||
|
if t.vnetHdr {
|
||||||
|
readBuf = t.readBuf
|
||||||
|
}
|
||||||
|
|
||||||
for count < maxPackets {
|
for count < maxPackets {
|
||||||
n, err := unix.Read(t.fd, packets[count])
|
if !t.vnetHdr {
|
||||||
|
readBuf = packets[count]
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := unix.Read(t.fd, readBuf)
|
||||||
if err == nil && n > 0 {
|
if err == nil && n > 0 {
|
||||||
sizes[count] = n
|
if t.vnetHdr {
|
||||||
|
if n > virtioNetHdrLen {
|
||||||
|
packetLen := n - virtioNetHdrLen
|
||||||
|
copy(packets[count], readBuf[virtioNetHdrLen:n])
|
||||||
|
sizes[count] = packetLen
|
||||||
|
} else {
|
||||||
|
// Malformed packet (no data after header), skip
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sizes[count] = n
|
||||||
|
}
|
||||||
count++
|
count++
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user