mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-23 08:54:25 +01:00
try with sendmmsg merged back
This commit is contained in:
@@ -18,10 +18,16 @@ type Conn interface {
|
||||
LocalAddr() (netip.AddrPort, error)
|
||||
ListenOut(r EncReader) error
|
||||
WriteTo(b []byte, addr netip.AddrPort) error
|
||||
WriteBatch(pkts []BatchPacket) (int, error)
|
||||
ReloadConfig(c *config.C)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type BatchPacket struct {
|
||||
Payload []byte
|
||||
Addr netip.AddrPort
|
||||
}
|
||||
|
||||
type NoopConn struct{}
|
||||
|
||||
func (NoopConn) Rebind() error {
|
||||
@@ -36,6 +42,9 @@ func (NoopConn) ListenOut(_ EncReader) error {
|
||||
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
||||
return nil
|
||||
}
|
||||
func (NoopConn) WriteBatch(_ []BatchPacket) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (NoopConn) ReloadConfig(_ *config.C) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -140,6 +140,17 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) WriteBatch(pkts []BatchPacket) (int, error) {
|
||||
sent := 0
|
||||
for _, pkt := range pkts {
|
||||
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
|
||||
return sent, err
|
||||
}
|
||||
sent++
|
||||
}
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
||||
a := u.UDPConn.LocalAddr()
|
||||
|
||||
|
||||
@@ -42,6 +42,17 @@ func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *GenericConn) WriteBatch(pkts []BatchPacket) (int, error) {
|
||||
sent := 0
|
||||
for _, pkt := range pkts {
|
||||
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
|
||||
return sent, err
|
||||
}
|
||||
sent++
|
||||
}
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
|
||||
a := u.UDPConn.LocalAddr()
|
||||
|
||||
|
||||
112
udp/udp_linux.go
112
udp/udp_linux.go
@@ -343,6 +343,118 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
||||
return u.writeTo6(b, ip)
|
||||
}
|
||||
|
||||
func (u *StdConn) WriteBatch(pkts []BatchPacket) (int, error) {
|
||||
if len(pkts) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
msgs := make([]rawMessage, 0, len(pkts))
|
||||
iovs := make([]iovec, 0, len(pkts))
|
||||
names := make([][unix.SizeofSockaddrInet6]byte, 0, len(pkts))
|
||||
|
||||
sent := 0
|
||||
|
||||
for _, pkt := range pkts {
|
||||
if len(pkt.Payload) == 0 {
|
||||
sent++
|
||||
continue
|
||||
}
|
||||
|
||||
if u.enableGSO && pkt.Addr.IsValid() {
|
||||
if err := u.queueGSOPacket(pkt.Payload, pkt.Addr); err == nil {
|
||||
sent++
|
||||
continue
|
||||
} else if !errors.Is(err, errGSOFallback) {
|
||||
return sent, err
|
||||
}
|
||||
}
|
||||
|
||||
if !pkt.Addr.IsValid() {
|
||||
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
|
||||
return sent, err
|
||||
}
|
||||
sent++
|
||||
continue
|
||||
}
|
||||
|
||||
msgs = append(msgs, rawMessage{})
|
||||
iovs = append(iovs, iovec{})
|
||||
names = append(names, [unix.SizeofSockaddrInet6]byte{})
|
||||
|
||||
idx := len(msgs) - 1
|
||||
msg := &msgs[idx]
|
||||
iov := &iovs[idx]
|
||||
name := &names[idx]
|
||||
|
||||
setIovecSlice(iov, pkt.Payload)
|
||||
msg.Hdr.Iov = iov
|
||||
msg.Hdr.Iovlen = 1
|
||||
setRawMessageControl(msg, nil)
|
||||
msg.Hdr.Flags = 0
|
||||
|
||||
nameLen, err := u.encodeSockaddr(name[:], pkt.Addr)
|
||||
if err != nil {
|
||||
return sent, err
|
||||
}
|
||||
msg.Hdr.Name = &name[0]
|
||||
msg.Hdr.Namelen = nameLen
|
||||
}
|
||||
|
||||
if len(msgs) == 0 {
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
offset := 0
|
||||
for offset < len(msgs) {
|
||||
n, _, errno := unix.Syscall6(
|
||||
unix.SYS_SENDMMSG,
|
||||
uintptr(u.sysFd),
|
||||
uintptr(unsafe.Pointer(&msgs[offset])),
|
||||
uintptr(len(msgs)-offset),
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
|
||||
if errno != 0 {
|
||||
if errno == unix.EINTR {
|
||||
continue
|
||||
}
|
||||
return sent + offset, &net.OpError{Op: "sendmmsg", Err: errno}
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
offset += int(n)
|
||||
}
|
||||
|
||||
return sent + len(msgs), nil
|
||||
}
|
||||
|
||||
func (u *StdConn) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
|
||||
if u.isV4 {
|
||||
if !addr.Addr().Is4() {
|
||||
return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
|
||||
}
|
||||
var sa unix.RawSockaddrInet4
|
||||
sa.Family = unix.AF_INET
|
||||
sa.Addr = addr.Addr().As4()
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
|
||||
size := unix.SizeofSockaddrInet4
|
||||
copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
|
||||
return uint32(size), nil
|
||||
}
|
||||
|
||||
var sa unix.RawSockaddrInet6
|
||||
sa.Family = unix.AF_INET6
|
||||
sa.Addr = addr.Addr().As16()
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
|
||||
size := unix.SizeofSockaddrInet6
|
||||
copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
|
||||
return uint32(size), nil
|
||||
}
|
||||
|
||||
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
|
||||
var rsa unix.RawSockaddrInet6
|
||||
rsa.Family = unix.AF_INET6
|
||||
|
||||
@@ -77,3 +77,13 @@ func getRawMessageFlags(msg *rawMessage) int {
|
||||
func setCmsgLen(h *unix.Cmsghdr, l int) {
|
||||
h.Len = uint32(l)
|
||||
}
|
||||
|
||||
func setIovecSlice(iov *iovec, b []byte) {
|
||||
if len(b) == 0 {
|
||||
iov.Base = nil
|
||||
iov.Len = 0
|
||||
return
|
||||
}
|
||||
iov.Base = &b[0]
|
||||
iov.Len = uint32(len(b))
|
||||
}
|
||||
|
||||
@@ -80,3 +80,13 @@ func getRawMessageFlags(msg *rawMessage) int {
|
||||
func setCmsgLen(h *unix.Cmsghdr, l int) {
|
||||
h.Len = uint64(l)
|
||||
}
|
||||
|
||||
func setIovecSlice(iov *iovec, b []byte) {
|
||||
if len(b) == 0 {
|
||||
iov.Base = nil
|
||||
iov.Len = 0
|
||||
return
|
||||
}
|
||||
iov.Base = &b[0]
|
||||
iov.Len = uint64(len(b))
|
||||
}
|
||||
|
||||
@@ -304,6 +304,17 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error {
|
||||
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
|
||||
}
|
||||
|
||||
func (u *RIOConn) WriteBatch(pkts []BatchPacket) (int, error) {
|
||||
sent := 0
|
||||
for _, pkt := range pkts {
|
||||
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
|
||||
return sent, err
|
||||
}
|
||||
sent++
|
||||
}
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
|
||||
sa, err := windows.Getsockname(u.sock)
|
||||
if err != nil {
|
||||
|
||||
@@ -106,6 +106,17 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *TesterConn) WriteBatch(pkts []BatchPacket) (int, error) {
|
||||
sent := 0
|
||||
for _, pkt := range pkts {
|
||||
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
|
||||
return sent, err
|
||||
}
|
||||
sent++
|
||||
}
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
func (u *TesterConn) ListenOut(r EncReader) {
|
||||
for {
|
||||
p, ok := <-u.RxPackets
|
||||
|
||||
Reference in New Issue
Block a user