JackDoan
2025-11-05 15:38:47 -06:00
parent 2d128a3254
commit befba57366
3 changed files with 380 additions and 354 deletions

View File

@@ -1,12 +1,14 @@
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
@@ -16,7 +18,6 @@ import (
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
)
var (
@@ -29,28 +30,53 @@ var (
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
mu sync.Mutex // protects all fields except as specified
ipv4 *net.UDPConn
ipv6 *net.UDPConn
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
mu sync.Mutex // protects all fields except as specified
ipv4 *net.UDPConn
ipv6 *net.UDPConn
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
ipv4TxOffload bool
ipv4RxOffload bool
ipv6TxOffload bool
ipv6RxOffload bool
// these three fields are not guarded by mu
udpAddrPool sync.Pool
ipv4MsgsPool sync.Pool
ipv6MsgsPool sync.Pool
// these two fields are not guarded by mu
udpAddrPool sync.Pool
msgsPool sync.Pool
blackhole4 bool
blackhole6 bool
listenAddr4 string
listenAddr6 string
bindV4 bool
bindV6 bool
reusePort bool
}
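A hedged sketch (not part of this commit) of the batched receive contract the comment above refers to: the caller hands parallel bufs/sizes/eps slices to a ReceiveFunc, and a single call may fill up to BatchSize() of the slots. The 1500-byte per-datagram buffer is illustrative.
func receiveOnce(recv ReceiveFunc) (int, []int, []Endpoint, error) {
	bufs := make([][]byte, IdealBatchSize)
	sizes := make([]int, IdealBatchSize)
	eps := make([]Endpoint, IdealBatchSize)
	for i := range bufs {
		bufs[i] = make([]byte, 1500) // illustrative per-datagram buffer size
	}
	// One syscall-batched call; the first n entries of sizes/eps are valid afterwards.
	n, err := recv(bufs, sizes, eps)
	return n, sizes[:n], eps[:n], err
}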
func newStdNetBind() *StdNetBind {
// NewStdNetBind creates a bind that listens on all interfaces.
func NewStdNetBind() *StdNetBind {
return newStdNetBind().(*StdNetBind)
}
// NewStdNetBindForAddr creates a bind that listens on a specific address.
// If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the
// IPv6 socket will be created.
func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind {
b := NewStdNetBind()
//if addr.IsValid() {
// if addr.IsUnspecified() {
// // keep dual-stack defaults with empty listen addresses
// } else if addr.Is4() {
// b.listenAddr4 = addr.Unmap().String()
// b.bindV4 = true
// b.bindV6 = false
// } else {
// b.listenAddr6 = addr.Unmap().String()
// b.bindV6 = true
// b.bindV4 = false
// }
//}
//b.reusePort = reusePort
return b
}
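A hypothetical usage sketch of the two constructors; the address and reusePort values are illustrative, and both return a *StdNetBind that satisfies Bind.
func exampleBinds() (Bind, Bind) {
	all := NewStdNetBind()                                              // dual-stack, all interfaces
	one := NewStdNetBindForAddr(netip.MustParseAddr("192.0.2.1"), true) // single-address form
	return all, one
}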
func newStdNetBind() Bind {
return &StdNetBind{
udpAddrPool: sync.Pool{
New: func() any {
@@ -60,68 +86,28 @@ func newStdNetBind() *StdNetBind {
},
},
ipv4MsgsPool: sync.Pool{
New: func() any {
msgs := make([]ipv4.Message, IdealBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, srcControlSize)
}
return &msgs
},
},
ipv6MsgsPool: sync.Pool{
msgsPool: sync.Pool{
New: func() any {
// ipv6.Message and ipv4.Message are interchangeable as they are
// both aliases for x/net/internal/socket.Message.
msgs := make([]ipv6.Message, IdealBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, srcControlSize)
msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
}
return &msgs
},
},
bindV4: true,
bindV6: true,
reusePort: false,
}
}
// NewStdNetBind creates a bind that listens on all interfaces.
func NewStdNetBind() *StdNetBind {
return newStdNetBind()
}
// NewStdNetBindForAddr creates a bind that listens on a specific address.
// If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the
// IPv6 socket will be created.
func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind {
b := newStdNetBind()
if addr.IsValid() {
if addr.IsUnspecified() {
// keep dual-stack defaults with empty listen addresses
} else if addr.Is4() {
b.listenAddr4 = addr.Unmap().String()
b.bindV4 = true
b.bindV6 = false
} else {
b.listenAddr6 = addr.Unmap().String()
b.bindV6 = true
b.bindV4 = false
}
}
b.reusePort = reusePort
return b
}
type StdNetEndpoint struct {
// AddrPort is the endpoint destination.
netip.AddrPort
// src is the current sticky source address and interface index, if supported.
src struct {
netip.Addr
ifidx int32
}
// src is the current sticky source address and interface index, if
// supported. Typically this is a PKTINFO structure from/for control
// messages, see unix.PKTINFO for an example.
src []byte
}
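A hedged illustration of what src now holds: on Linux it is the raw control message, a cmsghdr followed by a pktinfo struct, so its length alone identifies the address family (assumes golang.org/x/sys/unix as used in control_linux.go below).
func srcIsIPv4(e *StdNetEndpoint) bool {
	return len(e.src) == unix.CmsgSpace(unix.SizeofInet4Pktinfo)
}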
var (
@@ -140,21 +126,17 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
}
func (e *StdNetEndpoint) ClearSrc() {
e.src.ifidx = 0
e.src.Addr = netip.Addr{}
if e.src != nil {
// Truncate src, no need to reallocate.
e.src = e.src[:0]
}
}
func (e *StdNetEndpoint) DstIP() netip.Addr {
return e.AddrPort.Addr()
}
func (e *StdNetEndpoint) SrcIP() netip.Addr {
return e.src.Addr
}
func (e *StdNetEndpoint) SrcIfidx() int32 {
return e.src.ifidx
}
// See control_default.go, control_linux.go, etc. for implementations of SrcIP and SrcIfidx.
func (e *StdNetEndpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
@@ -165,32 +147,8 @@ func (e *StdNetEndpoint) DstToString() string {
return e.AddrPort.String()
}
func (e *StdNetEndpoint) SrcToString() string {
return e.src.Addr.String()
}
func (s *StdNetBind) listenNet(network string, host string, port int) (*net.UDPConn, int, error) {
lc := listenConfig()
if s.reusePort {
base := lc.Control
lc.Control = func(network, address string, c syscall.RawConn) error {
if base != nil {
if err := base(network, address, c); err != nil {
return err
}
}
return c.Control(func(fd uintptr) {
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
})
}
}
addr := ":" + strconv.Itoa(port)
if host != "" {
addr = net.JoinHostPort(host, strconv.Itoa(port))
}
conn, err := lc.ListenPacket(context.Background(), network, addr)
func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
if err != nil {
return nil, 0, err
}
@@ -204,47 +162,10 @@ func (s *StdNetBind) listenNet(network string, host string, port int) (*net.UDPC
if err != nil {
return nil, 0, err
}
return conn.(*net.UDPConn), uaddr.Port, nil
}
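A hedged sketch of the port-discovery behavior above: with port 0 the kernel assigns a port, and listenNet reads it back from LocalAddr.
func ephemeralPort() (int, error) {
	c, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
	if err != nil {
		return 0, err
	}
	defer c.Close()
	// LocalAddr carries the kernel-chosen port, as listenNet reads it above.
	return c.LocalAddr().(*net.UDPAddr).Port, nil
}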
func (s *StdNetBind) openIPv4(port int) (*net.UDPConn, *ipv4.PacketConn, int, error) {
if !s.bindV4 {
return nil, nil, port, nil
}
host := s.listenAddr4
conn, actualPort, err := s.listenNet("udp4", host, port)
if err != nil {
if errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, nil, port, nil
}
return nil, nil, port, err
}
if runtime.GOOS != "linux" {
return conn, nil, actualPort, nil
}
pc := ipv4.NewPacketConn(conn)
return conn, pc, actualPort, nil
}
func (s *StdNetBind) openIPv6(port int) (*net.UDPConn, *ipv6.PacketConn, int, error) {
if !s.bindV6 {
return nil, nil, port, nil
}
host := s.listenAddr6
conn, actualPort, err := s.listenNet("udp6", host, port)
if err != nil {
if errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, nil, port, nil
}
return nil, nil, port, err
}
if runtime.GOOS != "linux" {
return conn, nil, actualPort, nil
}
pc := ipv6.NewPacketConn(conn)
return conn, pc, actualPort, nil
}
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -260,46 +181,44 @@ func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
// If uport is 0, we can retry on failure.
again:
port := int(uport)
var v4conn *net.UDPConn
var v6conn *net.UDPConn
var v4conn, v6conn *net.UDPConn
var v4pc *ipv4.PacketConn
var v6pc *ipv6.PacketConn
v4conn, v4pc, port, err = s.openIPv4(port)
if err != nil {
v4conn, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
// Listen on the same port as we're using for ipv4.
v6conn, v6pc, port, err = s.openIPv6(port)
v6conn, port, err = listenNet("udp6", port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
if v4conn != nil {
v4conn.Close()
}
v4conn.Close()
tries++
goto again
}
if err != nil {
if v4conn != nil {
v4conn.Close()
}
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
v4conn.Close()
return nil, 0, err
}
var fns []ReceiveFunc
if v4conn != nil {
s.ipv4 = v4conn
if v4pc != nil {
s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
v4pc = ipv4.NewPacketConn(v4conn)
s.ipv4PC = v4pc
}
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
s.ipv4 = v4conn
}
if v6conn != nil {
s.ipv6 = v6conn
if v6pc != nil {
s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
v6pc = ipv6.NewPacketConn(v6conn)
s.ipv6PC = v6pc
}
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
s.ipv6 = v6conn
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
@@ -308,76 +227,101 @@ again:
return fns, uint16(port), nil
}
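A hedged caller sketch: Open returns one ReceiveFunc per bound socket plus the (possibly kernel-chosen) port; callers typically run each ReceiveFunc in its own goroutine with buffers shaped as sketched after the StdNetBind struct above.
func openAndCount(b *StdNetBind) (uint16, int, error) {
	fns, port, err := b.Open(0) // 0 lets listenNet pick a free port
	if err != nil {
		return 0, 0, err
	}
	return port, len(fns), nil // len(fns) is 1 or 2 (IPv4 and/or IPv6 bound)
}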
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
defer s.ipv4MsgsPool.Put(msgs)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
}
var numMsgs int
if runtime.GOOS == "linux" && pc != nil {
numMsgs, err = pc.ReadBatch(*msgs, 0)
func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
for i := range *msgs {
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
}
s.msgsPool.Put(msgs)
}
func (s *StdNetBind) getMessages() *[]ipv6.Message {
return s.msgsPool.Get().(*[]ipv6.Message)
}
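A hedged sketch of the pool discipline: getMessages hands out a batch of IdealBatchSize messages, and putMessages truncates the OOB slices and clears per-call fields before returning them to the pool.
func withMessages(s *StdNetBind, fn func(msgs []ipv6.Message)) {
	msgs := s.getMessages()
	defer s.putMessages(msgs)
	fn(*msgs)
}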
var (
// If compilation fails here these are no longer the same underlying type.
_ ipv6.Message = ipv4.Message{}
)
type batchReader interface {
ReadBatch([]ipv6.Message, int) (int, error)
}
type batchWriter interface {
WriteBatch([]ipv6.Message, int) (int, error)
}
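Compile-time sketch: both concrete PacketConn types satisfy these interfaces, because ipv4.Message and ipv6.Message alias the same underlying type.
var (
	_ batchReader = (*ipv4.PacketConn)(nil)
	_ batchWriter = (*ipv4.PacketConn)(nil)
	_ batchReader = (*ipv6.PacketConn)(nil)
	_ batchWriter = (*ipv6.PacketConn)(nil)
)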
func (s *StdNetBind) receiveIP(
br batchReader,
conn *net.UDPConn,
rxOffload bool,
bufs [][]byte,
sizes []int,
eps []Endpoint,
) (n int, err error) {
msgs := s.getMessages()
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer s.putMessages(msgs)
var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
numMsgs, err = br.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
return numMsgs, nil
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
if sizes[i] == 0 {
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
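Worked example for the rxOffload path, assuming IdealBatchSize is 128 as in upstream wireguard-go: udpSegmentMaxDatagrams is 64, so readAt = 128 - 128/64 = 126 and ReadBatch only fills msgs[126] and msgs[127]. splitCoalescedMessages then fans each of those coalesced reads out into at most 64 datagrams starting at index 0, so the batch of 128 slots can never overflow.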
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
}
}
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
defer s.ipv6MsgsPool.Put(msgs)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
}
var numMsgs int
if runtime.GOOS == "linux" && pc != nil {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
}
}
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (s *StdNetBind) BatchSize() int {
if runtime.GOOS == "linux" {
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
return IdealBatchSize
}
return 1
@@ -400,28 +344,42 @@ func (s *StdNetBind) Close() error {
}
s.blackhole4 = false
s.blackhole6 = false
s.ipv4TxOffload = false
s.ipv4RxOffload = false
s.ipv6TxOffload = false
s.ipv6RxOffload = false
if err1 != nil {
return err1
}
return err2
}
type ErrUDPGSODisabled struct {
onLaddr string
RetryErr error
}
func (e ErrUDPGSODisabled) Error() string {
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
}
func (e ErrUDPGSODisabled) Unwrap() error {
return e.RetryErr
}
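A hedged caller sketch: when GSO has to be disabled mid-send, Send retries without it and still returns this error so callers can log the downgrade; RetryErr is the outcome of the retried send and may be nil if that retry succeeded.
func gsoWasDisabled(err error) bool {
	var gsoErr ErrUDPGSODisabled
	return errors.As(err, &gsoErr)
}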
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
s.mu.Lock()
blackhole := s.blackhole4
conn := s.ipv4
var (
pc4 *ipv4.PacketConn
pc6 *ipv6.PacketConn
)
offload := s.ipv4TxOffload
br := batchWriter(s.ipv4PC)
is6 := false
if endpoint.DstIP().Is6() {
blackhole = s.blackhole6
conn = s.ipv6
pc6 = s.ipv6PC
br = s.ipv6PC
is6 = true
} else {
pc4 = s.ipv4PC
offload = s.ipv6TxOffload
}
s.mu.Unlock()
@@ -431,109 +389,185 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
if conn == nil {
return syscall.EAFNOSUPPORT
}
msgs := s.getMessages()
defer s.putMessages(msgs)
ua := s.udpAddrPool.Get().(*net.UDPAddr)
defer s.udpAddrPool.Put(ua)
if is6 {
return s.send6(conn, pc6, endpoint, bufs)
as16 := endpoint.DstIP().As16()
copy(ua.IP, as16[:])
ua.IP = ua.IP[:16]
} else {
return s.send4(conn, pc4, endpoint, bufs)
as4 := endpoint.DstIP().As4()
copy(ua.IP, as4[:])
ua.IP = ua.IP[:4]
}
ua.Port = int(endpoint.(*StdNetEndpoint).Port())
var (
retried bool
err error
)
retry:
if offload {
n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
err = s.send(conn, br, (*msgs)[:n])
if err != nil && offload && errShouldDisableUDPGSO(err) {
offload = false
s.mu.Lock()
if is6 {
s.ipv6TxOffload = false
} else {
s.ipv4TxOffload = false
}
s.mu.Unlock()
retried = true
goto retry
}
} else {
for i := range bufs {
(*msgs)[i].Addr = ua
(*msgs)[i].Buffers[0] = bufs[i]
setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
}
err = s.send(conn, br, (*msgs)[:len(bufs)])
}
if retried {
return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
}
return err
}
func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as4 := ep.DstIP().As4()
copy(ua.IP, as4[:])
ua.IP = ua.IP[:4]
ua.Port = int(ep.(*StdNetEndpoint).Port())
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
for i, buf := range bufs {
(*msgs)[i].Buffers[0] = buf
(*msgs)[i].Addr = ua
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
}
func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
var (
n int
err error
start int
)
if runtime.GOOS == "linux" && pc != nil {
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
for {
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
if err != nil {
if errors.Is(err, syscall.EAFNOSUPPORT) {
for j := start; j < len(bufs); j++ {
_, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua)
if werr != nil {
err = werr
break
}
}
}
break
}
if n == len((*msgs)[start:len(bufs)]) {
n, err = pc.WriteBatch(msgs[start:], 0)
if err != nil || n == len(msgs[start:]) {
break
}
start += n
}
} else {
for i, buf := range bufs {
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
for _, msg := range msgs {
_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
if err != nil {
break
}
}
}
s.udpAddrPool.Put(ua)
s.ipv4MsgsPool.Put(msgs)
return err
}
func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as16 := ep.DstIP().As16()
copy(ua.IP, as16[:])
ua.IP = ua.IP[:16]
ua.Port = int(ep.(*StdNetEndpoint).Port())
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
for i, buf := range bufs {
(*msgs)[i].Buffers[0] = buf
(*msgs)[i].Addr = ua
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
}
const (
// Exceeding these values results in EMSGSIZE. They account for layer3 and
// layer4 headers. IPv6 does not need to account for itself as the payload
// length field is self excluding.
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
maxIPv6PayloadLen = 1<<16 - 1 - 8
// This is a hard limit imposed by the kernel.
udpSegmentMaxDatagrams = 64
)
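Worked out: maxIPv4PayloadLen is 65535 - 20 - 8 = 65507, since the 16-bit IPv4 total-length field also covers the IPv4 and UDP headers, while maxIPv6PayloadLen is 65535 - 8 = 65527, since the IPv6 payload-length field excludes the fixed IPv6 header but still covers the UDP header.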
type setGSOFunc func(control *[]byte, gsoSize uint16)
func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
var (
n int
err error
start int
base = -1 // index of msg we are currently coalescing into
gsoSize int // segmentation size of msgs[base]
dgramCnt int // number of dgrams coalesced into msgs[base]
endBatch bool // tracking flag to start a new batch on next iteration of bufs
)
if runtime.GOOS == "linux" && pc != nil {
for {
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
if err != nil {
if errors.Is(err, syscall.EAFNOSUPPORT) {
for j := start; j < len(bufs); j++ {
_, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua)
if werr != nil {
err = werr
break
}
}
maxPayloadLen := maxIPv4PayloadLen
if ep.DstIP().Is6() {
maxPayloadLen = maxIPv6PayloadLen
}
for i, buf := range bufs {
if i > 0 {
msgLen := len(buf)
baseLenBefore := len(msgs[base].Buffers[0])
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
if msgLen+baseLenBefore <= maxPayloadLen &&
msgLen <= gsoSize &&
msgLen <= freeBaseCap &&
dgramCnt < udpSegmentMaxDatagrams &&
!endBatch {
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
if i == len(bufs)-1 {
setGSO(&msgs[base].OOB, uint16(gsoSize))
}
break
dgramCnt++
if msgLen < gsoSize {
// A smaller than gsoSize packet on the tail is legal, but
// it must end the batch.
endBatch = true
}
continue
}
if n == len((*msgs)[start:len(bufs)]) {
break
}
start += n
}
} else {
for i, buf := range bufs {
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
if err != nil {
break
if dgramCnt > 1 {
setGSO(&msgs[base].OOB, uint16(gsoSize))
}
// Reset prior to incrementing base since we are preparing to start a
// new potential batch.
endBatch = false
base++
gsoSize = len(buf)
setSrcControl(&msgs[base].OOB, ep)
msgs[base].Buffers[0] = buf
msgs[base].Addr = addr
dgramCnt = 1
}
return base + 1
}
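Worked example: with bufs of 1400, 1400, and 900 bytes and enough spare capacity in the first buffer, gsoSize is fixed at 1400 by the first datagram, the second datagram and the smaller tail are appended into msgs[0].Buffers[0] (the short tail sets endBatch), dgramCnt reaches 3, the GSO size 1400 is written into msgs[0].OOB, and coalesceMessages returns 1, so a single sendmsg carries all three datagrams.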
type getGSOFunc func(control []byte) (int, error)
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
for i := firstMsgAt; i < len(msgs); i++ {
msg := &msgs[i]
if msg.N == 0 {
return n, err
}
var (
gsoSize int
start int
end = msg.N
numToSplit = 1
)
gsoSize, err = getGSO(msg.OOB[:msg.NN])
if err != nil {
return n, err
}
if gsoSize > 0 {
numToSplit = (msg.N + gsoSize - 1) / gsoSize
end = gsoSize
}
for j := 0; j < numToSplit; j++ {
if n > i {
return n, errors.New("splitting coalesced packet resulted in overflow")
}
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
msgs[n].N = copied
msgs[n].Addr = msg.Addr
start = end
end += gsoSize
if end > msg.N {
end = msg.N
}
n++
}
if i != n-1 {
// It is legal for bytes to move within msg.Buffers[0] as a result
// of splitting, so we only zero the source msg len when it is not
// the destination of the last split operation above.
msg.N = 0
}
}
s.udpAddrPool.Put(ua)
s.ipv6MsgsPool.Put(msgs)
return err
return n, nil
}
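Worked example: a coalesced read with msg.N of 3000 and a GSO size of 1400 gives numToSplit = (3000 + 1399) / 1400 = 3; the copies cover bytes [0:1400), [1400:2800), and [2800:3000), i.e. two full 1400-byte datagrams and a 200-byte tail, each placed in its own message slot with N set to the copied length.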

View File

@@ -29,6 +29,9 @@ func init() {
// Set beyond *mem_max if CAP_NET_ADMIN
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) //todo!!!
_ = unix.SetsockoptInt(int(fd), unix.SOL_UDP, unix.UDP_SEGMENT, 0xffff) //todo!!!
//print(err.Error())
})
},

View File

@@ -1,9 +1,3 @@
//go:build linux && !android
// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package conn
import (
@@ -13,6 +7,37 @@ import (
"golang.org/x/sys/unix"
)
func (e *StdNetEndpoint) SrcIP() netip.Addr {
switch len(e.src) {
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
return netip.AddrFrom4(info.Spec_dst)
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
// TODO: set zone. in order to do so we need to check if the address is
// link local, and if it is perform a syscall to turn the ifindex into a
// zone string because netip uses string zones.
return netip.AddrFrom16(info.Addr)
}
return netip.Addr{}
}
func (e *StdNetEndpoint) SrcIfidx() int32 {
switch len(e.src) {
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
return info.Ifindex
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
return int32(info.Ifindex)
}
return 0
}
func (e *StdNetEndpoint) SrcToString() string {
return e.SrcIP().String()
}
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
@@ -34,83 +59,47 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
if hdr.Level == unix.IPPROTO_IP &&
hdr.Type == unix.IP_PKTINFO {
info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
ep.src.ifidx = info.Ifindex
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
}
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
copy(ep.src, hdrBuf)
copy(ep.src[unix.CmsgLen(0):], data)
return
}
if hdr.Level == unix.IPPROTO_IPV6 &&
hdr.Type == unix.IPV6_PKTINFO {
info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
ep.src.Addr = netip.AddrFrom16(info.Addr)
ep.src.ifidx = int32(info.Ifindex)
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
}
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
copy(ep.src, hdrBuf)
copy(ep.src[unix.CmsgLen(0):], data)
return
}
}
}
// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
// panics if buf is of insufficient size.
func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
size := int(unsafe.Sizeof(t))
if len(buf) < size {
panic("pktInfoFromBuf: buffer too small")
}
copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
return t
}
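A hedged usage sketch for pktInfoFromBuf: given the data portion of an IP_PKTINFO control message, decode it and lift the source address into a netip.Addr.
func srcFromPktinfo(data []byte) netip.Addr {
	info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
	return netip.AddrFrom4(info.Spec_dst)
}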
// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
// and source ifindex found in ep. control's len will be set to 0 in the event
// that ep is a default value.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
*control = (*control)[:cap(*control)]
if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
*control = (*control)[:0]
if cap(*control) < len(ep.src) {
return
}
if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() {
*control = (*control)[:0]
return
}
if len(*control) < srcControlSize {
*control = (*control)[:0]
return
}
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
if ep.SrcIP().Is4() {
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
info.Ifindex = ep.src.ifidx
if ep.SrcIP().IsValid() {
info.Spec_dst = ep.SrcIP().As4()
}
*control = (*control)[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
} else {
hdr.Level = unix.IPPROTO_IPV6
hdr.Type = unix.IPV6_PKTINFO
hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
info.Ifindex = uint32(ep.src.ifidx)
if ep.SrcIP().IsValid() {
info.Addr = ep.SrcIP().As16()
}
*control = (*control)[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
}
*control = (*control)[:0]
*control = append(*control, ep.src...)
}
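A hedged round-trip sketch: the control bytes captured by getSrcFromControl are replayed verbatim by setSrcControl, so replies leave from the address the peer originally targeted.
func replySrcControl(ep *StdNetEndpoint, oobIn []byte) []byte {
	getSrcFromControl(oobIn, ep) // cache the raw PKTINFO cmsg in ep.src
	oobOut := make([]byte, 0, stickyControlSize)
	setSrcControl(&oobOut, ep) // copy ep.src back out for sendmsg
	return oobOut
}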
var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
// stickyControlSize is the recommended buffer size for pooling sticky
// offloading control data.
var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
const StdNetSupportsStickySockets = true