From e8ea021bddcd0ef086723f658a54b9f331b63a0c Mon Sep 17 00:00:00 2001
From: JackDoan
Date: Thu, 13 Nov 2025 16:29:16 -0600
Subject: [PATCH] BAM! a hit from my spice-weasel. Now TX is zero-copy, at the
 expense of my sanity!

---
 interface.go                          |  31 ++-----
 outside.go                            |  21 +++--
 overlay/tun.go                        |   6 +-
 overlay/tun_disabled.go               |  20 ++---
 overlay/tun_linux.go                  |  48 ++++++-----
 overlay/user.go                       |  21 ++---
 overlay/vhostnet/device.go            | 117 +++++++++++++++++--------
 overlay/virtqueue/descriptor_table.go | 114 +++++++++++++++++++++++++
 overlay/virtqueue/split_virtqueue.go  |  53 ++++++++++--
 overlay/virtqueue/used_ring.go        |  55 ++++++++++++
 packet/outpacket.go                   |  54 ++++++++++--
 packet/virtio.go                      |  14 ++-
 12 files changed, 431 insertions(+), 123 deletions(-)

diff --git a/interface.go b/interface.go
index a44b574..5abb837 100644
--- a/interface.go
+++ b/interface.go
@@ -295,29 +295,16 @@ func (f *Interface) listenOut(q int) {
 		}

 		f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l), time.Now())

-		for i := range pkts {
-			if pkts[i].OutLen != -1 {
-				for j := 0; j < outPackets[i].SegCounter; j++ {
-					if len(outPackets[i].Segments[j]) > 0 {
-						toSend = append(toSend, outPackets[i].Segments[j])
-					}
-				}
-			}
-		}
-		n := len(toSend)
-		if f.l.Level == logrus.DebugLevel {
-			f.listenOutMetric.Update(int64(n))
-		}
-		f.listenOutN = n
-		//toSend = toSend[:toSendCount]
-		for i := 0; i < n; i += batch {
-			x := min(len(toSend[i:]), batch)
-			toSendThisTime := toSend[i : i+x]
-			_, err := f.readers[q].WriteMany(toSendThisTime, q)
-			if err != nil {
-				f.l.WithError(err).Error("Failed to write messages")
-			}
+		//we opportunistically TX during decrypt; this WriteMany flushes any stragglers
+		if _, err := f.readers[q].WriteMany(outPackets, q); err != nil {
+			f.l.WithError(err).Error("Failed to send packets")
 		}
+		//todo I broke this metric (toSend is gone)
+		//n := len(toSend)
+		//if f.l.Level == logrus.DebugLevel {
+		//	f.listenOutMetric.Update(int64(n))
+		//}
+		//f.listenOutN = n
 	})
 }
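One way to restore the broken debug metric without resurrecting toSend: count queued segments on the OutPacket side before the WriteMany call (a sketch, not in the patch; it must run before WriteMany because Device.TransmitPacket resets each packet once its segments are offered):

    // sketch: rebuild listenOutMetric from the OutPacket bookkeeping
    n := 0
    for i := range outPackets {
        n += len(outPackets[i].SegmentIDs) // one ID per segment still queued
    }
    if f.l.Level == logrus.DebugLevel {
        f.listenOutMetric.Update(int64(n))
    }
    f.listenOutN = n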
diff --git a/outside.go b/outside.go
index 8cc9a5e..d0c64c7 100644
--- a/outside.go
+++ b/outside.go
@@ -419,7 +419,12 @@ func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*pack

 			f.handleHostRoaming(hostinfo, ip)

 			f.connectionManager.In(hostinfo)
+		}
+		//_, err := f.readers[q].WriteOne(out[i], false, q)
+		//if err != nil {
+		//	f.l.WithError(err).Error("Failed to write packet")
+		//}
 	}
 }
@@ -675,14 +680,20 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []

 func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
 	var err error

-	out.Segments[out.SegCounter] = out.Segments[out.SegCounter][:0]
-	out.Segments[out.SegCounter], err = hostinfo.ConnectionState.dKey.DecryptDanger(out.Segments[out.SegCounter], inSegment[:header.Len], inSegment[header.Len:], messageCounter, nb)
+	seg, err := f.readers[q].AllocSeg(out, q)
+	if err != nil {
+		f.l.WithError(err).Errorln("decryptToTunDelayWrite: failed to allocate segment")
+		return false
+	}
+
+	out.SegmentPayloads[seg] = out.SegmentPayloads[seg][:0]
+	out.SegmentPayloads[seg], err = hostinfo.ConnectionState.dKey.DecryptDanger(out.SegmentPayloads[seg], inSegment[:header.Len], inSegment[header.Len:], messageCounter, nb)
 	if err != nil {
 		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
 		return false
 	}

-	err = newPacket(out.Segments[out.SegCounter], true, fwPacket)
+	err = newPacket(out.SegmentPayloads[seg], true, fwPacket)
 	if err != nil {
 		hostinfo.logger(f.l).WithError(err).WithField("packet", out).
 			Warnf("Error while validating inbound packet")
@@ -699,7 +710,7 @@ func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter ui
 	if dropReason != nil {
 		// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
 		// This gives us a buffer to build the reject packet in
-		f.rejectOutside(out.Segments[out.SegCounter], hostinfo.ConnectionState, hostinfo, nb, inSegment, q)
+		f.rejectOutside(out.SegmentPayloads[seg], hostinfo.ConnectionState, hostinfo, nb, inSegment, q)
 		if f.l.Level >= logrus.DebugLevel {
 			hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
 				WithField("reason", dropReason).
@@ -710,7 +721,7 @@ func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter ui

 	f.connectionManager.In(hostinfo)
 	pkt.OutLen += len(inSegment)
-	out.SegCounter++
+	out.Segments[seg] = out.Segments[seg][:len(out.SegmentHeaders[seg])+len(out.SegmentPayloads[seg])]
 	return true
 }

diff --git a/overlay/tun.go b/overlay/tun.go
index 871e01d..0567594 100644
--- a/overlay/tun.go
+++ b/overlay/tun.go
@@ -17,7 +17,11 @@ const DefaultMTU = 1300

 type TunDev interface {
 	io.WriteCloser
 	ReadMany(x []*packet.VirtIOPacket, q int) (int, error)
-	WriteMany(x [][]byte, q int) (int, error)
+
+	//todo this interface sux
+	AllocSeg(pkt *packet.OutPacket, q int) (int, error)
+	WriteOne(x *packet.OutPacket, kick bool, q int) (int, error)
+	WriteMany(x []*packet.OutPacket, q int) (int, error)
 }

 // TODO: We may be able to remove routines
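A rough sketch of how a caller drives the new TunDev TX contract (hypothetical names dev, q, and payload; the real zero-copy path in decryptToTunDelayWrite decrypts straight into the segment instead of copying):

    out := packet.NewOut()
    seg, err := dev.AllocSeg(out, q) // reserves a device-owned TX buffer and wires it into out
    if err != nil {
        return err
    }
    n := copy(out.SegmentPayloads[seg], payload) // stand-in for building the packet in place
    // record the final frame length: virtio-net + ethernet header, then payload
    out.Segments[seg] = out.Segments[seg][:len(out.SegmentHeaders[seg])+n]
    _, err = dev.WriteOne(out, true, q) // kick=true notifies the device immediately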
diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go
index 55e6303..0da3edc 100644
--- a/overlay/tun_disabled.go
+++ b/overlay/tun_disabled.go
@@ -111,16 +111,16 @@ func (t *disabledTun) Write(b []byte) (int, error) {
 	return len(b), nil
 }

-func (t *disabledTun) WriteMany(b [][]byte, _ int) (int, error) {
-	out := 0
-	for i := range b {
-		x, err := t.Write(b[i])
-		if err != nil {
-			return out, err
-		}
-		out += x
-	}
-	return out, nil
+func (t *disabledTun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
+	return 0, fmt.Errorf("tun_disabled: AllocSeg not implemented")
+}
+
+func (t *disabledTun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
+	return 0, fmt.Errorf("tun_disabled: WriteOne not implemented")
+}
+
+func (t *disabledTun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
+	return 0, fmt.Errorf("tun_disabled: WriteMany not implemented")
 }

 func (t *disabledTun) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {

diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go
index d8c81cd..86e36e9 100644
--- a/overlay/tun_linux.go
+++ b/overlay/tun_linux.go
@@ -717,17 +717,17 @@ func (t *tun) ReadMany(p []*packet.VirtIOPacket, q int) (int, error) {

 func (t *tun) Write(b []byte) (int, error) {
 	maximum := len(b) //we are RXing

-	hdr := virtio.NetHdr{ //todo
-		Flags:      unix.VIRTIO_NET_HDR_F_DATA_VALID,
-		GSOType:    unix.VIRTIO_NET_HDR_GSO_NONE,
-		HdrLen:     0,
-		GSOSize:    0,
-		CsumStart:  0,
-		CsumOffset: 0,
-		NumBuffers: 0,
+	//todo garbagey
+	out := packet.NewOut()
+	x, err := t.AllocSeg(out, 0)
+	if err != nil {
+		return 0, err
 	}
+	copy(out.SegmentPayloads[x], b)
+	//truncate the segment to header + payload, as decryptToTunDelayWrite does;
+	//otherwise TransmitPacket sends the descriptor's full capacity
+	out.Segments[x] = out.Segments[x][:len(out.SegmentHeaders[x])+len(b)]
+	err = t.vdev[0].TransmitPacket(out, true)

-	err := t.vdev[0].TransmitPackets(hdr, [][]byte{b})
 	if err != nil {
 		t.l.WithError(err).Error("Transmitting packet")
 		return 0, err
@@ -735,22 +733,30 @@
 	}
 	return maximum, nil
 }

-func (t *tun) WriteMany(b [][]byte, q int) (int, error) {
-	maximum := len(b) //we are RXing
+func (t *tun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
+	idx, buf, err := t.vdev[q].GetPacketForTx()
+	if err != nil {
+		return 0, err
+	}
+	x := pkt.UseSegment(idx, buf)
+	return x, nil
+}
+
+func (t *tun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
+	if err := t.vdev[q].TransmitPacket(x, kick); err != nil {
+		t.l.WithError(err).Error("Transmitting packet")
+		return 0, err
+	}
+	return 1, nil
+}
+
+func (t *tun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
+	maximum := len(x) //we are RXing
 	if maximum == 0 {
 		return 0, nil
 	}
-	hdr := virtio.NetHdr{ //todo
-		Flags:      unix.VIRTIO_NET_HDR_F_DATA_VALID,
-		GSOType:    unix.VIRTIO_NET_HDR_GSO_NONE,
-		HdrLen:     0,
-		GSOSize:    0,
-		CsumStart:  0,
-		CsumOffset: 0,
-		NumBuffers: 0,
-	}
-	err := t.vdev[q].TransmitPackets(hdr, b)
+	err := t.vdev[q].TransmitPackets(x)
 	if err != nil {
 		t.l.WithError(err).Error("Transmitting packet")
 		return 0, err

diff --git a/overlay/user.go b/overlay/user.go
index 992b74a..d469d3e 100644
--- a/overlay/user.go
+++ b/overlay/user.go
@@ -1,6 +1,7 @@
 package overlay

 import (
+	"fmt"
 	"io"
 	"net/netip"

@@ -71,14 +72,14 @@ func (d *UserDevice) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
 	return d.Read(b[0].Payload)
 }

-func (d *UserDevice) WriteMany(b [][]byte, _ int) (int, error) {
-	out := 0
-	for i := range b {
-		x, err := d.Write(b[i])
-		if err != nil {
-			return out, err
-		}
-		out += x
-	}
-	return out, nil
+func (d *UserDevice) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
+	return 0, fmt.Errorf("user: AllocSeg not implemented")
+}
+
+func (d *UserDevice) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
+	return 0, fmt.Errorf("user: WriteOne not implemented")
+}
+
+func (d *UserDevice) WriteMany(x []*packet.OutPacket, q int) (int, error) {
+	return 0, fmt.Errorf("user: WriteMany not implemented")
 }
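The stubs above make the new TX path Linux/vhost-only for now. If UserDevice or disabledTun ever need it, a copy-based fallback over the existing Write is straightforward. A sketch (it assumes AllocSeg also gains a heap-backed implementation for these devices, so SegmentPayloads actually gets populated):

    func (d *UserDevice) WriteMany(x []*packet.OutPacket, q int) (int, error) {
        written := 0
        for _, pkt := range x {
            for _, payload := range pkt.SegmentPayloads {
                if len(payload) == 0 {
                    continue // skip segments that were allocated but never filled
                }
                if _, err := d.Write(payload); err != nil {
                    return written, err
                }
                written++
            }
            pkt.Reset()
        }
        return written, nil
    }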
diff --git a/overlay/vhostnet/device.go b/overlay/vhostnet/device.go
index 452882c..ba41c7e 100644
--- a/overlay/vhostnet/device.go
+++ b/overlay/vhostnet/device.go
@@ -6,7 +6,6 @@ import (
 	"fmt"
 	"os"
 	"runtime"
-	"slices"

 	"github.com/slackhq/nebula/overlay/vhost"
 	"github.com/slackhq/nebula/overlay/virtqueue"
@@ -31,6 +30,7 @@ type Device struct {
 	initialized bool
 	controlFD   int
+	fullTable   bool

 	ReceiveQueue  *virtqueue.SplitQueue
 	TransmitQueue *virtqueue.SplitQueue
 }
@@ -123,6 +123,9 @@ func NewDevice(options ...Option) (*Device, error) {
 	if err = dev.refillReceiveQueue(); err != nil {
 		return nil, fmt.Errorf("refill receive queue: %w", err)
 	}
+	if err = dev.refillTransmitQueue(); err != nil {
+		return nil, fmt.Errorf("refill transmit queue: %w", err)
+	}

 	dev.initialized = true

@@ -139,7 +142,7 @@ func NewDevice(options ...Option) (*Device, error) {
 // packets.
 func (dev *Device) refillReceiveQueue() error {
 	for {
-		_, err := dev.ReceiveQueue.OfferInDescriptorChains(1)
+		_, err := dev.ReceiveQueue.OfferInDescriptorChains()
 		if err != nil {
 			if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
 				// Queue is full, job is done.
 				return nil
 			}
 			return fmt.Errorf("offer descriptor chain: %w", err)
 		}
 	}
 }

+func (dev *Device) refillTransmitQueue() error {
+	//for {
+	//	desc, err := dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
+	//	if err != nil {
+	//		if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
+	//			// Queue is full, job is done.
+	//			return nil
+	//		}
+	//		return fmt.Errorf("offer descriptor chain: %w", err)
+	//	} else {
+	//		dev.TransmitQueue.UsedRing().InitOfferSingle(desc, 0)
+	//	}
+	//}
+	return nil
+}

 // Close cleans up the vhost networking device within the kernel and releases
 // all resources used for it.
 // The implementation will try to release as many resources as possible and
@@ -238,49 +257,65 @@ func truncateBuffers(buffers [][]byte, length int) (out [][]byte) {
 	return
 }

-func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) error {
-	// Prepend the packet with its virtio-net header.
-	vnethdrBuf := make([]byte, virtio.NetHdrSize+14) //todo WHY
-	if err := vnethdr.Encode(vnethdrBuf); err != nil {
-		return fmt.Errorf("encode vnethdr: %w", err)
-	}
-	vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
-	vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
+func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
+	var err error
+	var idx uint16
+	if !dev.fullTable {
-	chainIndexes, err := dev.TransmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets)
+		idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
+		if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
+			dev.fullTable = true
+			idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
+		}
+	} else {
+		idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
+	}
 	if err != nil {
-		return fmt.Errorf("offer descriptor chain: %w", err)
+		return 0, nil, fmt.Errorf("transmit queue: %w", err)
 	}
-	//todo surely there's something better to do here
+	buf, err := dev.TransmitQueue.GetDescriptorItem(idx)
+	if err != nil {
+		return 0, nil, fmt.Errorf("get descriptor chain: %w", err)
+	}
+	return idx, buf, nil
+}

-	for {
-		txedChains, err := dev.TransmitQueue.BlockAndGetHeadsCapped(context.TODO(), len(chainIndexes))
-		if err != nil {
+func (dev *Device) TransmitPacket(pkt *packet.OutPacket, kick bool) error {
+	if len(pkt.SegmentIDs) == 0 {
+		return nil
+	}
+	for idx := range pkt.SegmentIDs {
+		segmentID := pkt.SegmentIDs[idx]
+		dev.TransmitQueue.SetDescSize(segmentID, len(pkt.Segments[idx]))
+	}
+	err := dev.TransmitQueue.OfferDescriptorChains(pkt.SegmentIDs, false)
+	if err != nil {
+		return fmt.Errorf("offer descriptor chains: %w", err)
+	}
+	pkt.Reset()
+	if kick {
+		if err := dev.TransmitQueue.Kick(); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error {
+	if len(pkts) == 0 {
+		return nil
+	}
+
+	for i := range pkts {
+		if err := dev.TransmitPacket(pkts[i], false); err != nil {
 			return err
-		} else if len(txedChains) == 0 {
-			continue //todo will this ever exit?
-		}
-		for _, c := range txedChains {
-			idx := slices.Index(chainIndexes, c.GetHead())
-			if idx < 0 {
-				continue
-			} else {
-				_ = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[idx])
-				chainIndexes[idx] = 0 //todo I hope this works
-			}
-		}
-		done := true //optimism!
-		for _, x := range chainIndexes {
-			if x != 0 {
-				done = false
-				break
-			}
-		}
-
-		if done {
-			return nil
 		}
 	}
+	if err := dev.TransmitQueue.Kick(); err != nil {
+		return err
+	}
+	return nil
 }

 // TODO: Make above methods cancelable by taking a context.Context argument?
@@ -327,7 +364,7 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us
 	//shift the buffer out of out:
 	pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length]
 	pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex))
-	pkt.Recycler = dev.ReceiveQueue.RecycleDescriptorChains
+	pkt.Recycler = dev.ReceiveQueue.OfferDescriptorChains
 	return 1, nil

 	//cursor := n - virtio.NetHdrSize
@@ -385,7 +422,7 @@ func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
 	}

 	// Now that we have copied all buffers, we can recycle the used descriptor chains
-	//if err = dev.ReceiveQueue.RecycleDescriptorChains(chains); err != nil {
+	//if err = dev.ReceiveQueue.OfferDescriptorChains(chains); err != nil {
 	//	return 0, err
 	//}
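The TX descriptor lifecycle above, in miniature (a hypothetical single-packet send; `frame` is an IPv6 packet, and UseSegment writes the virtio-net plus ethernet header into the front of the buffer):

    idx, buf, err := dev.GetPacketForTx() // fresh descriptor, or a recycled one once fullTable is set
    if err != nil {
        return err
    }
    out := packet.NewOut()
    seg := out.UseSegment(idx, buf) // lays the headers down in buf, splits off the payload view
    n := copy(out.SegmentPayloads[seg], frame)
    out.Segments[seg] = out.Segments[seg][:len(out.SegmentHeaders[seg])+n]
    return dev.TransmitPacket(out, true) // sets descriptor sizes, offers the chain, kicks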
diff --git a/overlay/virtqueue/descriptor_table.go b/overlay/virtqueue/descriptor_table.go
index a56b779..44b8494 100644
--- a/overlay/virtqueue/descriptor_table.go
+++ b/overlay/virtqueue/descriptor_table.go
@@ -281,6 +281,106 @@ func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffe
 	return head, nil
 }

+func (dt *DescriptorTable) CreateDescriptorForOutputs() (uint16, error) {
+	//todo just fill the damn table
+	// Do we still have enough free descriptors?
+	if dt.freeNum < 1 {
+		return 0, ErrNotEnoughFreeDescriptors
+	}
+
+	// Above validation ensured that there is at least one free descriptor, so
+	// the free descriptor chain head should be valid.
+	if dt.freeHeadIndex == noFreeHead {
+		panic("free descriptor chain head is unset but there should be free descriptors")
+	}
+
+	// To avoid having to iterate over the whole table to find the descriptor
+	// pointing to the head just to replace the free head, we instead always
+	// create descriptor chains from the descriptors coming after the head.
+	// This way we only have to touch the head as a last resort, when all other
+	// descriptors are already used.
+	head := dt.descriptors[dt.freeHeadIndex].next
+	desc := &dt.descriptors[head]
+	next := desc.next
+
+	checkUnusedDescriptorLength(head, desc)
+
+	// Expose the buffer's full capacity; the caller shrinks the length to the actual frame size (SplitQueue.SetDescSize) before offering.
+	desc.length = uint32(dt.itemSize)
+	desc.flags = 0 // no descriptorFlagWritable: the device only reads output buffers
+	desc.next = 0  // Not necessary to clear this, it's just for looks.
+
+	dt.freeNum -= 1
+
+	if dt.freeNum == 0 {
+		// The last descriptor in the chain should be the free chain head
+		// itself.
+		if next != dt.freeHeadIndex {
+			panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
+		}
+
+		// When this new chain takes up all remaining descriptors, we no longer
+		// have a free chain.
+		dt.freeHeadIndex = noFreeHead
+	} else {
+		// We took some descriptors out of the free chain, so make sure to close
+		// the circle again.
+		dt.descriptors[dt.freeHeadIndex].next = next
+	}
+
+	return head, nil
+}
+
+func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) {
+	// Do we still have enough free descriptors?
+	if dt.freeNum < 1 {
+		return 0, ErrNotEnoughFreeDescriptors
+	}
+
+	// Above validation ensured that there is at least one free descriptor, so
+	// the free descriptor chain head should be valid.
+	if dt.freeHeadIndex == noFreeHead {
+		panic("free descriptor chain head is unset but there should be free descriptors")
+	}
+
+	// To avoid having to iterate over the whole table to find the descriptor
+	// pointing to the head just to replace the free head, we instead always
+	// create descriptor chains from the descriptors coming after the head.
+	// This way we only have to touch the head as a last resort, when all other
+	// descriptors are already used.
+	head := dt.descriptors[dt.freeHeadIndex].next
+	desc := &dt.descriptors[head]
+	next := desc.next
+
+	checkUnusedDescriptorLength(head, desc)
+
+	// Give the device the maximum available number of bytes to write into.
+	desc.length = uint32(dt.itemSize)
+	desc.flags = descriptorFlagWritable
+	desc.next = 0 // Not necessary to clear this, it's just for looks.
+
+	dt.freeNum -= 1
+
+	if dt.freeNum == 0 {
+		// The last descriptor in the chain should be the free chain head
+		// itself.
+		if next != dt.freeHeadIndex {
+			panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
+		}
+
+		// When this new chain takes up all remaining descriptors, we no longer
+		// have a free chain.
+		dt.freeHeadIndex = noFreeHead
+	} else {
+		// We took some descriptors out of the free chain, so make sure to close
+		// the circle again.
+		dt.descriptors[dt.freeHeadIndex].next = next
+	}
+
+	return head, nil
+}
+
 // TODO: Implement a zero-copy variant of createDescriptorChain?

 // getDescriptorChain returns the device-readable buffers (out buffers) and
@@ -334,6 +434,20 @@ func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffer
 	return
 }

+func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
+	if int(head) >= len(dt.descriptors) {
+		return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
+	}
+
+	desc := &dt.descriptors[head] //todo this is a pretty nasty hack with no checks
+
+	// The descriptor address points to memory not managed by Go, so this
+	// conversion is safe. See https://github.com/golang/go/issues/58625
+	//goland:noinspection GoVetUnsafePointer
+	bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
+	return bs, nil
+}
+
 func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
 	if int(head) > len(dt.descriptors) {
 		return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
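CreateDescriptorForOutputs and createDescriptorForInputs differ only in the flags they set on the descriptor; they could collapse into one helper. A sketch, not in the patch (it assumes the descriptor flags field is a uint16 here; the real type may differ):

    func (dt *DescriptorTable) createSingleDescriptor(flags uint16) (uint16, error) {
        if dt.freeNum < 1 {
            return 0, ErrNotEnoughFreeDescriptors
        }
        if dt.freeHeadIndex == noFreeHead {
            panic("free descriptor chain head is unset but there should be free descriptors")
        }
        // take the descriptor after the free head, as the existing functions do
        head := dt.descriptors[dt.freeHeadIndex].next
        desc := &dt.descriptors[head]
        next := desc.next
        checkUnusedDescriptorLength(head, desc)
        desc.length = uint32(dt.itemSize)
        desc.flags = flags // 0 for device-readable TX, descriptorFlagWritable for RX
        desc.next = 0
        dt.freeNum--
        if dt.freeNum == 0 {
            if next != dt.freeHeadIndex {
                panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
            }
            dt.freeHeadIndex = noFreeHead
        } else {
            dt.descriptors[dt.freeHeadIndex].next = next // close the free circle again
        }
        return head, nil
    }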
diff --git a/overlay/virtqueue/split_virtqueue.go b/overlay/virtqueue/split_virtqueue.go
index fd15743..0da7734 100644
--- a/overlay/virtqueue/split_virtqueue.go
+++ b/overlay/virtqueue/split_virtqueue.go
@@ -208,6 +208,31 @@ func (sq *SplitQueue) BlockAndGetHeads(ctx context.Context) ([]UsedElement, erro
 	return nil, ctx.Err()
 }

+func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
+	var n int
+	var err error
+	for ctx.Err() == nil {
+		out, ok := sq.usedRing.takeOne()
+		if ok {
+			return out, nil
+		}
+		// Wait for a signal from the device.
+		if n, err = sq.epoll.Block(); err != nil {
+			return 0, fmt.Errorf("wait: %w", err)
+		}
+
+		if n > 0 {
+			out, ok = sq.usedRing.takeOne()
+			if ok {
+				_ = sq.epoll.Clear() //todo: can clearing here drop a wakeup another taker needed?
+				return out, nil
+			}
+			continue //todo: spurious wakeup, retry
+		}
+	}
+	return 0, ctx.Err()
+}
+
 func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) {
 	var n int
 	var err error
@@ -268,14 +294,14 @@ func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int)
 // they're done with them. When this does not happen, the queue will run full
 // and any further calls to [SplitQueue.OfferDescriptorChain] will stall.

-func (sq *SplitQueue) OfferInDescriptorChains(numInBuffers int) (uint16, error) {
+func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
 	// Create a descriptor chain for the given buffers.
 	var (
 		head uint16
 		err  error
 	)
 	for {
-		head, err = sq.descriptorTable.createDescriptorChain(nil, numInBuffers)
+		head, err = sq.descriptorTable.createDescriptorForInputs()
 		if err == nil {
 			break
 		}
@@ -361,6 +387,12 @@ func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][
 	return sq.descriptorTable.getDescriptorChain(head)
 }

+func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) {
+	// Restore the descriptor to its full capacity before handing it back out.
+	sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize)
+	return sq.descriptorTable.getDescriptorItem(head)
+}
+
 func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
 	return sq.descriptorTable.getDescriptorChainContents(head, out, maxLen)
 }
@@ -387,7 +418,12 @@ func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
 	return nil
 }

-func (sq *SplitQueue) RecycleDescriptorChains(chains []uint16, kick bool) error {
+func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
+	//not called under lock
+	sq.descriptorTable.descriptors[int(head)].length = uint32(sz)
+}
+
+func (sq *SplitQueue) OfferDescriptorChains(chains []uint16, kick bool) error {
 	//todo not doing this may break eventually?
 	//not called under lock
 	//if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
@@ -399,14 +435,19 @@ func (sq *SplitQueue) RecycleDescriptorChains(chains []uint16, kick bool) error

 	// Notify the device to make it process the updated available ring.
 	if kick {
-		if err := sq.kickEventFD.Kick(); err != nil {
-			return fmt.Errorf("notify device: %w", err)
-		}
+		return sq.Kick()
 	}

 	return nil
 }

+func (sq *SplitQueue) Kick() error {
+	if err := sq.kickEventFD.Kick(); err != nil {
+		return fmt.Errorf("notify device: %w", err)
+	}
+	return nil
+}
+
 // Close releases all resources used for this queue.
 // The implementation will try to release as many resources as possible and
 // collect potential errors before returning them.
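With SetDescSize, OfferDescriptorChains, and Kick split apart, a batched transmit can touch the kick eventfd once instead of once per chain. Roughly, mirroring what Device.TransmitPacket does with kick=false plus a final notification:

    // sizes first, then one offer, then one kick
    for i, id := range pkt.SegmentIDs {
        sq.SetDescSize(id, len(pkt.Segments[i]))
    }
    if err := sq.OfferDescriptorChains(pkt.SegmentIDs, false); err != nil {
        return err
    }
    return sq.Kick()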
diff --git a/overlay/virtqueue/used_ring.go b/overlay/virtqueue/used_ring.go
index c08b48b..824c07c 100644
--- a/overlay/virtqueue/used_ring.go
+++ b/overlay/virtqueue/used_ring.go
@@ -127,3 +127,47 @@ func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {

 	return stillNeedToTake, elems
 }
+
+func (r *UsedRing) takeOne() (uint16, bool) {
+	//r.mu.Lock()
+	//defer r.mu.Unlock()
+
+	ringIndex := *r.ringIndex
+	if ringIndex == r.lastIndex {
+		// Nothing new.
+		return 0xffff, false
+	}
+
+	// Calculate the number of new used elements that we can read from the
+	// ring. Both indexes are uint16, so the subtraction wraps correctly when
+	// the ring index overflows; count is always at least 1 here.
+	count := int(ringIndex - r.lastIndex)
+
+	// The number of new elements can never exceed the queue size.
+	if count > len(r.ring) {
+		panic("used ring contains more new elements than the ring is long")
+	}
+
+	out := r.ring[r.lastIndex%uint16(len(r.ring))].GetHead()
+	r.lastIndex++
+
+	return out, true
+}
+
+// InitOfferSingle is only used to pre-fill the used ring at startup, and must
+// not be used while the device is running!
+func (r *UsedRing) InitOfferSingle(x uint16, size int) {
+	//always called under lock
+	//r.mu.Lock()
+	//defer r.mu.Unlock()
+
+	// Add the descriptor chain head to the ring. The 16-bit ring index may
+	// overflow. This is expected and is not an issue because the size of the
+	// ring array (which equals the queue size) is always a power of 2 and
+	// smaller than the highest possible 16-bit value.
+	insertIndex := int(*r.ringIndex) % len(r.ring)
+	r.ring[insertIndex] = UsedElement{
+		DescriptorIndex: uint32(x),
+		Length:          uint32(size),
+	}
+
+	// Advance the ring index by one.
+	*r.ringIndex += 1
+}
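For reference, the wraparound arithmetic in takeOne worked through; this is why the dropped `if count < 0` guard (which also used the wrong constant, 0xffff instead of 0x10000) could never fire:

    var lastIndex, ringIndex uint16 = 0xfffe, 0x0001 // device wrote 3 elements, index wrapped
    count := int(ringIndex - lastIndex)              // uint16(0x0001 - 0xfffe) == 0x0003
    fmt.Println(count)                               // 3: the uint16 subtraction wraps before the int conversion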
diff --git a/packet/outpacket.go b/packet/outpacket.go
index e345afe..42dd5dd 100644
--- a/packet/outpacket.go
+++ b/packet/outpacket.go
@@ -1,7 +1,15 @@
 package packet

+import (
+	"github.com/slackhq/nebula/util/virtio"
+	"golang.org/x/sys/unix"
+)
+
 type OutPacket struct {
-	Segments [][]byte
+	Segments        [][]byte
+	SegmentPayloads [][]byte
+	SegmentHeaders  [][]byte
+	SegmentIDs      []uint16 //todo virtio header?
 	SegSize    int
 	SegCounter int
@@ -13,11 +21,45 @@
 func NewOut() *OutPacket {
 	out := new(OutPacket)
-	const numSegments = 64
-	out.Segments = make([][]byte, numSegments)
-	for i := 0; i < numSegments; i++ { //todo this is dumb
-		out.Segments[i] = make([]byte, Size)
-	}
+	out.Segments = make([][]byte, 0, 64)
+	out.SegmentHeaders = make([][]byte, 0, 64)
+	out.SegmentPayloads = make([][]byte, 0, 64)
+	out.SegmentIDs = make([]uint16, 0, 64)
 	out.Scratch = make([]byte, Size)
 	return out
 }
+
+func (pkt *OutPacket) Reset() {
+	pkt.Segments = pkt.Segments[:0]
+	pkt.SegmentPayloads = pkt.SegmentPayloads[:0]
+	pkt.SegmentHeaders = pkt.SegmentHeaders[:0]
+	pkt.SegmentIDs = pkt.SegmentIDs[:0]
+	pkt.SegSize = 0
+	pkt.Valid = false
+	pkt.wasSegmented = false
+}
+
+func (pkt *OutPacket) UseSegment(segID uint16, seg []byte) int {
+	pkt.Valid = true
+	pkt.SegmentIDs = append(pkt.SegmentIDs, segID)
+	pkt.Segments = append(pkt.Segments, seg) //todo do we need this?
+
+	vhdr := virtio.NetHdr{ //todo
+		Flags:      unix.VIRTIO_NET_HDR_F_DATA_VALID,
+		GSOType:    unix.VIRTIO_NET_HDR_GSO_NONE,
+		HdrLen:     0,
+		GSOSize:    0,
+		CsumStart:  0,
+		CsumOffset: 0,
+		NumBuffers: 0,
+	}
+
+	hdr := seg[0 : virtio.NetHdrSize+14]
+	_ = vhdr.Encode(hdr)
+	hdr[virtio.NetHdrSize+14-2] = 0x86
+	hdr[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype is hardcoded here
+
+	pkt.SegmentHeaders = append(pkt.SegmentHeaders, hdr)
+	pkt.SegmentPayloads = append(pkt.SegmentPayloads, seg[virtio.NetHdrSize+14:])
+	return len(pkt.SegmentIDs) - 1
+}

diff --git a/packet/virtio.go b/packet/virtio.go
index 446443c..133c2b7 100644
--- a/packet/virtio.go
+++ b/packet/virtio.go
@@ -9,13 +9,13 @@ type VirtIOPacket struct {
 	Header    virtio.NetHdr
 	Chains    []uint16
 	ChainRefs [][]byte
-	// RecycleDescriptorChains(chains []uint16, kick bool) error
+	// OfferDescriptorChains(chains []uint16, kick bool) error
 	Recycler func([]uint16, bool) error
 }

 func NewVIO() *VirtIOPacket {
 	out := new(VirtIOPacket)
-	out.Payload = make([]byte, Size)
+	out.Payload = nil //Payload now aliases into ChainRefs, see Device.processChains
 	out.ChainRefs = make([][]byte, 0, 4)
 	out.Chains = make([]uint16, 0, 8)
 	return out
@@ -37,3 +37,13 @@ func (v *VirtIOPacket) Recycle(lastOne bool) error {
 	v.Reset()
 	return nil
 }
+
+type VirtIOTXPacket struct {
+	VirtIOPacket
+}
+
+func NewVIOTX(isV4 bool) *VirtIOTXPacket {
+	out := new(VirtIOTXPacket)
+	out.VirtIOPacket = *NewVIO()
+	return out //todo: isV4 is currently unused
+}
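For reference, the layout UseSegment imposes on each descriptor-backed segment (14 is the length of the synthetic ethernet header that follows the virtio-net header; the 2048-byte buffer below is a stand-in for one returned by GetPacketForTx):

    seg := make([]byte, 2048)
    pkt := packet.NewOut()
    i := pkt.UseSegment(0, seg)
    // [0, NetHdrSize)             virtio-net header (GSO_NONE, DATA_VALID)
    // [NetHdrSize, NetHdrSize+14) ethernet header, ethertype hardcoded to 0x86dd (IPv6)
    // [NetHdrSize+14, len(seg))   payload; the decrypt path builds packets here
    _ = len(pkt.SegmentHeaders[i])  // == virtio.NetHdrSize + 14
    _ = len(pkt.SegmentPayloads[i]) // == len(seg) - virtio.NetHdrSize - 14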