// Mirror of https://github.com/slackhq/nebula.git, synced 2025-11-22 16:34:25 +01:00.
// Original file: 218 lines, 5.0 KiB, Go.
//go:build linux && !android && !e2e_testing
|
|
|
|
package overlay
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
|
|
wgtun "github.com/slackhq/nebula/wgstack/tun"
|
|
)
|
|
|
|
type wireguardTunIO struct {
|
|
dev wgtun.Device
|
|
mtu int
|
|
batchSize int
|
|
|
|
readMu sync.Mutex
|
|
readBuffers [][]byte
|
|
readLens []int
|
|
legacyBuf []byte
|
|
|
|
writeMu sync.Mutex
|
|
writeBuf []byte
|
|
writeWrap [][]byte
|
|
writeBuffers [][]byte
|
|
}
|
|
|
|
func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
|
|
batch := dev.BatchSize()
|
|
if batch <= 0 {
|
|
batch = 1
|
|
}
|
|
if mtu <= 0 {
|
|
mtu = DefaultMTU
|
|
}
|
|
return &wireguardTunIO{
|
|
dev: dev,
|
|
mtu: mtu,
|
|
batchSize: batch,
|
|
readLens: make([]int, batch),
|
|
legacyBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
|
|
writeBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
|
|
writeWrap: make([][]byte, 1),
|
|
}
|
|
}
|
|
|
|
func (w *wireguardTunIO) Read(p []byte) (int, error) {
|
|
w.readMu.Lock()
|
|
defer w.readMu.Unlock()
|
|
|
|
bufs := w.readBuffers
|
|
if len(bufs) == 0 {
|
|
bufs = [][]byte{w.legacyBuf}
|
|
w.readBuffers = bufs
|
|
}
|
|
n, err := w.dev.Read(bufs[:1], w.readLens[:1], wgtun.VirtioNetHdrLen)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
if n == 0 {
|
|
return 0, nil
|
|
}
|
|
length := w.readLens[0]
|
|
copy(p, w.legacyBuf[wgtun.VirtioNetHdrLen:wgtun.VirtioNetHdrLen+length])
|
|
return length, nil
|
|
}
|
|
|
|
func (w *wireguardTunIO) Write(p []byte) (int, error) {
|
|
if len(p) > w.mtu {
|
|
return 0, fmt.Errorf("wireguard tun: payload exceeds MTU (%d > %d)", len(p), w.mtu)
|
|
}
|
|
w.writeMu.Lock()
|
|
defer w.writeMu.Unlock()
|
|
buf := w.writeBuf[:wgtun.VirtioNetHdrLen+len(p)]
|
|
for i := 0; i < wgtun.VirtioNetHdrLen; i++ {
|
|
buf[i] = 0
|
|
}
|
|
copy(buf[wgtun.VirtioNetHdrLen:], p)
|
|
w.writeWrap[0] = buf
|
|
n, err := w.dev.Write(w.writeWrap, wgtun.VirtioNetHdrLen)
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
return len(p), nil
|
|
}
|
|
|
|
func (w *wireguardTunIO) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
|
|
if pool == nil {
|
|
return nil, fmt.Errorf("wireguard tun: packet pool is nil")
|
|
}
|
|
|
|
w.readMu.Lock()
|
|
defer w.readMu.Unlock()
|
|
|
|
if len(w.readBuffers) < w.batchSize {
|
|
w.readBuffers = make([][]byte, w.batchSize)
|
|
}
|
|
if len(w.readLens) < w.batchSize {
|
|
w.readLens = make([]int, w.batchSize)
|
|
}
|
|
|
|
packets := make([]*Packet, w.batchSize)
|
|
requiredHeadroom := w.BatchHeadroom()
|
|
requiredPayload := w.BatchPayloadCap()
|
|
headroom := 0
|
|
for i := 0; i < w.batchSize; i++ {
|
|
pkt := pool.Get()
|
|
if pkt == nil {
|
|
releasePackets(packets[:i])
|
|
return nil, fmt.Errorf("wireguard tun: packet pool returned nil packet")
|
|
}
|
|
if pkt.Capacity() < requiredPayload {
|
|
pkt.Release()
|
|
releasePackets(packets[:i])
|
|
return nil, fmt.Errorf("wireguard tun: packet capacity %d below required %d", pkt.Capacity(), requiredPayload)
|
|
}
|
|
if i == 0 {
|
|
headroom = pkt.Offset
|
|
if headroom < requiredHeadroom {
|
|
pkt.Release()
|
|
releasePackets(packets[:i])
|
|
return nil, fmt.Errorf("wireguard tun: packet headroom %d below virtio requirement %d", headroom, requiredHeadroom)
|
|
}
|
|
} else if pkt.Offset != headroom {
|
|
pkt.Release()
|
|
releasePackets(packets[:i])
|
|
return nil, fmt.Errorf("wireguard tun: inconsistent packet headroom (%d != %d)", pkt.Offset, headroom)
|
|
}
|
|
packets[i] = pkt
|
|
w.readBuffers[i] = pkt.Buf
|
|
}
|
|
|
|
n, err := w.dev.Read(w.readBuffers[:w.batchSize], w.readLens[:w.batchSize], headroom)
|
|
if err != nil {
|
|
releasePackets(packets)
|
|
return nil, err
|
|
}
|
|
if n == 0 {
|
|
releasePackets(packets)
|
|
return nil, nil
|
|
}
|
|
for i := 0; i < n; i++ {
|
|
packets[i].Len = w.readLens[i]
|
|
}
|
|
for i := n; i < w.batchSize; i++ {
|
|
packets[i].Release()
|
|
packets[i] = nil
|
|
}
|
|
return packets[:n], nil
|
|
}
|
|
|
|
func (w *wireguardTunIO) WriteBatch(packets []*Packet) (int, error) {
|
|
if len(packets) == 0 {
|
|
return 0, nil
|
|
}
|
|
requiredHeadroom := w.BatchHeadroom()
|
|
offset := packets[0].Offset
|
|
if offset < requiredHeadroom {
|
|
releasePackets(packets)
|
|
return 0, fmt.Errorf("wireguard tun: packet offset %d smaller than required headroom %d", offset, requiredHeadroom)
|
|
}
|
|
for _, pkt := range packets {
|
|
if pkt == nil {
|
|
continue
|
|
}
|
|
if pkt.Offset != offset {
|
|
releasePackets(packets)
|
|
return 0, fmt.Errorf("wireguard tun: mixed packet offsets not supported")
|
|
}
|
|
limit := pkt.Offset + pkt.Len
|
|
if limit > len(pkt.Buf) {
|
|
releasePackets(packets)
|
|
return 0, fmt.Errorf("wireguard tun: packet length %d exceeds buffer capacity %d", pkt.Len, len(pkt.Buf)-pkt.Offset)
|
|
}
|
|
}
|
|
w.writeMu.Lock()
|
|
defer w.writeMu.Unlock()
|
|
|
|
if len(w.writeBuffers) < len(packets) {
|
|
w.writeBuffers = make([][]byte, len(packets))
|
|
}
|
|
for i, pkt := range packets {
|
|
if pkt == nil {
|
|
w.writeBuffers[i] = nil
|
|
continue
|
|
}
|
|
limit := pkt.Offset + pkt.Len
|
|
w.writeBuffers[i] = pkt.Buf[:limit]
|
|
}
|
|
n, err := w.dev.Write(w.writeBuffers[:len(packets)], offset)
|
|
releasePackets(packets)
|
|
return n, err
|
|
}
|
|
|
|
func (w *wireguardTunIO) BatchHeadroom() int {
|
|
return wgtun.VirtioNetHdrLen
|
|
}
|
|
|
|
func (w *wireguardTunIO) BatchPayloadCap() int {
|
|
return w.mtu
|
|
}
|
|
|
|
func (w *wireguardTunIO) BatchSize() int {
|
|
return w.batchSize
|
|
}
|
|
|
|
func (w *wireguardTunIO) Close() error {
|
|
return nil
|
|
}
|
|
|
|
func releasePackets(pkts []*Packet) {
|
|
for _, pkt := range pkts {
|
|
if pkt != nil {
|
|
pkt.Release()
|
|
}
|
|
}
|
|
}
|