Mirror of https://github.com/slackhq/nebula.git (synced 2025-11-23 17:04:25 +01:00)
BAM! a hit from my spice-weasel. Now TX is zero-copy, at the expense of my sanity!
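In short: rather than assembling each outbound frame in a scratch slice and letting the virtqueue copy it into a descriptor, the TX path now borrows a descriptor-backed buffer up front (AllocSeg / GetPacketForTx in the diff below), writes the payload directly into it (DecryptDanger now decrypts straight into SegmentPayloads), and submits the buffer IDs back to the queue in a batch (WriteMany / TransmitPackets). A self-contained toy model of that borrow, fill in place, submit scheme, with illustrative names that are not this repo's API:

package main

// Toy model of a zero-copy transmit path: the device owns a pool of
// fixed-size descriptor buffers; a sender borrows one, writes the payload
// into it in place, and submits the buffer ID, so the payload bytes are
// never copied between staging slices.
import "fmt"

type device struct {
	bufs [4][2048]byte // descriptor-backed buffers, owned by the device
	free []uint16      // IDs of buffers not currently in flight
}

// allocSeg borrows a free descriptor buffer (compare AllocSeg/GetPacketForTx).
func (d *device) allocSeg() (uint16, []byte, error) {
	if len(d.free) == 0 {
		return 0, nil, fmt.Errorf("no free descriptors")
	}
	id := d.free[len(d.free)-1]
	d.free = d.free[:len(d.free)-1]
	return id, d.bufs[id][:], nil
}

// transmit submits borrowed buffers by ID and recycles them (the real code
// recycles through the virtqueue's used ring instead).
func (d *device) transmit(ids []uint16) {
	for _, id := range ids {
		fmt.Printf("tx descriptor %d\n", id)
		d.free = append(d.free, id)
	}
}

func main() {
	d := &device{free: []uint16{0, 1, 2, 3}}
	id, buf, err := d.allocSeg()
	if err != nil {
		panic(err)
	}
	copy(buf, "payload written in place") // stands in for decrypt-into-buffer
	d.transmit([]uint16{id})
}

The cost of the scheme shows up throughout the diff: buffer lifetime is now tracked per descriptor ID, so every segment has to travel back through the used ring before it can be borrowed again.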

interface.go (31 changed lines)

@@ -295,29 +295,16 @@ func (f *Interface) listenOut(q int) {
 		}

 		f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l), time.Now())
-		for i := range pkts {
-			if pkts[i].OutLen != -1 {
-				for j := 0; j < outPackets[i].SegCounter; j++ {
-					if len(outPackets[i].Segments[j]) > 0 {
-						toSend = append(toSend, outPackets[i].Segments[j])
-					}
-				}
-			}
-		}
-		n := len(toSend)
-		if f.l.Level == logrus.DebugLevel {
-			f.listenOutMetric.Update(int64(n))
-		}
-		f.listenOutN = n
-		//toSend = toSend[:toSendCount]
-		for i := 0; i < n; i += batch {
-			x := min(len(toSend[i:]), batch)
-			toSendThisTime := toSend[i : i+x]
-			_, err := f.readers[q].WriteMany(toSendThisTime, q)
-			if err != nil {
-				f.l.WithError(err).Error("Failed to write messages")
-			}
+		//we opportunistically tx, but try to also send stragglers
+		if _, err := f.readers[q].WriteMany(outPackets, q); err != nil {
+			f.l.WithError(err).Error("Failed to send packets")
 		}
+		//todo I broke this
+		//n := len(toSend)
+		//if f.l.Level == logrus.DebugLevel {
+		//	f.listenOutMetric.Update(int64(n))
+		//}
+		//f.listenOutN = n

 	})
 }

outside.go (21 changed lines)

@@ -419,7 +419,12 @@ func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*pack
 			f.handleHostRoaming(hostinfo, ip)

 			f.connectionManager.In(hostinfo)

 		}
+	//_, err := f.readers[q].WriteOne(out[i], false, q)
+	//if err != nil {
+	//	f.l.WithError(err).Error("Failed to write packet")
+	//}
 	}
 }
+

@@ -675,14 +680,20 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
 func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
 	var err error

-	out.Segments[out.SegCounter] = out.Segments[out.SegCounter][:0]
-	out.Segments[out.SegCounter], err = hostinfo.ConnectionState.dKey.DecryptDanger(out.Segments[out.SegCounter], inSegment[:header.Len], inSegment[header.Len:], messageCounter, nb)
+	seg, err := f.readers[q].AllocSeg(out, q)
+	if err != nil {
+		f.l.WithError(err).Errorln("decryptToTunDelayWrite: failed to allocate segment")
+		return false
+	}
+
+	out.SegmentPayloads[seg] = out.SegmentPayloads[seg][:0]
+	out.SegmentPayloads[seg], err = hostinfo.ConnectionState.dKey.DecryptDanger(out.SegmentPayloads[seg], inSegment[:header.Len], inSegment[header.Len:], messageCounter, nb)
 	if err != nil {
 		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
 		return false
 	}

-	err = newPacket(out.Segments[out.SegCounter], true, fwPacket)
+	err = newPacket(out.SegmentPayloads[seg], true, fwPacket)
 	if err != nil {
 		hostinfo.logger(f.l).WithError(err).WithField("packet", out).
 			Warnf("Error while validating inbound packet")

@@ -699,7 +710,7 @@ func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter ui
 	if dropReason != nil {
 		// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
 		// This gives us a buffer to build the reject packet in
-		f.rejectOutside(out.Segments[out.SegCounter], hostinfo.ConnectionState, hostinfo, nb, inSegment, q)
+		f.rejectOutside(out.SegmentPayloads[seg], hostinfo.ConnectionState, hostinfo, nb, inSegment, q)
 		if f.l.Level >= logrus.DebugLevel {
 			hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
 				WithField("reason", dropReason).

@@ -710,7 +721,7 @@ func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter ui

 	f.connectionManager.In(hostinfo)
 	pkt.OutLen += len(inSegment)
-	out.SegCounter++
+	out.Segments[seg] = out.Segments[seg][:len(out.SegmentHeaders[seg])+len(out.SegmentPayloads[seg])]
 	return true
 }


@@ -17,7 +17,11 @@ const DefaultMTU = 1300
 type TunDev interface {
 	io.WriteCloser
 	ReadMany(x []*packet.VirtIOPacket, q int) (int, error)
-	WriteMany(x [][]byte, q int) (int, error)
+
+	//todo this interface sux
+	AllocSeg(pkt *packet.OutPacket, q int) (int, error)
+	WriteOne(x *packet.OutPacket, kick bool, q int) (int, error)
+	WriteMany(x []*packet.OutPacket, q int) (int, error)
 }

 // TODO: We may be able to remove routines

@@ -111,16 +111,16 @@ func (t *disabledTun) Write(b []byte) (int, error) {
 	return len(b), nil
 }

-func (t *disabledTun) WriteMany(b [][]byte, _ int) (int, error) {
-	out := 0
-	for i := range b {
-		x, err := t.Write(b[i])
-		if err != nil {
-			return out, err
-		}
-		out += x
-	}
-	return out, nil
+func (t *disabledTun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
+	return 0, fmt.Errorf("tun_disabled: AllocSeg not implemented")
+}
+
+func (t *disabledTun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
+	return 0, fmt.Errorf("tun_disabled: WriteOne not implemented")
+}
+
+func (t *disabledTun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
+	return 0, fmt.Errorf("tun_disabled: WriteMany not implemented")
 }

 func (t *disabledTun) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {

@@ -717,17 +717,15 @@ func (t *tun) ReadMany(p []*packet.VirtIOPacket, q int) (int, error) {
 func (t *tun) Write(b []byte) (int, error) {
 	maximum := len(b) //we are RXing

-	hdr := virtio.NetHdr{ //todo
-		Flags:      unix.VIRTIO_NET_HDR_F_DATA_VALID,
-		GSOType:    unix.VIRTIO_NET_HDR_GSO_NONE,
-		HdrLen:     0,
-		GSOSize:    0,
-		CsumStart:  0,
-		CsumOffset: 0,
-		NumBuffers: 0,
+	//todo garbagey
+	out := packet.NewOut()
+	x, err := t.AllocSeg(out, 0)
+	if err != nil {
+		return 0, err
 	}
+	copy(out.SegmentPayloads[x], b)
+	err = t.vdev[0].TransmitPacket(out, true)

-	err := t.vdev[0].TransmitPackets(hdr, [][]byte{b})
 	if err != nil {
 		t.l.WithError(err).Error("Transmitting packet")
 		return 0, err

@@ -735,22 +733,30 @@ func (t *tun) Write(b []byte) (int, error) {
 	return maximum, nil
 }

-func (t *tun) WriteMany(b [][]byte, q int) (int, error) {
-	maximum := len(b) //we are RXing
+func (t *tun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
+	idx, buf, err := t.vdev[q].GetPacketForTx()
+	if err != nil {
+		return 0, err
+	}
+	x := pkt.UseSegment(idx, buf)
+	return x, nil
+}
+
+func (t *tun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
+	if err := t.vdev[q].TransmitPacket(x, kick); err != nil {
+		t.l.WithError(err).Error("Transmitting packet")
+		return 0, err
+	}
+	return 1, nil
+}
+
+func (t *tun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
+	maximum := len(x) //we are RXing
 	if maximum == 0 {
 		return 0, nil
 	}
-	hdr := virtio.NetHdr{ //todo
-		Flags:      unix.VIRTIO_NET_HDR_F_DATA_VALID,
-		GSOType:    unix.VIRTIO_NET_HDR_GSO_NONE,
-		HdrLen:     0,
-		GSOSize:    0,
-		CsumStart:  0,
-		CsumOffset: 0,
-		NumBuffers: 0,
-	}

-	err := t.vdev[q].TransmitPackets(hdr, b)
+	err := t.vdev[q].TransmitPackets(x)
 	if err != nil {
 		t.l.WithError(err).Error("Transmitting packet")
 		return 0, err

@@ -1,6 +1,7 @@
 package overlay

 import (
+	"fmt"
 	"io"
 	"net/netip"


@@ -71,14 +72,14 @@ func (d *UserDevice) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
 	return d.Read(b[0].Payload)
 }

-func (d *UserDevice) WriteMany(b [][]byte, _ int) (int, error) {
-	out := 0
-	for i := range b {
-		x, err := d.Write(b[i])
-		if err != nil {
-			return out, err
-		}
-		out += x
-	}
-	return out, nil
+func (d *UserDevice) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
+	return 0, fmt.Errorf("user: AllocSeg not implemented")
+}
+
+func (d *UserDevice) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
+	return 0, fmt.Errorf("user: WriteOne not implemented")
+}
+
+func (d *UserDevice) WriteMany(x []*packet.OutPacket, q int) (int, error) {
+	return 0, fmt.Errorf("user: WriteMany not implemented")
 }

@@ -6,7 +6,6 @@ import (
 	"fmt"
 	"os"
 	"runtime"
-	"slices"

 	"github.com/slackhq/nebula/overlay/vhost"
 	"github.com/slackhq/nebula/overlay/virtqueue"

@@ -31,6 +30,7 @@ type Device struct {
 	initialized bool
 	controlFD int

+	fullTable bool
 	ReceiveQueue *virtqueue.SplitQueue
 	TransmitQueue *virtqueue.SplitQueue
 }

@@ -123,6 +123,9 @@ func NewDevice(options ...Option) (*Device, error) {
 	if err = dev.refillReceiveQueue(); err != nil {
 		return nil, fmt.Errorf("refill receive queue: %w", err)
 	}
+	if err = dev.refillTransmitQueue(); err != nil {
+		return nil, fmt.Errorf("refill receive queue: %w", err)
+	}

 	dev.initialized = true


@@ -139,7 +142,7 @@ func NewDevice(options ...Option) (*Device, error) {
 // packets.
 func (dev *Device) refillReceiveQueue() error {
 	for {
-		_, err := dev.ReceiveQueue.OfferInDescriptorChains(1)
+		_, err := dev.ReceiveQueue.OfferInDescriptorChains()
 		if err != nil {
 			if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
 				// Queue is full, job is done.

@@ -150,6 +153,22 @@ func (dev *Device) refillReceiveQueue() error {
 	}
 }

+func (dev *Device) refillTransmitQueue() error {
+	//for {
+	//	desc, err := dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
+	//	if err != nil {
+	//		if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
+	//			// Queue is full, job is done.
+	//			return nil
+	//		}
+	//		return fmt.Errorf("offer descriptor chain: %w", err)
+	//	} else {
+	//		dev.TransmitQueue.UsedRing().InitOfferSingle(desc, 0)
+	//	}
+	//}
+	return nil
+}
+
 // Close cleans up the vhost networking device within the kernel and releases
 // all resources used for it.
 // The implementation will try to release as many resources as possible and

@@ -238,49 +257,67 @@ func truncateBuffers(buffers [][]byte, length int) (out [][]byte) {
 	return
 }

-func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) error {
-	// Prepend the packet with its virtio-net header.
-	vnethdrBuf := make([]byte, virtio.NetHdrSize+14) //todo WHY
-	if err := vnethdr.Encode(vnethdrBuf); err != nil {
-		return fmt.Errorf("encode vnethdr: %w", err)
-	}
-	vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
-	vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
-
-	chainIndexes, err := dev.TransmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets)
-	if err != nil {
-		return fmt.Errorf("offer descriptor chain: %w", err)
-	}
-	//todo surely there's something better to do here
-
-	for {
-		txedChains, err := dev.TransmitQueue.BlockAndGetHeadsCapped(context.TODO(), len(chainIndexes))
-		if err != nil {
-			return err
-		} else if len(txedChains) == 0 {
-			continue //todo will this ever exit?
-		}
-		for _, c := range txedChains {
-			idx := slices.Index(chainIndexes, c.GetHead())
-			if idx < 0 {
-				continue
-			} else {
-				_ = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[idx])
-				chainIndexes[idx] = 0 //todo I hope this works
-			}
-		}
-		done := true //optimism!
-		for _, x := range chainIndexes {
-			if x != 0 {
-				done = false
-				break
-			}
-		}
-
-		if done {
-			return nil
-		}
-	}
-}
+func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
+	var err error
+	var idx uint16
+	if !dev.fullTable {
+		idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
+		if err == virtqueue.ErrNotEnoughFreeDescriptors {
+			dev.fullTable = true
+			idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
+		}
+	} else {
+		idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
+	}
+	if err != nil {
+		return 0, nil, fmt.Errorf("transmit queue: %w", err)
+	}
+	buf, err := dev.TransmitQueue.GetDescriptorItem(idx)
+	if err != nil {
+		return 0, nil, fmt.Errorf("get descriptor chain: %w", err)
+	}
+	return idx, buf, nil
+}
+
+func (dev *Device) TransmitPacket(pkt *packet.OutPacket, kick bool) error {
+	//if pkt.Valid {
+	if len(pkt.SegmentIDs) == 0 {
+		return nil
+	}
+	for idx := range pkt.SegmentIDs {
+		segmentID := pkt.SegmentIDs[idx]
+		dev.TransmitQueue.SetDescSize(segmentID, len(pkt.Segments[idx]))
+	}
+	err := dev.TransmitQueue.OfferDescriptorChains(pkt.SegmentIDs, false)
+	if err != nil {
+		return fmt.Errorf("offer descriptor chains: %w", err)
+	}
+	pkt.Reset()
+	//}
+	//if kick {
+	if err := dev.TransmitQueue.Kick(); err != nil {
+		return err
+	}
+	//}
+
+	return nil
+}
+
+func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error {
+	if len(pkts) == 0 {
+		return nil
+	}
+
+	for i := range pkts {
+		if err := dev.TransmitPacket(pkts[i], false); err != nil {
+			return err
+		}
+	}
+	if err := dev.TransmitQueue.Kick(); err != nil {
+		return err
+	}
+	return nil
+}

 // TODO: Make above methods cancelable by taking a context.Context argument?

@@ -327,7 +364,7 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us
 	//shift the buffer out of out:
 	pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length]
 	pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex))
-	pkt.Recycler = dev.ReceiveQueue.RecycleDescriptorChains
+	pkt.Recycler = dev.ReceiveQueue.OfferDescriptorChains
 	return 1, nil

 	//cursor := n - virtio.NetHdrSize

@@ -385,7 +422,7 @@ func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
 	}

 	// Now that we have copied all buffers, we can recycle the used descriptor chains
-	//if err = dev.ReceiveQueue.RecycleDescriptorChains(chains); err != nil {
+	//if err = dev.ReceiveQueue.OfferDescriptorChains(chains); err != nil {
 	//	return 0, err
 	//}

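The descriptor lifecycle behind GetPacketForTx above boils down to: hand out fresh descriptors until the table has been filled once, then flip dev.fullTable and reclaim a completed descriptor from the used ring (TakeSingle) before each further send. A minimal self-contained model of that allocation policy, with hypothetical names and the queue machinery stubbed out:

package main

import (
	"errors"
	"fmt"
)

var errNoFreeDescriptors = errors.New("not enough free descriptors")

// txTable is a stand-in for the transmit queue's descriptor table plus its
// used ring; none of these names exist in the repo.
type txTable struct {
	capacity  int
	created   int
	completed []uint16 // descriptor IDs the device has finished with
	fullTable bool
}

func (t *txTable) create() (uint16, error) {
	if t.created == t.capacity {
		return 0, errNoFreeDescriptors
	}
	idx := uint16(t.created)
	t.created++
	return idx, nil
}

func (t *txTable) takeSingle() (uint16, error) {
	if len(t.completed) == 0 {
		return 0, errors.New("would block until the device completes a descriptor")
	}
	idx := t.completed[0]
	t.completed = t.completed[1:]
	return idx, nil
}

// next mirrors the policy of GetPacketForTx: create until full, then recycle.
func (t *txTable) next() (uint16, error) {
	if !t.fullTable {
		idx, err := t.create()
		if err == nil {
			return idx, nil
		}
		if !errors.Is(err, errNoFreeDescriptors) {
			return 0, err
		}
		t.fullTable = true // table is fully populated; reuse from now on
	}
	return t.takeSingle()
}

func main() {
	t := &txTable{capacity: 2}
	fmt.Println(t.next()) // 0 <nil>
	fmt.Println(t.next()) // 1 <nil>
	t.completed = append(t.completed, 0)
	fmt.Println(t.next()) // 0 <nil>: recycled from the used ring
}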

@@ -281,6 +281,106 @@ func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffe
 	return head, nil
 }

+func (dt *DescriptorTable) CreateDescriptorForOutputs() (uint16, error) {
+	//todo just fill the damn table
+	// Do we still have enough free descriptors?
+
+	if 1 > dt.freeNum {
+		return 0, ErrNotEnoughFreeDescriptors
+	}
+
+	// Above validation ensured that there is at least one free descriptor, so
+	// the free descriptor chain head should be valid.
+	if dt.freeHeadIndex == noFreeHead {
+		panic("free descriptor chain head is unset but there should be free descriptors")
+	}
+
+	// To avoid having to iterate over the whole table to find the descriptor
+	// pointing to the head just to replace the free head, we instead always
+	// create descriptor chains from the descriptors coming after the head.
+	// This way we only have to touch the head as a last resort, when all other
+	// descriptors are already used.
+	head := dt.descriptors[dt.freeHeadIndex].next
+	desc := &dt.descriptors[head]
+	next := desc.next
+
+	checkUnusedDescriptorLength(head, desc)
+
+	// Give the device the maximum available number of bytes to write into.
+	desc.length = uint32(dt.itemSize)
+	desc.flags = 0 // descriptorFlagWritable
+	desc.next = 0 // Not necessary to clear this, it's just for looks.
+
+	dt.freeNum -= 1
+
+	if dt.freeNum == 0 {
+		// The last descriptor in the chain should be the free chain head
+		// itself.
+		if next != dt.freeHeadIndex {
+			panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
+		}
+
+		// When this new chain takes up all remaining descriptors, we no longer
+		// have a free chain.
+		dt.freeHeadIndex = noFreeHead
+	} else {
+		// We took some descriptors out of the free chain, so make sure to close
+		// the circle again.
+		dt.descriptors[dt.freeHeadIndex].next = next
+	}
+
+	return head, nil
+}
+
+func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) {
+	// Do we still have enough free descriptors?
+	if 1 > dt.freeNum {
+		return 0, ErrNotEnoughFreeDescriptors
+	}
+
+	// Above validation ensured that there is at least one free descriptor, so
+	// the free descriptor chain head should be valid.
+	if dt.freeHeadIndex == noFreeHead {
+		panic("free descriptor chain head is unset but there should be free descriptors")
+	}
+
+	// To avoid having to iterate over the whole table to find the descriptor
+	// pointing to the head just to replace the free head, we instead always
+	// create descriptor chains from the descriptors coming after the head.
+	// This way we only have to touch the head as a last resort, when all other
+	// descriptors are already used.
+	head := dt.descriptors[dt.freeHeadIndex].next
+	desc := &dt.descriptors[head]
+	next := desc.next
+
+	checkUnusedDescriptorLength(head, desc)
+
+	// Give the device the maximum available number of bytes to write into.
+	desc.length = uint32(dt.itemSize)
+	desc.flags = descriptorFlagWritable
+	desc.next = 0 // Not necessary to clear this, it's just for looks.
+
+	dt.freeNum -= 1
+
+	if dt.freeNum == 0 {
+		// The last descriptor in the chain should be the free chain head
+		// itself.
+		if next != dt.freeHeadIndex {
+			panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
+		}
+
+		// When this new chain takes up all remaining descriptors, we no longer
+		// have a free chain.
+		dt.freeHeadIndex = noFreeHead
+	} else {
+		// We took some descriptors out of the free chain, so make sure to close
+		// the circle again.
+		dt.descriptors[dt.freeHeadIndex].next = next
+	}
+
+	return head, nil
+}
+
 // TODO: Implement a zero-copy variant of createDescriptorChain?

 // getDescriptorChain returns the device-readable buffers (out buffers) and

@@ -334,6 +434,20 @@ func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffer
 	return
 }

+func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
+	if int(head) > len(dt.descriptors) {
+		return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
+	}
+
+	desc := &dt.descriptors[head] //todo this is a pretty nasty hack with no checks
+
+	// The descriptor address points to memory not managed by Go, so this
+	// conversion is safe. See https://github.com/golang/go/issues/58625
+	//goland:noinspection GoVetUnsafePointer
+	bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
+	return bs, nil
+}
+
 func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
 	if int(head) > len(dt.descriptors) {
 		return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)

@@ -208,6 +208,32 @@ func (sq *SplitQueue) BlockAndGetHeads(ctx context.Context) ([]UsedElement, erro
 	return nil, ctx.Err()
 }

+func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
+	var n int
+	var err error
+	for ctx.Err() == nil {
+		out, ok := sq.usedRing.takeOne()
+		if ok {
+			return out, nil
+		}
+		// Wait for a signal from the device.
+		if n, err = sq.epoll.Block(); err != nil {
+			return 0, fmt.Errorf("wait: %w", err)
+		}
+
+		if n > 0 {
+			out, ok = sq.usedRing.takeOne()
+			if ok {
+				_ = sq.epoll.Clear() //???
+				return out, nil
+			} else {
+				continue //???
+			}
+		}
+	}
+	return 0, ctx.Err()
+}
+
 func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) {
 	var n int
 	var err error

@@ -268,14 +294,14 @@ func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int)
 // they're done with them. When this does not happen, the queue will run full
 // and any further calls to [SplitQueue.OfferDescriptorChain] will stall.

-func (sq *SplitQueue) OfferInDescriptorChains(numInBuffers int) (uint16, error) {
+func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
 	// Create a descriptor chain for the given buffers.
 	var (
 		head uint16
 		err error
 	)
 	for {
-		head, err = sq.descriptorTable.createDescriptorChain(nil, numInBuffers)
+		head, err = sq.descriptorTable.createDescriptorForInputs()
 		if err == nil {
 			break
 		}

@@ -361,6 +387,11 @@ func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][
 	return sq.descriptorTable.getDescriptorChain(head)
 }

+func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) {
+	sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize)
+	return sq.descriptorTable.getDescriptorItem(head)
+}
+
 func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
 	return sq.descriptorTable.getDescriptorChainContents(head, out, maxLen)
 }

@@ -387,7 +418,12 @@ func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
 	return nil
 }

-func (sq *SplitQueue) RecycleDescriptorChains(chains []uint16, kick bool) error {
+func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
+	//not called under lock
+	sq.descriptorTable.descriptors[int(head)].length = uint32(sz)
+}
+
+func (sq *SplitQueue) OfferDescriptorChains(chains []uint16, kick bool) error {
 	//todo not doing this may break eventually?
 	//not called under lock
 	//if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {

@@ -399,11 +435,16 @@ func (sq *SplitQueue) RecycleDescriptorChains(chains []uint16, kick bool) error

 	// Notify the device to make it process the updated available ring.
 	if kick {
+		return sq.Kick()
+	}
+
+	return nil
+}
+
+func (sq *SplitQueue) Kick() error {
 	if err := sq.kickEventFD.Kick(); err != nil {
 		return fmt.Errorf("notify device: %w", err)
 	}
-	}

 	return nil
 }

@@ -127,3 +127,58 @@ func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {

 	return stillNeedToTake, elems
 }
+
+func (r *UsedRing) takeOne() (uint16, bool) {
+	//r.mu.Lock()
+	//defer r.mu.Unlock()
+
+	ringIndex := *r.ringIndex
+	if ringIndex == r.lastIndex {
+		// Nothing new.
+		return 0xffff, false
+	}
+
+	// Calculate the number new used elements that we can read from the ring.
+	// The ring index may wrap, so special handling for that case is needed.
+	count := int(ringIndex - r.lastIndex)
+	if count < 0 {
+		count += 0xffff
+	}
+
+	// The number of new elements can never exceed the queue size.
+	if count > len(r.ring) {
+		panic("used ring contains more new elements than the ring is long")
+	}
+
+	if count == 0 {
+		return 0xffff, false
+	}
+
+	out := r.ring[r.lastIndex%uint16(len(r.ring))].GetHead()
+	r.lastIndex++
+
+	return out, true
+}
+
+// InitOfferSingle is only used to pre-fill the used queue at startup, and should not be used if the device is running!
+func (r *UsedRing) InitOfferSingle(x uint16, size int) {
+	//always called under lock
+	//r.mu.Lock()
+	//defer r.mu.Unlock()
+
+	offset := 0
+	// Add descriptor chain heads to the ring.
+
+	// The 16-bit ring index may overflow. This is expected and is not an
+	// issue because the size of the ring array (which equals the queue
+	// size) is always a power of 2 and smaller than the highest possible
+	// 16-bit value.
+	insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
+	r.ring[insertIndex] = UsedElement{
+		DescriptorIndex: uint32(x),
+		Length:          uint32(size),
+	}
+
+	// Increase the ring index by the number of descriptor chains added to the ring.
+	*r.ringIndex += 1
+}

@@ -1,7 +1,15 @@
 package packet

+import (
+	"github.com/slackhq/nebula/util/virtio"
+	"golang.org/x/sys/unix"
+)
+
 type OutPacket struct {
 	Segments [][]byte
+	SegmentPayloads [][]byte
+	SegmentHeaders [][]byte
+	SegmentIDs []uint16
 	//todo virtio header?
 	SegSize int
 	SegCounter int

@@ -13,11 +21,45 @@ type OutPacket struct {

 func NewOut() *OutPacket {
 	out := new(OutPacket)
-	const numSegments = 64
-	out.Segments = make([][]byte, numSegments)
-	for i := 0; i < numSegments; i++ { //todo this is dumb
-		out.Segments[i] = make([]byte, Size)
-	}
+	out.Segments = make([][]byte, 0, 64)
+	out.SegmentHeaders = make([][]byte, 0, 64)
+	out.SegmentPayloads = make([][]byte, 0, 64)
+	out.SegmentIDs = make([]uint16, 0, 64)
 	out.Scratch = make([]byte, Size)
 	return out
 }

+func (pkt *OutPacket) Reset() {
+	pkt.Segments = pkt.Segments[:0]
+	pkt.SegmentPayloads = pkt.SegmentPayloads[:0]
+	pkt.SegmentHeaders = pkt.SegmentHeaders[:0]
+	pkt.SegmentIDs = pkt.SegmentIDs[:0]
+	pkt.SegSize = 0
+	pkt.Valid = false
+	pkt.wasSegmented = false
+}
+
+func (pkt *OutPacket) UseSegment(segID uint16, seg []byte) int {
+	pkt.Valid = true
+	pkt.SegmentIDs = append(pkt.SegmentIDs, segID)
+	pkt.Segments = append(pkt.Segments, seg) //todo do we need this?
+
+	vhdr := virtio.NetHdr{ //todo
+		Flags:      unix.VIRTIO_NET_HDR_F_DATA_VALID,
+		GSOType:    unix.VIRTIO_NET_HDR_GSO_NONE,
+		HdrLen:     0,
+		GSOSize:    0,
+		CsumStart:  0,
+		CsumOffset: 0,
+		NumBuffers: 0,
+	}
+
+	hdr := seg[0 : virtio.NetHdrSize+14]
+	_ = vhdr.Encode(hdr)
+	hdr[virtio.NetHdrSize+14-2] = 0x86
+	hdr[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
+
+	pkt.SegmentHeaders = append(pkt.SegmentHeaders, hdr)
+	pkt.SegmentPayloads = append(pkt.SegmentPayloads, seg[virtio.NetHdrSize+14:])
+	return len(pkt.SegmentIDs) - 1
+}
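UseSegment above fixes the layout of every TX segment: a virtio-net header, then a 14-byte synthetic Ethernet header whose last two bytes are forced to the IPv6 ethertype 0x86dd, then the payload. SegmentHeaders and SegmentPayloads are just two views into the same descriptor buffer. A small sketch of that arithmetic, assuming virtio.NetHdrSize is 12 (the header variant that includes num_buffers; treat that value as an assumption):

package main

import "fmt"

const (
	netHdrSize = 12                      // assumption: virtio.NetHdrSize
	ethHdrSize = 14                      // dst MAC (6) + src MAC (6) + ethertype (2)
	prefixLen  = netHdrSize + ethHdrSize // bytes reserved before the payload
)

// useSegment splits one descriptor buffer the way UseSegment does: a header
// view over the prefix and a payload view over the rest of the same memory.
func useSegment(seg []byte) (hdr, payload []byte) {
	hdr = seg[:prefixLen]
	hdr[prefixLen-2] = 0x86 // ethertype high byte
	hdr[prefixLen-1] = 0xdd // ethertype low byte: IPv6
	return hdr, seg[prefixLen:]
}

func main() {
	seg := make([]byte, 2048)
	hdr, payload := useSegment(seg)
	fmt.Println(len(hdr), len(payload)) // 26 2022
}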

@@ -9,13 +9,13 @@ type VirtIOPacket struct {
 	Header virtio.NetHdr
 	Chains []uint16
 	ChainRefs [][]byte
-	// RecycleDescriptorChains(chains []uint16, kick bool) error
+	// OfferDescriptorChains(chains []uint16, kick bool) error
 	Recycler func([]uint16, bool) error
 }

 func NewVIO() *VirtIOPacket {
 	out := new(VirtIOPacket)
-	out.Payload = make([]byte, Size)
+	out.Payload = nil
 	out.ChainRefs = make([][]byte, 0, 4)
 	out.Chains = make([]uint16, 0, 8)
 	return out

@@ -37,3 +37,13 @@ func (v *VirtIOPacket) Recycle(lastOne bool) error {
 	v.Reset()
 	return nil
 }
+
+type VirtIOTXPacket struct {
+	VirtIOPacket
+}
+
+func NewVIOTX(isV4 bool) *VirtIOTXPacket {
+	out := new(VirtIOTXPacket)
+	out.VirtIOPacket = *NewVIO()
+	return out
+}