mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-23 00:44:25 +01:00
hmmmmmm it works i guess maybe
This commit is contained in:
@@ -14,15 +14,15 @@ type wireguardTunIO struct {
|
||||
mtu int
|
||||
batchSize int
|
||||
|
||||
readMu sync.Mutex
|
||||
readBufs [][]byte
|
||||
readLens []int
|
||||
pending [][]byte
|
||||
pendIdx int
|
||||
readMu sync.Mutex
|
||||
readBuffers [][]byte
|
||||
readLens []int
|
||||
legacyBuf []byte
|
||||
|
||||
writeMu sync.Mutex
|
||||
writeBuf []byte
|
||||
writeWrap [][]byte
|
||||
writeMu sync.Mutex
|
||||
writeBuf []byte
|
||||
writeWrap [][]byte
|
||||
writeBuffers [][]byte
|
||||
}
|
||||
|
||||
func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
|
||||
@@ -33,17 +33,12 @@ func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
|
||||
if mtu <= 0 {
|
||||
mtu = DefaultMTU
|
||||
}
|
||||
bufs := make([][]byte, batch)
|
||||
for i := range bufs {
|
||||
bufs[i] = make([]byte, wgtun.VirtioNetHdrLen+mtu)
|
||||
}
|
||||
return &wireguardTunIO{
|
||||
dev: dev,
|
||||
mtu: mtu,
|
||||
batchSize: batch,
|
||||
readBufs: bufs,
|
||||
readLens: make([]int, batch),
|
||||
pending: make([][]byte, 0, batch),
|
||||
legacyBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
|
||||
writeBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
|
||||
writeWrap: make([][]byte, 1),
|
||||
}
|
||||
@@ -53,29 +48,21 @@ func (w *wireguardTunIO) Read(p []byte) (int, error) {
|
||||
w.readMu.Lock()
|
||||
defer w.readMu.Unlock()
|
||||
|
||||
for {
|
||||
if w.pendIdx < len(w.pending) {
|
||||
segment := w.pending[w.pendIdx]
|
||||
w.pendIdx++
|
||||
n := copy(p, segment)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
n, err := w.dev.Read(w.readBufs, w.readLens, wgtun.VirtioNetHdrLen)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w.pending = w.pending[:0]
|
||||
w.pendIdx = 0
|
||||
for i := 0; i < n; i++ {
|
||||
length := w.readLens[i]
|
||||
if length == 0 {
|
||||
continue
|
||||
}
|
||||
segment := w.readBufs[i][wgtun.VirtioNetHdrLen : wgtun.VirtioNetHdrLen+length]
|
||||
w.pending = append(w.pending, segment)
|
||||
}
|
||||
bufs := w.readBuffers
|
||||
if len(bufs) == 0 {
|
||||
bufs = [][]byte{w.legacyBuf}
|
||||
w.readBuffers = bufs
|
||||
}
|
||||
n, err := w.dev.Read(bufs[:1], w.readLens[:1], wgtun.VirtioNetHdrLen)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if n == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
length := w.readLens[0]
|
||||
copy(p, w.legacyBuf[wgtun.VirtioNetHdrLen:wgtun.VirtioNetHdrLen+length])
|
||||
return length, nil
|
||||
}
|
||||
|
||||
func (w *wireguardTunIO) Write(p []byte) (int, error) {
|
||||
@@ -97,6 +84,134 @@ func (w *wireguardTunIO) Write(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *wireguardTunIO) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
|
||||
if pool == nil {
|
||||
return nil, fmt.Errorf("wireguard tun: packet pool is nil")
|
||||
}
|
||||
|
||||
w.readMu.Lock()
|
||||
defer w.readMu.Unlock()
|
||||
|
||||
if len(w.readBuffers) < w.batchSize {
|
||||
w.readBuffers = make([][]byte, w.batchSize)
|
||||
}
|
||||
if len(w.readLens) < w.batchSize {
|
||||
w.readLens = make([]int, w.batchSize)
|
||||
}
|
||||
|
||||
packets := make([]*Packet, w.batchSize)
|
||||
requiredHeadroom := w.BatchHeadroom()
|
||||
requiredPayload := w.BatchPayloadCap()
|
||||
headroom := 0
|
||||
for i := 0; i < w.batchSize; i++ {
|
||||
pkt := pool.Get()
|
||||
if pkt == nil {
|
||||
releasePackets(packets[:i])
|
||||
return nil, fmt.Errorf("wireguard tun: packet pool returned nil packet")
|
||||
}
|
||||
if pkt.Capacity() < requiredPayload {
|
||||
pkt.Release()
|
||||
releasePackets(packets[:i])
|
||||
return nil, fmt.Errorf("wireguard tun: packet capacity %d below required %d", pkt.Capacity(), requiredPayload)
|
||||
}
|
||||
if i == 0 {
|
||||
headroom = pkt.Offset
|
||||
if headroom < requiredHeadroom {
|
||||
pkt.Release()
|
||||
releasePackets(packets[:i])
|
||||
return nil, fmt.Errorf("wireguard tun: packet headroom %d below virtio requirement %d", headroom, requiredHeadroom)
|
||||
}
|
||||
} else if pkt.Offset != headroom {
|
||||
pkt.Release()
|
||||
releasePackets(packets[:i])
|
||||
return nil, fmt.Errorf("wireguard tun: inconsistent packet headroom (%d != %d)", pkt.Offset, headroom)
|
||||
}
|
||||
packets[i] = pkt
|
||||
w.readBuffers[i] = pkt.Buf
|
||||
}
|
||||
|
||||
n, err := w.dev.Read(w.readBuffers[:w.batchSize], w.readLens[:w.batchSize], headroom)
|
||||
if err != nil {
|
||||
releasePackets(packets)
|
||||
return nil, err
|
||||
}
|
||||
if n == 0 {
|
||||
releasePackets(packets)
|
||||
return nil, nil
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
packets[i].Len = w.readLens[i]
|
||||
}
|
||||
for i := n; i < w.batchSize; i++ {
|
||||
packets[i].Release()
|
||||
packets[i] = nil
|
||||
}
|
||||
return packets[:n], nil
|
||||
}
|
||||
|
||||
func (w *wireguardTunIO) WriteBatch(packets []*Packet) (int, error) {
|
||||
if len(packets) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
requiredHeadroom := w.BatchHeadroom()
|
||||
offset := packets[0].Offset
|
||||
if offset < requiredHeadroom {
|
||||
releasePackets(packets)
|
||||
return 0, fmt.Errorf("wireguard tun: packet offset %d smaller than required headroom %d", offset, requiredHeadroom)
|
||||
}
|
||||
for _, pkt := range packets {
|
||||
if pkt == nil {
|
||||
continue
|
||||
}
|
||||
if pkt.Offset != offset {
|
||||
releasePackets(packets)
|
||||
return 0, fmt.Errorf("wireguard tun: mixed packet offsets not supported")
|
||||
}
|
||||
limit := pkt.Offset + pkt.Len
|
||||
if limit > len(pkt.Buf) {
|
||||
releasePackets(packets)
|
||||
return 0, fmt.Errorf("wireguard tun: packet length %d exceeds buffer capacity %d", pkt.Len, len(pkt.Buf)-pkt.Offset)
|
||||
}
|
||||
}
|
||||
w.writeMu.Lock()
|
||||
defer w.writeMu.Unlock()
|
||||
|
||||
if len(w.writeBuffers) < len(packets) {
|
||||
w.writeBuffers = make([][]byte, len(packets))
|
||||
}
|
||||
for i, pkt := range packets {
|
||||
if pkt == nil {
|
||||
w.writeBuffers[i] = nil
|
||||
continue
|
||||
}
|
||||
limit := pkt.Offset + pkt.Len
|
||||
w.writeBuffers[i] = pkt.Buf[:limit]
|
||||
}
|
||||
n, err := w.dev.Write(w.writeBuffers[:len(packets)], offset)
|
||||
releasePackets(packets)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *wireguardTunIO) BatchHeadroom() int {
|
||||
return wgtun.VirtioNetHdrLen
|
||||
}
|
||||
|
||||
func (w *wireguardTunIO) BatchPayloadCap() int {
|
||||
return w.mtu
|
||||
}
|
||||
|
||||
func (w *wireguardTunIO) BatchSize() int {
|
||||
return w.batchSize
|
||||
}
|
||||
|
||||
func (w *wireguardTunIO) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func releasePackets(pkts []*Packet) {
|
||||
for _, pkt := range pkts {
|
||||
if pkt != nil {
|
||||
pkt.Release()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user