diff --git a/inside.go b/inside.go index 0d53f952..369b86b6 100644 --- a/inside.go +++ b/inside.go @@ -69,7 +69,6 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) - } else { f.rejectInside(packet, out, q) if f.l.Level >= logrus.DebugLevel { diff --git a/udp/conn.go b/udp/conn.go index 1ae585c2..95f20d53 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -13,11 +13,18 @@ type EncReader func( payload []byte, ) +// BatchPacket represents a single packet in a batch write operation +type BatchPacket struct { + Payload []byte + Addr netip.AddrPort +} + type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) ListenOut(r EncReader) WriteTo(b []byte, addr netip.AddrPort) error + WriteBatch(pkts []BatchPacket) (int, error) ReloadConfig(c *config.C) SupportsMultipleReaders() bool Close() error @@ -40,6 +47,9 @@ func (NoopConn) SupportsMultipleReaders() bool { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } +func (NoopConn) WriteBatch(pkts []BatchPacket) (int, error) { + return len(pkts), nil +} func (NoopConn) ReloadConfig(_ *config.C) { return } diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 91201194..8360ed4e 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -202,3 +202,12 @@ func (u *StdConn) Rebind() error { return nil } + +func (u *StdConn) WriteBatch(pkts []BatchPacket) (int, error) { + for i := range pkts { + if err := u.WriteTo(pkts[i].Payload, pkts[i].Addr); err != nil { + return i, err + } + } + return len(pkts), nil +} diff --git a/udp/udp_generic.go b/udp/udp_generic.go index e9dad6c5..bd47f40c 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -101,3 +101,12 @@ func (u *GenericConn) ListenOut(r EncReader) { func (u *GenericConn) SupportsMultipleReaders() bool { return false } + +func (u *GenericConn) WriteBatch(pkts []BatchPacket) (int, error) { + for i := range pkts { + if err := u.WriteTo(pkts[i].Payload, pkts[i].Addr); err != nil { + return i, err + } + } + return len(pkts), nil +} diff --git a/udp/udp_linux.go b/udp/udp_linux.go index e7759329..392f19ce 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -313,6 +313,69 @@ func (u *StdConn) Close() error { return syscall.Close(u.sysFd) } +func (u *StdConn) WriteBatch(pkts []BatchPacket) (int, error) { + if len(pkts) == 0 { + return 0, nil + } + + msgs := make([]rawMessage, len(pkts)) + iovecs := make([]iovec, len(pkts)) + var names4 []unix.RawSockaddrInet4 + var names6 []unix.RawSockaddrInet6 + + if u.isV4 { + names4 = make([]unix.RawSockaddrInet4, len(pkts)) + } else { + names6 = make([]unix.RawSockaddrInet6, len(pkts)) + } + + for i := range pkts { + setIovecBase(&iovecs[i], &pkts[i].Payload[0]) + setIovecLen(&iovecs[i], len(pkts[i].Payload)) + msgs[i].Hdr.Iov = &iovecs[i] + setMsghdrIovlen(&msgs[i].Hdr, 1) + + if u.isV4 { + names4[i].Family = unix.AF_INET + names4[i].Addr = pkts[i].Addr.Addr().As4() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&names4[i].Port))[:], pkts[i].Addr.Port()) + msgs[i].Hdr.Name = (*byte)(unsafe.Pointer(&names4[i])) + msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet4 + } else { + names6[i].Family = unix.AF_INET6 + names6[i].Addr = pkts[i].Addr.Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&names6[i].Port))[:], pkts[i].Addr.Port()) + msgs[i].Hdr.Name = (*byte)(unsafe.Pointer(&names6[i])) + msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet6 + } + } + + var sent int + for sent < len(msgs) { + n, _, errno := unix.Syscall6( + unix.SYS_SENDMMSG, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&msgs[sent])), + uintptr(len(msgs)-sent), + 0, + 0, + 0, + ) + + if errno == unix.EINTR { + continue + } + + if errno != 0 { + return sent, &net.OpError{Op: "sendmmsg", Err: errno} + } + + sent += int(n) + } + + return sent, nil +} + func NewUDPStatsEmitter(udpConns []Conn) func() { // Check if our kernel supports SO_MEMINFO before registering the gauges var udpGauges [][unix.SK_MEMINFO_VARS]metrics.Gauge diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go index de8f1cdf..1dec8793 100644 --- a/udp/udp_linux_32.go +++ b/udp/udp_linux_32.go @@ -52,3 +52,15 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { return msgs, buffers, names } + +func setIovecBase(iov *iovec, base *byte) { + iov.Base = base +} + +func setIovecLen(iov *iovec, l int) { + iov.Len = uint32(l) +} + +func setMsghdrIovlen(hdr *msghdr, l int) { + hdr.Iovlen = uint32(l) +} diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 48c5a978..430f16ca 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -55,3 +55,15 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { return msgs, buffers, names } + +func setIovecBase(iov *iovec, base *byte) { + iov.Base = base +} + +func setIovecLen(iov *iovec, l int) { + iov.Len = uint64(l) +} + +func setMsghdrIovlen(hdr *msghdr, l int) { + hdr.Iovlen = uint64(l) +} diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 3d60f34c..a456bf1f 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -338,6 +338,15 @@ func (u *RIOConn) Rebind() error { func (u *RIOConn) ReloadConfig(*config.C) {} +func (u *RIOConn) WriteBatch(pkts []BatchPacket) (int, error) { + for i := range pkts { + if err := u.WriteTo(pkts[i].Payload, pkts[i].Addr); err != nil { + return i, err + } + } + return len(pkts), nil +} + func (u *RIOConn) Close() error { if !u.isOpen.CompareAndSwap(true, false) { return nil diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 5f0f7765..6f881110 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -142,3 +142,12 @@ func (u *TesterConn) Close() error { } return nil } + +func (u *TesterConn) WriteBatch(pkts []BatchPacket) (int, error) { + for i := range pkts { + if err := u.WriteTo(pkts[i].Payload, pkts[i].Addr); err != nil { + return i, err + } + } + return len(pkts), nil +}