This commit is contained in:
JackDoan
2025-11-05 15:38:47 -06:00
parent 2d128a3254
commit befba57366
3 changed files with 380 additions and 354 deletions

View File

@@ -1,12 +1,14 @@
// SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
// *
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn package conn
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"net" "net"
"net/netip" "net/netip"
"runtime" "runtime"
@@ -16,7 +18,6 @@ import (
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
) )
var ( var (
@@ -29,28 +30,53 @@ var (
// methods for sending and receiving multiple datagrams per-syscall. See the // methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564. // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct { type StdNetBind struct {
mu sync.Mutex // protects all fields except as specified mu sync.Mutex // protects all fields except as specified
ipv4 *net.UDPConn ipv4 *net.UDPConn
ipv6 *net.UDPConn ipv6 *net.UDPConn
ipv4PC *ipv4.PacketConn // will be nil on non-Linux ipv4PC *ipv4.PacketConn // will be nil on non-Linux
ipv6PC *ipv6.PacketConn // will be nil on non-Linux ipv6PC *ipv6.PacketConn // will be nil on non-Linux
ipv4TxOffload bool
ipv4RxOffload bool
ipv6TxOffload bool
ipv6RxOffload bool
// these three fields are not guarded by mu // these two fields are not guarded by mu
udpAddrPool sync.Pool udpAddrPool sync.Pool
ipv4MsgsPool sync.Pool msgsPool sync.Pool
ipv6MsgsPool sync.Pool
blackhole4 bool blackhole4 bool
blackhole6 bool blackhole6 bool
listenAddr4 string
listenAddr6 string
bindV4 bool
bindV6 bool
reusePort bool
} }
func newStdNetBind() *StdNetBind { // NewStdNetBind creates a bind that listens on all interfaces.
func NewStdNetBind() *StdNetBind {
return newStdNetBind().(*StdNetBind)
}
// NewStdNetBindForAddr creates a bind that listens on a specific address.
// If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the
// IPv6 socket will be created.
func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind {
b := NewStdNetBind()
//if addr.IsValid() {
// if addr.IsUnspecified() {
// // keep dual-stack defaults with empty listen addresses
// } else if addr.Is4() {
// b.listenAddr4 = addr.Unmap().String()
// b.bindV4 = true
// b.bindV6 = false
// } else {
// b.listenAddr6 = addr.Unmap().String()
// b.bindV6 = true
// b.bindV4 = false
// }
//}
//b.reusePort = reusePort
return b
}
func newStdNetBind() Bind {
return &StdNetBind{ return &StdNetBind{
udpAddrPool: sync.Pool{ udpAddrPool: sync.Pool{
New: func() any { New: func() any {
@@ -60,68 +86,28 @@ func newStdNetBind() *StdNetBind {
}, },
}, },
ipv4MsgsPool: sync.Pool{ msgsPool: sync.Pool{
New: func() any {
msgs := make([]ipv4.Message, IdealBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, srcControlSize)
}
return &msgs
},
},
ipv6MsgsPool: sync.Pool{
New: func() any { New: func() any {
// ipv6.Message and ipv4.Message are interchangeable as they are
// both aliases for x/net/internal/socket.Message.
msgs := make([]ipv6.Message, IdealBatchSize) msgs := make([]ipv6.Message, IdealBatchSize)
for i := range msgs { for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1) msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, srcControlSize) msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
} }
return &msgs return &msgs
}, },
}, },
bindV4: true,
bindV6: true,
reusePort: false,
} }
} }
// NewStdNetBind creates a bind that listens on all interfaces.
func NewStdNetBind() *StdNetBind {
return newStdNetBind()
}
// NewStdNetBindForAddr creates a bind that listens on a specific address.
// If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the
// IPv6 socket will be created.
func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind {
b := newStdNetBind()
if addr.IsValid() {
if addr.IsUnspecified() {
// keep dual-stack defaults with empty listen addresses
} else if addr.Is4() {
b.listenAddr4 = addr.Unmap().String()
b.bindV4 = true
b.bindV6 = false
} else {
b.listenAddr6 = addr.Unmap().String()
b.bindV6 = true
b.bindV4 = false
}
}
b.reusePort = reusePort
return b
}
type StdNetEndpoint struct { type StdNetEndpoint struct {
// AddrPort is the endpoint destination. // AddrPort is the endpoint destination.
netip.AddrPort netip.AddrPort
// src is the current sticky source address and interface index, if supported. // src is the current sticky source address and interface index, if
src struct { // supported. Typically this is a PKTINFO structure from/for control
netip.Addr // messages, see unix.PKTINFO for an example.
ifidx int32 src []byte
}
} }
var ( var (
@@ -140,21 +126,17 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
} }
func (e *StdNetEndpoint) ClearSrc() { func (e *StdNetEndpoint) ClearSrc() {
e.src.ifidx = 0 if e.src != nil {
e.src.Addr = netip.Addr{} // Truncate src, no need to reallocate.
e.src = e.src[:0]
}
} }
func (e *StdNetEndpoint) DstIP() netip.Addr { func (e *StdNetEndpoint) DstIP() netip.Addr {
return e.AddrPort.Addr() return e.AddrPort.Addr()
} }
func (e *StdNetEndpoint) SrcIP() netip.Addr { // See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
return e.src.Addr
}
func (e *StdNetEndpoint) SrcIfidx() int32 {
return e.src.ifidx
}
func (e *StdNetEndpoint) DstToBytes() []byte { func (e *StdNetEndpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary() b, _ := e.AddrPort.MarshalBinary()
@@ -165,32 +147,8 @@ func (e *StdNetEndpoint) DstToString() string {
return e.AddrPort.String() return e.AddrPort.String()
} }
func (e *StdNetEndpoint) SrcToString() string { func listenNet(network string, port int) (*net.UDPConn, int, error) {
return e.src.Addr.String() conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
}
func (s *StdNetBind) listenNet(network string, host string, port int) (*net.UDPConn, int, error) {
lc := listenConfig()
if s.reusePort {
base := lc.Control
lc.Control = func(network, address string, c syscall.RawConn) error {
if base != nil {
if err := base(network, address, c); err != nil {
return err
}
}
return c.Control(func(fd uintptr) {
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
})
}
}
addr := ":" + strconv.Itoa(port)
if host != "" {
addr = net.JoinHostPort(host, strconv.Itoa(port))
}
conn, err := lc.ListenPacket(context.Background(), network, addr)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@@ -204,47 +162,10 @@ func (s *StdNetBind) listenNet(network string, host string, port int) (*net.UDPC
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
return conn.(*net.UDPConn), uaddr.Port, nil return conn.(*net.UDPConn), uaddr.Port, nil
} }
func (s *StdNetBind) openIPv4(port int) (*net.UDPConn, *ipv4.PacketConn, int, error) {
if !s.bindV4 {
return nil, nil, port, nil
}
host := s.listenAddr4
conn, actualPort, err := s.listenNet("udp4", host, port)
if err != nil {
if errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, nil, port, nil
}
return nil, nil, port, err
}
if runtime.GOOS != "linux" {
return conn, nil, actualPort, nil
}
pc := ipv4.NewPacketConn(conn)
return conn, pc, actualPort, nil
}
func (s *StdNetBind) openIPv6(port int) (*net.UDPConn, *ipv6.PacketConn, int, error) {
if !s.bindV6 {
return nil, nil, port, nil
}
host := s.listenAddr6
conn, actualPort, err := s.listenNet("udp6", host, port)
if err != nil {
if errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, nil, port, nil
}
return nil, nil, port, err
}
if runtime.GOOS != "linux" {
return conn, nil, actualPort, nil
}
pc := ipv6.NewPacketConn(conn)
return conn, pc, actualPort, nil
}
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@@ -260,46 +181,44 @@ func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
// If uport is 0, we can retry on failure. // If uport is 0, we can retry on failure.
again: again:
port := int(uport) port := int(uport)
var v4conn *net.UDPConn var v4conn, v6conn *net.UDPConn
var v6conn *net.UDPConn
var v4pc *ipv4.PacketConn var v4pc *ipv4.PacketConn
var v6pc *ipv6.PacketConn var v6pc *ipv6.PacketConn
v4conn, v4pc, port, err = s.openIPv4(port) v4conn, port, err = listenNet("udp4", port)
if err != nil { if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err return nil, 0, err
} }
// Listen on the same port as we're using for ipv4. // Listen on the same port as we're using for ipv4.
v6conn, v6pc, port, err = s.openIPv6(port) v6conn, port, err = listenNet("udp6", port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
if v4conn != nil { v4conn.Close()
v4conn.Close()
}
tries++ tries++
goto again goto again
} }
if err != nil { if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
if v4conn != nil { v4conn.Close()
v4conn.Close()
}
return nil, 0, err return nil, 0, err
} }
var fns []ReceiveFunc var fns []ReceiveFunc
if v4conn != nil { if v4conn != nil {
s.ipv4 = v4conn s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
if v4pc != nil { if runtime.GOOS == "linux" || runtime.GOOS == "android" {
v4pc = ipv4.NewPacketConn(v4conn)
s.ipv4PC = v4pc s.ipv4PC = v4pc
} }
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn)) fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
s.ipv4 = v4conn
} }
if v6conn != nil { if v6conn != nil {
s.ipv6 = v6conn s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
if v6pc != nil { if runtime.GOOS == "linux" || runtime.GOOS == "android" {
v6pc = ipv6.NewPacketConn(v6conn)
s.ipv6PC = v6pc s.ipv6PC = v6pc
} }
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn)) fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
s.ipv6 = v6conn
} }
if len(fns) == 0 { if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT return nil, 0, syscall.EAFNOSUPPORT
@@ -308,76 +227,101 @@ again:
return fns, uint16(port), nil return fns, uint16(port), nil
} }
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc { func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { for i := range *msgs {
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) (*msgs)[i].OOB = (*msgs)[i].OOB[:0]
defer s.ipv4MsgsPool.Put(msgs) (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
for i := range bufs { }
(*msgs)[i].Buffers[0] = bufs[i] s.msgsPool.Put(msgs)
} }
var numMsgs int
if runtime.GOOS == "linux" && pc != nil { func (s *StdNetBind) getMessages() *[]ipv6.Message {
numMsgs, err = pc.ReadBatch(*msgs, 0) return s.msgsPool.Get().(*[]ipv6.Message)
}
var (
// If compilation fails here these are no longer the same underlying type.
_ ipv6.Message = ipv4.Message{}
)
type batchReader interface {
ReadBatch([]ipv6.Message, int) (int, error)
}
type batchWriter interface {
WriteBatch([]ipv6.Message, int) (int, error)
}
func (s *StdNetBind) receiveIP(
br batchReader,
conn *net.UDPConn,
rxOffload bool,
bufs [][]byte,
sizes []int,
eps []Endpoint,
) (n int, err error) {
msgs := s.getMessages()
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer s.putMessages(msgs)
var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
if err != nil { if err != nil {
return 0, err return 0, err
} }
} else { } else {
msg := &(*msgs)[0] numMsgs, err = br.ReadBatch(*msgs, 0)
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil { if err != nil {
return 0, err return 0, err
} }
numMsgs = 1
} }
for i := 0; i < numMsgs; i++ { } else {
msg := &(*msgs)[i] msg := &(*msgs)[0]
sizes[i] = msg.N msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
addrPort := msg.Addr.(*net.UDPAddr).AddrPort() if err != nil {
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation return 0, err
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
} }
return numMsgs, nil numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
if sizes[i] == 0 {
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
} }
} }
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc { func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
defer s.ipv6MsgsPool.Put(msgs)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
}
var numMsgs int
if runtime.GOOS == "linux" && pc != nil {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
} }
} }
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize. // rename the IdealBatchSize constant to BatchSize.
func (s *StdNetBind) BatchSize() int { func (s *StdNetBind) BatchSize() int {
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" || runtime.GOOS == "android" {
return IdealBatchSize return IdealBatchSize
} }
return 1 return 1
@@ -400,28 +344,42 @@ func (s *StdNetBind) Close() error {
} }
s.blackhole4 = false s.blackhole4 = false
s.blackhole6 = false s.blackhole6 = false
s.ipv4TxOffload = false
s.ipv4RxOffload = false
s.ipv6TxOffload = false
s.ipv6RxOffload = false
if err1 != nil { if err1 != nil {
return err1 return err1
} }
return err2 return err2
} }
type ErrUDPGSODisabled struct {
onLaddr string
RetryErr error
}
func (e ErrUDPGSODisabled) Error() string {
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
}
func (e ErrUDPGSODisabled) Unwrap() error {
return e.RetryErr
}
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
s.mu.Lock() s.mu.Lock()
blackhole := s.blackhole4 blackhole := s.blackhole4
conn := s.ipv4 conn := s.ipv4
var ( offload := s.ipv4TxOffload
pc4 *ipv4.PacketConn br := batchWriter(s.ipv4PC)
pc6 *ipv6.PacketConn
)
is6 := false is6 := false
if endpoint.DstIP().Is6() { if endpoint.DstIP().Is6() {
blackhole = s.blackhole6 blackhole = s.blackhole6
conn = s.ipv6 conn = s.ipv6
pc6 = s.ipv6PC br = s.ipv6PC
is6 = true is6 = true
} else { offload = s.ipv6TxOffload
pc4 = s.ipv4PC
} }
s.mu.Unlock() s.mu.Unlock()
@@ -431,109 +389,185 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
if conn == nil { if conn == nil {
return syscall.EAFNOSUPPORT return syscall.EAFNOSUPPORT
} }
msgs := s.getMessages()
defer s.putMessages(msgs)
ua := s.udpAddrPool.Get().(*net.UDPAddr)
defer s.udpAddrPool.Put(ua)
if is6 { if is6 {
return s.send6(conn, pc6, endpoint, bufs) as16 := endpoint.DstIP().As16()
copy(ua.IP, as16[:])
ua.IP = ua.IP[:16]
} else { } else {
return s.send4(conn, pc4, endpoint, bufs) as4 := endpoint.DstIP().As4()
copy(ua.IP, as4[:])
ua.IP = ua.IP[:4]
} }
ua.Port = int(endpoint.(*StdNetEndpoint).Port())
var (
retried bool
err error
)
retry:
if offload {
n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
err = s.send(conn, br, (*msgs)[:n])
if err != nil && offload && errShouldDisableUDPGSO(err) {
offload = false
s.mu.Lock()
if is6 {
s.ipv6TxOffload = false
} else {
s.ipv4TxOffload = false
}
s.mu.Unlock()
retried = true
goto retry
}
} else {
for i := range bufs {
(*msgs)[i].Addr = ua
(*msgs)[i].Buffers[0] = bufs[i]
setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
}
err = s.send(conn, br, (*msgs)[:len(bufs)])
}
if retried {
return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
}
return err
} }
func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error { func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as4 := ep.DstIP().As4()
copy(ua.IP, as4[:])
ua.IP = ua.IP[:4]
ua.Port = int(ep.(*StdNetEndpoint).Port())
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
for i, buf := range bufs {
(*msgs)[i].Buffers[0] = buf
(*msgs)[i].Addr = ua
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
}
var ( var (
n int n int
err error err error
start int start int
) )
if runtime.GOOS == "linux" && pc != nil { if runtime.GOOS == "linux" || runtime.GOOS == "android" {
for { for {
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) n, err = pc.WriteBatch(msgs[start:], 0)
if err != nil { if err != nil || n == len(msgs[start:]) {
if errors.Is(err, syscall.EAFNOSUPPORT) {
for j := start; j < len(bufs); j++ {
_, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua)
if werr != nil {
err = werr
break
}
}
}
break
}
if n == len((*msgs)[start:len(bufs)]) {
break break
} }
start += n start += n
} }
} else { } else {
for i, buf := range bufs { for _, msg := range msgs {
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
if err != nil { if err != nil {
break break
} }
} }
} }
s.udpAddrPool.Put(ua)
s.ipv4MsgsPool.Put(msgs)
return err return err
} }
func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error { const (
ua := s.udpAddrPool.Get().(*net.UDPAddr) // Exceeding these values results in EMSGSIZE. They account for layer3 and
as16 := ep.DstIP().As16() // layer4 headers. IPv6 does not need to account for itself as the payload
copy(ua.IP, as16[:]) // length field is self excluding.
ua.IP = ua.IP[:16] maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
ua.Port = int(ep.(*StdNetEndpoint).Port()) maxIPv6PayloadLen = 1<<16 - 1 - 8
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
for i, buf := range bufs { // This is a hard limit imposed by the kernel.
(*msgs)[i].Buffers[0] = buf udpSegmentMaxDatagrams = 64
(*msgs)[i].Addr = ua )
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
} type setGSOFunc func(control *[]byte, gsoSize uint16)
func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
var ( var (
n int base = -1 // index of msg we are currently coalescing into
err error gsoSize int // segmentation size of msgs[base]
start int dgramCnt int // number of dgrams coalesced into msgs[base]
endBatch bool // tracking flag to start a new batch on next iteration of bufs
) )
if runtime.GOOS == "linux" && pc != nil { maxPayloadLen := maxIPv4PayloadLen
for { if ep.DstIP().Is6() {
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) maxPayloadLen = maxIPv6PayloadLen
if err != nil { }
if errors.Is(err, syscall.EAFNOSUPPORT) { for i, buf := range bufs {
for j := start; j < len(bufs); j++ { if i > 0 {
_, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua) msgLen := len(buf)
if werr != nil { baseLenBefore := len(msgs[base].Buffers[0])
err = werr freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
break if msgLen+baseLenBefore <= maxPayloadLen &&
} msgLen <= gsoSize &&
} msgLen <= freeBaseCap &&
dgramCnt < udpSegmentMaxDatagrams &&
!endBatch {
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
if i == len(bufs)-1 {
setGSO(&msgs[base].OOB, uint16(gsoSize))
} }
break dgramCnt++
if msgLen < gsoSize {
// A smaller than gsoSize packet on the tail is legal, but
// it must end the batch.
endBatch = true
}
continue
} }
if n == len((*msgs)[start:len(bufs)]) {
break
}
start += n
} }
} else { if dgramCnt > 1 {
for i, buf := range bufs { setGSO(&msgs[base].OOB, uint16(gsoSize))
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) }
if err != nil { // Reset prior to incrementing base since we are preparing to start a
break // new potential batch.
endBatch = false
base++
gsoSize = len(buf)
setSrcControl(&msgs[base].OOB, ep)
msgs[base].Buffers[0] = buf
msgs[base].Addr = addr
dgramCnt = 1
}
return base + 1
}
type getGSOFunc func(control []byte) (int, error)
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
for i := firstMsgAt; i < len(msgs); i++ {
msg := &msgs[i]
if msg.N == 0 {
return n, err
}
var (
gsoSize int
start int
end = msg.N
numToSplit = 1
)
gsoSize, err = getGSO(msg.OOB[:msg.NN])
if err != nil {
return n, err
}
if gsoSize > 0 {
numToSplit = (msg.N + gsoSize - 1) / gsoSize
end = gsoSize
}
for j := 0; j < numToSplit; j++ {
if n > i {
return n, errors.New("splitting coalesced packet resulted in overflow")
} }
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
msgs[n].N = copied
msgs[n].Addr = msg.Addr
start = end
end += gsoSize
if end > msg.N {
end = msg.N
}
n++
}
if i != n-1 {
// It is legal for bytes to move within msg.Buffers[0] as a result
// of splitting, so we only zero the source msg len when it is not
// the destination of the last split operation above.
msg.N = 0
} }
} }
s.udpAddrPool.Put(ua) return n, nil
s.ipv6MsgsPool.Put(msgs)
return err
} }

