// SPDX-License-Identifier: MIT // // Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. package conn import ( "context" "errors" "net" "net/netip" "runtime" "strconv" "sync" "syscall" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" "golang.org/x/sys/unix" ) var ( _ Bind = (*StdNetBind)(nil) ) // StdNetBind implements Bind for all platforms. While Windows has its own Bind // (see bind_windows.go), it may fall back to StdNetBind. // TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable // methods for sending and receiving multiple datagrams per-syscall. See the // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564. type StdNetBind struct { mu sync.Mutex // protects all fields except as specified ipv4 *net.UDPConn ipv6 *net.UDPConn ipv4PC *ipv4.PacketConn // will be nil on non-Linux ipv6PC *ipv6.PacketConn // will be nil on non-Linux // these three fields are not guarded by mu udpAddrPool sync.Pool ipv4MsgsPool sync.Pool ipv6MsgsPool sync.Pool blackhole4 bool blackhole6 bool listenAddr4 string listenAddr6 string bindV4 bool bindV6 bool reusePort bool } func newStdNetBind() *StdNetBind { return &StdNetBind{ udpAddrPool: sync.Pool{ New: func() any { return &net.UDPAddr{ IP: make([]byte, 16), } }, }, ipv4MsgsPool: sync.Pool{ New: func() any { msgs := make([]ipv4.Message, IdealBatchSize) for i := range msgs { msgs[i].Buffers = make(net.Buffers, 1) msgs[i].OOB = make([]byte, srcControlSize) } return &msgs }, }, ipv6MsgsPool: sync.Pool{ New: func() any { msgs := make([]ipv6.Message, IdealBatchSize) for i := range msgs { msgs[i].Buffers = make(net.Buffers, 1) msgs[i].OOB = make([]byte, srcControlSize) } return &msgs }, }, bindV4: true, bindV6: true, reusePort: false, } } // NewStdNetBind creates a bind that listens on all interfaces. func NewStdNetBind() *StdNetBind { return newStdNetBind() } // NewStdNetBindForAddr creates a bind that listens on a specific address. // If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the // IPv6 socket will be created. func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind { b := newStdNetBind() if addr.IsValid() { if addr.Is4() { b.listenAddr4 = addr.Unmap().String() b.bindV4 = true b.bindV6 = false } else { b.listenAddr6 = addr.Unmap().String() b.bindV6 = true b.bindV4 = false } } b.reusePort = reusePort return b } type StdNetEndpoint struct { // AddrPort is the endpoint destination. netip.AddrPort // src is the current sticky source address and interface index, if supported. src struct { netip.Addr ifidx int32 } } var ( _ Bind = (*StdNetBind)(nil) _ Endpoint = &StdNetEndpoint{} ) func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { e, err := netip.ParseAddrPort(s) if err != nil { return nil, err } return &StdNetEndpoint{ AddrPort: e, }, nil } func (e *StdNetEndpoint) ClearSrc() { e.src.ifidx = 0 e.src.Addr = netip.Addr{} } func (e *StdNetEndpoint) DstIP() netip.Addr { return e.AddrPort.Addr() } func (e *StdNetEndpoint) SrcIP() netip.Addr { return e.src.Addr } func (e *StdNetEndpoint) SrcIfidx() int32 { return e.src.ifidx } func (e *StdNetEndpoint) DstToBytes() []byte { b, _ := e.AddrPort.MarshalBinary() return b } func (e *StdNetEndpoint) DstToString() string { return e.AddrPort.String() } func (e *StdNetEndpoint) SrcToString() string { return e.src.Addr.String() } func (s *StdNetBind) listenNet(network string, host string, port int) (*net.UDPConn, int, error) { lc := listenConfig() if s.reusePort { base := lc.Control lc.Control = func(network, address string, c syscall.RawConn) error { if base != nil { if err := base(network, address, c); err != nil { return err } } return c.Control(func(fd uintptr) { _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) }) } } addr := ":" + strconv.Itoa(port) if host != "" { addr = net.JoinHostPort(host, strconv.Itoa(port)) } conn, err := lc.ListenPacket(context.Background(), network, addr) if err != nil { return nil, 0, err } // Retrieve port. laddr := conn.LocalAddr() uaddr, err := net.ResolveUDPAddr( laddr.Network(), laddr.String(), ) if err != nil { return nil, 0, err } return conn.(*net.UDPConn), uaddr.Port, nil } func (s *StdNetBind) openIPv4(port int) (*net.UDPConn, *ipv4.PacketConn, int, error) { if !s.bindV4 { return nil, nil, port, nil } host := s.listenAddr4 conn, actualPort, err := s.listenNet("udp4", host, port) if err != nil { if errors.Is(err, syscall.EAFNOSUPPORT) { return nil, nil, port, nil } return nil, nil, port, err } if runtime.GOOS != "linux" { return conn, nil, actualPort, nil } pc := ipv4.NewPacketConn(conn) return conn, pc, actualPort, nil } func (s *StdNetBind) openIPv6(port int) (*net.UDPConn, *ipv6.PacketConn, int, error) { if !s.bindV6 { return nil, nil, port, nil } host := s.listenAddr6 conn, actualPort, err := s.listenNet("udp6", host, port) if err != nil { if errors.Is(err, syscall.EAFNOSUPPORT) { return nil, nil, port, nil } return nil, nil, port, err } if runtime.GOOS != "linux" { return conn, nil, actualPort, nil } pc := ipv6.NewPacketConn(conn) return conn, pc, actualPort, nil } func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { s.mu.Lock() defer s.mu.Unlock() var err error var tries int if s.ipv4 != nil || s.ipv6 != nil { return nil, 0, ErrBindAlreadyOpen } // Attempt to open ipv4 and ipv6 listeners on the same port. // If uport is 0, we can retry on failure. again: port := int(uport) var v4conn *net.UDPConn var v6conn *net.UDPConn var v4pc *ipv4.PacketConn var v6pc *ipv6.PacketConn v4conn, v4pc, port, err = s.openIPv4(port) if err != nil { return nil, 0, err } // Listen on the same port as we're using for ipv4. v6conn, v6pc, port, err = s.openIPv6(port) if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { if v4conn != nil { v4conn.Close() } tries++ goto again } if err != nil { if v4conn != nil { v4conn.Close() } return nil, 0, err } var fns []ReceiveFunc if v4conn != nil { s.ipv4 = v4conn if v4pc != nil { s.ipv4PC = v4pc } fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn)) } if v6conn != nil { s.ipv6 = v6conn if v6pc != nil { s.ipv6PC = v6pc } fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn)) } if len(fns) == 0 { return nil, 0, syscall.EAFNOSUPPORT } return fns, uint16(port), nil } func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc { return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) defer s.ipv4MsgsPool.Put(msgs) for i := range bufs { (*msgs)[i].Buffers[0] = bufs[i] } var numMsgs int if runtime.GOOS == "linux" && pc != nil { numMsgs, err = pc.ReadBatch(*msgs, 0) if err != nil { return 0, err } } else { msg := &(*msgs)[0] msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) if err != nil { return 0, err } numMsgs = 1 } for i := 0; i < numMsgs; i++ { msg := &(*msgs)[i] sizes[i] = msg.N addrPort := msg.Addr.(*net.UDPAddr).AddrPort() ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation getSrcFromControl(msg.OOB[:msg.NN], ep) eps[i] = ep } return numMsgs, nil } } func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc { return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) defer s.ipv6MsgsPool.Put(msgs) for i := range bufs { (*msgs)[i].Buffers[0] = bufs[i] } var numMsgs int if runtime.GOOS == "linux" && pc != nil { numMsgs, err = pc.ReadBatch(*msgs, 0) if err != nil { return 0, err } } else { msg := &(*msgs)[0] msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) if err != nil { return 0, err } numMsgs = 1 } for i := 0; i < numMsgs; i++ { msg := &(*msgs)[i] sizes[i] = msg.N addrPort := msg.Addr.(*net.UDPAddr).AddrPort() ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation getSrcFromControl(msg.OOB[:msg.NN], ep) eps[i] = ep } return numMsgs, nil } } // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and // rename the IdealBatchSize constant to BatchSize. func (s *StdNetBind) BatchSize() int { if runtime.GOOS == "linux" { return IdealBatchSize } return 1 } func (s *StdNetBind) Close() error { s.mu.Lock() defer s.mu.Unlock() var err1, err2 error if s.ipv4 != nil { err1 = s.ipv4.Close() s.ipv4 = nil s.ipv4PC = nil } if s.ipv6 != nil { err2 = s.ipv6.Close() s.ipv6 = nil s.ipv6PC = nil } s.blackhole4 = false s.blackhole6 = false if err1 != nil { return err1 } return err2 } func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { s.mu.Lock() blackhole := s.blackhole4 conn := s.ipv4 var ( pc4 *ipv4.PacketConn pc6 *ipv6.PacketConn ) is6 := false if endpoint.DstIP().Is6() { blackhole = s.blackhole6 conn = s.ipv6 pc6 = s.ipv6PC is6 = true } else { pc4 = s.ipv4PC } s.mu.Unlock() if blackhole { return nil } if conn == nil { return syscall.EAFNOSUPPORT } if is6 { return s.send6(conn, pc6, endpoint, bufs) } else { return s.send4(conn, pc4, endpoint, bufs) } } func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error { ua := s.udpAddrPool.Get().(*net.UDPAddr) as4 := ep.DstIP().As4() copy(ua.IP, as4[:]) ua.IP = ua.IP[:4] ua.Port = int(ep.(*StdNetEndpoint).Port()) msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) for i, buf := range bufs { (*msgs)[i].Buffers[0] = buf (*msgs)[i].Addr = ua setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) } var ( n int err error start int ) if runtime.GOOS == "linux" && pc != nil { for { n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) if err != nil { if errors.Is(err, syscall.EAFNOSUPPORT) { for j := start; j < len(bufs); j++ { _, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua) if werr != nil { err = werr break } } } break } if n == len((*msgs)[start:len(bufs)]) { break } start += n } } else { for i, buf := range bufs { _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) if err != nil { break } } } s.udpAddrPool.Put(ua) s.ipv4MsgsPool.Put(msgs) return err } func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error { ua := s.udpAddrPool.Get().(*net.UDPAddr) as16 := ep.DstIP().As16() copy(ua.IP, as16[:]) ua.IP = ua.IP[:16] ua.Port = int(ep.(*StdNetEndpoint).Port()) msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) for i, buf := range bufs { (*msgs)[i].Buffers[0] = buf (*msgs)[i].Addr = ua setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) } var ( n int err error start int ) if runtime.GOOS == "linux" && pc != nil { for { n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) if err != nil { if errors.Is(err, syscall.EAFNOSUPPORT) { for j := start; j < len(bufs); j++ { _, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua) if werr != nil { err = werr break } } } break } if n == len((*msgs)[start:len(bufs)]) { break } start += n } } else { for i, buf := range bufs { _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) if err != nil { break } } } s.udpAddrPool.Put(ua) s.ipv6MsgsPool.Put(msgs) return err }