better and batched tun interface

This commit is contained in:
JackDoan
2026-04-17 10:25:05 -05:00
parent 398d67e2da
commit f95857b4c3
34 changed files with 1189 additions and 483 deletions

View File

@@ -8,6 +8,12 @@ import (
const MTU = 9001
// MaxWriteBatch is the largest batch any Conn.WriteBatch implementation is
// required to accept. Callers SHOULD NOT pass more than this per call; Linux
// backends preallocate sendmmsg scratch sized to this value, so exceeding it
// only costs a chunked retry.
const MaxWriteBatch = 128
type EncReader func(
addr netip.AddrPort,
payload []byte,
@@ -16,8 +22,19 @@ type EncReader func(
type Conn interface {
Rebind() error
LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader) error
// ListenOut invokes r for each received packet. On batch-capable
// backends (recvmmsg), flush is called after each batch is fully
// delivered — callers use it to flush per-batch accumulators such as
// TUN write coalescers. Single-packet backends call flush after each
// packet. flush must not be nil.
ListenOut(r EncReader, flush func()) error
WriteTo(b []byte, addr netip.AddrPort) error
// WriteBatch sends a contiguous batch of packets, each with its own
// destination. bufs and addrs must have the same length. Linux uses
// sendmmsg(2) for a single syscall; other backends fall back to a
// WriteTo loop. Returns on the first error; callers may observe a
// partial send if some packets went out before the error.
WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error
ReloadConfig(c *config.C)
SupportsMultipleReaders() bool
Close() error
@@ -31,7 +48,7 @@ func (NoopConn) Rebind() error {
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
return netip.AddrPort{}, nil
}
func (NoopConn) ListenOut(_ EncReader) error {
func (NoopConn) ListenOut(_ EncReader, _ func()) error {
return nil
}
func (NoopConn) SupportsMultipleReaders() bool {
@@ -40,6 +57,9 @@ func (NoopConn) SupportsMultipleReaders() bool {
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil
}
func (NoopConn) WriteBatch(_ [][]byte, _ []netip.AddrPort) error {
return nil
}
func (NoopConn) ReloadConfig(_ *config.C) {
return
}

View File

@@ -140,6 +140,15 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
}
}
func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error {
for i, b := range bufs {
if err := u.WriteTo(b, addrs[i]); err != nil {
return err
}
}
return nil
}
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
a := u.UDPConn.LocalAddr()
@@ -165,7 +174,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
return func() {}
}
func (u *StdConn) ListenOut(r EncReader) error {
func (u *StdConn) ListenOut(r EncReader, flush func()) error {
buffer := make([]byte, MTU)
for {
@@ -180,6 +189,7 @@ func (u *StdConn) ListenOut(r EncReader) error {
}
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
flush()
}
}

View File

@@ -44,6 +44,15 @@ func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
return err
}
func (u *GenericConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error {
for i, b := range bufs {
if _, err := u.UDPConn.WriteToUDPAddrPort(b, addrs[i]); err != nil {
return err
}
}
return nil
}
func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
a := u.UDPConn.LocalAddr()
@@ -73,7 +82,7 @@ type rawMessage struct {
Len uint32
}
func (u *GenericConn) ListenOut(r EncReader) error {
func (u *GenericConn) ListenOut(r EncReader, flush func()) error {
buffer := make([]byte, MTU)
var lastRecvErr time.Time
@@ -94,6 +103,7 @@ func (u *GenericConn) ListenOut(r EncReader) error {
}
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
flush()
}
}

View File

@@ -171,7 +171,7 @@ func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) {
return int(n), true, nil
}
func (u *StdConn) listenOutSingle(r EncReader) error {
func (u *StdConn) listenOutSingle(r EncReader, flush func()) error {
var err error
var n int
var from netip.AddrPort
@@ -184,15 +184,17 @@ func (u *StdConn) listenOutSingle(r EncReader) error {
}
from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port())
r(from, buffer[:n])
flush()
}
}
func (u *StdConn) listenOutBatch(r EncReader) error {
func (u *StdConn) listenOutBatch(r EncReader, flush func()) error {
var ip netip.Addr
var n int
var operr error
msgs, buffers, names := u.PrepareRawMessages(u.batch)
bufSize := MTU
msgs, buffers, names := u.PrepareRawMessages(u.batch, bufSize)
//reader needs to capture variables from this function, since it's used as a lambda with rawConn.Read
//defining it outside the loop so it gets re-used
@@ -217,16 +219,22 @@ func (u *StdConn) listenOutBatch(r EncReader) error {
} else {
ip, _ = netip.AddrFromSlice(names[i][8:24])
}
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
from := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
payload := buffers[i][:msgs[i].Len]
r(from, payload)
}
// End-of-batch: let callers (e.g. TUN write coalescer) flush any
// state they accumulated across this batch.
flush()
}
}
func (u *StdConn) ListenOut(r EncReader) error {
func (u *StdConn) ListenOut(r EncReader, flush func()) error {
if u.batch == 1 {
return u.listenOutSingle(r)
return u.listenOutSingle(r, flush)
} else {
return u.listenOutBatch(r)
return u.listenOutBatch(r, flush)
}
}
@@ -235,6 +243,19 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
return err
}
func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error {
if len(bufs) != len(addrs) {
return fmt.Errorf("WriteBatch: len(bufs)=%d != len(addrs)=%d", len(bufs), len(addrs))
}
//todo use sendmmsg
for i := 0; i < len(bufs); i++ {
if _, err := u.udpConn.WriteToUDPAddrPort(bufs[i], addrs[i]); err != nil {
return err
}
}
return nil
}
func (u *StdConn) ReloadConfig(c *config.C) {
b := c.GetInt("listen.read_buffer", 0)
if b > 0 {

View File

@@ -30,13 +30,13 @@ type rawMessage struct {
Len uint32
}
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
func (u *StdConn) PrepareRawMessages(n, bufSize int) ([]rawMessage, [][]byte, [][]byte) {
msgs := make([]rawMessage, n)
buffers := make([][]byte, n)
names := make([][]byte, n)
for i := range msgs {
buffers[i] = make([]byte, MTU)
buffers[i] = make([]byte, bufSize)
names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{
@@ -52,3 +52,19 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
return msgs, buffers, names
}
func setIovLen(v *iovec, n int) {
v.Len = uint32(n)
}
func setMsgIovlen(m *msghdr, n int) {
m.Iovlen = uint32(n)
}
func setMsgControllen(m *msghdr, n int) {
m.Controllen = uint32(n)
}
func setCmsgLen(h *unix.Cmsghdr, n int) {
h.Len = uint32(n)
}

View File

@@ -33,13 +33,13 @@ type rawMessage struct {
Pad0 [4]byte
}
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
func (u *StdConn) PrepareRawMessages(n, bufSize int) ([]rawMessage, [][]byte, [][]byte) {
msgs := make([]rawMessage, n)
buffers := make([][]byte, n)
names := make([][]byte, n)
for i := range msgs {
buffers[i] = make([]byte, MTU)
buffers[i] = make([]byte, bufSize)
names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{
@@ -55,3 +55,19 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
return msgs, buffers, names
}
func setIovLen(v *iovec, n int) {
v.Len = uint64(n)
}
func setMsgIovlen(m *msghdr, n int) {
m.Iovlen = uint64(n)
}
func setMsgControllen(m *msghdr, n int) {
m.Controllen = uint64(n)
}
func setCmsgLen(h *unix.Cmsghdr, n int) {
h.Len = uint64(n)
}

View File

@@ -140,7 +140,7 @@ func (u *RIOConn) bind(l *slog.Logger, sa windows.Sockaddr) error {
return nil
}
func (u *RIOConn) ListenOut(r EncReader) error {
func (u *RIOConn) ListenOut(r EncReader, flush func()) error {
buffer := make([]byte, MTU)
var lastRecvErr time.Time
@@ -162,6 +162,7 @@ func (u *RIOConn) ListenOut(r EncReader) error {
}
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n])
flush()
}
}
@@ -316,6 +317,15 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error {
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}
func (u *RIOConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error {
for i, b := range bufs {
if err := u.WriteTo(b, addrs[i]); err != nil {
return err
}
}
return nil
}
func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
sa, err := windows.Getsockname(u.sock)
if err != nil {

View File

@@ -157,8 +157,16 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
return nil
}
}
func (u *TesterConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error {
for i, b := range bufs {
if err := u.WriteTo(b, addrs[i]); err != nil {
return err
}
}
return nil
}
func (u *TesterConn) ListenOut(r EncReader) error {
func (u *TesterConn) ListenOut(r EncReader, flush func()) error {
for {
select {
case <-u.done:
@@ -166,6 +174,7 @@ func (u *TesterConn) ListenOut(r EncReader) error {
case p := <-u.RxPackets:
r(p.From, p.Data)
p.Release()
flush()
}
}
}