mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 00:15:37 +01:00
add new files for compat layer
This commit is contained in:
102
overlay/wireguard_tun_linux.go
Normal file
102
overlay/wireguard_tun_linux.go
Normal file
@@ -0,0 +1,102 @@
|
||||
//go:build linux && !android && !e2e_testing
|
||||
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
wgtun "github.com/slackhq/nebula/wgstack/tun"
|
||||
)
|
||||
|
||||
type wireguardTunIO struct {
|
||||
dev wgtun.Device
|
||||
mtu int
|
||||
batchSize int
|
||||
|
||||
readMu sync.Mutex
|
||||
readBufs [][]byte
|
||||
readLens []int
|
||||
pending [][]byte
|
||||
pendIdx int
|
||||
|
||||
writeMu sync.Mutex
|
||||
writeBuf []byte
|
||||
writeWrap [][]byte
|
||||
}
|
||||
|
||||
func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
|
||||
batch := dev.BatchSize()
|
||||
if batch <= 0 {
|
||||
batch = 1
|
||||
}
|
||||
if mtu <= 0 {
|
||||
mtu = DefaultMTU
|
||||
}
|
||||
bufs := make([][]byte, batch)
|
||||
for i := range bufs {
|
||||
bufs[i] = make([]byte, wgtun.VirtioNetHdrLen+mtu)
|
||||
}
|
||||
return &wireguardTunIO{
|
||||
dev: dev,
|
||||
mtu: mtu,
|
||||
batchSize: batch,
|
||||
readBufs: bufs,
|
||||
readLens: make([]int, batch),
|
||||
pending: make([][]byte, 0, batch),
|
||||
writeBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
|
||||
writeWrap: make([][]byte, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *wireguardTunIO) Read(p []byte) (int, error) {
|
||||
w.readMu.Lock()
|
||||
defer w.readMu.Unlock()
|
||||
|
||||
for {
|
||||
if w.pendIdx < len(w.pending) {
|
||||
segment := w.pending[w.pendIdx]
|
||||
w.pendIdx++
|
||||
n := copy(p, segment)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
n, err := w.dev.Read(w.readBufs, w.readLens, wgtun.VirtioNetHdrLen)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w.pending = w.pending[:0]
|
||||
w.pendIdx = 0
|
||||
for i := 0; i < n; i++ {
|
||||
length := w.readLens[i]
|
||||
if length == 0 {
|
||||
continue
|
||||
}
|
||||
segment := w.readBufs[i][wgtun.VirtioNetHdrLen : wgtun.VirtioNetHdrLen+length]
|
||||
w.pending = append(w.pending, segment)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *wireguardTunIO) Write(p []byte) (int, error) {
|
||||
if len(p) > w.mtu {
|
||||
return 0, fmt.Errorf("wireguard tun: payload exceeds MTU (%d > %d)", len(p), w.mtu)
|
||||
}
|
||||
w.writeMu.Lock()
|
||||
defer w.writeMu.Unlock()
|
||||
buf := w.writeBuf[:wgtun.VirtioNetHdrLen+len(p)]
|
||||
for i := 0; i < wgtun.VirtioNetHdrLen; i++ {
|
||||
buf[i] = 0
|
||||
}
|
||||
copy(buf[wgtun.VirtioNetHdrLen:], p)
|
||||
w.writeWrap[0] = buf
|
||||
n, err := w.dev.Write(w.writeWrap, wgtun.VirtioNetHdrLen)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *wireguardTunIO) Close() error {
|
||||
return nil
|
||||
}
|
||||
132
udp/wireguard_conn_linux.go
Normal file
132
udp/wireguard_conn_linux.go
Normal file
@@ -0,0 +1,132 @@
|
||||
//go:build linux && !android && !e2e_testing
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
wgconn "github.com/slackhq/nebula/wgstack/conn"
|
||||
)
|
||||
|
||||
// WGConn adapts WireGuard's batched UDP bind implementation to Nebula's udp.Conn interface.
|
||||
type WGConn struct {
|
||||
l *logrus.Logger
|
||||
bind *wgconn.StdNetBind
|
||||
recvers []wgconn.ReceiveFunc
|
||||
batch int
|
||||
localIP netip.Addr
|
||||
localPort uint16
|
||||
closed atomic.Bool
|
||||
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// NewWireguardListener creates a UDP listener backed by WireGuard's StdNetBind.
|
||||
func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
bind := wgconn.NewStdNetBindForAddr(ip, multi)
|
||||
recvers, actualPort, err := bind.Open(uint16(port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if batch <= 0 || batch > bind.BatchSize() {
|
||||
batch = bind.BatchSize()
|
||||
}
|
||||
return &WGConn{
|
||||
l: l,
|
||||
bind: bind,
|
||||
recvers: recvers,
|
||||
batch: batch,
|
||||
localIP: ip,
|
||||
localPort: actualPort,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *WGConn) Rebind() error {
|
||||
// WireGuard's bind does not support rebinding in place.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WGConn) LocalAddr() (netip.AddrPort, error) {
|
||||
if !c.localIP.IsValid() || c.localIP.IsUnspecified() {
|
||||
// Fallback to wildcard IPv4 for display purposes.
|
||||
return netip.AddrPortFrom(netip.IPv4Unspecified(), c.localPort), nil
|
||||
}
|
||||
return netip.AddrPortFrom(c.localIP, c.localPort), nil
|
||||
}
|
||||
|
||||
func (c *WGConn) listen(fn wgconn.ReceiveFunc, r EncReader) {
|
||||
batchSize := c.batch
|
||||
packets := make([][]byte, batchSize)
|
||||
for i := range packets {
|
||||
packets[i] = make([]byte, MTU)
|
||||
}
|
||||
sizes := make([]int, batchSize)
|
||||
endpoints := make([]wgconn.Endpoint, batchSize)
|
||||
|
||||
for {
|
||||
if c.closed.Load() {
|
||||
return
|
||||
}
|
||||
n, err := fn(packets, sizes, endpoints)
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
if c.l != nil {
|
||||
c.l.WithError(err).Debug("wireguard UDP listener receive error")
|
||||
}
|
||||
continue
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
if sizes[i] == 0 {
|
||||
continue
|
||||
}
|
||||
stdEp, ok := endpoints[i].(*wgconn.StdNetEndpoint)
|
||||
if !ok {
|
||||
if c.l != nil {
|
||||
c.l.Warn("wireguard UDP listener received unexpected endpoint type")
|
||||
}
|
||||
continue
|
||||
}
|
||||
addr := stdEp.AddrPort
|
||||
r(addr, packets[i][:sizes[i]])
|
||||
endpoints[i] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WGConn) ListenOut(r EncReader) {
|
||||
for _, fn := range c.recvers {
|
||||
go c.listen(fn, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WGConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
if c.closed.Load() {
|
||||
return net.ErrClosed
|
||||
}
|
||||
ep := &wgconn.StdNetEndpoint{AddrPort: addr}
|
||||
return c.bind.Send([][]byte{b}, ep)
|
||||
}
|
||||
|
||||
func (c *WGConn) ReloadConfig(*config.C) {
|
||||
// WireGuard bind currently does not expose runtime configuration knobs.
|
||||
}
|
||||
|
||||
func (c *WGConn) Close() error {
|
||||
var err error
|
||||
c.closeOnce.Do(func() {
|
||||
c.closed.Store(true)
|
||||
err = c.bind.Close()
|
||||
})
|
||||
return err
|
||||
}
|
||||
15
udp/wireguard_conn_unsupported.go
Normal file
15
udp/wireguard_conn_unsupported.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build !linux || android || e2e_testing
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// NewWireguardListener is only available on Linux builds.
|
||||
func NewWireguardListener(*logrus.Logger, netip.Addr, int, bool, int) (Conn, error) {
|
||||
return nil, fmt.Errorf("wireguard experimental UDP listener is only supported on Linux")
|
||||
}
|
||||
513
wgstack/conn/bind_std.go
Normal file
513
wgstack/conn/bind_std.go
Normal file
@@ -0,0 +1,513 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Bind = (*StdNetBind)(nil)
|
||||
)
|
||||
|
||||
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
|
||||
// (see bind_windows.go), it may fall back to StdNetBind.
|
||||
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
|
||||
// methods for sending and receiving multiple datagrams per-syscall. See the
|
||||
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
|
||||
type StdNetBind struct {
|
||||
mu sync.Mutex // protects all fields except as specified
|
||||
ipv4 *net.UDPConn
|
||||
ipv6 *net.UDPConn
|
||||
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
|
||||
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
|
||||
|
||||
// these three fields are not guarded by mu
|
||||
udpAddrPool sync.Pool
|
||||
ipv4MsgsPool sync.Pool
|
||||
ipv6MsgsPool sync.Pool
|
||||
|
||||
blackhole4 bool
|
||||
blackhole6 bool
|
||||
|
||||
listenAddr4 string
|
||||
listenAddr6 string
|
||||
bindV4 bool
|
||||
bindV6 bool
|
||||
reusePort bool
|
||||
}
|
||||
|
||||
func newStdNetBind() *StdNetBind {
|
||||
return &StdNetBind{
|
||||
udpAddrPool: sync.Pool{
|
||||
New: func() any {
|
||||
return &net.UDPAddr{
|
||||
IP: make([]byte, 16),
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
ipv4MsgsPool: sync.Pool{
|
||||
New: func() any {
|
||||
msgs := make([]ipv4.Message, IdealBatchSize)
|
||||
for i := range msgs {
|
||||
msgs[i].Buffers = make(net.Buffers, 1)
|
||||
msgs[i].OOB = make([]byte, srcControlSize)
|
||||
}
|
||||
return &msgs
|
||||
},
|
||||
},
|
||||
|
||||
ipv6MsgsPool: sync.Pool{
|
||||
New: func() any {
|
||||
msgs := make([]ipv6.Message, IdealBatchSize)
|
||||
for i := range msgs {
|
||||
msgs[i].Buffers = make(net.Buffers, 1)
|
||||
msgs[i].OOB = make([]byte, srcControlSize)
|
||||
}
|
||||
return &msgs
|
||||
},
|
||||
},
|
||||
bindV4: true,
|
||||
bindV6: true,
|
||||
reusePort: false,
|
||||
}
|
||||
}
|
||||
|
||||
// NewStdNetBind creates a bind that listens on all interfaces.
|
||||
func NewStdNetBind() *StdNetBind {
|
||||
return newStdNetBind()
|
||||
}
|
||||
|
||||
// NewStdNetBindForAddr creates a bind that listens on a specific address.
|
||||
// If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the
|
||||
// IPv6 socket will be created.
|
||||
func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind {
|
||||
b := newStdNetBind()
|
||||
if addr.IsValid() {
|
||||
if addr.Is4() {
|
||||
b.listenAddr4 = addr.Unmap().String()
|
||||
b.bindV4 = true
|
||||
b.bindV6 = false
|
||||
} else {
|
||||
b.listenAddr6 = addr.Unmap().String()
|
||||
b.bindV6 = true
|
||||
b.bindV4 = false
|
||||
}
|
||||
}
|
||||
b.reusePort = reusePort
|
||||
return b
|
||||
}
|
||||
|
||||
type StdNetEndpoint struct {
|
||||
// AddrPort is the endpoint destination.
|
||||
netip.AddrPort
|
||||
// src is the current sticky source address and interface index, if supported.
|
||||
src struct {
|
||||
netip.Addr
|
||||
ifidx int32
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
_ Bind = (*StdNetBind)(nil)
|
||||
_ Endpoint = &StdNetEndpoint{}
|
||||
)
|
||||
|
||||
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
|
||||
e, err := netip.ParseAddrPort(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &StdNetEndpoint{
|
||||
AddrPort: e,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) ClearSrc() {
|
||||
e.src.ifidx = 0
|
||||
e.src.Addr = netip.Addr{}
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) DstIP() netip.Addr {
|
||||
return e.AddrPort.Addr()
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
||||
return e.src.Addr
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcIfidx() int32 {
|
||||
return e.src.ifidx
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) DstToBytes() []byte {
|
||||
b, _ := e.AddrPort.MarshalBinary()
|
||||
return b
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) DstToString() string {
|
||||
return e.AddrPort.String()
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcToString() string {
|
||||
return e.src.Addr.String()
|
||||
}
|
||||
|
||||
func (s *StdNetBind) listenNet(network string, host string, port int) (*net.UDPConn, int, error) {
|
||||
lc := listenConfig()
|
||||
if s.reusePort {
|
||||
base := lc.Control
|
||||
lc.Control = func(network, address string, c syscall.RawConn) error {
|
||||
if base != nil {
|
||||
if err := base(network, address, c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return c.Control(func(fd uintptr) {
|
||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
addr := ":" + strconv.Itoa(port)
|
||||
if host != "" {
|
||||
addr = net.JoinHostPort(host, strconv.Itoa(port))
|
||||
}
|
||||
|
||||
conn, err := lc.ListenPacket(context.Background(), network, addr)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Retrieve port.
|
||||
laddr := conn.LocalAddr()
|
||||
uaddr, err := net.ResolveUDPAddr(
|
||||
laddr.Network(),
|
||||
laddr.String(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return conn.(*net.UDPConn), uaddr.Port, nil
|
||||
}
|
||||
|
||||
func (s *StdNetBind) openIPv4(port int) (*net.UDPConn, *ipv4.PacketConn, int, error) {
|
||||
if !s.bindV4 {
|
||||
return nil, nil, port, nil
|
||||
}
|
||||
host := s.listenAddr4
|
||||
conn, actualPort, err := s.listenNet("udp4", host, port)
|
||||
if err != nil {
|
||||
if errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return nil, nil, port, nil
|
||||
}
|
||||
return nil, nil, port, err
|
||||
}
|
||||
if runtime.GOOS != "linux" {
|
||||
return conn, nil, actualPort, nil
|
||||
}
|
||||
pc := ipv4.NewPacketConn(conn)
|
||||
return conn, pc, actualPort, nil
|
||||
}
|
||||
|
||||
func (s *StdNetBind) openIPv6(port int) (*net.UDPConn, *ipv6.PacketConn, int, error) {
|
||||
if !s.bindV6 {
|
||||
return nil, nil, port, nil
|
||||
}
|
||||
host := s.listenAddr6
|
||||
conn, actualPort, err := s.listenNet("udp6", host, port)
|
||||
if err != nil {
|
||||
if errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return nil, nil, port, nil
|
||||
}
|
||||
return nil, nil, port, err
|
||||
}
|
||||
if runtime.GOOS != "linux" {
|
||||
return conn, nil, actualPort, nil
|
||||
}
|
||||
pc := ipv6.NewPacketConn(conn)
|
||||
return conn, pc, actualPort, nil
|
||||
}
|
||||
|
||||
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var err error
|
||||
var tries int
|
||||
|
||||
if s.ipv4 != nil || s.ipv6 != nil {
|
||||
return nil, 0, ErrBindAlreadyOpen
|
||||
}
|
||||
|
||||
// Attempt to open ipv4 and ipv6 listeners on the same port.
|
||||
// If uport is 0, we can retry on failure.
|
||||
again:
|
||||
port := int(uport)
|
||||
var v4conn *net.UDPConn
|
||||
var v6conn *net.UDPConn
|
||||
var v4pc *ipv4.PacketConn
|
||||
var v6pc *ipv6.PacketConn
|
||||
|
||||
v4conn, v4pc, port, err = s.openIPv4(port)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Listen on the same port as we're using for ipv4.
|
||||
v6conn, v6pc, port, err = s.openIPv6(port)
|
||||
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
|
||||
if v4conn != nil {
|
||||
v4conn.Close()
|
||||
}
|
||||
tries++
|
||||
goto again
|
||||
}
|
||||
if err != nil {
|
||||
if v4conn != nil {
|
||||
v4conn.Close()
|
||||
}
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var fns []ReceiveFunc
|
||||
if v4conn != nil {
|
||||
s.ipv4 = v4conn
|
||||
if v4pc != nil {
|
||||
s.ipv4PC = v4pc
|
||||
}
|
||||
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
|
||||
}
|
||||
if v6conn != nil {
|
||||
s.ipv6 = v6conn
|
||||
if v6pc != nil {
|
||||
s.ipv6PC = v6pc
|
||||
}
|
||||
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
|
||||
}
|
||||
if len(fns) == 0 {
|
||||
return nil, 0, syscall.EAFNOSUPPORT
|
||||
}
|
||||
|
||||
return fns, uint16(port), nil
|
||||
}
|
||||
|
||||
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
|
||||
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
|
||||
defer s.ipv4MsgsPool.Put(msgs)
|
||||
for i := range bufs {
|
||||
(*msgs)[i].Buffers[0] = bufs[i]
|
||||
}
|
||||
var numMsgs int
|
||||
if runtime.GOOS == "linux" && pc != nil {
|
||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
msg := &(*msgs)[0]
|
||||
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs = 1
|
||||
}
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
msg := &(*msgs)[i]
|
||||
sizes[i] = msg.N
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||
getSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||
eps[i] = ep
|
||||
}
|
||||
return numMsgs, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
|
||||
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
|
||||
defer s.ipv6MsgsPool.Put(msgs)
|
||||
for i := range bufs {
|
||||
(*msgs)[i].Buffers[0] = bufs[i]
|
||||
}
|
||||
var numMsgs int
|
||||
if runtime.GOOS == "linux" && pc != nil {
|
||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
msg := &(*msgs)[0]
|
||||
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs = 1
|
||||
}
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
msg := &(*msgs)[i]
|
||||
sizes[i] = msg.N
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||
getSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||
eps[i] = ep
|
||||
}
|
||||
return numMsgs, nil
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
|
||||
// rename the IdealBatchSize constant to BatchSize.
|
||||
func (s *StdNetBind) BatchSize() int {
|
||||
if runtime.GOOS == "linux" {
|
||||
return IdealBatchSize
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func (s *StdNetBind) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var err1, err2 error
|
||||
if s.ipv4 != nil {
|
||||
err1 = s.ipv4.Close()
|
||||
s.ipv4 = nil
|
||||
s.ipv4PC = nil
|
||||
}
|
||||
if s.ipv6 != nil {
|
||||
err2 = s.ipv6.Close()
|
||||
s.ipv6 = nil
|
||||
s.ipv6PC = nil
|
||||
}
|
||||
s.blackhole4 = false
|
||||
s.blackhole6 = false
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
|
||||
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
||||
s.mu.Lock()
|
||||
blackhole := s.blackhole4
|
||||
conn := s.ipv4
|
||||
var (
|
||||
pc4 *ipv4.PacketConn
|
||||
pc6 *ipv6.PacketConn
|
||||
)
|
||||
is6 := false
|
||||
if endpoint.DstIP().Is6() {
|
||||
blackhole = s.blackhole6
|
||||
conn = s.ipv6
|
||||
pc6 = s.ipv6PC
|
||||
is6 = true
|
||||
} else {
|
||||
pc4 = s.ipv4PC
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if blackhole {
|
||||
return nil
|
||||
}
|
||||
if conn == nil {
|
||||
return syscall.EAFNOSUPPORT
|
||||
}
|
||||
if is6 {
|
||||
return s.send6(conn, pc6, endpoint, bufs)
|
||||
} else {
|
||||
return s.send4(conn, pc4, endpoint, bufs)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error {
|
||||
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
||||
as4 := ep.DstIP().As4()
|
||||
copy(ua.IP, as4[:])
|
||||
ua.IP = ua.IP[:4]
|
||||
ua.Port = int(ep.(*StdNetEndpoint).Port())
|
||||
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
|
||||
for i, buf := range bufs {
|
||||
(*msgs)[i].Buffers[0] = buf
|
||||
(*msgs)[i].Addr = ua
|
||||
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
|
||||
}
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
start int
|
||||
)
|
||||
if runtime.GOOS == "linux" && pc != nil {
|
||||
for {
|
||||
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
|
||||
if err != nil || n == len((*msgs)[start:len(bufs)]) {
|
||||
break
|
||||
}
|
||||
start += n
|
||||
}
|
||||
} else {
|
||||
for i, buf := range bufs {
|
||||
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
s.udpAddrPool.Put(ua)
|
||||
s.ipv4MsgsPool.Put(msgs)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error {
|
||||
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
||||
as16 := ep.DstIP().As16()
|
||||
copy(ua.IP, as16[:])
|
||||
ua.IP = ua.IP[:16]
|
||||
ua.Port = int(ep.(*StdNetEndpoint).Port())
|
||||
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
|
||||
for i, buf := range bufs {
|
||||
(*msgs)[i].Buffers[0] = buf
|
||||
(*msgs)[i].Addr = ua
|
||||
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
|
||||
}
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
start int
|
||||
)
|
||||
if runtime.GOOS == "linux" && pc != nil {
|
||||
for {
|
||||
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
|
||||
if err != nil || n == len((*msgs)[start:len(bufs)]) {
|
||||
break
|
||||
}
|
||||
start += n
|
||||
}
|
||||
} else {
|
||||
for i, buf := range bufs {
|
||||
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
s.udpAddrPool.Put(ua)
|
||||
s.ipv6MsgsPool.Put(msgs)
|
||||
return err
|
||||
}
|
||||
131
wgstack/conn/conn.go
Normal file
131
wgstack/conn/conn.go
Normal file
@@ -0,0 +1,131 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
IdealBatchSize = 128 // maximum number of packets handled per read and write
|
||||
)
|
||||
|
||||
// A ReceiveFunc receives at least one packet from the network and writes them
|
||||
// into packets. On a successful read it returns the number of elements of
|
||||
// sizes, packets, and endpoints that should be evaluated. Some elements of
|
||||
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
|
||||
// and eps slice with a length greater than or equal to the length of packets.
|
||||
// These lengths must not exceed the length of the associated Bind.BatchSize().
|
||||
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
|
||||
|
||||
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
|
||||
//
|
||||
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
|
||||
// depending on the platform-specific implementation.
|
||||
type Bind interface {
|
||||
// Open puts the Bind into a listening state on a given port and reports the actual
|
||||
// port that it bound to. Passing zero results in a random selection.
|
||||
// fns is the set of functions that will be called to receive packets.
|
||||
Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
|
||||
|
||||
// Close closes the Bind listener.
|
||||
// All fns returned by Open must return net.ErrClosed after a call to Close.
|
||||
Close() error
|
||||
|
||||
// SetMark sets the mark for each packet sent through this Bind.
|
||||
// This mark is passed to the kernel as the socket option SO_MARK.
|
||||
SetMark(mark uint32) error
|
||||
|
||||
// Send writes one or more packets in bufs to address ep. The length of
|
||||
// bufs must not exceed BatchSize().
|
||||
Send(bufs [][]byte, ep Endpoint) error
|
||||
|
||||
// ParseEndpoint creates a new endpoint from a string.
|
||||
ParseEndpoint(s string) (Endpoint, error)
|
||||
|
||||
// BatchSize is the number of buffers expected to be passed to
|
||||
// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
|
||||
BatchSize() int
|
||||
}
|
||||
|
||||
// BindSocketToInterface is implemented by Bind objects that support being
|
||||
// tied to a single network interface. Used by wireguard-windows.
|
||||
type BindSocketToInterface interface {
|
||||
BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error
|
||||
BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error
|
||||
}
|
||||
|
||||
// PeekLookAtSocketFd is implemented by Bind objects that support having their
|
||||
// file descriptor peeked at. Used by wireguard-android.
|
||||
type PeekLookAtSocketFd interface {
|
||||
PeekLookAtSocketFd4() (fd int, err error)
|
||||
PeekLookAtSocketFd6() (fd int, err error)
|
||||
}
|
||||
|
||||
// An Endpoint maintains the source/destination caching for a peer.
|
||||
//
|
||||
// dst: the remote address of a peer ("endpoint" in uapi terminology)
|
||||
// src: the local address from which datagrams originate going to the peer
|
||||
type Endpoint interface {
|
||||
ClearSrc() // clears the source address
|
||||
SrcToString() string // returns the local source address (ip:port)
|
||||
DstToString() string // returns the destination address (ip:port)
|
||||
DstToBytes() []byte // used for mac2 cookie calculations
|
||||
DstIP() netip.Addr
|
||||
SrcIP() netip.Addr
|
||||
}
|
||||
|
||||
var (
|
||||
ErrBindAlreadyOpen = errors.New("bind is already open")
|
||||
ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
|
||||
)
|
||||
|
||||
func (fn ReceiveFunc) PrettyName() string {
|
||||
name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
|
||||
// 0. cheese/taco.beansIPv6.func12.func21218-fm
|
||||
name = strings.TrimSuffix(name, "-fm")
|
||||
// 1. cheese/taco.beansIPv6.func12.func21218
|
||||
if idx := strings.LastIndexByte(name, '/'); idx != -1 {
|
||||
name = name[idx+1:]
|
||||
// 2. taco.beansIPv6.func12.func21218
|
||||
}
|
||||
for {
|
||||
var idx int
|
||||
for idx = len(name) - 1; idx >= 0; idx-- {
|
||||
if name[idx] < '0' || name[idx] > '9' {
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx == len(name)-1 {
|
||||
break
|
||||
}
|
||||
const dotFunc = ".func"
|
||||
if !strings.HasSuffix(name[:idx+1], dotFunc) {
|
||||
break
|
||||
}
|
||||
name = name[:idx+1-len(dotFunc)]
|
||||
// 3. taco.beansIPv6.func12
|
||||
// 4. taco.beansIPv6
|
||||
}
|
||||
if idx := strings.LastIndexByte(name, '.'); idx != -1 {
|
||||
name = name[idx+1:]
|
||||
// 5. beansIPv6
|
||||
}
|
||||
if name == "" {
|
||||
return fmt.Sprintf("%p", fn)
|
||||
}
|
||||
if strings.HasSuffix(name, "IPv4") {
|
||||
return "v4"
|
||||
}
|
||||
if strings.HasSuffix(name, "IPv6") {
|
||||
return "v6"
|
||||
}
|
||||
return name
|
||||
}
|
||||
42
wgstack/conn/controlfns.go
Normal file
42
wgstack/conn/controlfns.go
Normal file
@@ -0,0 +1,42 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"net"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
|
||||
// the max supported by a default configuration of macOS. Some platforms will
|
||||
// silently clamp the value to other maximums, such as linux clamping to
|
||||
// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
|
||||
// around this limitation)
|
||||
const socketBufferSize = 7 << 20
|
||||
|
||||
// controlFn is the callback function signature from net.ListenConfig.Control.
|
||||
// It is used to apply platform specific configuration to the socket prior to
|
||||
// bind.
|
||||
type controlFn func(network, address string, c syscall.RawConn) error
|
||||
|
||||
// controlFns is a list of functions that are called from the listen config
|
||||
// that can apply socket options.
|
||||
var controlFns = []controlFn{}
|
||||
|
||||
// listenConfig returns a net.ListenConfig that applies the controlFns to the
|
||||
// socket prior to bind. This is used to apply socket buffer sizing and packet
|
||||
// information OOB configuration for sticky sockets.
|
||||
func listenConfig() *net.ListenConfig {
|
||||
return &net.ListenConfig{
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
for _, fn := range controlFns {
|
||||
if err := fn(network, address, c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
62
wgstack/conn/controlfns_linux.go
Normal file
62
wgstack/conn/controlfns_linux.go
Normal file
@@ -0,0 +1,62 @@
|
||||
//go:build linux
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func init() {
|
||||
controlFns = append(controlFns,
|
||||
|
||||
// Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
|
||||
// using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
|
||||
// fail silently - the result of failure is lower performance on very fast
|
||||
// links or high latency links.
|
||||
func(network, address string, c syscall.RawConn) error {
|
||||
return c.Control(func(fd uintptr) {
|
||||
// Set up to *mem_max
|
||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
|
||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
|
||||
// Set beyond *mem_max if CAP_NET_ADMIN
|
||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
|
||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
|
||||
})
|
||||
},
|
||||
|
||||
// Enable receiving of the packet information (IP_PKTINFO for IPv4,
|
||||
// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
|
||||
func(network, address string, c syscall.RawConn) error {
|
||||
var err error
|
||||
switch network {
|
||||
case "udp4":
|
||||
if runtime.GOOS != "android" {
|
||||
c.Control(func(fd uintptr) {
|
||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
|
||||
})
|
||||
}
|
||||
case "udp6":
|
||||
c.Control(func(fd uintptr) {
|
||||
if runtime.GOOS != "android" {
|
||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
|
||||
})
|
||||
default:
|
||||
err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
|
||||
}
|
||||
return err
|
||||
},
|
||||
)
|
||||
}
|
||||
9
wgstack/conn/default.go
Normal file
9
wgstack/conn/default.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build !windows
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package conn
|
||||
|
||||
func NewDefaultBind() Bind { return NewStdNetBind() }
|
||||
64
wgstack/conn/mark_unix.go
Normal file
64
wgstack/conn/mark_unix.go
Normal file
@@ -0,0 +1,64 @@
|
||||
//go:build linux || openbsd || freebsd
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var fwmarkIoctl int
|
||||
|
||||
func init() {
|
||||
switch runtime.GOOS {
|
||||
case "linux", "android":
|
||||
fwmarkIoctl = 36 /* unix.SO_MARK */
|
||||
case "freebsd":
|
||||
fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
|
||||
case "openbsd":
|
||||
fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StdNetBind) SetMark(mark uint32) error {
|
||||
var operr error
|
||||
if fwmarkIoctl == 0 {
|
||||
return nil
|
||||
}
|
||||
if s.ipv4 != nil {
|
||||
fd, err := s.ipv4.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = fd.Control(func(fd uintptr) {
|
||||
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
|
||||
})
|
||||
if err == nil {
|
||||
err = operr
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if s.ipv6 != nil {
|
||||
fd, err := s.ipv6.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = fd.Control(func(fd uintptr) {
|
||||
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
|
||||
})
|
||||
if err == nil {
|
||||
err = operr
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
116
wgstack/conn/sticky_linux.go
Normal file
116
wgstack/conn/sticky_linux.go
Normal file
@@ -0,0 +1,116 @@
|
||||
//go:build linux && !android
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
||||
// the source information found.
|
||||
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
||||
ep.ClearSrc()
|
||||
|
||||
var (
|
||||
hdr unix.Cmsghdr
|
||||
data []byte
|
||||
rem []byte = control
|
||||
err error
|
||||
)
|
||||
|
||||
for len(rem) > unix.SizeofCmsghdr {
|
||||
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if hdr.Level == unix.IPPROTO_IP &&
|
||||
hdr.Type == unix.IP_PKTINFO {
|
||||
|
||||
info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
|
||||
ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
|
||||
ep.src.ifidx = info.Ifindex
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if hdr.Level == unix.IPPROTO_IPV6 &&
|
||||
hdr.Type == unix.IPV6_PKTINFO {
|
||||
|
||||
info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
|
||||
ep.src.Addr = netip.AddrFrom16(info.Addr)
|
||||
ep.src.ifidx = int32(info.Ifindex)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
|
||||
// panics if buf is of insufficient size.
|
||||
func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
|
||||
size := int(unsafe.Sizeof(t))
|
||||
if len(buf) < size {
|
||||
panic("pktInfoFromBuf: buffer too small")
|
||||
}
|
||||
copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
|
||||
return t
|
||||
}
|
||||
|
||||
// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
|
||||
// and source ifindex found in ep. control's len will be set to 0 in the event
|
||||
// that ep is a default value.
|
||||
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
||||
*control = (*control)[:cap(*control)]
|
||||
if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
|
||||
*control = (*control)[:0]
|
||||
return
|
||||
}
|
||||
|
||||
if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() {
|
||||
*control = (*control)[:0]
|
||||
return
|
||||
}
|
||||
|
||||
if len(*control) < srcControlSize {
|
||||
*control = (*control)[:0]
|
||||
return
|
||||
}
|
||||
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
|
||||
if ep.SrcIP().Is4() {
|
||||
hdr.Level = unix.IPPROTO_IP
|
||||
hdr.Type = unix.IP_PKTINFO
|
||||
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
|
||||
|
||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
|
||||
info.Ifindex = ep.src.ifidx
|
||||
if ep.SrcIP().IsValid() {
|
||||
info.Spec_dst = ep.SrcIP().As4()
|
||||
}
|
||||
*control = (*control)[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
|
||||
} else {
|
||||
hdr.Level = unix.IPPROTO_IPV6
|
||||
hdr.Type = unix.IPV6_PKTINFO
|
||||
hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
|
||||
|
||||
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
|
||||
info.Ifindex = uint32(ep.src.ifidx)
|
||||
if ep.SrcIP().IsValid() {
|
||||
info.Addr = ep.SrcIP().As16()
|
||||
}
|
||||
*control = (*control)[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
|
||||
|
||||
const StdNetSupportsStickySockets = true
|
||||
42
wgstack/tun/checksum.go
Normal file
42
wgstack/tun/checksum.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package tun
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
// TODO: Explore SIMD and/or other assembly optimizations.
|
||||
func checksumNoFold(b []byte, initial uint64) uint64 {
|
||||
ac := initial
|
||||
i := 0
|
||||
n := len(b)
|
||||
for n >= 4 {
|
||||
ac += uint64(binary.BigEndian.Uint32(b[i : i+4]))
|
||||
n -= 4
|
||||
i += 4
|
||||
}
|
||||
for n >= 2 {
|
||||
ac += uint64(binary.BigEndian.Uint16(b[i : i+2]))
|
||||
n -= 2
|
||||
i += 2
|
||||
}
|
||||
if n == 1 {
|
||||
ac += uint64(b[i]) << 8
|
||||
}
|
||||
return ac
|
||||
}
|
||||
|
||||
func checksum(b []byte, initial uint64) uint16 {
|
||||
ac := checksumNoFold(b, initial)
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
return uint16(ac)
|
||||
}
|
||||
|
||||
func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
|
||||
sum := checksumNoFold(srcAddr, 0)
|
||||
sum = checksumNoFold(dstAddr, sum)
|
||||
sum = checksumNoFold([]byte{0, protocol}, sum)
|
||||
tmp := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(tmp, totalLen)
|
||||
return checksumNoFold(tmp, sum)
|
||||
}
|
||||
3
wgstack/tun/export.go
Normal file
3
wgstack/tun/export.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package tun
|
||||
|
||||
const VirtioNetHdrLen = virtioNetHdrLen
|
||||
630
wgstack/tun/tcp_offload_linux.go
Normal file
630
wgstack/tun/tcp_offload_linux.go
Normal file
@@ -0,0 +1,630 @@
|
||||
//go:build linux
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"unsafe"
|
||||
|
||||
wgconn "github.com/slackhq/nebula/wgstack/conn"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var ErrTooManySegments = errors.New("tun: too many segments for TSO")
|
||||
|
||||
const tcpFlagsOffset = 13
|
||||
|
||||
const (
|
||||
tcpFlagFIN uint8 = 0x01
|
||||
tcpFlagPSH uint8 = 0x08
|
||||
tcpFlagACK uint8 = 0x10
|
||||
)
|
||||
|
||||
// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
|
||||
// kernel symbol is virtio_net_hdr.
|
||||
type virtioNetHdr struct {
|
||||
flags uint8
|
||||
gsoType uint8
|
||||
hdrLen uint16
|
||||
gsoSize uint16
|
||||
csumStart uint16
|
||||
csumOffset uint16
|
||||
}
|
||||
|
||||
func (v *virtioNetHdr) decode(b []byte) error {
|
||||
if len(b) < virtioNetHdrLen {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *virtioNetHdr) encode(b []byte) error {
|
||||
if len(b) < virtioNetHdrLen {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
|
||||
// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
|
||||
virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
|
||||
)
|
||||
|
||||
// flowKey represents the key for a flow.
|
||||
type flowKey struct {
|
||||
srcAddr, dstAddr [16]byte
|
||||
srcPort, dstPort uint16
|
||||
rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
|
||||
}
|
||||
|
||||
// tcpGROTable holds flow and coalescing information for the purposes of GRO.
|
||||
type tcpGROTable struct {
|
||||
itemsByFlow map[flowKey][]tcpGROItem
|
||||
itemsPool [][]tcpGROItem
|
||||
}
|
||||
|
||||
func newTCPGROTable() *tcpGROTable {
|
||||
t := &tcpGROTable{
|
||||
itemsByFlow: make(map[flowKey][]tcpGROItem, wgconn.IdealBatchSize),
|
||||
itemsPool: make([][]tcpGROItem, wgconn.IdealBatchSize),
|
||||
}
|
||||
for i := range t.itemsPool {
|
||||
t.itemsPool[i] = make([]tcpGROItem, 0, wgconn.IdealBatchSize)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
|
||||
key := flowKey{}
|
||||
addrSize := dstAddr - srcAddr
|
||||
copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
|
||||
copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
|
||||
key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
|
||||
key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
|
||||
key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
|
||||
return key
|
||||
}
|
||||
|
||||
// lookupOrInsert looks up a flow for the provided packet and metadata,
|
||||
// returning the packets found for the flow, or inserting a new one if none
|
||||
// is found.
|
||||
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
|
||||
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
|
||||
items, ok := t.itemsByFlow[key]
|
||||
if ok {
|
||||
return items, ok
|
||||
}
|
||||
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
|
||||
t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// insert an item in the table for the provided packet and packet metadata.
|
||||
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
|
||||
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
|
||||
item := tcpGROItem{
|
||||
key: key,
|
||||
bufsIndex: uint16(bufsIndex),
|
||||
gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
|
||||
iphLen: uint8(tcphOffset),
|
||||
tcphLen: uint8(tcphLen),
|
||||
sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
|
||||
pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
|
||||
}
|
||||
items, ok := t.itemsByFlow[key]
|
||||
if !ok {
|
||||
items = t.newItems()
|
||||
}
|
||||
items = append(items, item)
|
||||
t.itemsByFlow[key] = items
|
||||
}
|
||||
|
||||
func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
|
||||
items, _ := t.itemsByFlow[item.key]
|
||||
items[i] = item
|
||||
}
|
||||
|
||||
func (t *tcpGROTable) deleteAt(key flowKey, i int) {
|
||||
items, _ := t.itemsByFlow[key]
|
||||
items = append(items[:i], items[i+1:]...)
|
||||
t.itemsByFlow[key] = items
|
||||
}
|
||||
|
||||
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
|
||||
// of a GRO evaluation across a vector of packets.
|
||||
type tcpGROItem struct {
|
||||
key flowKey
|
||||
sentSeq uint32 // the sequence number
|
||||
bufsIndex uint16 // the index into the original bufs slice
|
||||
numMerged uint16 // the number of packets merged into this item
|
||||
gsoSize uint16 // payload size
|
||||
iphLen uint8 // ip header len
|
||||
tcphLen uint8 // tcp header len
|
||||
pshSet bool // psh flag is set
|
||||
}
|
||||
|
||||
func (t *tcpGROTable) newItems() []tcpGROItem {
|
||||
var items []tcpGROItem
|
||||
items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
|
||||
return items
|
||||
}
|
||||
|
||||
func (t *tcpGROTable) reset() {
|
||||
for k, items := range t.itemsByFlow {
|
||||
items = items[:0]
|
||||
t.itemsPool = append(t.itemsPool, items)
|
||||
delete(t.itemsByFlow, k)
|
||||
}
|
||||
}
|
||||
|
||||
// canCoalesce represents the outcome of checking if two TCP packets are
|
||||
// candidates for coalescing.
|
||||
type canCoalesce int
|
||||
|
||||
const (
|
||||
coalescePrepend canCoalesce = -1
|
||||
coalesceUnavailable canCoalesce = 0
|
||||
coalesceAppend canCoalesce = 1
|
||||
)
|
||||
|
||||
// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
|
||||
// described by item. This function makes considerations that match the kernel's
|
||||
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
|
||||
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
|
||||
pktTarget := bufs[item.bufsIndex][bufsOffset:]
|
||||
if tcphLen != item.tcphLen {
|
||||
// cannot coalesce with unequal tcp options len
|
||||
return coalesceUnavailable
|
||||
}
|
||||
if tcphLen > 20 {
|
||||
if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
|
||||
// cannot coalesce with unequal tcp options
|
||||
return coalesceUnavailable
|
||||
}
|
||||
}
|
||||
if pkt[0]>>4 == 6 {
|
||||
if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 {
|
||||
// cannot coalesce with unequal Traffic class values
|
||||
return coalesceUnavailable
|
||||
}
|
||||
if pkt[7] != pktTarget[7] {
|
||||
// cannot coalesce with unequal Hop limit values
|
||||
return coalesceUnavailable
|
||||
}
|
||||
} else {
|
||||
if pkt[1] != pktTarget[1] {
|
||||
// cannot coalesce with unequal ToS values
|
||||
return coalesceUnavailable
|
||||
}
|
||||
if pkt[6]>>5 != pktTarget[6]>>5 {
|
||||
// cannot coalesce with unequal DF or reserved bits. MF is checked
|
||||
// further up the stack.
|
||||
return coalesceUnavailable
|
||||
}
|
||||
if pkt[8] != pktTarget[8] {
|
||||
// cannot coalesce with unequal TTL values
|
||||
return coalesceUnavailable
|
||||
}
|
||||
}
|
||||
// seq adjacency
|
||||
lhsLen := item.gsoSize
|
||||
lhsLen += item.numMerged * item.gsoSize
|
||||
if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
|
||||
if item.pshSet {
|
||||
// We cannot append to a segment that has the PSH flag set, PSH
|
||||
// can only be set on the final segment in a reassembled group.
|
||||
return coalesceUnavailable
|
||||
}
|
||||
if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
|
||||
// A smaller than gsoSize packet has been appended previously.
|
||||
// Nothing can come after a smaller packet on the end.
|
||||
return coalesceUnavailable
|
||||
}
|
||||
if gsoSize > item.gsoSize {
|
||||
// We cannot have a larger packet following a smaller one.
|
||||
return coalesceUnavailable
|
||||
}
|
||||
return coalesceAppend
|
||||
} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
|
||||
if pshSet {
|
||||
// We cannot prepend with a segment that has the PSH flag set, PSH
|
||||
// can only be set on the final segment in a reassembled group.
|
||||
return coalesceUnavailable
|
||||
}
|
||||
if gsoSize < item.gsoSize {
|
||||
// We cannot have a larger packet following a smaller one.
|
||||
return coalesceUnavailable
|
||||
}
|
||||
if gsoSize > item.gsoSize && item.numMerged > 0 {
|
||||
// There's at least one previous merge, and we're larger than all
|
||||
// previous. This would put multiple smaller packets on the end.
|
||||
return coalesceUnavailable
|
||||
}
|
||||
return coalescePrepend
|
||||
}
|
||||
return coalesceUnavailable
|
||||
}
|
||||
|
||||
func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
|
||||
srcAddrAt := ipv4SrcAddrOffset
|
||||
addrSize := 4
|
||||
if isV6 {
|
||||
srcAddrAt = ipv6SrcAddrOffset
|
||||
addrSize = 16
|
||||
}
|
||||
tcpTotalLen := uint16(len(pkt) - int(iphLen))
|
||||
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
|
||||
return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0
|
||||
}
|
||||
|
||||
// coalesceResult represents the result of attempting to coalesce two TCP
|
||||
// packets.
|
||||
type coalesceResult int
|
||||
|
||||
const (
|
||||
coalesceInsufficientCap coalesceResult = 0
|
||||
coalescePSHEnding coalesceResult = 1
|
||||
coalesceItemInvalidCSum coalesceResult = 2
|
||||
coalescePktInvalidCSum coalesceResult = 3
|
||||
coalesceSuccess coalesceResult = 4
|
||||
)
|
||||
|
||||
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
|
||||
// item, returning the outcome. This function may swap bufs elements in the
|
||||
// event of a prepend as item's bufs index is already being tracked for writing
|
||||
// to a Device.
|
||||
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
|
||||
var pktHead []byte // the packet that will end up at the front
|
||||
headersLen := item.iphLen + item.tcphLen
|
||||
coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
|
||||
|
||||
// Copy data
|
||||
if mode == coalescePrepend {
|
||||
pktHead = pkt
|
||||
if cap(pkt)-bufsOffset < coalescedLen {
|
||||
// We don't want to allocate a new underlying array if capacity is
|
||||
// too small.
|
||||
return coalesceInsufficientCap
|
||||
}
|
||||
if pshSet {
|
||||
return coalescePSHEnding
|
||||
}
|
||||
if item.numMerged == 0 {
|
||||
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
|
||||
return coalesceItemInvalidCSum
|
||||
}
|
||||
}
|
||||
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
|
||||
return coalescePktInvalidCSum
|
||||
}
|
||||
item.sentSeq = seq
|
||||
extendBy := coalescedLen - len(pktHead)
|
||||
bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
|
||||
copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
|
||||
// Flip the slice headers in bufs as part of prepend. The index of item
|
||||
// is already being tracked for writing.
|
||||
bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
|
||||
} else {
|
||||
pktHead = bufs[item.bufsIndex][bufsOffset:]
|
||||
if cap(pktHead)-bufsOffset < coalescedLen {
|
||||
// We don't want to allocate a new underlying array if capacity is
|
||||
// too small.
|
||||
return coalesceInsufficientCap
|
||||
}
|
||||
if item.numMerged == 0 {
|
||||
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
|
||||
return coalesceItemInvalidCSum
|
||||
}
|
||||
}
|
||||
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
|
||||
return coalescePktInvalidCSum
|
||||
}
|
||||
if pshSet {
|
||||
// We are appending a segment with PSH set.
|
||||
item.pshSet = pshSet
|
||||
pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
|
||||
}
|
||||
extendBy := len(pkt) - int(headersLen)
|
||||
bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
|
||||
copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
|
||||
}
|
||||
|
||||
if gsoSize > item.gsoSize {
|
||||
item.gsoSize = gsoSize
|
||||
}
|
||||
hdr := virtioNetHdr{
|
||||
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
|
||||
hdrLen: uint16(headersLen),
|
||||
gsoSize: uint16(item.gsoSize),
|
||||
csumStart: uint16(item.iphLen),
|
||||
csumOffset: 16,
|
||||
}
|
||||
|
||||
// Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the
|
||||
// (IPv4) header checksum.
|
||||
if isV6 {
|
||||
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
|
||||
binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len
|
||||
} else {
|
||||
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
|
||||
pktHead[10], pktHead[11] = 0, 0 // clear checksum field
|
||||
binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length
|
||||
iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum
|
||||
binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field
|
||||
}
|
||||
hdr.encode(bufs[item.bufsIndex][bufsOffset-virtioNetHdrLen:])
|
||||
|
||||
// Calculate the pseudo header checksum and place it at the TCP checksum
|
||||
// offset. Downstream checksum offloading will combine this with computation
|
||||
// of the tcp header and payload checksum.
|
||||
addrLen := 4
|
||||
addrOffset := ipv4SrcAddrOffset
|
||||
if isV6 {
|
||||
addrLen = 16
|
||||
addrOffset = ipv6SrcAddrOffset
|
||||
}
|
||||
srcAddrAt := bufsOffset + addrOffset
|
||||
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
|
||||
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
|
||||
psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen)))
|
||||
binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
|
||||
|
||||
item.numMerged++
|
||||
return coalesceSuccess
|
||||
}
|
||||
|
||||
const (
|
||||
ipv4FlagMoreFragments uint8 = 0x20
|
||||
)
|
||||
|
||||
const (
|
||||
ipv4SrcAddrOffset = 12
|
||||
ipv6SrcAddrOffset = 8
|
||||
maxUint16 = 1<<16 - 1
|
||||
)
|
||||
|
||||
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
|
||||
// existing packets tracked in table. It will return false when pktI is not
|
||||
// coalesced, otherwise true. This indicates to the caller if bufs[pktI]
|
||||
// should be written to the Device.
|
||||
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) {
|
||||
pkt := bufs[pktI][offset:]
|
||||
if len(pkt) > maxUint16 {
|
||||
// A valid IPv4 or IPv6 packet will never exceed this.
|
||||
return false
|
||||
}
|
||||
iphLen := int((pkt[0] & 0x0F) * 4)
|
||||
if isV6 {
|
||||
iphLen = 40
|
||||
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
|
||||
if ipv6HPayloadLen != len(pkt)-iphLen {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
|
||||
if totalLen != len(pkt) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if len(pkt) < iphLen {
|
||||
return false
|
||||
}
|
||||
tcphLen := int((pkt[iphLen+12] >> 4) * 4)
|
||||
if tcphLen < 20 || tcphLen > 60 {
|
||||
return false
|
||||
}
|
||||
if len(pkt) < iphLen+tcphLen {
|
||||
return false
|
||||
}
|
||||
if !isV6 {
|
||||
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
|
||||
// no GRO support for fragmented segments for now
|
||||
return false
|
||||
}
|
||||
}
|
||||
tcpFlags := pkt[iphLen+tcpFlagsOffset]
|
||||
var pshSet bool
|
||||
// not a candidate if any non-ACK flags (except PSH+ACK) are set
|
||||
if tcpFlags != tcpFlagACK {
|
||||
if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
|
||||
return false
|
||||
}
|
||||
pshSet = true
|
||||
}
|
||||
gsoSize := uint16(len(pkt) - tcphLen - iphLen)
|
||||
// not a candidate if payload len is 0
|
||||
if gsoSize < 1 {
|
||||
return false
|
||||
}
|
||||
seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
|
||||
srcAddrOffset := ipv4SrcAddrOffset
|
||||
addrLen := 4
|
||||
if isV6 {
|
||||
srcAddrOffset = ipv6SrcAddrOffset
|
||||
addrLen = 16
|
||||
}
|
||||
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
|
||||
if !existing {
|
||||
return false
|
||||
}
|
||||
for i := len(items) - 1; i >= 0; i-- {
|
||||
// In the best case of packets arriving in order iterating in reverse is
|
||||
// more efficient if there are multiple items for a given flow. This
|
||||
// also enables a natural table.deleteAt() in the
|
||||
// coalesceItemInvalidCSum case without the need for index tracking.
|
||||
// This algorithm makes a best effort to coalesce in the event of
|
||||
// unordered packets, where pkt may land anywhere in items from a
|
||||
// sequence number perspective, however once an item is inserted into
|
||||
// the table it is never compared across other items later.
|
||||
item := items[i]
|
||||
can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
|
||||
if can != coalesceUnavailable {
|
||||
result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
|
||||
switch result {
|
||||
case coalesceSuccess:
|
||||
table.updateAt(item, i)
|
||||
return true
|
||||
case coalesceItemInvalidCSum:
|
||||
// delete the item with an invalid csum
|
||||
table.deleteAt(item.key, i)
|
||||
case coalescePktInvalidCSum:
|
||||
// no point in inserting an item that we can't coalesce
|
||||
return false
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
// failed to coalesce with any other packets; store the item in the flow
|
||||
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
|
||||
return false
|
||||
}
|
||||
|
||||
func isTCP4NoIPOptions(b []byte) bool {
|
||||
if len(b) < 40 {
|
||||
return false
|
||||
}
|
||||
if b[0]>>4 != 4 {
|
||||
return false
|
||||
}
|
||||
if b[0]&0x0F != 5 {
|
||||
return false
|
||||
}
|
||||
if b[9] != unix.IPPROTO_TCP {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isTCP6NoEH(b []byte) bool {
|
||||
if len(b) < 60 {
|
||||
return false
|
||||
}
|
||||
if b[0]>>4 != 6 {
|
||||
return false
|
||||
}
|
||||
if b[6] != unix.IPPROTO_TCP {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
|
||||
// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
|
||||
// empty (but non-nil), and are passed in to save allocs as the caller may reset
|
||||
// and recycle them across vectors of packets.
|
||||
func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
|
||||
for i := range bufs {
|
||||
if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
|
||||
return errors.New("invalid offset")
|
||||
}
|
||||
var coalesced bool
|
||||
switch {
|
||||
case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce
|
||||
coalesced = tcpGRO(bufs, offset, i, tcp4Table, false)
|
||||
case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
|
||||
coalesced = tcpGRO(bufs, offset, i, tcp6Table, true)
|
||||
}
|
||||
if !coalesced {
|
||||
hdr := virtioNetHdr{}
|
||||
err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*toWrite = append(*toWrite, i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// tcpTSO splits packets from in into outBuffs, writing the size of each
|
||||
// element into sizes. It returns the number of buffers populated, and/or an
|
||||
// error.
|
||||
func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
|
||||
iphLen := int(hdr.csumStart)
|
||||
srcAddrOffset := ipv6SrcAddrOffset
|
||||
addrLen := 16
|
||||
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
|
||||
in[10], in[11] = 0, 0 // clear ipv4 header checksum
|
||||
srcAddrOffset = ipv4SrcAddrOffset
|
||||
addrLen = 4
|
||||
}
|
||||
tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
|
||||
in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
|
||||
firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
|
||||
nextSegmentDataAt := int(hdr.hdrLen)
|
||||
i := 0
|
||||
for ; nextSegmentDataAt < len(in); i++ {
|
||||
if i == len(outBuffs) {
|
||||
return i - 1, ErrTooManySegments
|
||||
}
|
||||
nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
|
||||
if nextSegmentEnd > len(in) {
|
||||
nextSegmentEnd = len(in)
|
||||
}
|
||||
segmentDataLen := nextSegmentEnd - nextSegmentDataAt
|
||||
totalLen := int(hdr.hdrLen) + segmentDataLen
|
||||
sizes[i] = totalLen
|
||||
out := outBuffs[i][outOffset:]
|
||||
|
||||
copy(out, in[:iphLen])
|
||||
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
|
||||
// For IPv4 we are responsible for incrementing the ID field,
|
||||
// updating the total len field, and recalculating the header
|
||||
// checksum.
|
||||
if i > 0 {
|
||||
id := binary.BigEndian.Uint16(out[4:])
|
||||
id += uint16(i)
|
||||
binary.BigEndian.PutUint16(out[4:], id)
|
||||
}
|
||||
binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
|
||||
ipv4CSum := ^checksum(out[:iphLen], 0)
|
||||
binary.BigEndian.PutUint16(out[10:], ipv4CSum)
|
||||
} else {
|
||||
// For IPv6 we are responsible for updating the payload length field.
|
||||
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
|
||||
}
|
||||
|
||||
// TCP header
|
||||
copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
|
||||
tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
|
||||
binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
|
||||
if nextSegmentEnd != len(in) {
|
||||
// FIN and PSH should only be set on last segment
|
||||
clearFlags := tcpFlagFIN | tcpFlagPSH
|
||||
out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
|
||||
}
|
||||
|
||||
// payload
|
||||
copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
|
||||
|
||||
// TCP checksum
|
||||
tcpHLen := int(hdr.hdrLen - hdr.csumStart)
|
||||
tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
|
||||
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
|
||||
tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold)
|
||||
binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
|
||||
|
||||
nextSegmentDataAt += int(hdr.gsoSize)
|
||||
}
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
|
||||
cSumAt := cSumStart + cSumOffset
|
||||
// The initial value at the checksum offset should be summed with the
|
||||
// checksum we compute. This is typically the pseudo-header checksum.
|
||||
initial := binary.BigEndian.Uint16(in[cSumAt:])
|
||||
in[cSumAt], in[cSumAt+1] = 0, 0
|
||||
binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
|
||||
return nil
|
||||
}
|
||||
52
wgstack/tun/tun.go
Normal file
52
wgstack/tun/tun.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
type Event int
|
||||
|
||||
const (
|
||||
EventUp = 1 << iota
|
||||
EventDown
|
||||
EventMTUUpdate
|
||||
)
|
||||
|
||||
type Device interface {
|
||||
// File returns the file descriptor of the device.
|
||||
File() *os.File
|
||||
|
||||
// Read one or more packets from the Device (without any additional headers).
|
||||
// On a successful read it returns the number of packets read, and sets
|
||||
// packet lengths within the sizes slice. len(sizes) must be >= len(bufs).
|
||||
// A nonzero offset can be used to instruct the Device on where to begin
|
||||
// reading into each element of the bufs slice.
|
||||
Read(bufs [][]byte, sizes []int, offset int) (n int, err error)
|
||||
|
||||
// Write one or more packets to the device (without any additional headers).
|
||||
// On a successful write it returns the number of packets written. A nonzero
|
||||
// offset can be used to instruct the Device on where to begin writing from
|
||||
// each packet contained within the bufs slice.
|
||||
Write(bufs [][]byte, offset int) (int, error)
|
||||
|
||||
// MTU returns the MTU of the Device.
|
||||
MTU() (int, error)
|
||||
|
||||
// Name returns the current name of the Device.
|
||||
Name() (string, error)
|
||||
|
||||
// Events returns a channel of type Event, which is fed Device events.
|
||||
Events() <-chan Event
|
||||
|
||||
// Close stops the Device and closes the Event channel.
|
||||
Close() error
|
||||
|
||||
// BatchSize returns the preferred/max number of packets that can be read or
|
||||
// written in a single read/write call. BatchSize must not change over the
|
||||
// lifetime of a Device.
|
||||
BatchSize() int
|
||||
}
|
||||
652
wgstack/tun/tun_linux.go
Normal file
652
wgstack/tun/tun_linux.go
Normal file
@@ -0,0 +1,652 @@
|
||||
//go:build linux
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
package tun
|
||||
|
||||
/* Implementation of the TUN device interface for linux
|
||||
*/
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
wgconn "github.com/slackhq/nebula/wgstack/conn"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/rwcancel"
|
||||
)
|
||||
|
||||
const (
|
||||
cloneDevicePath = "/dev/net/tun"
|
||||
ifReqSize = unix.IFNAMSIZ + 64
|
||||
)
|
||||
|
||||
type NativeTun struct {
|
||||
tunFile *os.File
|
||||
index int32 // if index
|
||||
errors chan error // async error handling
|
||||
events chan Event // device related events
|
||||
netlinkSock int
|
||||
netlinkCancel *rwcancel.RWCancel
|
||||
hackListenerClosed sync.Mutex
|
||||
statusListenersShutdown chan struct{}
|
||||
batchSize int
|
||||
vnetHdr bool
|
||||
|
||||
closeOnce sync.Once
|
||||
|
||||
nameOnce sync.Once // guards calling initNameCache, which sets following fields
|
||||
nameCache string // name of interface
|
||||
nameErr error
|
||||
|
||||
readOpMu sync.Mutex // readOpMu guards readBuff
|
||||
readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
|
||||
|
||||
writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable
|
||||
toWrite []int
|
||||
tcp4GROTable, tcp6GROTable *tcpGROTable
|
||||
}
|
||||
|
||||
func (tun *NativeTun) File() *os.File {
|
||||
return tun.tunFile
|
||||
}
|
||||
|
||||
func (tun *NativeTun) routineHackListener() {
|
||||
defer tun.hackListenerClosed.Unlock()
|
||||
/* This is needed for the detection to work across network namespaces
|
||||
* If you are reading this and know a better method, please get in touch.
|
||||
*/
|
||||
last := 0
|
||||
const (
|
||||
up = 1
|
||||
down = 2
|
||||
)
|
||||
for {
|
||||
sysconn, err := tun.tunFile.SyscallConn()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err2 := sysconn.Control(func(fd uintptr) {
|
||||
_, err = unix.Write(int(fd), nil)
|
||||
})
|
||||
if err2 != nil {
|
||||
return
|
||||
}
|
||||
switch err {
|
||||
case unix.EINVAL:
|
||||
if last != up {
|
||||
// If the tunnel is up, it reports that write() is
|
||||
// allowed but we provided invalid data.
|
||||
tun.events <- EventUp
|
||||
last = up
|
||||
}
|
||||
case unix.EIO:
|
||||
if last != down {
|
||||
// If the tunnel is down, it reports that no I/O
|
||||
// is possible, without checking our provided data.
|
||||
tun.events <- EventDown
|
||||
last = down
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
// nothing
|
||||
case <-tun.statusListenersShutdown:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func createNetlinkSocket() (int, error) {
|
||||
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
saddr := &unix.SockaddrNetlink{
|
||||
Family: unix.AF_NETLINK,
|
||||
Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
|
||||
}
|
||||
err = unix.Bind(sock, saddr)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
return sock, nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) routineNetlinkListener() {
|
||||
defer func() {
|
||||
unix.Close(tun.netlinkSock)
|
||||
tun.hackListenerClosed.Lock()
|
||||
close(tun.events)
|
||||
tun.netlinkCancel.Close()
|
||||
}()
|
||||
|
||||
for msg := make([]byte, 1<<16); ; {
|
||||
var err error
|
||||
var msgn int
|
||||
for {
|
||||
msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
|
||||
if err == nil || !rwcancel.RetryAfterError(err) {
|
||||
break
|
||||
}
|
||||
if !tun.netlinkCancel.ReadyRead() {
|
||||
tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-tun.statusListenersShutdown:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
wasEverUp := false
|
||||
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
||||
|
||||
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
||||
|
||||
if int(hdr.Len) > len(remain) {
|
||||
break
|
||||
}
|
||||
|
||||
switch hdr.Type {
|
||||
case unix.NLMSG_DONE:
|
||||
remain = []byte{}
|
||||
|
||||
case unix.RTM_NEWLINK:
|
||||
info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
|
||||
remain = remain[hdr.Len:]
|
||||
|
||||
if info.Index != tun.index {
|
||||
// not our interface
|
||||
continue
|
||||
}
|
||||
|
||||
if info.Flags&unix.IFF_RUNNING != 0 {
|
||||
tun.events <- EventUp
|
||||
wasEverUp = true
|
||||
}
|
||||
|
||||
if info.Flags&unix.IFF_RUNNING == 0 {
|
||||
// Don't emit EventDown before we've ever emitted EventUp.
|
||||
// This avoids a startup race with HackListener, which
|
||||
// might detect Up before we have finished reporting Down.
|
||||
if wasEverUp {
|
||||
tun.events <- EventDown
|
||||
}
|
||||
}
|
||||
|
||||
tun.events <- EventMTUUpdate
|
||||
|
||||
default:
|
||||
remain = remain[hdr.Len:]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getIFIndex(name string) (int32, error) {
|
||||
fd, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
defer unix.Close(fd)
|
||||
|
||||
var ifr [ifReqSize]byte
|
||||
copy(ifr[:], name)
|
||||
_, _, errno := unix.Syscall(
|
||||
unix.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
uintptr(unix.SIOCGIFINDEX),
|
||||
uintptr(unsafe.Pointer(&ifr[0])),
|
||||
)
|
||||
|
||||
if errno != 0 {
|
||||
return 0, errno
|
||||
}
|
||||
|
||||
return *(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) setMTU(n int) error {
|
||||
name, err := tun.Name()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// open datagram socket
|
||||
fd, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer unix.Close(fd)
|
||||
|
||||
var ifr [ifReqSize]byte
|
||||
copy(ifr[:], name)
|
||||
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
|
||||
|
||||
_, _, errno := unix.Syscall(
|
||||
unix.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
uintptr(unix.SIOCSIFMTU),
|
||||
uintptr(unsafe.Pointer(&ifr[0])),
|
||||
)
|
||||
|
||||
if errno != 0 {
|
||||
return errno
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) routineNetlinkRead() {
|
||||
defer func() {
|
||||
unix.Close(tun.netlinkSock)
|
||||
tun.hackListenerClosed.Lock()
|
||||
close(tun.events)
|
||||
tun.netlinkCancel.Close()
|
||||
}()
|
||||
|
||||
for msg := make([]byte, 1<<16); ; {
|
||||
var err error
|
||||
var msgn int
|
||||
for {
|
||||
msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
|
||||
if err == nil || !rwcancel.RetryAfterError(err) {
|
||||
break
|
||||
}
|
||||
if !tun.netlinkCancel.ReadyRead() {
|
||||
tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
wasEverUp := false
|
||||
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
||||
|
||||
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
||||
|
||||
if int(hdr.Len) > len(remain) {
|
||||
break
|
||||
}
|
||||
|
||||
switch hdr.Type {
|
||||
case unix.NLMSG_DONE:
|
||||
remain = []byte{}
|
||||
|
||||
case unix.RTM_NEWLINK:
|
||||
info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
|
||||
remain = remain[hdr.Len:]
|
||||
|
||||
if info.Index != tun.index {
|
||||
continue
|
||||
}
|
||||
|
||||
if info.Flags&unix.IFF_RUNNING != 0 {
|
||||
tun.events <- EventUp
|
||||
wasEverUp = true
|
||||
}
|
||||
|
||||
if info.Flags&unix.IFF_RUNNING == 0 {
|
||||
if wasEverUp {
|
||||
tun.events <- EventDown
|
||||
}
|
||||
}
|
||||
tun.events <- EventMTUUpdate
|
||||
|
||||
default:
|
||||
remain = remain[hdr.Len:]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tun *NativeTun) routineNetlink() {
|
||||
var err error
|
||||
|
||||
tun.netlinkSock, err = createNetlinkSocket()
|
||||
if err != nil {
|
||||
tun.errors <- fmt.Errorf("failed to create netlink socket: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
tun.netlinkCancel, err = rwcancel.NewRWCancel(tun.netlinkSock)
|
||||
if err != nil {
|
||||
tun.errors <- fmt.Errorf("failed to create netlink cancel: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
go tun.routineNetlinkListener()
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Close() error {
|
||||
var err1, err2 error
|
||||
tun.closeOnce.Do(func() {
|
||||
if tun.statusListenersShutdown != nil {
|
||||
close(tun.statusListenersShutdown)
|
||||
if tun.netlinkCancel != nil {
|
||||
err1 = tun.netlinkCancel.Cancel()
|
||||
}
|
||||
} else if tun.events != nil {
|
||||
close(tun.events)
|
||||
}
|
||||
err2 = tun.tunFile.Close()
|
||||
})
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
|
||||
func (tun *NativeTun) BatchSize() int {
|
||||
return tun.batchSize
|
||||
}
|
||||
|
||||
const (
|
||||
// TODO: support TSO with ECN bits
|
||||
tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
|
||||
)
|
||||
|
||||
func (tun *NativeTun) initFromFlags(name string) error {
|
||||
sc, err := tun.tunFile.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if e := sc.Control(func(fd uintptr) {
|
||||
var (
|
||||
ifr *unix.Ifreq
|
||||
)
|
||||
ifr, err = unix.NewIfreq(name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
got := ifr.Uint16()
|
||||
if got&unix.IFF_VNET_HDR != 0 {
|
||||
err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
tun.vnetHdr = true
|
||||
tun.batchSize = wgconn.IdealBatchSize
|
||||
} else {
|
||||
tun.batchSize = 1
|
||||
}
|
||||
}); e != nil {
|
||||
return e
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// CreateTUN creates a Device with the provided name and MTU.
|
||||
func CreateTUN(name string, mtu int) (Device, error) {
|
||||
nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
|
||||
}
|
||||
fd := os.NewFile(uintptr(nfd), cloneDevicePath)
|
||||
tun, err := CreateTUNFromFile(fd, mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if name != "tun" {
|
||||
if err := tun.(*NativeTun).initFromFlags(name); err != nil {
|
||||
tun.Close()
|
||||
return nil, fmt.Errorf("CreateTUN(%q) failed to set flags: %w", name, err)
|
||||
}
|
||||
}
|
||||
return tun, nil
|
||||
}
|
||||
|
||||
// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
|
||||
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
||||
tun := &NativeTun{
|
||||
tunFile: file,
|
||||
errors: make(chan error, 5),
|
||||
events: make(chan Event, 5),
|
||||
}
|
||||
|
||||
var err error
|
||||
tun.index, err = getIFIndex("tun")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get TUN index: %w", err)
|
||||
}
|
||||
|
||||
if err = tun.setMTU(mtu); err != nil {
|
||||
return nil, fmt.Errorf("failed to set MTU: %w", err)
|
||||
}
|
||||
|
||||
tun.statusListenersShutdown = make(chan struct{})
|
||||
go tun.routineNetlink()
|
||||
|
||||
if tun.batchSize == 0 {
|
||||
tun.batchSize = 1
|
||||
}
|
||||
|
||||
tun.tcp4GROTable = newTCPGROTable()
|
||||
tun.tcp6GROTable = newTCPGROTable()
|
||||
|
||||
return tun, nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Name() (string, error) {
|
||||
tun.nameOnce.Do(tun.initNameCache)
|
||||
return tun.nameCache, tun.nameErr
|
||||
}
|
||||
|
||||
func (tun *NativeTun) initNameCache() {
|
||||
sysconn, err := tun.tunFile.SyscallConn()
|
||||
if err != nil {
|
||||
tun.nameErr = err
|
||||
return
|
||||
}
|
||||
err = sysconn.Control(func(fd uintptr) {
|
||||
var ifr [ifReqSize]byte
|
||||
_, _, errno := unix.Syscall(
|
||||
unix.SYS_IOCTL,
|
||||
fd,
|
||||
uintptr(unix.TUNGETIFF),
|
||||
uintptr(unsafe.Pointer(&ifr[0])),
|
||||
)
|
||||
if errno != 0 {
|
||||
tun.nameErr = errno
|
||||
return
|
||||
}
|
||||
tun.nameCache = unix.ByteSliceToString(ifr[:])
|
||||
})
|
||||
if err != nil && tun.nameErr == nil {
|
||||
tun.nameErr = err
|
||||
}
|
||||
}
|
||||
|
||||
func (tun *NativeTun) MTU() (int, error) {
|
||||
name, err := tun.Name()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// open datagram socket
|
||||
fd, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer unix.Close(fd)
|
||||
|
||||
var ifr [ifReqSize]byte
|
||||
copy(ifr[:], name)
|
||||
|
||||
_, _, errno := unix.Syscall(
|
||||
unix.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
uintptr(unix.SIOCGIFMTU),
|
||||
uintptr(unsafe.Pointer(&ifr[0])),
|
||||
)
|
||||
|
||||
if errno != 0 {
|
||||
return 0, errno
|
||||
}
|
||||
|
||||
return int(*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Events() <-chan Event {
|
||||
return tun.events
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
|
||||
tun.writeOpMu.Lock()
|
||||
defer func() {
|
||||
tun.tcp4GROTable.reset()
|
||||
tun.tcp6GROTable.reset()
|
||||
tun.writeOpMu.Unlock()
|
||||
}()
|
||||
var (
|
||||
errs error
|
||||
total int
|
||||
)
|
||||
tun.toWrite = tun.toWrite[:0]
|
||||
if tun.vnetHdr {
|
||||
err := handleGRO(bufs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
offset -= virtioNetHdrLen
|
||||
} else {
|
||||
for i := range bufs {
|
||||
tun.toWrite = append(tun.toWrite, i)
|
||||
}
|
||||
}
|
||||
for _, bufsI := range tun.toWrite {
|
||||
n, err := tun.tunFile.Write(bufs[bufsI][offset:])
|
||||
if errors.Is(err, syscall.EBADFD) {
|
||||
return total, os.ErrClosed
|
||||
}
|
||||
if err != nil {
|
||||
errs = errors.Join(errs, err)
|
||||
} else {
|
||||
total += n
|
||||
}
|
||||
}
|
||||
return total, errs
|
||||
}
|
||||
|
||||
// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
|
||||
// each buffer. It mutates sizes to reflect the size of each element of bufs,
|
||||
// and returns the number of packets read.
|
||||
func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
|
||||
var hdr virtioNetHdr
|
||||
if err := hdr.decode(in); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
in = in[virtioNetHdrLen:]
|
||||
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
|
||||
if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
|
||||
if err := gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if len(in) > len(bufs[0][offset:]) {
|
||||
return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
|
||||
}
|
||||
n := copy(bufs[0][offset:], in)
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
|
||||
return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
|
||||
}
|
||||
|
||||
ipVersion := in[0] >> 4
|
||||
switch ipVersion {
|
||||
case 4:
|
||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
|
||||
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
|
||||
}
|
||||
case 6:
|
||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
|
||||
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
|
||||
}
|
||||
|
||||
if len(in) <= int(hdr.csumStart+12) {
|
||||
return 0, errors.New("packet is too short")
|
||||
}
|
||||
tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
|
||||
if tcpHLen < 20 || tcpHLen > 60 {
|
||||
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
|
||||
}
|
||||
hdr.hdrLen = hdr.csumStart + tcpHLen
|
||||
if len(in) < int(hdr.hdrLen) {
|
||||
return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
|
||||
}
|
||||
if hdr.hdrLen < hdr.csumStart {
|
||||
return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
|
||||
}
|
||||
cSumAt := int(hdr.csumStart + hdr.csumOffset)
|
||||
if cSumAt+1 >= len(in) {
|
||||
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
|
||||
}
|
||||
|
||||
return tcpTSO(in, hdr, bufs, sizes, offset)
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
|
||||
tun.readOpMu.Lock()
|
||||
defer tun.readOpMu.Unlock()
|
||||
select {
|
||||
case err := <-tun.errors:
|
||||
return 0, err
|
||||
default:
|
||||
readInto := bufs[0][offset:]
|
||||
if tun.vnetHdr {
|
||||
readInto = tun.readBuff[:]
|
||||
}
|
||||
n, err := tun.tunFile.Read(readInto)
|
||||
if errors.Is(err, syscall.EBADFD) {
|
||||
err = os.ErrClosed
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if tun.vnetHdr {
|
||||
return handleVirtioRead(readInto[:n], bufs, sizes, offset)
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user