Mirror of https://github.com/slackhq/nebula.git (synced 2025-12-18 04:48:28 +01:00)

Commit: vhost
@@ -1,18 +1,17 @@
package overlay

import (
-	"io"
	"net/netip"

	"github.com/slackhq/nebula/routing"
)

type Device interface {
-	io.ReadWriteCloser
+	TunDev
	Activate() error
	Networks() []netip.Prefix
	Name() string
	RoutesFor(netip.Addr) routing.Gateways
	SupportsMultiqueue() bool
-	NewMultiQueueReader() (io.ReadWriteCloser, error)
+	NewMultiQueueReader() (TunDev, error)
}
91 overlay/eventfd/eventfd.go Normal file
@@ -0,0 +1,91 @@
package eventfd

import (
	"encoding/binary"
	"syscall"

	"golang.org/x/sys/unix"
)

type EventFD struct {
	fd  int
	buf [8]byte
}

func New() (EventFD, error) {
	fd, err := unix.Eventfd(0, unix.EFD_NONBLOCK)
	if err != nil {
		return EventFD{}, err
	}
	return EventFD{
		fd:  fd,
		buf: [8]byte{},
	}, nil
}

func (e *EventFD) Kick() error {
	// An eventfd write must be an 8-byte counter increment in host byte
	// order; little-endian matches the platforms this path targets.
	binary.LittleEndian.PutUint64(e.buf[:], 1)
	_, err := syscall.Write(e.fd, e.buf[:])
	return err
}

func (e *EventFD) Close() error {
	if e.fd != 0 {
		return unix.Close(e.fd)
	}
	return nil
}

func (e *EventFD) FD() int {
	return e.fd
}

type Epoll struct {
	fd     int
	buf    [8]byte
	events []syscall.EpollEvent
}

func NewEpoll() (Epoll, error) {
	fd, err := unix.EpollCreate1(0)
	if err != nil {
		return Epoll{}, err
	}
	return Epoll{
		fd:     fd,
		buf:    [8]byte{},
		events: make([]syscall.EpollEvent, 1),
	}, nil
}

func (ep *Epoll) AddEvent(fdToAdd int) error {
	event := syscall.EpollEvent{
		Events: syscall.EPOLLIN,
		Fd:     int32(fdToAdd),
	}
	return syscall.EpollCtl(ep.fd, syscall.EPOLL_CTL_ADD, fdToAdd, &event)
}

func (ep *Epoll) Block() (int, error) {
	n, err := syscall.EpollWait(ep.fd, ep.events, -1)
	if err != nil {
		//goland:noinspection GoDirectComparisonOfErrors
		if err == syscall.EINTR {
			// Interrupted by a signal; report a spurious wakeup with no
			// ready file descriptors instead of an error.
			return 0, nil
		}
		return -1, err
	}
	return n, nil
}

func (ep *Epoll) Clear() error {
	_, err := syscall.Read(int(ep.events[0].Fd), ep.buf[:])
	return err
}

func (ep *Epoll) Close() error {
	if ep.fd != 0 {
		return unix.Close(ep.fd)
	}
	return nil
}
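The eventfd/epoll pair above is what the vhost queues use for kick and call notifications. Below is a minimal sketch (not part of the change) of how the two types compose on Linux: register the eventfd with the epoll instance, kick it from another goroutine, block until it becomes readable, then drain it with Clear so it does not stay readable. Error handling is kept deliberately simple.

package main

import (
	"fmt"
	"log"

	"github.com/slackhq/nebula/overlay/eventfd"
)

func main() {
	efd, err := eventfd.New()
	if err != nil {
		log.Fatal(err)
	}
	defer efd.Close()

	ep, err := eventfd.NewEpoll()
	if err != nil {
		log.Fatal(err)
	}
	defer ep.Close()

	// Register the eventfd so Block() wakes up whenever someone kicks it.
	if err := ep.AddEvent(efd.FD()); err != nil {
		log.Fatal(err)
	}

	// Signal the eventfd from another goroutine.
	go func() {
		if err := efd.Kick(); err != nil {
			log.Println("kick:", err)
		}
	}()

	// Block until the kick arrives (EINTR is reported as n == 0).
	n, err := ep.Block()
	if err != nil {
		log.Fatal(err)
	}
	if n > 0 {
		// Drain the counter so the fd does not stay readable.
		if err := ep.Clear(); err != nil {
			log.Fatal(err)
		}
	}
	fmt.Println("woke up with", n, "ready file descriptor(s)")
}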
@@ -2,16 +2,29 @@ package overlay

import (
	"fmt"
	"io"
	"net"
	"net/netip"

	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/packet"
	"github.com/slackhq/nebula/util"
)

const DefaultMTU = 1300

type TunDev interface {
	io.WriteCloser
	ReadMany(x []*packet.VirtIOPacket, q int) (int, error)

	// TODO: this interface is awkward and should be revisited
	AllocSeg(pkt *packet.OutPacket, q int) (int, error)
	WriteOne(x *packet.OutPacket, kick bool, q int) (int, error)
	WriteMany(x []*packet.OutPacket, q int) (int, error)
	RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error
}

// TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
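TunDev is the batch-oriented replacement for the plain io.ReadWriteCloser that Device embedded before. The following is a hedged sketch of how a caller might drive one receive queue through it; readLoop and handle are illustrative names, and how the *packet.VirtIOPacket batch is allocated is not shown in this diff, so it is left to the caller.

package example

import (
	"github.com/slackhq/nebula/overlay"
	"github.com/slackhq/nebula/packet"
)

// readLoop drains queue q of a TunDev until ReadMany fails. batch must be a
// preallocated slice of VirtIOPacket buffers; handle receives each payload.
func readLoop(dev overlay.TunDev, batch []*packet.VirtIOPacket, q int, handle func([]byte)) error {
	for {
		n, err := dev.ReadMany(batch, q)
		if err != nil {
			return err
		}
		for i := 0; i < n; i++ {
			handle(batch[i].Payload)
			// Return the RX descriptor chains to the device so it can reuse
			// them; only kick once per batch.
			if err := dev.RecycleRxSeg(batch[i], i == n-1, q); err != nil {
				return err
			}
		}
	}
}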
@@ -26,11 +39,11 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref
	}
}

func NewFdDeviceFromConfig(fd *int) DeviceFactory {
	return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
		return newTunFromFd(c, l, *fd, vpnNetworks)
	}
}
//func NewFdDeviceFromConfig(fd *int) DeviceFactory {
//	return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
//		return newTunFromFd(c, l, *fd, vpnNetworks)
//	}
//}

func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) {
	if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
@@ -9,6 +9,8 @@ import (
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/overlay/virtqueue"
|
||||
"github.com/slackhq/nebula/packet"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
@@ -22,6 +24,10 @@ type disabledTun struct {
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func (*disabledTun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
|
||||
tun := &disabledTun{
|
||||
vpnNetworks: vpnNetworks,
|
||||
@@ -40,6 +46,10 @@ func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled boo
|
||||
return tun
|
||||
}
|
||||
|
||||
func (*disabledTun) GetQueues() []*virtqueue.SplitQueue {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*disabledTun) Activate() error {
|
||||
return nil
|
||||
}
|
||||
@@ -109,7 +119,23 @@ func (t *disabledTun) SupportsMultiqueue() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (t *disabledTun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
|
||||
return 0, fmt.Errorf("tun_disabled: AllocSeg not implemented")
|
||||
}
|
||||
|
||||
func (t *disabledTun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
|
||||
return 0, fmt.Errorf("tun_disabled: WriteOne not implemented")
|
||||
}
|
||||
|
||||
func (t *disabledTun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
|
||||
return 0, fmt.Errorf("tun_disabled: WriteMany not implemented")
|
||||
}
|
||||
|
||||
func (t *disabledTun) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
|
||||
return t.Read(b[0].Payload)
|
||||
}
|
||||
|
||||
func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ package overlay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -17,15 +16,19 @@ import (
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/overlay/vhostnet"
|
||||
"github.com/slackhq/nebula/packet"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/slackhq/nebula/util/virtio"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
file *os.File
|
||||
fd int
|
||||
vdev []*vhostnet.Device
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MaxMTU int
|
||||
@@ -40,7 +43,8 @@ type tun struct {
|
||||
useSystemRoutes bool
|
||||
useSystemRoutesBufferSize int
|
||||
|
||||
l *logrus.Logger
|
||||
isV6 bool
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func (t *tun) Networks() []netip.Prefix {
|
||||
@@ -102,7 +106,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
||||
}
|
||||
|
||||
var req ifReq
|
||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
|
||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_TUN_EXCL | unix.IFF_VNET_HDR | unix.IFF_NAPI)
|
||||
if multiqueue {
|
||||
req.Flags |= unix.IFF_MULTI_QUEUE
|
||||
}
|
||||
@@ -112,20 +116,47 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
||||
}
|
||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||
|
||||
if err = unix.SetNonblock(fd, true); err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return nil, fmt.Errorf("make file descriptor non-blocking: %w", err)
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
|
||||
err = unix.IoctlSetPointerInt(fd, unix.TUNSETVNETHDRSZ, virtio.NetHdrSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("set vnethdr size: %w", err)
|
||||
}
|
||||
|
||||
flags := 0
|
||||
//flags = //unix.TUN_F_CSUM //| unix.TUN_F_TSO4 | unix.TUN_F_USO4 | unix.TUN_F_TSO6 | unix.TUN_F_USO6
|
||||
err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, flags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("set offloads: %w", err)
|
||||
}
|
||||
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.fd = fd
|
||||
t.Device = name
|
||||
|
||||
vdev, err := vhostnet.NewDevice(
|
||||
vhostnet.WithBackendFD(fd),
|
||||
vhostnet.WithQueueSize(8192), //todo config
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.vdev = []*vhostnet.Device{vdev}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
file: file,
|
||||
fd: int(file.Fd()),
|
||||
vpnNetworks: vpnNetworks,
|
||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||
@@ -133,6 +164,9 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n
|
||||
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
||||
l: l,
|
||||
}
|
||||
if len(vpnNetworks) != 0 {
|
||||
t.isV6 = vpnNetworks[0].Addr().Is6() //todo what about multi-IP?
|
||||
}
|
||||
|
||||
err := t.reload(c, true)
|
||||
if err != nil {
|
||||
@@ -220,7 +254,7 @@ func (t *tun) SupportsMultiqueue() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (t *tun) NewMultiQueueReader() (TunDev, error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -233,9 +267,17 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
vdev, err := vhostnet.NewDevice(
|
||||
vhostnet.WithBackendFD(fd),
|
||||
vhostnet.WithQueueSize(8192), //todo config
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return file, nil
|
||||
t.vdev = append(t.vdev, vdev)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
@@ -243,29 +285,6 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *tun) Write(b []byte) (int, error) {
|
||||
var nn int
|
||||
maximum := len(b)
|
||||
|
||||
for {
|
||||
n, err := unix.Write(t.fd, b[nn:maximum])
|
||||
if n > 0 {
|
||||
nn += n
|
||||
}
|
||||
if nn == len(b) {
|
||||
return nn, err
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nn, err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
return nn, io.ErrUnexpectedEOF
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tun) deviceBytes() (o [16]byte) {
|
||||
for i, c := range t.Device {
|
||||
o[i] = byte(c)
|
||||
@@ -689,8 +708,14 @@ func (t *tun) Close() error {
|
||||
close(t.routeChan)
|
||||
}
|
||||
|
||||
if t.ReadWriteCloser != nil {
|
||||
_ = t.ReadWriteCloser.Close()
|
||||
for _, v := range t.vdev {
|
||||
if v != nil {
|
||||
_ = v.Close()
|
||||
}
|
||||
}
|
||||
|
||||
if t.file != nil {
|
||||
_ = t.file.Close()
|
||||
}
|
||||
|
||||
if t.ioctlFd > 0 {
|
||||
@@ -699,3 +724,65 @@ func (t *tun) Close() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) ReadMany(p []*packet.VirtIOPacket, q int) (int, error) {
|
||||
n, err := t.vdev[q].ReceivePackets(p) //we are TXing
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (t *tun) Write(b []byte) (int, error) {
|
||||
maximum := len(b) //we are RXing
|
||||
|
||||
//todo: this allocates a fresh OutPacket for every write; reuse/pool these to avoid garbage
|
||||
out := packet.NewOut()
|
||||
x, err := t.AllocSeg(out, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
copy(out.SegmentPayloads[x], b)
|
||||
err = t.vdev[0].TransmitPacket(out, true)
|
||||
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Transmitting packet")
|
||||
return 0, err
|
||||
}
|
||||
return maximum, nil
|
||||
}
|
||||
|
||||
func (t *tun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
|
||||
idx, buf, err := t.vdev[q].GetPacketForTx()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
x := pkt.UseSegment(idx, buf, t.isV6)
|
||||
return x, nil
|
||||
}
|
||||
|
||||
func (t *tun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
|
||||
if err := t.vdev[q].TransmitPacket(x, kick); err != nil {
|
||||
t.l.WithError(err).Error("Transmitting packet")
|
||||
return 0, err
|
||||
}
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (t *tun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
|
||||
maximum := len(x) //we are RXing
|
||||
if maximum == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
err := t.vdev[q].TransmitPackets(x)
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("Transmitting packet")
|
||||
return 0, err
|
||||
}
|
||||
return maximum, nil
|
||||
}
|
||||
|
||||
func (t *tun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
|
||||
return t.vdev[q].ReceiveQueue.OfferDescriptorChains(pkt.Chains, kick)
|
||||
}
|
||||
|
||||
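The Write method above shows the compatibility path: reserve a TX segment, copy the buffer in, submit. For callers that can build the frame directly in the device-owned buffer, the same primitives compose without that copy. This is a sketch under the assumption that the caller fills out.SegmentPayloads[seg] itself; sendInPlace and buildFrame are illustrative names, and the way the written length reaches the descriptor is handled inside the packet package, which this diff only shows in part.

package example

import (
	"github.com/slackhq/nebula/overlay"
	"github.com/slackhq/nebula/packet"
)

// sendInPlace reserves one TX segment on queue q, lets buildFrame write the
// frame directly into the device-owned buffer, and submits it with a kick.
func sendInPlace(dev overlay.TunDev, q int, buildFrame func(buf []byte)) error {
	out := packet.NewOut()
	seg, err := dev.AllocSeg(out, q)
	if err != nil {
		return err
	}
	buildFrame(out.SegmentPayloads[seg])
	_, err = dev.WriteOne(out, true, q)
	return err
}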
@@ -1,11 +1,13 @@
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/packet"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
@@ -36,6 +38,10 @@ type UserDevice struct {
|
||||
inboundWriter *io.PipeWriter
|
||||
}
|
||||
|
||||
func (d *UserDevice) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *UserDevice) Activate() error {
|
||||
return nil
|
||||
}
|
||||
@@ -50,7 +56,7 @@ func (d *UserDevice) SupportsMultiqueue() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
func (d *UserDevice) NewMultiQueueReader() (TunDev, error) {
|
||||
return d, nil
|
||||
}
|
||||
|
||||
@@ -69,3 +75,19 @@ func (d *UserDevice) Close() error {
|
||||
d.outboundWriter.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *UserDevice) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
|
||||
return d.Read(b[0].Payload)
|
||||
}
|
||||
|
||||
func (d *UserDevice) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
|
||||
return 0, fmt.Errorf("user: AllocSeg not implemented")
|
||||
}
|
||||
|
||||
func (d *UserDevice) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
|
||||
return 0, fmt.Errorf("user: WriteOne not implemented")
|
||||
}
|
||||
|
||||
func (d *UserDevice) WriteMany(x []*packet.OutPacket, q int) (int, error) {
|
||||
return 0, fmt.Errorf("user: WriteMany not implemented")
|
||||
}
|
||||
|
||||
23 overlay/vhost/README.md Normal file
@@ -0,0 +1,23 @@
|
||||
Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 Hetzner Cloud GmbH
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
4 overlay/vhost/doc.go Normal file
@@ -0,0 +1,4 @@
|
||||
// Package vhost implements the basic ioctl requests needed to interact with the
|
||||
// kernel-level virtio server that provides accelerated virtio devices for
|
||||
// networking and more.
|
||||
package vhost
|
||||
218 overlay/vhost/ioctl.go Normal file
@@ -0,0 +1,218 @@
|
||||
package vhost
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"github.com/slackhq/nebula/overlay/virtqueue"
|
||||
"github.com/slackhq/nebula/util/virtio"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
// vhostIoctlGetFeatures can be used to retrieve the features supported by
|
||||
// the vhost implementation in the kernel.
|
||||
//
|
||||
// Response payload: [virtio.Feature]
|
||||
// Kernel name: VHOST_GET_FEATURES
|
||||
vhostIoctlGetFeatures = 0x8008af00
|
||||
|
||||
// vhostIoctlSetFeatures can be used to communicate the features supported
|
||||
// by this virtio implementation to the kernel.
|
||||
//
|
||||
// Request payload: [virtio.Feature]
|
||||
// Kernel name: VHOST_SET_FEATURES
|
||||
vhostIoctlSetFeatures = 0x4008af00
|
||||
|
||||
// vhostIoctlSetOwner can be used to set the current process as the
|
||||
// exclusive owner of a control file descriptor.
|
||||
//
|
||||
// Request payload: none
|
||||
// Kernel name: VHOST_SET_OWNER
|
||||
vhostIoctlSetOwner = 0x0000af01
|
||||
|
||||
// vhostIoctlSetMemoryLayout can be used to set up or modify the memory
|
||||
// layout which describes the IOTLB mappings in the kernel.
|
||||
//
|
||||
// Request payload: [MemoryLayout] with custom serialization
|
||||
// Kernel name: VHOST_SET_MEM_TABLE
|
||||
vhostIoctlSetMemoryLayout = 0x4008af03
|
||||
|
||||
// vhostIoctlSetQueueSize can be used to set the size of the virtqueue.
|
||||
//
|
||||
// Request payload: [QueueState]
|
||||
// Kernel name: VHOST_SET_VRING_NUM
|
||||
vhostIoctlSetQueueSize = 0x4008af10
|
||||
|
||||
// vhostIoctlSetQueueAddress can be used to set the addresses of the
|
||||
// different parts of the virtqueue.
|
||||
//
|
||||
// Request payload: [QueueAddresses]
|
||||
// Kernel name: VHOST_SET_VRING_ADDR
|
||||
vhostIoctlSetQueueAddress = 0x4028af11
|
||||
|
||||
// vhostIoctlSetAvailableRingBase can be used to set the index of the next
|
||||
// available ring entry the device will process.
|
||||
//
|
||||
// Request payload: [QueueState]
|
||||
// Kernel name: VHOST_SET_VRING_BASE
|
||||
vhostIoctlSetAvailableRingBase = 0x4008af12
|
||||
|
||||
// vhostIoctlSetQueueKickEventFD can be used to set the event file
|
||||
// descriptor to signal the device when descriptor chains were added to the
|
||||
// available ring.
|
||||
//
|
||||
// Request payload: [QueueFile]
|
||||
// Kernel name: VHOST_SET_VRING_KICK
|
||||
vhostIoctlSetQueueKickEventFD = 0x4008af20
|
||||
|
||||
// vhostIoctlSetQueueCallEventFD can be used to set the event file
|
||||
// descriptor that gets signaled by the device when descriptor chains have
|
||||
// been used by it.
|
||||
//
|
||||
// Request payload: [QueueFile]
|
||||
// Kernel name: VHOST_SET_VRING_CALL
|
||||
vhostIoctlSetQueueCallEventFD = 0x4008af21
|
||||
)
|
||||
|
||||
// QueueState is an ioctl request payload that can hold a queue index and any
|
||||
// 32-bit number.
|
||||
//
|
||||
// Kernel name: vhost_vring_state
|
||||
type QueueState struct {
|
||||
// QueueIndex is the index of the virtqueue.
|
||||
QueueIndex uint32
|
||||
// Num is any 32-bit number, depending on the request.
|
||||
Num uint32
|
||||
}
|
||||
|
||||
// QueueAddresses is an ioctl request payload that can hold the addresses of the
|
||||
// different parts of a virtqueue.
|
||||
//
|
||||
// Kernel name: vhost_vring_addr
|
||||
type QueueAddresses struct {
|
||||
// QueueIndex is the index of the virtqueue.
|
||||
QueueIndex uint32
|
||||
// Flags that are not used in this implementation.
|
||||
Flags uint32
|
||||
// DescriptorTableAddress is the address of the descriptor table in user
|
||||
// space memory. It must be 16-byte aligned.
|
||||
DescriptorTableAddress uintptr
|
||||
// UsedRingAddress is the address of the used ring in user space memory. It
|
||||
// must be 4-byte aligned.
|
||||
UsedRingAddress uintptr
|
||||
// AvailableRingAddress is the address of the available ring in user space
|
||||
// memory. It must be 2-byte aligned.
|
||||
AvailableRingAddress uintptr
|
||||
// LogAddress is used for an optional logging support, not supported by this
|
||||
// implementation.
|
||||
LogAddress uintptr
|
||||
}
|
||||
|
||||
// QueueFile is an ioctl request payload that can hold a queue index and a file
|
||||
// descriptor.
|
||||
//
|
||||
// Kernel name: vhost_vring_file
|
||||
type QueueFile struct {
|
||||
// QueueIndex is the index of the virtqueue.
|
||||
QueueIndex uint32
|
||||
// FD is the file descriptor of the file. Pass -1 to unbind from a file.
|
||||
FD int32
|
||||
}
|
||||
|
||||
// IoctlPtr is a copy of the similarly named unexported function from the Go
|
||||
// unix package. This is needed to do custom ioctl requests not supported by the
|
||||
// standard library.
|
||||
func IoctlPtr(fd int, req uint, arg unsafe.Pointer) error {
|
||||
_, _, err := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(arg))
|
||||
if err != 0 {
|
||||
return fmt.Errorf("ioctl request %d: %w", req, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFeatures requests the supported feature bits from the virtio device
|
||||
// associated with the given control file descriptor.
|
||||
func GetFeatures(controlFD int) (virtio.Feature, error) {
|
||||
var features virtio.Feature
|
||||
if err := IoctlPtr(controlFD, vhostIoctlGetFeatures, unsafe.Pointer(&features)); err != nil {
|
||||
return 0, fmt.Errorf("get features: %w", err)
|
||||
}
|
||||
return features, nil
|
||||
}
|
||||
|
||||
// SetFeatures communicates the feature bits supported by this implementation
|
||||
// to the virtio device associated with the given control file descriptor.
|
||||
func SetFeatures(controlFD int, features virtio.Feature) error {
|
||||
if err := IoctlPtr(controlFD, vhostIoctlSetFeatures, unsafe.Pointer(&features)); err != nil {
|
||||
return fmt.Errorf("set features: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// OwnControlFD sets the current process as the exclusive owner for the
|
||||
// given control file descriptor. This must be called before interacting with
|
||||
// the control file descriptor in any other way.
|
||||
func OwnControlFD(controlFD int) error {
|
||||
if err := IoctlPtr(controlFD, vhostIoctlSetOwner, unsafe.Pointer(nil)); err != nil {
|
||||
return fmt.Errorf("set control file descriptor owner: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetMemoryLayout sets up or modifies the memory layout for the kernel-level
|
||||
// virtio device associated with the given control file descriptor.
|
||||
func SetMemoryLayout(controlFD int, layout MemoryLayout) error {
|
||||
payload := layout.serializePayload()
|
||||
if err := IoctlPtr(controlFD, vhostIoctlSetMemoryLayout, unsafe.Pointer(&payload[0])); err != nil {
|
||||
return fmt.Errorf("set memory layout: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterQueue registers a virtio queue with the kernel-level virtio server.
|
||||
// The virtqueue will be linked to the given control file descriptor and will
|
||||
// have the given index. The kernel will use this queue until the control file
|
||||
// descriptor is closed.
|
||||
func RegisterQueue(controlFD int, queueIndex uint32, queue *virtqueue.SplitQueue) error {
|
||||
if err := IoctlPtr(controlFD, vhostIoctlSetQueueSize, unsafe.Pointer(&QueueState{
|
||||
QueueIndex: queueIndex,
|
||||
Num: uint32(queue.Size()),
|
||||
})); err != nil {
|
||||
return fmt.Errorf("set queue size: %w", err)
|
||||
}
|
||||
|
||||
if err := IoctlPtr(controlFD, vhostIoctlSetQueueAddress, unsafe.Pointer(&QueueAddresses{
|
||||
QueueIndex: queueIndex,
|
||||
Flags: 0,
|
||||
DescriptorTableAddress: queue.DescriptorTable().Address(),
|
||||
UsedRingAddress: queue.UsedRing().Address(),
|
||||
AvailableRingAddress: queue.AvailableRing().Address(),
|
||||
LogAddress: 0,
|
||||
})); err != nil {
|
||||
return fmt.Errorf("set queue addresses: %w", err)
|
||||
}
|
||||
|
||||
if err := IoctlPtr(controlFD, vhostIoctlSetAvailableRingBase, unsafe.Pointer(&QueueState{
|
||||
QueueIndex: queueIndex,
|
||||
Num: 0,
|
||||
})); err != nil {
|
||||
return fmt.Errorf("set available ring base: %w", err)
|
||||
}
|
||||
|
||||
if err := IoctlPtr(controlFD, vhostIoctlSetQueueKickEventFD, unsafe.Pointer(&QueueFile{
|
||||
QueueIndex: queueIndex,
|
||||
FD: int32(queue.KickEventFD()),
|
||||
})); err != nil {
|
||||
return fmt.Errorf("set kick event file descriptor: %w", err)
|
||||
}
|
||||
|
||||
if err := IoctlPtr(controlFD, vhostIoctlSetQueueCallEventFD, unsafe.Pointer(&QueueFile{
|
||||
QueueIndex: queueIndex,
|
||||
FD: int32(queue.CallEventFD()),
|
||||
})); err != nil {
|
||||
return fmt.Errorf("set call event file descriptor: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
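RegisterQueue is one step of a fixed setup order that every vhost-net user has to follow: own the control fd, negotiate features, describe the userspace memory, register each queue, then attach the backend (SetQueueBackend lives in the vhostnet package added later in this change). The sketch below condenses that ordering using only helpers that exist in this change; setupVhostNet is an illustrative name and the two queues are assumed to be created already.

package example

import (
	"github.com/slackhq/nebula/overlay/vhost"
	"github.com/slackhq/nebula/overlay/vhostnet"
	"github.com/slackhq/nebula/overlay/virtqueue"
	"github.com/slackhq/nebula/util/virtio"
	"golang.org/x/sys/unix"
)

// setupVhostNet wires two already-created virtqueues (rx, tx) to a vhost-net
// control fd and attaches tunFD as their backend. It mirrors what
// vhostnet.NewDevice does, condensed to show the required ordering.
func setupVhostNet(rx, tx *virtqueue.SplitQueue, tunFD int) (controlFD int, err error) {
	controlFD, err = unix.Open("/dev/vhost-net", unix.O_RDWR, 0)
	if err != nil {
		return -1, err
	}
	if err = vhost.OwnControlFD(controlFD); err != nil { // must happen before anything else
		return controlFD, err
	}
	if err = vhost.SetFeatures(controlFD, virtio.FeatureVersion1|virtio.FeatureNetMergeRXBuffers); err != nil {
		return controlFD, err
	}
	// Describe every buffer the queues own before the kernel starts using them.
	layout := vhost.NewMemoryLayoutForQueues([]*virtqueue.SplitQueue{rx, tx})
	if err = vhost.SetMemoryLayout(controlFD, layout); err != nil {
		return controlFD, err
	}
	if err = vhost.RegisterQueue(controlFD, 0, rx); err != nil { // queue 0: receive
		return controlFD, err
	}
	if err = vhost.RegisterQueue(controlFD, 1, tx); err != nil { // queue 1: transmit
		return controlFD, err
	}
	// Attaching the backend activates the queues in the kernel.
	if err = vhostnet.SetQueueBackend(controlFD, 0, tunFD); err != nil {
		return controlFD, err
	}
	if err = vhostnet.SetQueueBackend(controlFD, 1, tunFD); err != nil {
		return controlFD, err
	}
	return controlFD, nil
}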
21 overlay/vhost/ioctl_test.go Normal file
@@ -0,0 +1,21 @@
|
||||
package vhost_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
"github.com/slackhq/nebula/overlay/vhost"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestQueueState_Size(t *testing.T) {
|
||||
assert.EqualValues(t, 8, unsafe.Sizeof(vhost.QueueState{}))
|
||||
}
|
||||
|
||||
func TestQueueAddresses_Size(t *testing.T) {
|
||||
assert.EqualValues(t, 40, unsafe.Sizeof(vhost.QueueAddresses{}))
|
||||
}
|
||||
|
||||
func TestQueueFile_Size(t *testing.T) {
|
||||
assert.EqualValues(t, 8, unsafe.Sizeof(vhost.QueueFile{}))
|
||||
}
|
||||
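The struct sizes checked above pin the ABI of the ioctl payloads. The request numbers in ioctl.go can be cross-checked the same way: they follow the standard Linux ioctl encoding (direction in the top two bits, a 14-bit payload size, the vhost magic 0xAF, then the request number). A small standalone sketch that reproduces a few of the constants:

package main

import "fmt"

// Linux ioctl encoding: | dir(2 bits) | size(14 bits) | type(8 bits) | nr(8 bits) |
const (
	iocNone  = 0 // _IOC_NONE
	iocWrite = 1 // _IOC_WRITE
	iocRead  = 2 // _IOC_READ

	vhostMagic = 0xAF
)

func ioc(dir, nr, size uint32) uint32 {
	return dir<<30 | size<<16 | vhostMagic<<8 | nr
}

func main() {
	fmt.Printf("VHOST_GET_FEATURES   = %#08x\n", ioc(iocRead, 0x00, 8))   // 0x8008af00
	fmt.Printf("VHOST_SET_FEATURES   = %#08x\n", ioc(iocWrite, 0x00, 8))  // 0x4008af00
	fmt.Printf("VHOST_SET_OWNER      = %#08x\n", ioc(iocNone, 0x01, 0))   // 0x0000af01
	fmt.Printf("VHOST_SET_VRING_NUM  = %#08x\n", ioc(iocWrite, 0x10, 8))  // 0x4008af10
	fmt.Printf("VHOST_SET_VRING_ADDR = %#08x\n", ioc(iocWrite, 0x11, 40)) // 0x4028af11
	fmt.Printf("VHOST_SET_VRING_KICK = %#08x\n", ioc(iocWrite, 0x20, 8))  // 0x4008af20
}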
73 overlay/vhost/memory.go Normal file
@@ -0,0 +1,73 @@
|
||||
package vhost
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"github.com/slackhq/nebula/overlay/virtqueue"
|
||||
)
|
||||
|
||||
// MemoryRegion describes a region of userspace memory which is being made
|
||||
// accessible to a vhost device.
|
||||
//
|
||||
// Kernel name: vhost_memory_region
|
||||
type MemoryRegion struct {
|
||||
// GuestPhysicalAddress is the physical address of the memory region within
|
||||
// the guest, when virtualization is used. When no virtualization is used,
|
||||
// this should be the same as UserspaceAddress.
|
||||
GuestPhysicalAddress uintptr
|
||||
// Size is the size of the memory region.
|
||||
Size uint64
|
||||
// UserspaceAddress is the virtual address in the userspace of the host
|
||||
// where the memory region can be found.
|
||||
UserspaceAddress uintptr
|
||||
// Padding and room for flags. Currently unused.
|
||||
_ uint64
|
||||
}
|
||||
|
||||
// MemoryLayout is a list of [MemoryRegion]s.
|
||||
type MemoryLayout []MemoryRegion
|
||||
|
||||
// NewMemoryLayoutForQueues returns a new [MemoryLayout] that describes the
|
||||
// memory pages used by the descriptor tables of the given queues.
|
||||
func NewMemoryLayoutForQueues(queues []*virtqueue.SplitQueue) MemoryLayout {
|
||||
regions := make([]MemoryRegion, 0)
|
||||
for _, queue := range queues {
|
||||
for address, size := range queue.DescriptorTable().BufferAddresses() {
|
||||
regions = append(regions, MemoryRegion{
|
||||
// There is no virtualization in play here, so the guest address
|
||||
// is the same as in the host's userspace.
|
||||
GuestPhysicalAddress: address,
|
||||
Size: uint64(size),
|
||||
UserspaceAddress: address,
|
||||
})
|
||||
}
|
||||
}
|
||||
return regions
|
||||
}
|
||||
|
||||
// serializePayload serializes the list of memory regions into a format that is
|
||||
// compatible to the vhost_memory kernel struct. The returned byte slice can be
|
||||
// used as a payload for the vhostIoctlSetMemoryLayout ioctl.
|
||||
func (regions MemoryLayout) serializePayload() []byte {
|
||||
regionCount := len(regions)
|
||||
regionSize := int(unsafe.Sizeof(MemoryRegion{}))
|
||||
payload := make([]byte, 8+regionCount*regionSize)
|
||||
|
||||
// The first 32 bits contain the number of memory regions. The following 32
|
||||
// bits are padding.
|
||||
binary.LittleEndian.PutUint32(payload[0:4], uint32(regionCount))
|
||||
|
||||
if regionCount > 0 {
|
||||
// The underlying byte array of the slice should already have the correct
|
||||
// format, so just copy that.
|
||||
copied := copy(payload[8:], unsafe.Slice((*byte)(unsafe.Pointer(®ions[0])), regionCount*regionSize))
|
||||
if copied != regionCount*regionSize {
|
||||
panic(fmt.Sprintf("copied only %d bytes of the memory regions, but expected %d",
|
||||
copied, regionCount*regionSize))
|
||||
}
|
||||
}
|
||||
|
||||
return payload
|
||||
}
|
||||
42 overlay/vhost/memory_internal_test.go Normal file
@@ -0,0 +1,42 @@
|
||||
package vhost
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMemoryRegion_Size(t *testing.T) {
|
||||
assert.EqualValues(t, 32, unsafe.Sizeof(MemoryRegion{}))
|
||||
}
|
||||
|
||||
func TestMemoryLayout_SerializePayload(t *testing.T) {
|
||||
layout := MemoryLayout([]MemoryRegion{
|
||||
{
|
||||
GuestPhysicalAddress: 42,
|
||||
Size: 100,
|
||||
UserspaceAddress: 142,
|
||||
}, {
|
||||
GuestPhysicalAddress: 99,
|
||||
Size: 100,
|
||||
UserspaceAddress: 99,
|
||||
},
|
||||
})
|
||||
payload := layout.serializePayload()
|
||||
|
||||
assert.Equal(t, []byte{
|
||||
0x02, 0x00, 0x00, 0x00, // nregions
|
||||
0x00, 0x00, 0x00, 0x00, // padding
|
||||
// region 0
|
||||
0x2a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // guest_phys_addr
|
||||
0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // memory_size
|
||||
0x8e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // userspace_addr
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // flags_padding
|
||||
// region 1
|
||||
0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // guest_phys_addr
|
||||
0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // memory_size
|
||||
0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // userspace_addr
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // flags_padding
|
||||
}, payload)
|
||||
}
|
||||
23 overlay/vhostnet/README.md Normal file
@@ -0,0 +1,23 @@
|
||||
Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 Hetzner Cloud GmbH
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
372 overlay/vhostnet/device.go Normal file
@@ -0,0 +1,372 @@
|
||||
package vhostnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/slackhq/nebula/overlay/vhost"
|
||||
"github.com/slackhq/nebula/overlay/virtqueue"
|
||||
"github.com/slackhq/nebula/packet"
|
||||
"github.com/slackhq/nebula/util/virtio"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// ErrDeviceClosed is returned when the [Device] is closed while operations are
|
||||
// still running.
|
||||
var ErrDeviceClosed = errors.New("device was closed")
|
||||
|
||||
// The indexes for the receive and transmit queues.
|
||||
const (
|
||||
receiveQueueIndex = 0
|
||||
transmitQueueIndex = 1
|
||||
)
|
||||
|
||||
// Device represents a vhost networking device within the kernel-level virtio
|
||||
// implementation and provides methods to interact with it.
|
||||
type Device struct {
|
||||
initialized bool
|
||||
controlFD int
|
||||
|
||||
fullTable bool
|
||||
ReceiveQueue *virtqueue.SplitQueue
|
||||
TransmitQueue *virtqueue.SplitQueue
|
||||
}
|
||||
|
||||
// NewDevice initializes a new vhost networking device within the
|
||||
// kernel-level virtio implementation, sets up the virtqueues and returns a
|
||||
// [Device] instance that can be used to communicate with that vhost device.
|
||||
//
|
||||
// There are multiple options that can be passed to this constructor to
|
||||
// influence device creation:
|
||||
// - [WithQueueSize]
|
||||
// - [WithBackendFD]
|
||||
// - [WithBackendDevice]
|
||||
//
|
||||
// Remember to call [Device.Close] after use to free up resources.
|
||||
func NewDevice(options ...Option) (*Device, error) {
|
||||
var err error
|
||||
opts := optionDefaults
|
||||
opts.apply(options)
|
||||
if err = opts.validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid options: %w", err)
|
||||
}
|
||||
|
||||
dev := Device{
|
||||
controlFD: -1,
|
||||
}
|
||||
|
||||
// Clean up a partially initialized device when something fails.
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = dev.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Retrieve a new control file descriptor. This will be used to configure
|
||||
// the vhost networking device in the kernel.
|
||||
dev.controlFD, err = unix.Open("/dev/vhost-net", os.O_RDWR, 0666)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get control file descriptor: %w", err)
|
||||
}
|
||||
if err = vhost.OwnControlFD(dev.controlFD); err != nil {
|
||||
return nil, fmt.Errorf("own control file descriptor: %w", err)
|
||||
}
|
||||
|
||||
// Advertise the supported features. This isn't much for now.
|
||||
// TODO: Add feature options and implement proper feature negotiation.
|
||||
getFeatures, err := vhost.GetFeatures(dev.controlFD) //todo: inspect and properly negotiate these bits (0x1033D008000 was observed here in testing)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get features: %w", err)
|
||||
}
|
||||
if getFeatures == 0 {
|
||||
|
||||
}
|
||||
//const funky = virtio.Feature(1 << 27)
|
||||
//features := virtio.FeatureVersion1 | funky // | todo virtio.FeatureNetMergeRXBuffers
|
||||
features := virtio.FeatureVersion1 | virtio.FeatureNetMergeRXBuffers
|
||||
if err = vhost.SetFeatures(dev.controlFD, features); err != nil {
|
||||
return nil, fmt.Errorf("set features: %w", err)
|
||||
}
|
||||
|
||||
itemSize := os.Getpagesize() * 4 //todo config
|
||||
|
||||
// Initialize and register the queues needed for the networking device.
|
||||
if dev.ReceiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize, itemSize); err != nil {
|
||||
return nil, fmt.Errorf("create receive queue: %w", err)
|
||||
}
|
||||
if dev.TransmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize, itemSize); err != nil {
|
||||
return nil, fmt.Errorf("create transmit queue: %w", err)
|
||||
}
|
||||
|
||||
// Set up memory mappings for all buffers used by the queues. This has to
|
||||
// happen before a backend for the queues can be registered.
|
||||
memoryLayout := vhost.NewMemoryLayoutForQueues(
|
||||
[]*virtqueue.SplitQueue{dev.ReceiveQueue, dev.TransmitQueue},
|
||||
)
|
||||
if err = vhost.SetMemoryLayout(dev.controlFD, memoryLayout); err != nil {
|
||||
return nil, fmt.Errorf("setup memory layout: %w", err)
|
||||
}
|
||||
|
||||
// Set the queue backends. This activates the queues within the kernel.
|
||||
if err = SetQueueBackend(dev.controlFD, receiveQueueIndex, opts.backendFD); err != nil {
|
||||
return nil, fmt.Errorf("set receive queue backend: %w", err)
|
||||
}
|
||||
if err = SetQueueBackend(dev.controlFD, transmitQueueIndex, opts.backendFD); err != nil {
|
||||
return nil, fmt.Errorf("set transmit queue backend: %w", err)
|
||||
}
|
||||
|
||||
// Fully populate the receive queue with available buffers which the device
|
||||
// can write new packets into.
|
||||
if err = dev.refillReceiveQueue(); err != nil {
|
||||
return nil, fmt.Errorf("refill receive queue: %w", err)
|
||||
}
|
||||
|
||||
dev.initialized = true
|
||||
|
||||
// Make sure to clean up even when the device gets garbage collected without
|
||||
// Close being called first.
|
||||
devPtr := &dev
|
||||
runtime.SetFinalizer(devPtr, (*Device).Close)
|
||||
|
||||
return devPtr, nil
|
||||
}
|
||||
|
||||
// refillReceiveQueue offers as many new device-writable buffers to the device
|
||||
// as the queue can fit. The device will then use these to write received
|
||||
// packets.
|
||||
func (dev *Device) refillReceiveQueue() error {
|
||||
for {
|
||||
_, err := dev.ReceiveQueue.OfferInDescriptorChains()
|
||||
if err != nil {
|
||||
if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
|
||||
// Queue is full, job is done.
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("offer descriptor chain: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close cleans up the vhost networking device within the kernel and releases
|
||||
// all resources used for it.
|
||||
// The implementation will try to release as many resources as possible and
|
||||
// collect potential errors before returning them.
|
||||
func (dev *Device) Close() error {
|
||||
dev.initialized = false
|
||||
|
||||
// Closing the control file descriptor will unregister all queues from the
|
||||
// kernel.
|
||||
if dev.controlFD >= 0 {
|
||||
if err := unix.Close(dev.controlFD); err != nil {
|
||||
// Return an error and do not continue, because the memory used for
|
||||
// the queues should not be released before they were unregistered
|
||||
// from the kernel.
|
||||
return fmt.Errorf("close control file descriptor: %w", err)
|
||||
}
|
||||
dev.controlFD = -1
|
||||
}
|
||||
|
||||
var errs []error
|
||||
|
||||
if dev.ReceiveQueue != nil {
|
||||
if err := dev.ReceiveQueue.Close(); err == nil {
|
||||
dev.ReceiveQueue = nil
|
||||
} else {
|
||||
errs = append(errs, fmt.Errorf("close receive queue: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if dev.TransmitQueue != nil {
|
||||
if err := dev.TransmitQueue.Close(); err == nil {
|
||||
dev.TransmitQueue = nil
|
||||
} else {
|
||||
errs = append(errs, fmt.Errorf("close transmit queue: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) == 0 {
|
||||
// Everything was cleaned up. No need to run the finalizer anymore.
|
||||
runtime.SetFinalizer(dev, nil)
|
||||
}
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// createQueue creates a new virtqueue and registers it with the vhost device
|
||||
// using the given index.
|
||||
func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) {
|
||||
var (
|
||||
queue *virtqueue.SplitQueue
|
||||
err error
|
||||
)
|
||||
if queue, err = virtqueue.NewSplitQueue(queueSize, itemSize); err != nil {
|
||||
return nil, fmt.Errorf("create virtqueue: %w", err)
|
||||
}
|
||||
if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil {
|
||||
return nil, fmt.Errorf("register virtqueue with index %d: %w", queueIndex, err)
|
||||
}
|
||||
return queue, nil
|
||||
}
|
||||
|
||||
func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
|
||||
var err error
|
||||
var idx uint16
|
||||
if !dev.fullTable {
|
||||
idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
|
||||
if err == virtqueue.ErrNotEnoughFreeDescriptors {
|
||||
dev.fullTable = true
|
||||
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
|
||||
}
|
||||
} else {
|
||||
idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
|
||||
}
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("transmit queue: %w", err)
|
||||
}
|
||||
buf, err := dev.TransmitQueue.GetDescriptorItem(idx)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("get descriptor chain: %w", err)
|
||||
}
|
||||
return idx, buf, nil
|
||||
}
|
||||
|
||||
func (dev *Device) TransmitPacket(pkt *packet.OutPacket, kick bool) error {
|
||||
if len(pkt.SegmentIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
for idx := range pkt.SegmentIDs {
|
||||
segmentID := pkt.SegmentIDs[idx]
|
||||
dev.TransmitQueue.SetDescSize(segmentID, len(pkt.Segments[idx]))
|
||||
}
|
||||
err := dev.TransmitQueue.OfferDescriptorChains(pkt.SegmentIDs, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("offer descriptor chains: %w", err)
|
||||
}
|
||||
pkt.Reset()
|
||||
if kick {
|
||||
if err := dev.TransmitQueue.Kick(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error {
|
||||
if len(pkts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := range pkts {
|
||||
if err := dev.TransmitPacket(pkts[i], false); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := dev.TransmitQueue.Kick(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: Make above methods cancelable by taking a context.Context argument?
|
||||
// TODO: Implement zero-copy variants to transmit and receive packets?
|
||||
|
||||
// processChains processes as many chains as needed to create one packet. The number of processed chains is returned.
|
||||
func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) {
|
||||
//read first element to see how many descriptors we need:
|
||||
pkt.Reset()
|
||||
|
||||
err := dev.ReceiveQueue.GetDescriptorInbuffers(uint16(chains[0].DescriptorIndex), &pkt.ChainRefs)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get descriptor chain: %w", err)
|
||||
}
|
||||
if len(pkt.ChainRefs) == 0 {
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
// The specification requires that the first descriptor chain starts
|
||||
// with a virtio-net header. It is not clear, whether it is also
|
||||
// required to be fully contained in the first buffer of that
|
||||
// descriptor chain, but it is reasonable to assume that this is
|
||||
// always the case.
|
||||
// The decode method already does the buffer length check.
|
||||
if err = pkt.Header.Decode(pkt.ChainRefs[0][0:]); err != nil {
|
||||
// The device misbehaved. There is no way we can gracefully
|
||||
// recover from this, because we don't know how many of the
|
||||
// following descriptor chains belong to this packet.
|
||||
return 0, fmt.Errorf("decode vnethdr: %w", err)
|
||||
}
|
||||
|
||||
//we have the header now: what do we need to do?
|
||||
if int(pkt.Header.NumBuffers) > len(chains) {
|
||||
return 0, fmt.Errorf("number of buffers is greater than number of chains %d", len(chains))
|
||||
}
|
||||
if int(pkt.Header.NumBuffers) != 1 {
|
||||
return 0, fmt.Errorf("too smol-brain to handle more than one chain right now: %d chains", len(chains))
|
||||
}
|
||||
if chains[0].Length > 16000 {
|
||||
//todo!
|
||||
return 1, fmt.Errorf("too big packet length: %d", chains[0].Length)
|
||||
}
|
||||
|
||||
//shift the buffer out of out:
|
||||
pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length]
|
||||
pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex))
|
||||
return 1, nil
|
||||
|
||||
//cursor := n - virtio.NetHdrSize
|
||||
//
|
||||
//if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 {
|
||||
// pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize]
|
||||
// return 1, nil
|
||||
//}
|
||||
//
|
||||
//i := 1
|
||||
//// we used chain 0 already
|
||||
//for i = 1; i < len(chains); i++ {
|
||||
// n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length))
|
||||
// if err != nil {
|
||||
// // When this fails we may miss to free some descriptor chains. We
|
||||
// // could try to mitigate this by deferring the freeing somehow, but
|
||||
// // it's not worth the hassle. When this method fails, the queue will
|
||||
// // be in a broken state anyway.
|
||||
// return i, fmt.Errorf("get descriptor chain: %w", err)
|
||||
// }
|
||||
// cursor += n
|
||||
//}
|
||||
////todo this has to be wrong
|
||||
//pkt.Payload = pkt.Payload[:cursor]
|
||||
//return i, nil
|
||||
}
|
||||
|
||||
func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
|
||||
//todo optimize?
|
||||
var chains []virtqueue.UsedElement
|
||||
var err error
|
||||
|
||||
chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), len(out))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(chains) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
numPackets := 0
|
||||
chainsIdx := 0
|
||||
for numPackets = 0; chainsIdx < len(chains); numPackets++ {
|
||||
if numPackets >= len(out) {
|
||||
return numPackets, fmt.Errorf("dropping %d packets, no room", len(chains)-numPackets)
|
||||
}
|
||||
numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
chainsIdx += numChains
|
||||
}
|
||||
|
||||
return numPackets, nil
|
||||
}
|
||||
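NewDevice, GetPacketForTx/TransmitPacket and ReceivePackets together form the complete TX/RX path against a TUN backend. Below is a hedged sketch of a direct user of this package: echoOnce is an illustrative name, tunFD is assumed to be a TUN file descriptor opened with IFF_VNET_HDR (as newTun does in this change), and allocation of the rxBatch packets is not shown in this diff.

package example

import (
	"github.com/slackhq/nebula/overlay/vhostnet"
	"github.com/slackhq/nebula/packet"
)

// echoOnce transmits one raw frame and then waits for the next batch of
// received packets. Recycling the RX chains afterwards is omitted; see the
// RecycleRxSeg path in tun_linux.go for that step.
func echoOnce(tunFD int, frame []byte, rxBatch []*packet.VirtIOPacket) error {
	dev, err := vhostnet.NewDevice(
		vhostnet.WithBackendFD(tunFD),
		vhostnet.WithQueueSize(256), // must be a power of two
	)
	if err != nil {
		return err
	}
	defer dev.Close()

	// TX: reserve a device-owned buffer, copy the frame in, and submit it.
	idx, buf, err := dev.GetPacketForTx()
	if err != nil {
		return err
	}
	out := packet.NewOut()
	seg := out.UseSegment(idx, buf, false /* isV6 */)
	copy(out.SegmentPayloads[seg], frame)
	if err := dev.TransmitPacket(out, true /* kick */); err != nil {
		return err
	}

	// RX: block until the device hands back used receive buffers.
	n, err := dev.ReceivePackets(rxBatch)
	if err != nil {
		return err
	}
	_ = n // rxBatch[:n] now reference device-owned payloads
	return nil
}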
3 overlay/vhostnet/doc.go Normal file
@@ -0,0 +1,3 @@
|
||||
// Package vhostnet implements methods to initialize vhost networking devices
|
||||
// within the kernel-level virtio implementation and communicate with them.
|
||||
package vhostnet
|
||||
31 overlay/vhostnet/ioctl.go Normal file
@@ -0,0 +1,31 @@
|
||||
package vhostnet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"github.com/slackhq/nebula/overlay/vhost"
|
||||
)
|
||||
|
||||
const (
|
||||
// vhostNetIoctlSetBackend can be used to attach a virtqueue to a RAW socket
|
||||
// or TAP device.
|
||||
//
|
||||
// Request payload: [vhost.QueueFile]
|
||||
// Kernel name: VHOST_NET_SET_BACKEND
|
||||
vhostNetIoctlSetBackend = 0x4008af30
|
||||
)
|
||||
|
||||
// SetQueueBackend attaches a virtqueue of the vhost networking device
|
||||
// described by controlFD to the given backend file descriptor.
|
||||
// The backend file descriptor can either be a RAW socket or a TAP device. When
|
||||
// it is -1, the queue will be detached.
|
||||
func SetQueueBackend(controlFD int, queueIndex uint32, backendFD int) error {
|
||||
if err := vhost.IoctlPtr(controlFD, vhostNetIoctlSetBackend, unsafe.Pointer(&vhost.QueueFile{
|
||||
QueueIndex: queueIndex,
|
||||
FD: int32(backendFD),
|
||||
})); err != nil {
|
||||
return fmt.Errorf("set queue backend file descriptor: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
69 overlay/vhostnet/options.go Normal file
@@ -0,0 +1,69 @@
|
||||
package vhostnet
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/slackhq/nebula/overlay/virtqueue"
|
||||
)
|
||||
|
||||
type optionValues struct {
|
||||
queueSize int
|
||||
backendFD int
|
||||
}
|
||||
|
||||
func (o *optionValues) apply(options []Option) {
|
||||
for _, option := range options {
|
||||
option(o)
|
||||
}
|
||||
}
|
||||
|
||||
func (o *optionValues) validate() error {
|
||||
if o.queueSize == -1 {
|
||||
return errors.New("queue size is required")
|
||||
}
|
||||
if err := virtqueue.CheckQueueSize(o.queueSize); err != nil {
|
||||
return err
|
||||
}
|
||||
if o.backendFD == -1 {
|
||||
return errors.New("backend file descriptor is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var optionDefaults = optionValues{
|
||||
// Required.
|
||||
queueSize: -1,
|
||||
// Required.
|
||||
backendFD: -1,
|
||||
}
|
||||
|
||||
// Option can be passed to [NewDevice] to influence device creation.
|
||||
type Option func(*optionValues)
|
||||
|
||||
// WithQueueSize returns an [Option] that sets the size of the TX and RX queues
|
||||
// that are to be created for the device. It specifies the number of
|
||||
// entries/buffers each queue can hold. This also affects the memory
|
||||
// consumption.
|
||||
// This is required and must be an integer from 1 to 32768 that is also a power
|
||||
// of 2.
|
||||
func WithQueueSize(queueSize int) Option {
|
||||
return func(o *optionValues) { o.queueSize = queueSize }
|
||||
}
|
||||
|
||||
// WithBackendFD returns an [Option] that sets the file descriptor of the
|
||||
// backend that will be used for the queues of the device. The device will write
|
||||
// and read packets to/from that backend. The file descriptor can either be of a
|
||||
// RAW socket or TUN/TAP device.
|
||||
// Either this or [WithBackendDevice] is required.
|
||||
func WithBackendFD(backendFD int) Option {
|
||||
return func(o *optionValues) { o.backendFD = backendFD }
|
||||
}
|
||||
|
||||
//// WithBackendDevice returns an [Option] that sets the given TAP device as the
|
||||
//// backend that will be used for the queues of the device. The device will
|
||||
//// write and read packets to/from that backend. The TAP device should have been
|
||||
//// created with the [tuntap.WithVirtioNetHdr] option enabled.
|
||||
//// Either this or [WithBackendFD] is required.
|
||||
//func WithBackendDevice(dev *tuntap.Device) Option {
|
||||
// return func(o *optionValues) { o.backendFD = int(dev.File().Fd()) }
|
||||
//}
|
||||
23 overlay/virtqueue/README.md Normal file
@@ -0,0 +1,23 @@
|
||||
Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 Hetzner Cloud GmbH
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
140 overlay/virtqueue/available_ring.go Normal file
@@ -0,0 +1,140 @@
|
||||
package virtqueue
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// availableRingFlag is a flag that describes an [AvailableRing].
|
||||
type availableRingFlag uint16
|
||||
|
||||
const (
|
||||
// availableRingFlagNoInterrupt is used by the guest to advise the host to
|
||||
// not interrupt it when consuming a buffer. It's unreliable, so it's simply
|
||||
// an optimization.
|
||||
availableRingFlagNoInterrupt availableRingFlag = 1 << iota
|
||||
)
|
||||
|
||||
// availableRingSize is the number of bytes needed to store an [AvailableRing]
|
||||
// with the given queue size in memory.
|
||||
func availableRingSize(queueSize int) int {
|
||||
return 6 + 2*queueSize
|
||||
}
|
||||
|
||||
// availableRingAlignment is the minimum alignment of an [AvailableRing]
|
||||
// in memory, as required by the virtio spec.
const availableRingAlignment = 2

// AvailableRing is used by the driver to offer descriptor chains to the device.
// Each ring entry refers to the head of a descriptor chain. It is only written
// to by the driver and read by the device.
//
// Because the size of the ring depends on the queue size, we cannot define a
// Go struct with a static size that maps to the memory of the ring. Instead,
// this struct only contains pointers to the corresponding memory areas.
type AvailableRing struct {
	initialized bool

	// flags that describe this ring.
	flags *availableRingFlag
	// ringIndex indicates where the driver would put the next entry into the
	// ring (modulo the queue size).
	ringIndex *uint16
	// ring references buffers using the index of the head of the descriptor
	// chain in the [DescriptorTable]. It wraps around at queue size.
	ring []uint16
	// usedEvent is not used by this implementation, but we reserve it anyway to
	// avoid issues in case a device may try to access it, contrary to the
	// virtio specification.
	usedEvent *uint16
}

// newAvailableRing creates an available ring that uses the given underlying
// memory. The length of the memory slice must match the size needed for the
// ring (see [availableRingSize]) for the given queue size.
func newAvailableRing(queueSize int, mem []byte) *AvailableRing {
	ringSize := availableRingSize(queueSize)
	if len(mem) != ringSize {
		panic(fmt.Sprintf("memory size (%v) does not match required size "+
			"for available ring: %v", len(mem), ringSize))
	}

	return &AvailableRing{
		initialized: true,
		flags:       (*availableRingFlag)(unsafe.Pointer(&mem[0])),
		ringIndex:   (*uint16)(unsafe.Pointer(&mem[2])),
		ring:        unsafe.Slice((*uint16)(unsafe.Pointer(&mem[4])), queueSize),
		usedEvent:   (*uint16)(unsafe.Pointer(&mem[ringSize-2])),
	}
}

// Address returns the pointer to the beginning of the ring in memory.
// Do not modify the memory directly to not interfere with this implementation.
func (r *AvailableRing) Address() uintptr {
	if !r.initialized {
		panic("available ring is not initialized")
	}
	return uintptr(unsafe.Pointer(r.flags))
}

// offerElements adds the heads of the given descriptor chains to the available
// ring and advances the ring index accordingly to make the device process the
// new descriptor chains.
// It is always called under the caller's lock.
func (r *AvailableRing) offerElements(chains []UsedElement) {
	// Add descriptor chain heads to the ring.
	for offset, x := range chains {
		// The 16-bit ring index may overflow. This is expected and is not an
		// issue because the size of the ring array (which equals the queue
		// size) is always a power of 2 and smaller than the highest possible
		// 16-bit value.
		insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
		r.ring[insertIndex] = x.GetHead()
	}

	// Increase the ring index by the number of descriptor chains added to the
	// ring.
	*r.ringIndex += uint16(len(chains))
}

// offer adds the given descriptor chain heads to the available ring and
// advances the ring index accordingly.
// It is always called under the caller's lock.
func (r *AvailableRing) offer(chains []uint16) {
	// Add descriptor chain heads to the ring.
	for offset, x := range chains {
		// The 16-bit ring index may overflow. This is expected and is not an
		// issue because the size of the ring array (which equals the queue
		// size) is always a power of 2 and smaller than the highest possible
		// 16-bit value.
		insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
		r.ring[insertIndex] = x
	}

	// Increase the ring index by the number of descriptor chains added to the
	// ring.
	*r.ringIndex += uint16(len(chains))
}

// offerSingle adds a single descriptor chain head to the available ring and
// advances the ring index by one.
// It is always called under the caller's lock.
func (r *AvailableRing) offerSingle(x uint16) {
	// The 16-bit ring index may overflow. This is expected and is not an
	// issue because the size of the ring array (which equals the queue
	// size) is always a power of 2 and smaller than the highest possible
	// 16-bit value.
	insertIndex := int(*r.ringIndex) % len(r.ring)
	r.ring[insertIndex] = x

	// Increase the ring index by the single descriptor chain added to the ring.
	*r.ringIndex += 1
}
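Since offer, offerElements and offerSingle all rely on the same modulo arithmetic, a small self-contained sketch (hypothetical, not part of this change) makes the 16-bit wraparound concrete; it assumes only that the queue size is a power of two, as CheckQueueSize enforces.

package main

import "fmt"

// demoOffer mirrors the index arithmetic used by AvailableRing.offer: entries
// land at ringIndex modulo the queue size while the free-running 16-bit index
// keeps incrementing and is allowed to overflow.
func demoOffer(ring []uint16, ringIndex *uint16, heads []uint16) {
	for offset, head := range heads {
		ring[int(*ringIndex+uint16(offset))%len(ring)] = head
	}
	*ringIndex += uint16(len(heads))
}

func main() {
	const queueSize = 8 // must be a power of two
	ring := make([]uint16, queueSize)
	index := uint16(65535) // one increment away from overflowing

	demoOffer(ring, &index, []uint16{42, 33, 69})
	fmt.Println(ring, index) // [33 69 0 0 0 0 0 42] 2
}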
|
||||
71
overlay/virtqueue/available_ring_internal_test.go
Normal file
@@ -0,0 +1,71 @@
package virtqueue

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestAvailableRing_MemoryLayout(t *testing.T) {
	const queueSize = 2

	memory := make([]byte, availableRingSize(queueSize))
	r := newAvailableRing(queueSize, memory)

	*r.flags = 0x01ff
	*r.ringIndex = 1
	r.ring[0] = 0x1234
	r.ring[1] = 0x5678

	assert.Equal(t, []byte{
		0xff, 0x01,
		0x01, 0x00,
		0x34, 0x12,
		0x78, 0x56,
		0x00, 0x00,
	}, memory)
}

func TestAvailableRing_Offer(t *testing.T) {
	const queueSize = 8

	chainHeads := []uint16{42, 33, 69}

	tests := []struct {
		name              string
		startRingIndex    uint16
		expectedRingIndex uint16
		expectedRing      []uint16
	}{
		{
			name:              "no overflow",
			startRingIndex:    0,
			expectedRingIndex: 3,
			expectedRing:      []uint16{42, 33, 69, 0, 0, 0, 0, 0},
		},
		{
			name:              "ring overflow",
			startRingIndex:    6,
			expectedRingIndex: 9,
			expectedRing:      []uint16{69, 0, 0, 0, 0, 0, 42, 33},
		},
		{
			name:              "index overflow",
			startRingIndex:    65535,
			expectedRingIndex: 2,
			expectedRing:      []uint16{33, 69, 0, 0, 0, 0, 0, 42},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			memory := make([]byte, availableRingSize(queueSize))
			r := newAvailableRing(queueSize, memory)
			*r.ringIndex = tt.startRingIndex

			r.offer(chainHeads)

			assert.Equal(t, tt.expectedRingIndex, *r.ringIndex)
			assert.Equal(t, tt.expectedRing, r.ring)
		})
	}
}
|
||||
43
overlay/virtqueue/descriptor.go
Normal file
@@ -0,0 +1,43 @@
package virtqueue

// descriptorFlag is a flag that describes a [Descriptor].
type descriptorFlag uint16

const (
	// descriptorFlagHasNext marks a descriptor chain as continuing via the next
	// field.
	descriptorFlagHasNext descriptorFlag = 1 << iota
	// descriptorFlagWritable marks a buffer as device write-only (otherwise
	// device read-only).
	descriptorFlagWritable
	// descriptorFlagIndirect means the buffer contains a list of buffer
	// descriptors to provide an additional layer of indirection.
	// Only allowed when the [virtio.FeatureIndirectDescriptors] feature was
	// negotiated.
	descriptorFlagIndirect
)

// descriptorSize is the number of bytes needed to store a [Descriptor] in
// memory.
const descriptorSize = 16

// Descriptor describes (a part of) a buffer which is either read-only for the
// device or write-only for the device (depending on [descriptorFlagWritable]).
// Multiple descriptors can be chained to produce a "descriptor chain" that can
// contain both device-readable and device-writable buffers. Device-readable
// descriptors always come first in a chain. A single, large buffer may be
// split up by chaining multiple similar descriptors that reference different
// memory pages. This is required because buffers may exceed a single page size
// and the memory accessed by the device is expected to be contiguous.
type Descriptor struct {
	// address is the address of the contiguous memory holding the data for this
	// descriptor.
	address uintptr
	// length is the number of bytes stored at address.
	length uint32
	// flags that describe this descriptor.
	flags descriptorFlag
	// next contains the index of the next descriptor continuing this descriptor
	// chain when the [descriptorFlagHasNext] flag is set.
	next uint16
}
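To illustrate how flags and next cooperate, here is a minimal in-package sketch (hypothetical, not part of this change) that links two preallocated descriptors into one chain: a device-readable buffer followed by a device-writable one, with the chain terminated by leaving the HasNext flag off the tail.

// exampleChain wires table[0] -> table[1] into a single descriptor chain.
// The buffer addresses are assumed to have been populated already (see
// initializeDescriptors in descriptor_table.go).
func exampleChain(table []Descriptor) {
	// Head: device-readable data the driver filled in; the chain continues.
	table[0].length = 64
	table[0].flags = descriptorFlagHasNext
	table[0].next = 1

	// Tail: room for the device to write into; no HasNext flag, so the chain ends here.
	table[1].length = 4096
	table[1].flags = descriptorFlagWritable
}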
|
||||
12
overlay/virtqueue/descriptor_internal_test.go
Normal file
@@ -0,0 +1,12 @@
package virtqueue

import (
	"testing"
	"unsafe"

	"github.com/stretchr/testify/assert"
)

func TestDescriptor_Size(t *testing.T) {
	assert.EqualValues(t, descriptorSize, unsafe.Sizeof(Descriptor{}))
}
|
||||
465
overlay/virtqueue/descriptor_table.go
Normal file
@@ -0,0 +1,465 @@
|
||||
package virtqueue
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrDescriptorChainEmpty is returned when a descriptor chain would contain
|
||||
// no buffers, which is not allowed.
|
||||
ErrDescriptorChainEmpty = errors.New("empty descriptor chains are not allowed")
|
||||
|
||||
// ErrNotEnoughFreeDescriptors is returned when the free descriptors are
|
||||
// exhausted, meaning that the queue is full.
|
||||
ErrNotEnoughFreeDescriptors = errors.New("not enough free descriptors, queue is full")
|
||||
|
||||
// ErrInvalidDescriptorChain is returned when a descriptor chain is not
|
||||
// valid for a given operation.
|
||||
ErrInvalidDescriptorChain = errors.New("invalid descriptor chain")
|
||||
)
|
||||
|
||||
// noFreeHead is used to mark when all descriptors are in use and we have no
|
||||
// free chain. This value is impossible to occur as an index naturally, because
|
||||
// it exceeds the maximum queue size.
|
||||
const noFreeHead = uint16(math.MaxUint16)
|
||||
|
||||
// descriptorTableSize is the number of bytes needed to store a
|
||||
// [DescriptorTable] with the given queue size in memory.
|
||||
func descriptorTableSize(queueSize int) int {
|
||||
return descriptorSize * queueSize
|
||||
}
|
||||
|
||||
// descriptorTableAlignment is the minimum alignment of a [DescriptorTable]
|
||||
// in memory, as required by the virtio spec.
|
||||
const descriptorTableAlignment = 16
|
||||
|
||||
// DescriptorTable is a table that holds [Descriptor]s, addressed via their
|
||||
// index in the slice.
|
||||
type DescriptorTable struct {
|
||||
descriptors []Descriptor
|
||||
|
||||
// freeHeadIndex is the index of the head of the descriptor chain which
|
||||
// contains all currently unused descriptors. When all descriptors are in
|
||||
// use, this has the special value of noFreeHead.
|
||||
freeHeadIndex uint16
|
||||
// freeNum tracks the number of descriptors which are currently not in use.
|
||||
freeNum uint16
|
||||
|
||||
bufferBase uintptr
|
||||
bufferSize int
|
||||
itemSize int
|
||||
}
|
||||
|
||||
// newDescriptorTable creates a descriptor table that uses the given underlying
|
||||
// memory. The Length of the memory slice must match the size needed for the
|
||||
// descriptor table (see [descriptorTableSize]) for the given queue size.
|
||||
//
|
||||
// Before this descriptor table can be used, [initialize] must be called.
|
||||
func newDescriptorTable(queueSize int, mem []byte, itemSize int) *DescriptorTable {
|
||||
dtSize := descriptorTableSize(queueSize)
|
||||
if len(mem) != dtSize {
|
||||
panic(fmt.Sprintf("memory size (%v) does not match required size "+
|
||||
"for descriptor table: %v", len(mem), dtSize))
|
||||
}
|
||||
|
||||
return &DescriptorTable{
|
||||
descriptors: unsafe.Slice((*Descriptor)(unsafe.Pointer(&mem[0])), queueSize),
|
||||
// We have no free descriptors until they were initialized.
|
||||
freeHeadIndex: noFreeHead,
|
||||
freeNum: 0,
|
||||
itemSize: itemSize, //todo configurable? needs to be page-aligned
|
||||
}
|
||||
}
|
||||
|
||||
// Address returns the pointer to the beginning of the descriptor table in
|
||||
// memory. Do not modify the memory directly to not interfere with this
|
||||
// implementation.
|
||||
func (dt *DescriptorTable) Address() uintptr {
|
||||
if dt.descriptors == nil {
|
||||
panic("descriptor table is not initialized")
|
||||
}
|
||||
//should be same as dt.bufferBase
|
||||
return uintptr(unsafe.Pointer(&dt.descriptors[0]))
|
||||
}
|
||||
|
||||
func (dt *DescriptorTable) Size() uintptr {
|
||||
if dt.descriptors == nil {
|
||||
panic("descriptor table is not initialized")
|
||||
}
|
||||
return uintptr(dt.bufferSize)
|
||||
}
|
||||
|
||||
// BufferAddresses returns a map of pointer->size for all allocations used by the table
|
||||
func (dt *DescriptorTable) BufferAddresses() map[uintptr]int {
|
||||
if dt.descriptors == nil {
|
||||
panic("descriptor table is not initialized")
|
||||
}
|
||||
|
||||
return map[uintptr]int{dt.bufferBase: dt.bufferSize}
|
||||
}
|
||||
|
||||
// initializeDescriptors allocates buffers with the size of a full memory page
|
||||
// for each descriptor in the table. While this may be a bit wasteful, it makes
|
||||
// dealing with descriptors way easier. Without this preallocation, we would
|
||||
// have to allocate and free memory on demand, increasing complexity.
|
||||
//
|
||||
// All descriptors will be marked as free and will form a free chain. The
|
||||
// addresses of all descriptors will be populated while their length remains
|
||||
// zero.
|
||||
func (dt *DescriptorTable) initializeDescriptors() error {
|
||||
numDescriptors := len(dt.descriptors)
|
||||
|
||||
// Allocate ONE large region for all buffers
|
||||
totalSize := dt.itemSize * numDescriptors
|
||||
basePtr, err := unix.MmapPtr(-1, 0, nil, uintptr(totalSize),
|
||||
unix.PROT_READ|unix.PROT_WRITE,
|
||||
unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("allocate buffer memory for descriptors: %w", err)
|
||||
}
|
||||
|
||||
// Store the base for cleanup later
|
||||
dt.bufferBase = uintptr(basePtr)
|
||||
dt.bufferSize = totalSize
|
||||
|
||||
for i := range dt.descriptors {
|
||||
dt.descriptors[i] = Descriptor{
|
||||
address: dt.bufferBase + uintptr(i*dt.itemSize),
|
||||
length: 0,
|
||||
// All descriptors should form a free chain that loops around.
|
||||
flags: descriptorFlagHasNext,
|
||||
next: uint16((i + 1) % len(dt.descriptors)),
|
||||
}
|
||||
}
|
||||
|
||||
// All descriptors are free to use now.
|
||||
dt.freeHeadIndex = 0
|
||||
dt.freeNum = uint16(len(dt.descriptors))
|
||||
|
||||
return nil
|
||||
}
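After initializeDescriptors runs, every descriptor belongs to one circular free chain in which descriptor i points to (i+1) modulo the queue size. A hypothetical in-package helper shows how that invariant could be checked in a test:

// freeChainIsCircular walks the free chain starting at its head and reports
// whether it visits every descriptor exactly once before returning to the head.
func (dt *DescriptorTable) freeChainIsCircular() bool {
	if dt.freeHeadIndex == noFreeHead {
		return false
	}
	visited := 0
	next := dt.freeHeadIndex
	for range dt.descriptors {
		visited++
		next = dt.descriptors[next].next
		if next == dt.freeHeadIndex {
			return visited == len(dt.descriptors)
		}
	}
	return false
}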
|
||||
|
||||
// releaseBuffers releases all allocated buffers for this descriptor table.
|
||||
// The implementation will try to release as many buffers as possible and
|
||||
// collect potential errors before returning them.
|
||||
// The descriptor table should no longer be used after calling this.
|
||||
func (dt *DescriptorTable) releaseBuffers() error {
|
||||
for i := range dt.descriptors {
|
||||
descriptor := &dt.descriptors[i]
|
||||
descriptor.address = 0
|
||||
}
|
||||
|
||||
// As a safety measure, make sure no descriptors can be used anymore.
|
||||
dt.freeHeadIndex = noFreeHead
|
||||
dt.freeNum = 0
|
||||
|
||||
	if dt.bufferBase != 0 {
		// The pointer points to memory not managed by Go, so this conversion
		// is safe. See https://github.com/golang/go/issues/58625
		// Capture the base before clearing it so that we unmap the real
		// region instead of address zero.
		base := dt.bufferBase
		dt.bufferBase = 0
		//goland:noinspection GoVetUnsafePointer
		err := unix.MunmapPtr(unsafe.Pointer(base), uintptr(dt.bufferSize))
		if err != nil {
			return fmt.Errorf("release buffer memory: %w", err)
		}
	}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dt *DescriptorTable) CreateDescriptorForOutputs() (uint16, error) {
|
||||
//todo just fill the damn table
|
||||
// Do we still have enough free descriptors?
|
||||
|
||||
if 1 > dt.freeNum {
|
||||
return 0, ErrNotEnoughFreeDescriptors
|
||||
}
|
||||
|
||||
// Above validation ensured that there is at least one free descriptor, so
|
||||
// the free descriptor chain head should be valid.
|
||||
if dt.freeHeadIndex == noFreeHead {
|
||||
panic("free descriptor chain head is unset but there should be free descriptors")
|
||||
}
|
||||
|
||||
// To avoid having to iterate over the whole table to find the descriptor
|
||||
// pointing to the head just to replace the free head, we instead always
|
||||
// create descriptor chains from the descriptors coming after the head.
|
||||
// This way we only have to touch the head as a last resort, when all other
|
||||
// descriptors are already used.
|
||||
head := dt.descriptors[dt.freeHeadIndex].next
|
||||
desc := &dt.descriptors[head]
|
||||
next := desc.next
|
||||
|
||||
checkUnusedDescriptorLength(head, desc)
|
||||
|
||||
	// This descriptor is device-readable (no writable flag set); reserve the
	// full item size so the driver can fill it before offering the chain.
	desc.length = uint32(dt.itemSize)
	desc.flags = 0 // device-readable, descriptorFlagWritable intentionally not set
|
||||
desc.next = 0 // Not necessary to clear this, it's just for looks.
|
||||
|
||||
dt.freeNum -= 1
|
||||
|
||||
if dt.freeNum == 0 {
|
||||
// The last descriptor in the chain should be the free chain head
|
||||
// itself.
|
||||
if next != dt.freeHeadIndex {
|
||||
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
|
||||
}
|
||||
|
||||
// When this new chain takes up all remaining descriptors, we no longer
|
||||
// have a free chain.
|
||||
dt.freeHeadIndex = noFreeHead
|
||||
} else {
|
||||
// We took some descriptors out of the free chain, so make sure to close
|
||||
// the circle again.
|
||||
dt.descriptors[dt.freeHeadIndex].next = next
|
||||
}
|
||||
|
||||
return head, nil
|
||||
}
|
||||
|
||||
func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) {
|
||||
// Do we still have enough free descriptors?
|
||||
if 1 > dt.freeNum {
|
||||
return 0, ErrNotEnoughFreeDescriptors
|
||||
}
|
||||
|
||||
// Above validation ensured that there is at least one free descriptor, so
|
||||
// the free descriptor chain head should be valid.
|
||||
if dt.freeHeadIndex == noFreeHead {
|
||||
panic("free descriptor chain head is unset but there should be free descriptors")
|
||||
}
|
||||
|
||||
// To avoid having to iterate over the whole table to find the descriptor
|
||||
// pointing to the head just to replace the free head, we instead always
|
||||
// create descriptor chains from the descriptors coming after the head.
|
||||
// This way we only have to touch the head as a last resort, when all other
|
||||
// descriptors are already used.
|
||||
head := dt.descriptors[dt.freeHeadIndex].next
|
||||
desc := &dt.descriptors[head]
|
||||
next := desc.next
|
||||
|
||||
checkUnusedDescriptorLength(head, desc)
|
||||
|
||||
// Give the device the maximum available number of bytes to write into.
|
||||
desc.length = uint32(dt.itemSize)
|
||||
desc.flags = descriptorFlagWritable
|
||||
desc.next = 0 // Not necessary to clear this, it's just for looks.
|
||||
|
||||
dt.freeNum -= 1
|
||||
|
||||
if dt.freeNum == 0 {
|
||||
// The last descriptor in the chain should be the free chain head
|
||||
// itself.
|
||||
if next != dt.freeHeadIndex {
|
||||
panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
|
||||
}
|
||||
|
||||
// When this new chain takes up all remaining descriptors, we no longer
|
||||
// have a free chain.
|
||||
dt.freeHeadIndex = noFreeHead
|
||||
} else {
|
||||
// We took some descriptors out of the free chain, so make sure to close
|
||||
// the circle again.
|
||||
dt.descriptors[dt.freeHeadIndex].next = next
|
||||
}
|
||||
|
||||
return head, nil
|
||||
}
|
||||
|
||||
// TODO: Implement a zero-copy variant of createDescriptorChain?
|
||||
|
||||
// getDescriptorChain returns the device-readable buffers (out buffers) and
|
||||
// device-writable buffers (in buffers) of the descriptor chain that starts with
|
||||
// the given head index. The descriptor chain must have been created using
|
||||
// [createDescriptorChain] and must not have been freed yet (meaning that the
|
||||
// head index must not be contained in the free chain).
|
||||
//
|
||||
// Be careful to only access the returned buffer slices when the device has not
|
||||
// yet or is no longer using them. They must not be accessed after
|
||||
// [freeDescriptorChain] has been called.
|
||||
func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
|
||||
	if int(head) >= len(dt.descriptors) {
|
||||
return nil, nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
||||
}
|
||||
|
||||
// Iterate over the chain. The iteration is limited to the queue size to
|
||||
// avoid ending up in an endless loop when things go very wrong.
|
||||
next := head
|
||||
for range len(dt.descriptors) {
|
||||
if next == dt.freeHeadIndex {
|
||||
return nil, nil, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
|
||||
}
|
||||
|
||||
desc := &dt.descriptors[next]
|
||||
|
||||
// The descriptor address points to memory not managed by Go, so this
|
||||
// conversion is safe. See https://github.com/golang/go/issues/58625
|
||||
//goland:noinspection GoVetUnsafePointer
|
||||
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
|
||||
|
||||
if desc.flags&descriptorFlagWritable == 0 {
|
||||
outBuffers = append(outBuffers, bs)
|
||||
} else {
|
||||
inBuffers = append(inBuffers, bs)
|
||||
}
|
||||
|
||||
// Is this the tail of the chain?
|
||||
if desc.flags&descriptorFlagHasNext == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// Detect loops.
|
||||
if desc.next == head {
|
||||
return nil, nil, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
|
||||
}
|
||||
|
||||
next = desc.next
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
|
||||
	if int(head) >= len(dt.descriptors) {
|
||||
return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
||||
}
|
||||
|
||||
desc := &dt.descriptors[head] //todo this is a pretty nasty hack with no checks
|
||||
|
||||
// The descriptor address points to memory not managed by Go, so this
|
||||
// conversion is safe. See https://github.com/golang/go/issues/58625
|
||||
//goland:noinspection GoVetUnsafePointer
|
||||
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
|
||||
return bs, nil
|
||||
}
|
||||
|
||||
func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
|
||||
	if int(head) >= len(dt.descriptors) {
|
||||
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
||||
}
|
||||
|
||||
// Iterate over the chain. The iteration is limited to the queue size to
|
||||
// avoid ending up in an endless loop when things go very wrong.
|
||||
next := head
|
||||
for range len(dt.descriptors) {
|
||||
if next == dt.freeHeadIndex {
|
||||
return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
|
||||
}
|
||||
|
||||
desc := &dt.descriptors[next]
|
||||
|
||||
// The descriptor address points to memory not managed by Go, so this
|
||||
// conversion is safe. See https://github.com/golang/go/issues/58625
|
||||
//goland:noinspection GoVetUnsafePointer
|
||||
bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
|
||||
|
||||
if desc.flags&descriptorFlagWritable == 0 {
|
||||
return fmt.Errorf("there should not be an outbuffer in %d", head)
|
||||
} else {
|
||||
*inBuffers = append(*inBuffers, bs)
|
||||
}
|
||||
|
||||
// Is this the tail of the chain?
|
||||
if desc.flags&descriptorFlagHasNext == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// Detect loops.
|
||||
if desc.next == head {
|
||||
return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
|
||||
}
|
||||
|
||||
next = desc.next
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// freeDescriptorChain can be used to free a descriptor chain when it is no
|
||||
// longer in use. The descriptor chain that starts with the given index will be
|
||||
// put back into the free chain, so the descriptors can be used for later calls
|
||||
// of [createDescriptorChain].
|
||||
// The descriptor chain must have been created using [createDescriptorChain] and
|
||||
// must not have been freed yet (meaning that the head index must not be
|
||||
// contained in the free chain).
|
||||
func (dt *DescriptorTable) freeDescriptorChain(head uint16) error {
|
||||
	if int(head) >= len(dt.descriptors) {
|
||||
return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
|
||||
}
|
||||
|
||||
// Iterate over the chain. The iteration is limited to the queue size to
|
||||
// avoid ending up in an endless loop when things go very wrong.
|
||||
next := head
|
||||
var tailDesc *Descriptor
|
||||
var chainLen uint16
|
||||
for range len(dt.descriptors) {
|
||||
if next == dt.freeHeadIndex {
|
||||
return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
|
||||
}
|
||||
|
||||
desc := &dt.descriptors[next]
|
||||
chainLen++
|
||||
|
||||
// Set the length of all unused descriptors back to zero.
|
||||
desc.length = 0
|
||||
|
||||
// Unset all flags except the next flag.
|
||||
desc.flags &= descriptorFlagHasNext
|
||||
|
||||
// Is this the tail of the chain?
|
||||
if desc.flags&descriptorFlagHasNext == 0 {
|
||||
tailDesc = desc
|
||||
break
|
||||
}
|
||||
|
||||
// Detect loops.
|
||||
if desc.next == head {
|
||||
return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
|
||||
}
|
||||
|
||||
next = desc.next
|
||||
}
|
||||
if tailDesc == nil {
|
||||
// A descriptor chain longer than the queue size but without loops
|
||||
// should be impossible.
|
||||
panic(fmt.Sprintf("could not find a tail for descriptor chain starting at %d", head))
|
||||
}
|
||||
|
||||
// The tail descriptor does not have the next flag set, but when it comes
|
||||
// back into the free chain, it should have.
|
||||
tailDesc.flags = descriptorFlagHasNext
|
||||
|
||||
if dt.freeHeadIndex == noFreeHead {
|
||||
// The whole free chain was used up, so we turn this returned descriptor
|
||||
// chain into the new free chain by completing the circle and using its
|
||||
// head.
|
||||
tailDesc.next = head
|
||||
dt.freeHeadIndex = head
|
||||
} else {
|
||||
// Attach the returned chain at the beginning of the free chain but
|
||||
// right after the free chain head.
|
||||
freeHeadDesc := &dt.descriptors[dt.freeHeadIndex]
|
||||
tailDesc.next = freeHeadDesc.next
|
||||
freeHeadDesc.next = head
|
||||
}
|
||||
|
||||
dt.freeNum += chainLen
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkUnusedDescriptorLength asserts that the length of an unused descriptor
|
||||
// is zero, as it should be.
|
||||
// This is not a requirement by the virtio spec but rather a thing we do to
|
||||
// notice when our algorithm goes sideways.
|
||||
func checkUnusedDescriptorLength(index uint16, desc *Descriptor) {
|
||||
if desc.length != 0 {
|
||||
panic(fmt.Sprintf("descriptor %d should be unused but has a non-zero length", index))
|
||||
}
|
||||
}
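The intended lifecycle of a descriptor is create, offer, use, free. A compact in-package sketch (hypothetical, assuming a table whose buffers were already initialized) shows one round trip through CreateDescriptorForOutputs and freeDescriptorChain; in real use the head would be published on the available ring in between.

// exampleRoundTrip takes one descriptor for an outgoing packet, pretends the
// driver filled it, and then hands it back to the free chain.
func exampleRoundTrip(dt *DescriptorTable, payloadLen int) error {
	head, err := dt.CreateDescriptorForOutputs()
	if err != nil {
		return err // typically ErrNotEnoughFreeDescriptors when the queue is full
	}

	buf, err := dt.getDescriptorItem(head)
	if err != nil {
		return err
	}
	copy(buf, make([]byte, payloadLen)) // stand-in for writing packet bytes
	dt.descriptors[head].length = uint32(payloadLen)

	// Normally the head would now be offered on the available ring and only
	// freed once the device reports it in the used ring.
	return dt.freeDescriptorChain(head)
}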
|
||||
7
overlay/virtqueue/doc.go
Normal file
@@ -0,0 +1,7 @@
// Package virtqueue implements the driver side of a virtio queue as described
// in the specification:
// https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-270006
// This package makes no assumptions about the device that consumes the queue.
// It only allocates the queue structures in memory and provides methods to
// interact with them.
package virtqueue
|
||||
45
overlay/virtqueue/eventfd_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package virtqueue
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gvisor.dev/gvisor/pkg/eventfd"
|
||||
)
|
||||
|
||||
// Tests how an eventfd and a waiting goroutine can be gracefully closed.
|
||||
// Extends the eventfd test suite:
|
||||
// https://github.com/google/gvisor/blob/0799336d64be65eb97d330606c30162dc3440cab/pkg/eventfd/eventfd_test.go
|
||||
func TestEventFD_CancelWait(t *testing.T) {
|
||||
efd, err := eventfd.Create()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, efd.Close())
|
||||
})
|
||||
|
||||
var stop bool
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for !stop {
|
||||
_ = efd.Wait()
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
t.Fatalf("goroutine ended early")
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
|
||||
stop = true
|
||||
assert.NoError(t, efd.Notify())
|
||||
select {
|
||||
case <-done:
|
||||
break
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("goroutine did not end")
|
||||
}
|
||||
}
|
||||
33
overlay/virtqueue/size.go
Normal file
@@ -0,0 +1,33 @@
package virtqueue

import (
	"errors"
	"fmt"
)

// ErrQueueSizeInvalid is returned when a queue size is invalid.
var ErrQueueSizeInvalid = errors.New("queue size is invalid")

// CheckQueueSize checks if the given value would be a valid size for a
// virtqueue and returns an [ErrQueueSizeInvalid], if not.
func CheckQueueSize(queueSize int) error {
	if queueSize <= 0 {
		return fmt.Errorf("%w: %d is too small", ErrQueueSizeInvalid, queueSize)
	}

	// The queue size must always be a power of 2.
	// This ensures that ring indexes wrap correctly when the 16-bit integers
	// overflow.
	if queueSize&(queueSize-1) != 0 {
		return fmt.Errorf("%w: %d is not a power of 2", ErrQueueSizeInvalid, queueSize)
	}

	// The largest power of 2 that fits into a 16-bit integer is 32768.
	// 2 * 32768 would be 65536 which no longer fits.
	if queueSize > 32768 {
		return fmt.Errorf("%w: %d is larger than the maximum possible queue size 32768",
			ErrQueueSizeInvalid, queueSize)
	}

	return nil
}
|
||||
59
overlay/virtqueue/size_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package virtqueue
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCheckQueueSize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
queueSize int
|
||||
containsErr string
|
||||
}{
|
||||
{
|
||||
name: "negative",
|
||||
queueSize: -1,
|
||||
containsErr: "too small",
|
||||
},
|
||||
{
|
||||
name: "zero",
|
||||
queueSize: 0,
|
||||
containsErr: "too small",
|
||||
},
|
||||
{
|
||||
name: "not a power of 2",
|
||||
queueSize: 24,
|
||||
containsErr: "not a power of 2",
|
||||
},
|
||||
{
|
||||
name: "too large",
|
||||
queueSize: 65536,
|
||||
containsErr: "larger than the maximum",
|
||||
},
|
||||
{
|
||||
name: "valid 1",
|
||||
queueSize: 1,
|
||||
},
|
||||
{
|
||||
name: "valid 256",
|
||||
queueSize: 256,
|
||||
},
|
||||
|
||||
{
|
||||
name: "valid 32768",
|
||||
queueSize: 32768,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := CheckQueueSize(tt.queueSize)
|
||||
if tt.containsErr != "" {
|
||||
assert.ErrorContains(t, err, tt.containsErr)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
421
overlay/virtqueue/split_virtqueue.go
Normal file
@@ -0,0 +1,421 @@
|
||||
package virtqueue
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/slackhq/nebula/overlay/eventfd"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// SplitQueue is a virtqueue that consists of several parts, where each part is
|
||||
// writeable by either the driver or the device, but not both.
|
||||
type SplitQueue struct {
|
||||
// size is the size of the queue.
|
||||
size int
|
||||
// buf is the underlying memory used for the queue.
|
||||
buf []byte
|
||||
|
||||
descriptorTable *DescriptorTable
|
||||
availableRing *AvailableRing
|
||||
usedRing *UsedRing
|
||||
|
||||
// kickEventFD is used to signal the device when descriptor chains were
|
||||
// added to the available ring.
|
||||
kickEventFD eventfd.EventFD
|
||||
// callEventFD is used by the device to signal when it has used descriptor
|
||||
// chains and put them in the used ring.
|
||||
callEventFD eventfd.EventFD
|
||||
|
||||
// stop is used by [SplitQueue.Close] to cancel the goroutine that handles
|
||||
// used buffer notifications. It blocks until the goroutine ended.
|
||||
stop func() error
|
||||
|
||||
itemSize int
|
||||
|
||||
epoll eventfd.Epoll
|
||||
more int
|
||||
}
|
||||
|
||||
// NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size
|
||||
// specifies the number of entries/buffers the queue can hold. This also affects
|
||||
// the memory consumption.
|
||||
func NewSplitQueue(queueSize int, itemSize int) (_ *SplitQueue, err error) {
|
||||
if err = CheckQueueSize(queueSize); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if itemSize%os.Getpagesize() != 0 {
|
||||
return nil, errors.New("split queue size must be multiple of os.Getpagesize()")
|
||||
}
|
||||
|
||||
sq := SplitQueue{
|
||||
size: queueSize,
|
||||
itemSize: itemSize,
|
||||
}
|
||||
|
||||
// Clean up a partially initialized queue when something fails.
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = sq.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// There are multiple ways for how the memory for the virtqueue could be
|
||||
// allocated. We could use Go native structs with arrays inside them, but
|
||||
// this wouldn't allow us to make the queue size configurable. And including
|
||||
// a slice in the Go structs wouldn't work, because this would just put the
|
||||
// Go slice descriptor into the memory region which the virtio device will
|
||||
// not understand.
|
||||
// Additionally, Go does not allow us to ensure a correct alignment of the
|
||||
// parts of the virtqueue, as it is required by the virtio specification.
|
||||
//
|
||||
// To resolve this, let's just allocate the memory manually by allocating
|
||||
// one or more memory pages, depending on the queue size. Making the
|
||||
// virtqueue start at the beginning of a page is not strictly necessary, as
|
||||
// the virtio specification does not require it to be continuous in the
|
||||
// physical memory of the host (e.g. the vhost implementation in the kernel
|
||||
// always uses copy_from_user to access it), but this makes it very easy to
|
||||
// guarantee the alignment. Also, it is not required for the virtqueue parts
|
||||
// to be in the same memory region, as we pass separate pointers to them to
|
||||
// the device, but this design just makes things easier to implement.
|
||||
//
|
||||
// One added benefit of allocating the memory manually is, that we have full
|
||||
// control over its lifetime and don't risk the garbage collector to collect
|
||||
// our valuable structures while the device still works with them.
|
||||
|
||||
// The descriptor table is at the start of the page, so alignment is not an
|
||||
// issue here.
|
||||
descriptorTableStart := 0
|
||||
descriptorTableEnd := descriptorTableStart + descriptorTableSize(queueSize)
|
||||
availableRingStart := align(descriptorTableEnd, availableRingAlignment)
|
||||
availableRingEnd := availableRingStart + availableRingSize(queueSize)
|
||||
usedRingStart := align(availableRingEnd, usedRingAlignment)
|
||||
usedRingEnd := usedRingStart + usedRingSize(queueSize)
|
||||
|
||||
sq.buf, err = unix.Mmap(-1, 0, usedRingEnd,
|
||||
unix.PROT_READ|unix.PROT_WRITE,
|
||||
unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("allocate virtqueue buffer: %w", err)
|
||||
}
|
||||
|
||||
sq.descriptorTable = newDescriptorTable(queueSize, sq.buf[descriptorTableStart:descriptorTableEnd], sq.itemSize)
|
||||
sq.availableRing = newAvailableRing(queueSize, sq.buf[availableRingStart:availableRingEnd])
|
||||
sq.usedRing = newUsedRing(queueSize, sq.buf[usedRingStart:usedRingEnd])
|
||||
|
||||
sq.kickEventFD, err = eventfd.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create kick event file descriptor: %w", err)
|
||||
}
|
||||
sq.callEventFD, err = eventfd.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create call event file descriptor: %w", err)
|
||||
}
|
||||
|
||||
if err = sq.descriptorTable.initializeDescriptors(); err != nil {
|
||||
return nil, fmt.Errorf("initialize descriptors: %w", err)
|
||||
}
|
||||
|
||||
sq.epoll, err = eventfd.NewEpoll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = sq.epoll.AddEvent(sq.callEventFD.FD())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Consume used buffer notifications in the background.
|
||||
sq.stop = sq.startConsumeUsedRing()
|
||||
|
||||
return &sq, nil
|
||||
}
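For a concrete feel of what NewSplitQueue lays out, the region offsets can be computed with the same helpers; the numbers below are for a queue of 256 entries and assume the ring sizes exercised by the memory-layout tests (a hypothetical helper, not part of this change).

// exampleLayout prints the offsets NewSplitQueue computes for queueSize = 256.
func exampleLayout() {
	const queueSize = 256
	descEnd := descriptorTableSize(queueSize)             // 16*256 = 4096
	availStart := align(descEnd, availableRingAlignment)  // 4096
	availEnd := availStart + availableRingSize(queueSize) // 4096 + 518 = 4614
	usedStart := align(availEnd, usedRingAlignment)       // 4616
	usedEnd := usedStart + usedRingSize(queueSize)        // 4616 + 2054 = 6670
	fmt.Println(descEnd, availStart, availEnd, usedStart, usedEnd)
}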
|
||||
|
||||
// Size returns the size of this queue, which is the number of entries/buffers
|
||||
// this queue can hold.
|
||||
func (sq *SplitQueue) Size() int {
|
||||
return sq.size
|
||||
}
|
||||
|
||||
// DescriptorTable returns the [DescriptorTable] behind this queue.
|
||||
func (sq *SplitQueue) DescriptorTable() *DescriptorTable {
|
||||
return sq.descriptorTable
|
||||
}
|
||||
|
||||
// AvailableRing returns the [AvailableRing] behind this queue.
|
||||
func (sq *SplitQueue) AvailableRing() *AvailableRing {
|
||||
return sq.availableRing
|
||||
}
|
||||
|
||||
// UsedRing returns the [UsedRing] behind this queue.
|
||||
func (sq *SplitQueue) UsedRing() *UsedRing {
|
||||
return sq.usedRing
|
||||
}
|
||||
|
||||
// KickEventFD returns the kick event file descriptor behind this queue.
|
||||
// The returned file descriptor should be used with great care to not interfere
|
||||
// with this implementation.
|
||||
func (sq *SplitQueue) KickEventFD() int {
|
||||
return sq.kickEventFD.FD()
|
||||
}
|
||||
|
||||
// CallEventFD returns the call event file descriptor behind this queue.
|
||||
// The returned file descriptor should be used with great care to not interfere
|
||||
// with this implementation.
|
||||
func (sq *SplitQueue) CallEventFD() int {
|
||||
return sq.callEventFD.FD()
|
||||
}
|
||||
|
||||
// startConsumeUsedRing returns a stop function for the used-ring consumer.
// Consumption itself happens on demand (see [SplitQueue.TakeSingle] and
// [SplitQueue.BlockAndGetHeadsCapped]); the returned function only wakes up
// any goroutine blocked on the call event file descriptor so it can exit.
func (sq *SplitQueue) startConsumeUsedRing() func() error {
|
||||
return func() error {
|
||||
|
||||
// The goroutine blocks until it receives a signal on the event file
|
||||
// descriptor, so it will never notice the context being canceled.
|
||||
// To resolve this, we can just produce a fake-signal ourselves to wake
|
||||
// it up.
|
||||
if err := sq.callEventFD.Kick(); err != nil {
|
||||
return fmt.Errorf("wake up goroutine: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
|
||||
var n int
|
||||
var err error
|
||||
for ctx.Err() == nil {
|
||||
out, ok := sq.usedRing.takeOne()
|
||||
if ok {
|
||||
return out, nil
|
||||
}
|
||||
// Wait for a signal from the device.
|
||||
if n, err = sq.epoll.Block(); err != nil {
|
||||
return 0, fmt.Errorf("wait: %w", err)
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
out, ok = sq.usedRing.takeOne()
|
||||
if ok {
|
||||
_ = sq.epoll.Clear() //???
|
||||
return out, nil
|
||||
} else {
|
||||
continue //???
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
|
||||
func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) {
|
||||
var n int
|
||||
var err error
|
||||
for ctx.Err() == nil {
|
||||
|
||||
		// Return leftovers from the previous capped take first.
|
||||
if sq.more > 0 {
|
||||
stillNeedToTake, out := sq.usedRing.take(maxToTake)
|
||||
sq.more = stillNeedToTake
|
||||
return out, nil
|
||||
}
|
||||
		// Check the used ring for new elements.
|
||||
stillNeedToTake, out := sq.usedRing.take(maxToTake)
|
||||
if len(out) > 0 {
|
||||
sq.more = stillNeedToTake
|
||||
return out, nil
|
||||
}
|
||||
		// Nothing buffered yet; we will have to wait for the device.
|
||||
|
||||
// Wait for a signal from the device.
|
||||
if n, err = sq.epoll.Block(); err != nil {
|
||||
return nil, fmt.Errorf("wait: %w", err)
|
||||
}
|
||||
if n > 0 {
|
||||
_ = sq.epoll.Clear() //???
|
||||
stillNeedToTake, out = sq.usedRing.take(maxToTake)
|
||||
sq.more = stillNeedToTake
|
||||
return out, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ctx.Err()
|
||||
}
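A typical consumer drains the used ring in bounded batches so a single queue cannot monopolize the reader. A hypothetical usage sketch, using only the methods shown here:

// drainUsed processes used descriptor chains in batches of at most 64 until
// the context is cancelled.
func drainUsed(ctx context.Context, sq *SplitQueue, handle func(UsedElement)) error {
	for {
		elems, err := sq.BlockAndGetHeadsCapped(ctx, 64)
		if err != nil {
			return err // ctx.Err() once the context is cancelled
		}
		for _, e := range elems {
			handle(e)
			if err := sq.FreeDescriptorChain(e.GetHead()); err != nil {
				return err
			}
		}
	}
}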
|
||||
|
||||
// OfferInDescriptorChains allocates a single device-writable descriptor (an
// "in" buffer covering one item of itemSize bytes), publishes it on the
// available ring and kicks the device.
//
// When the queue is full and no more descriptors can be added,
// [ErrNotEnoughFreeDescriptors] is returned.
//
// The index of the descriptor chain head is returned. Callers are responsible
// for collecting the chain from the [UsedRing] once the device is done with it
// and for making it available again; otherwise the queue will run full and
// further calls will fail.
func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
|
||||
// Create a descriptor chain for the given buffers.
|
||||
var (
|
||||
head uint16
|
||||
err error
|
||||
)
|
||||
for {
|
||||
head, err = sq.descriptorTable.createDescriptorForInputs()
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
		// Compare the sentinel error directly; errors.Is is avoided on this hot path.
|
||||
//goland:noinspection GoDirectComparisonOfErrors
|
||||
if err == ErrNotEnoughFreeDescriptors {
|
||||
return 0, err
|
||||
} else {
|
||||
return 0, fmt.Errorf("create descriptor chain: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Make the descriptor chain available to the device.
|
||||
sq.availableRing.offerSingle(head)
|
||||
|
||||
// Notify the device to make it process the updated available ring.
|
||||
if err := sq.kickEventFD.Kick(); err != nil {
|
||||
return head, fmt.Errorf("notify device: %w", err)
|
||||
}
|
||||
|
||||
return head, nil
|
||||
}
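On the receive path the driver typically keeps the queue topped up with device-writable buffers. Below is a hypothetical sketch that fills the queue until it reports ErrNotEnoughFreeDescriptors; note that each successful call kicks the device, so a real implementation might prefer batching via OfferDescriptorChains.

// fillRxQueue offers device-writable buffers until the queue is full and
// returns the heads that were handed to the device.
func fillRxQueue(sq *SplitQueue) ([]uint16, error) {
	var heads []uint16
	for {
		head, err := sq.OfferInDescriptorChains()
		if err != nil {
			//goland:noinspection GoDirectComparisonOfErrors
			if err == ErrNotEnoughFreeDescriptors {
				return heads, nil // queue is full, which is the goal here
			}
			return heads, err
		}
		heads = append(heads, head)
	}
}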
|
||||
|
||||
// GetDescriptorChain returns the device-readable buffers (out buffers) and
|
||||
// device-writable buffers (in buffers) of the descriptor chain with the given
|
||||
// head index.
|
||||
// The head index must be one that was returned by a previous call to
|
||||
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
|
||||
// freed yet.
|
||||
//
|
||||
// Be careful to only access the returned buffer slices when the device is no
|
||||
// longer using them. They must not be accessed after
|
||||
// [SplitQueue.FreeDescriptorChain] has been called.
|
||||
func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
|
||||
return sq.descriptorTable.getDescriptorChain(head)
|
||||
}
|
||||
|
||||
func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) {
|
||||
sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize)
|
||||
return sq.descriptorTable.getDescriptorItem(head)
|
||||
}
|
||||
|
||||
func (sq *SplitQueue) GetDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
|
||||
return sq.descriptorTable.getDescriptorInbuffers(head, inBuffers)
|
||||
}
|
||||
|
||||
// FreeDescriptorChain frees the descriptor chain with the given head index.
|
||||
// The head index must be one that was returned by a previous call to
|
||||
// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
|
||||
// freed yet.
|
||||
//
|
||||
// This creates new room in the queue which can be used by following
|
||||
// [SplitQueue.OfferDescriptorChain] calls.
|
||||
// When there are outstanding calls for [SplitQueue.OfferDescriptorChain] that
|
||||
// are waiting for free room in the queue, they may become unblocked by this.
|
||||
func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
|
||||
//not called under lock
|
||||
if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
|
||||
return fmt.Errorf("free: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
|
||||
//not called under lock
|
||||
sq.descriptorTable.descriptors[int(head)].length = uint32(sz)
|
||||
}
|
||||
|
||||
func (sq *SplitQueue) OfferDescriptorChains(chains []uint16, kick bool) error {
|
||||
//todo not doing this may break eventually?
|
||||
//not called under lock
|
||||
//if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
|
||||
// return fmt.Errorf("free: %w", err)
|
||||
//}
|
||||
|
||||
// Make the descriptor chain available to the device.
|
||||
sq.availableRing.offer(chains)
|
||||
|
||||
// Notify the device to make it process the updated available ring.
|
||||
if kick {
|
||||
return sq.Kick()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sq *SplitQueue) Kick() error {
|
||||
if err := sq.kickEventFD.Kick(); err != nil {
|
||||
return fmt.Errorf("notify device: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close releases all resources used for this queue.
|
||||
// The implementation will try to release as many resources as possible and
|
||||
// collect potential errors before returning them.
|
||||
func (sq *SplitQueue) Close() error {
|
||||
var errs []error
|
||||
|
||||
if sq.stop != nil {
|
||||
// This has to happen before the event file descriptors may be closed.
|
||||
if err := sq.stop(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("stop consume used ring: %w", err))
|
||||
}
|
||||
|
||||
// Make sure that this code block is executed only once.
|
||||
sq.stop = nil
|
||||
}
|
||||
|
||||
if err := sq.kickEventFD.Close(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("close kick event file descriptor: %w", err))
|
||||
}
|
||||
if err := sq.callEventFD.Close(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("close call event file descriptor: %w", err))
|
||||
}
|
||||
|
||||
if err := sq.descriptorTable.releaseBuffers(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("release descriptor buffers: %w", err))
|
||||
}
|
||||
|
||||
if sq.buf != nil {
|
||||
if err := unix.Munmap(sq.buf); err == nil {
|
||||
sq.buf = nil
|
||||
} else {
|
||||
errs = append(errs, fmt.Errorf("unmap virtqueue buffer: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func align(index, alignment int) int {
|
||||
remainder := index % alignment
|
||||
if remainder == 0 {
|
||||
return index
|
||||
}
|
||||
return index + alignment - remainder
|
||||
}
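align rounds an offset up to the next multiple of the given alignment, which is what places the available and used rings on their required boundaries. A few sample values (hypothetical arithmetic for illustration):

// align(4096, 2) == 4096  // already aligned
// align(4614, 4) == 4616  // rounded up to the used ring alignment
// align(7, 16)   == 16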
|
||||
21
overlay/virtqueue/used_element.go
Normal file
@@ -0,0 +1,21 @@
package virtqueue

// usedElementSize is the number of bytes needed to store a [UsedElement] in
// memory.
const usedElementSize = 8

// UsedElement is an element of the [UsedRing] and describes a descriptor chain
// that was used by the device.
type UsedElement struct {
	// DescriptorIndex is the index of the head of the used descriptor chain in
	// the [DescriptorTable].
	// The index is 32-bit here for padding reasons.
	DescriptorIndex uint32
	// Length is the number of bytes written into the device writable portion of
	// the buffer described by the descriptor chain.
	Length uint32
}

func (u *UsedElement) GetHead() uint16 {
	return uint16(u.DescriptorIndex)
}
|
||||
12
overlay/virtqueue/used_element_internal_test.go
Normal file
@@ -0,0 +1,12 @@
package virtqueue

import (
	"testing"
	"unsafe"

	"github.com/stretchr/testify/assert"
)

func TestUsedElement_Size(t *testing.T) {
	assert.EqualValues(t, usedElementSize, unsafe.Sizeof(UsedElement{}))
}
|
||||
184
overlay/virtqueue/used_ring.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package virtqueue
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// usedRingFlag is a flag that describes a [UsedRing].
|
||||
type usedRingFlag uint16
|
||||
|
||||
const (
|
||||
// usedRingFlagNoNotify is used by the host to advise the guest to not
|
||||
// kick it when adding a buffer. It's unreliable, so it's simply an
|
||||
// optimization. Guest will still kick when it's out of buffers.
|
||||
usedRingFlagNoNotify usedRingFlag = 1 << iota
|
||||
)
|
||||
|
||||
// usedRingSize is the number of bytes needed to store a [UsedRing] with the
|
||||
// given queue size in memory.
|
||||
func usedRingSize(queueSize int) int {
|
||||
return 6 + usedElementSize*queueSize
|
||||
}
|
||||
|
||||
// usedRingAlignment is the minimum alignment of a [UsedRing] in memory, as
|
||||
// required by the virtio spec.
|
||||
const usedRingAlignment = 4
|
||||
|
||||
// UsedRing is where the device returns descriptor chains once it is done with
|
||||
// them. Each ring entry is a [UsedElement]. It is only written to by the device
|
||||
// and read by the driver.
|
||||
//
|
||||
// Because the size of the ring depends on the queue size, we cannot define a
|
||||
// Go struct with a static size that maps to the memory of the ring. Instead,
|
||||
// this struct only contains pointers to the corresponding memory areas.
|
||||
type UsedRing struct {
|
||||
initialized bool
|
||||
|
||||
// flags that describe this ring.
|
||||
flags *usedRingFlag
|
||||
// ringIndex indicates where the device would put the next entry into the
|
||||
// ring (modulo the queue size).
|
||||
ringIndex *uint16
|
||||
// ring contains the [UsedElement]s. It wraps around at queue size.
|
||||
ring []UsedElement
|
||||
// availableEvent is not used by this implementation, but we reserve it
|
||||
// anyway to avoid issues in case a device may try to write to it, contrary
|
||||
// to the virtio specification.
|
||||
availableEvent *uint16
|
||||
|
||||
// lastIndex is the internal ringIndex up to which all [UsedElement]s were
|
||||
// processed.
|
||||
lastIndex uint16
|
||||
|
||||
//mu sync.Mutex
|
||||
}
|
||||
|
||||
// newUsedRing creates a used ring that uses the given underlying memory. The
|
||||
// length of the memory slice must match the size needed for the ring (see
|
||||
// [usedRingSize]) for the given queue size.
|
||||
func newUsedRing(queueSize int, mem []byte) *UsedRing {
|
||||
ringSize := usedRingSize(queueSize)
|
||||
if len(mem) != ringSize {
|
||||
panic(fmt.Sprintf("memory size (%v) does not match required size "+
|
||||
"for used ring: %v", len(mem), ringSize))
|
||||
}
|
||||
|
||||
r := UsedRing{
|
||||
initialized: true,
|
||||
flags: (*usedRingFlag)(unsafe.Pointer(&mem[0])),
|
||||
ringIndex: (*uint16)(unsafe.Pointer(&mem[2])),
|
||||
ring: unsafe.Slice((*UsedElement)(unsafe.Pointer(&mem[4])), queueSize),
|
||||
availableEvent: (*uint16)(unsafe.Pointer(&mem[ringSize-2])),
|
||||
}
|
||||
r.lastIndex = *r.ringIndex
|
||||
return &r
|
||||
}
|
||||
|
||||
// Address returns the pointer to the beginning of the ring in memory.
|
||||
// Do not modify the memory directly to not interfere with this implementation.
|
||||
func (r *UsedRing) Address() uintptr {
|
||||
if !r.initialized {
|
||||
panic("used ring is not initialized")
|
||||
}
|
||||
return uintptr(unsafe.Pointer(r.flags))
|
||||
}
|
||||
|
||||
// take returns new [UsedElement]s that the device put into the ring and that
// weren't already returned by a previous call to this method. When maxToTake
// is greater than zero, at most that many elements are returned and the first
// return value reports how many elements are still pending in the ring.
// Callers must provide their own synchronization; this method does not lock.
func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
|
||||
//r.mu.Lock()
|
||||
//defer r.mu.Unlock()
|
||||
|
||||
ringIndex := *r.ringIndex
|
||||
if ringIndex == r.lastIndex {
|
||||
// Nothing new.
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
	// Calculate the number of new used elements that we can read from the ring.
	// Both indexes are uint16, so the subtraction wraps correctly even when the
	// ring index has overflowed.
	count := int(ringIndex - r.lastIndex)
|
||||
|
||||
stillNeedToTake := 0
|
||||
|
||||
if maxToTake > 0 {
|
||||
stillNeedToTake = count - maxToTake
|
||||
if stillNeedToTake < 0 {
|
||||
stillNeedToTake = 0
|
||||
}
|
||||
count = min(count, maxToTake)
|
||||
}
|
||||
|
||||
// The number of new elements can never exceed the queue size.
|
||||
if count > len(r.ring) {
|
||||
panic("used ring contains more new elements than the ring is long")
|
||||
}
|
||||
|
||||
elems := make([]UsedElement, count)
|
||||
for i := range count {
|
||||
elems[i] = r.ring[r.lastIndex%uint16(len(r.ring))]
|
||||
r.lastIndex++
|
||||
}
|
||||
|
||||
return stillNeedToTake, elems
|
||||
}
|
||||
|
||||
func (r *UsedRing) takeOne() (uint16, bool) {
|
||||
//r.mu.Lock()
|
||||
//defer r.mu.Unlock()
|
||||
|
||||
ringIndex := *r.ringIndex
|
||||
if ringIndex == r.lastIndex {
|
||||
// Nothing new.
|
||||
return 0xffff, false
|
||||
}
|
||||
|
||||
	// Calculate the number of new used elements that we can read from the ring.
	// Both indexes are uint16, so the subtraction wraps correctly even when the
	// ring index has overflowed.
	count := int(ringIndex - r.lastIndex)
|
||||
|
||||
// The number of new elements can never exceed the queue size.
|
||||
if count > len(r.ring) {
|
||||
panic("used ring contains more new elements than the ring is long")
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
return 0xffff, false
|
||||
}
|
||||
|
||||
out := r.ring[r.lastIndex%uint16(len(r.ring))].GetHead()
|
||||
r.lastIndex++
|
||||
|
||||
return out, true
|
||||
}
|
||||
|
||||
// InitOfferSingle is only used to pre-fill the used queue at startup, and should not be used if the device is running!
|
||||
func (r *UsedRing) InitOfferSingle(x uint16, size int) {
|
||||
//always called under lock
|
||||
//r.mu.Lock()
|
||||
//defer r.mu.Unlock()
|
||||
|
||||
offset := 0
|
||||
// Add descriptor chain heads to the ring.
|
||||
|
||||
// The 16-bit ring index may overflow. This is expected and is not an
|
||||
// issue because the size of the ring array (which equals the queue
|
||||
// size) is always a power of 2 and smaller than the highest possible
|
||||
// 16-bit value.
|
||||
insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
|
||||
r.ring[insertIndex] = UsedElement{
|
||||
DescriptorIndex: uint32(x),
|
||||
Length: uint32(size),
|
||||
}
|
||||
|
||||
// Increase the ring index by the number of descriptor chains added to the ring.
|
||||
*r.ringIndex += 1
|
||||
}
|
||||
136
overlay/virtqueue/used_ring_internal_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package virtqueue
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUsedRing_MemoryLayout(t *testing.T) {
|
||||
const queueSize = 2
|
||||
|
||||
memory := make([]byte, usedRingSize(queueSize))
|
||||
r := newUsedRing(queueSize, memory)
|
||||
|
||||
*r.flags = 0x01ff
|
||||
*r.ringIndex = 1
|
||||
r.ring[0] = UsedElement{
|
||||
DescriptorIndex: 0x0123,
|
||||
Length: 0x4567,
|
||||
}
|
||||
r.ring[1] = UsedElement{
|
||||
DescriptorIndex: 0x89ab,
|
||||
Length: 0xcdef,
|
||||
}
|
||||
|
||||
assert.Equal(t, []byte{
|
||||
0xff, 0x01,
|
||||
0x01, 0x00,
|
||||
0x23, 0x01, 0x00, 0x00,
|
||||
0x67, 0x45, 0x00, 0x00,
|
||||
0xab, 0x89, 0x00, 0x00,
|
||||
0xef, 0xcd, 0x00, 0x00,
|
||||
0x00, 0x00,
|
||||
}, memory)
|
||||
}
|
||||
|
||||
//func TestUsedRing_Take(t *testing.T) {
|
||||
// const queueSize = 8
|
||||
//
|
||||
// tests := []struct {
|
||||
// name string
|
||||
// ring []UsedElement
|
||||
// ringIndex uint16
|
||||
// lastIndex uint16
|
||||
// expected []UsedElement
|
||||
// }{
|
||||
// {
|
||||
// name: "nothing new",
|
||||
// ring: []UsedElement{
|
||||
// {DescriptorIndex: 1},
|
||||
// {DescriptorIndex: 2},
|
||||
// {DescriptorIndex: 3},
|
||||
// {DescriptorIndex: 4},
|
||||
// {},
|
||||
// {},
|
||||
// {},
|
||||
// {},
|
||||
// },
|
||||
// ringIndex: 4,
|
||||
// lastIndex: 4,
|
||||
// expected: nil,
|
||||
// },
|
||||
// {
|
||||
// name: "no overflow",
|
||||
// ring: []UsedElement{
|
||||
// {DescriptorIndex: 1},
|
||||
// {DescriptorIndex: 2},
|
||||
// {DescriptorIndex: 3},
|
||||
// {DescriptorIndex: 4},
|
||||
// {},
|
||||
// {},
|
||||
// {},
|
||||
// {},
|
||||
// },
|
||||
// ringIndex: 4,
|
||||
// lastIndex: 1,
|
||||
// expected: []UsedElement{
|
||||
// {DescriptorIndex: 2},
|
||||
// {DescriptorIndex: 3},
|
||||
// {DescriptorIndex: 4},
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "ring overflow",
|
||||
// ring: []UsedElement{
|
||||
// {DescriptorIndex: 9},
|
||||
// {DescriptorIndex: 10},
|
||||
// {DescriptorIndex: 3},
|
||||
// {DescriptorIndex: 4},
|
||||
// {DescriptorIndex: 5},
|
||||
// {DescriptorIndex: 6},
|
||||
// {DescriptorIndex: 7},
|
||||
// {DescriptorIndex: 8},
|
||||
// },
|
||||
// ringIndex: 10,
|
||||
// lastIndex: 7,
|
||||
// expected: []UsedElement{
|
||||
// {DescriptorIndex: 8},
|
||||
// {DescriptorIndex: 9},
|
||||
// {DescriptorIndex: 10},
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "index overflow",
|
||||
// ring: []UsedElement{
|
||||
// {DescriptorIndex: 9},
|
||||
// {DescriptorIndex: 10},
|
||||
// {DescriptorIndex: 3},
|
||||
// {DescriptorIndex: 4},
|
||||
// {DescriptorIndex: 5},
|
||||
// {DescriptorIndex: 6},
|
||||
// {DescriptorIndex: 7},
|
||||
// {DescriptorIndex: 8},
|
||||
// },
|
||||
// ringIndex: 2,
|
||||
// lastIndex: 65535,
|
||||
// expected: []UsedElement{
|
||||
// {DescriptorIndex: 8},
|
||||
// {DescriptorIndex: 9},
|
||||
// {DescriptorIndex: 10},
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
// for _, tt := range tests {
|
||||
// t.Run(tt.name, func(t *testing.T) {
|
||||
// memory := make([]byte, usedRingSize(queueSize))
|
||||
// r := newUsedRing(queueSize, memory)
|
||||
//
|
||||
// copy(r.ring, tt.ring)
|
||||
// *r.ringIndex = tt.ringIndex
|
||||
// r.lastIndex = tt.lastIndex
|
||||
//
|
||||
// assert.Equal(t, tt.expected, r.take())
|
||||
// })
|
||||
// }
|
||||
//}