View File

@@ -29,6 +29,9 @@ func init() {
// Set beyond *mem_max if CAP_NET_ADMIN // Set beyond *mem_max if CAP_NET_ADMIN
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize) _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize) _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) //todo!!!
_ = unix.SetsockoptInt(int(fd), unix.SOL_UDP, unix.UDP_SEGMENT, 0xffff) //todo!!!
//print(err.Error())
}) })
}, },

View File

@@ -1,9 +1,3 @@
//go:build linux && !android
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package conn package conn
import ( import (
@@ -13,6 +7,37 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
func (e *StdNetEndpoint) SrcIP() netip.Addr {
switch len(e.src) {
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
return netip.AddrFrom4(info.Spec_dst)
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
// TODO: set zone. in order to do so we need to check if the address is
// link local, and if it is perform a syscall to turn the ifindex into a
// zone string because netip uses string zones.
return netip.AddrFrom16(info.Addr)
}
return netip.Addr{}
}
func (e *StdNetEndpoint) SrcIfidx() int32 {
switch len(e.src) {
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
return info.Ifindex
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
return int32(info.Ifindex)
}
return 0
}
func (e *StdNetEndpoint) SrcToString() string {
return e.SrcIP().String()
}
// getSrcFromControl parses the control for PKTINFO and if found updates ep with // getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found. // the source information found.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) { func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
@@ -34,83 +59,47 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
if hdr.Level == unix.IPPROTO_IP && if hdr.Level == unix.IPPROTO_IP &&
hdr.Type == unix.IP_PKTINFO { hdr.Type == unix.IP_PKTINFO {
info := pktInfoFromBuf[unix.Inet4Pktinfo](data) if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
ep.src.Addr = netip.AddrFrom4(info.Spec_dst) ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
ep.src.ifidx = info.Ifindex }
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
copy(ep.src, hdrBuf)
copy(ep.src[unix.CmsgLen(0):], data)
return return
} }
if hdr.Level == unix.IPPROTO_IPV6 && if hdr.Level == unix.IPPROTO_IPV6 &&
hdr.Type == unix.IPV6_PKTINFO { hdr.Type == unix.IPV6_PKTINFO {
info := pktInfoFromBuf[unix.Inet6Pktinfo](data) if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
ep.src.Addr = netip.AddrFrom16(info.Addr) ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
ep.src.ifidx = int32(info.Ifindex) }
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
copy(ep.src, hdrBuf)
copy(ep.src[unix.CmsgLen(0):], data)
return return
} }
} }
} }
// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
// panics if buf is of insufficient size.
func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
size := int(unsafe.Sizeof(t))
if len(buf) < size {
panic("pktInfoFromBuf: buffer too small")
}
copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
return t
}
// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address // setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
// and source ifindex found in ep. control's len will be set to 0 in the event // and source ifindex found in ep. control's len will be set to 0 in the event
// that ep is a default value. // that ep is a default value.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) { func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
*control = (*control)[:cap(*control)] if cap(*control) < len(ep.src) {
if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
*control = (*control)[:0]
return return
} }
*control = (*control)[:0]
if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() { *control = append(*control, ep.src...)
*control = (*control)[:0]
return
}
if len(*control) < srcControlSize {
*control = (*control)[:0]
return
}
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
if ep.SrcIP().Is4() {
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
info.Ifindex = ep.src.ifidx
if ep.SrcIP().IsValid() {
info.Spec_dst = ep.SrcIP().As4()
}
*control = (*control)[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
} else {
hdr.Level = unix.IPPROTO_IPV6
hdr.Type = unix.IPV6_PKTINFO
hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
info.Ifindex = uint32(ep.src.ifidx)
if ep.SrcIP().IsValid() {
info.Addr = ep.SrcIP().As16()
}
*control = (*control)[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
}
} }
var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) // stickyControlSize returns the recommended buffer size for pooling sticky
// offloading control data.
var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
const StdNetSupportsStickySockets = true const StdNetSupportsStickySockets = true