diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index ffdc15b..479c929 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -3,6 +3,9 @@ package main import ( "flag" "fmt" + "log" + "net/http" + _ "net/http/pprof" "os" "runtime/debug" "strings" @@ -71,6 +74,10 @@ func main() { os.Exit(1) } + go func() { + log.Println(http.ListenAndServe("0.0.0.0:6060", nil)) + }() + if !*configTest { ctrl.Start() notifyReady(l) diff --git a/connection_state.go b/connection_state.go index db885d4..6913cd0 100644 --- a/connection_state.go +++ b/connection_state.go @@ -13,7 +13,7 @@ import ( "github.com/slackhq/nebula/noiseutil" ) -const ReplayWindow = 1024 +const ReplayWindow = 4096 type ConnectionState struct { eKey *NebulaCipherState diff --git a/firewall.go b/firewall.go index 45dc069..1ef6578 100644 --- a/firewall.go +++ b/firewall.go @@ -403,9 +403,9 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table") // Drop returns an error if the packet should be dropped, explaining why. It // returns nil if the packet should not be dropped. -func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { +func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) error { // Check if we spoke to this tuple, if we did then allow this packet - if f.inConns(fp, h, caPool, localCache) { + if f.inConns(fp, h, caPool, localCache, now) { return nil } @@ -454,7 +454,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // We always want to conntrack since it is a faster operation - f.addConn(fp, incoming) + f.addConn(fp, incoming, now) return nil } @@ -483,7 +483,7 @@ func (f *Firewall) EmitStats() { metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) } -func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool { +func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) bool { if localCache != nil { if _, ok := localCache[fp]; ok { return true @@ -495,7 +495,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, // Purge every time we test ep, has := conntrack.TimerWheel.Purge() if has { - f.evict(ep) + f.evict(ep, now) } c, ok := conntrack.Conns[fp] @@ -542,11 +542,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, switch fp.Protocol { case firewall.ProtoTCP: - c.Expires = time.Now().Add(f.TCPTimeout) + c.Expires = now.Add(f.TCPTimeout) case firewall.ProtoUDP: - c.Expires = time.Now().Add(f.UDPTimeout) + c.Expires = now.Add(f.UDPTimeout) default: - c.Expires = time.Now().Add(f.DefaultTimeout) + c.Expires = now.Add(f.DefaultTimeout) } conntrack.Unlock() @@ -558,7 +558,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, return true } -func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { +func (f *Firewall) addConn(fp firewall.Packet, incoming bool, now time.Time) { var timeout time.Duration c := &conn{} @@ -574,7 +574,7 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { conntrack := f.Conntrack conntrack.Lock() if _, ok := conntrack.Conns[fp]; !ok { - conntrack.TimerWheel.Advance(time.Now()) + conntrack.TimerWheel.Advance(now) conntrack.TimerWheel.Add(fp, timeout) } @@ -582,14 +582,14 @@ func (f 
*Firewall) addConn(fp firewall.Packet, incoming bool) { // firewall reload c.incoming = incoming c.rulesVersion = f.rulesVersion - c.Expires = time.Now().Add(timeout) + c.Expires = now.Add(timeout) conntrack.Conns[fp] = c conntrack.Unlock() } // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel // Caller must own the connMutex lock! -func (f *Firewall) evict(p firewall.Packet) { +func (f *Firewall) evict(p firewall.Packet, now time.Time) { // Are we still tracking this conn? conntrack := f.Conntrack t, ok := conntrack.Conns[p] @@ -597,11 +597,11 @@ func (f *Firewall) evict(p firewall.Packet) { return } - newT := t.Expires.Sub(time.Now()) + newT := t.Expires.Sub(now) // Timeout is in the future, re-add the timer if newT > 0 { - conntrack.TimerWheel.Advance(time.Now()) + conntrack.TimerWheel.Advance(now) conntrack.TimerWheel.Add(p, newT) return } diff --git a/go.mod b/go.mod index 7999c64..8ad287c 100644 --- a/go.mod +++ b/go.mod @@ -50,6 +50,6 @@ require ( github.com/vishvananda/netns v0.0.5 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect golang.org/x/mod v0.24.0 // indirect - golang.org/x/time v0.5.0 // indirect + golang.org/x/time v0.7.0 // indirect golang.org/x/tools v0.33.0 // indirect ) diff --git a/go.sum b/go.sum index 1e1c9d8..4ec92ba 100644 --- a/go.sum +++ b/go.sum @@ -217,8 +217,8 @@ golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= +golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= diff --git a/inside.go b/inside.go index 0d53f95..1896df8 100644 --- a/inside.go +++ b/inside.go @@ -2,16 +2,18 @@ package nebula import ( "net/netip" + "time" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/noiseutil" + "github.com/slackhq/nebula/packet" "github.com/slackhq/nebula/routing" ) -func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache, now time.Time) { err := newPacket(packet, false, fwPacket) if err != nil { if f.l.Level >= logrus.DebugLevel { @@ -53,7 +55,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet }) if hostinfo == nil { - f.rejectInside(packet, out, q) + f.rejectInside(packet, out.Payload, q) //todo vector? if f.l.Level >= logrus.DebugLevel { f.l.WithField("vpnAddr", fwPacket.RemoteAddr). WithField("fwPacket", fwPacket). 
@@ -66,12 +68,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache, now) if dropReason == nil { - f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) - + f.sendNoMetricsDelayed(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) } else { - f.rejectInside(packet, out, q) + f.rejectInside(packet, out.Payload, q) //todo vector? if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l). WithField("fwPacket", fwPacket). @@ -218,7 +219,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp } // check if packet is in outbound fw rules - dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil) + dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil, time.Now()) if dropReason != nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("fwPacket", fp). @@ -410,3 +411,81 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } } } + +func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.Packet, q int) { + if ci.eKey == nil { + return + } + useRelay := !remote.IsValid() && !hostinfo.remote.IsValid() + fullOut := out.Payload + + if useRelay { + if len(out.Payload) < header.Len { + // out always has a capacity of mtu, but not always a length greater than the header.Len. + // Grow it to make sure the next operation works. + out.Payload = out.Payload[:header.Len] + } + // Save a header's worth of data at the front of the 'out' buffer. + out.Payload = out.Payload[header.Len:] + } + + if noiseutil.EncryptLockNeeded { + // NOTE: for goboring AESGCMTLS we need to lock because of the nonce check + ci.writeLock.Lock() + } + c := ci.messageCounter.Add(1) + + //l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p) + out.Payload = header.Encode(out.Payload, header.Version, t, st, hostinfo.remoteIndexId, c) + f.connectionManager.Out(hostinfo) + + // Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against + // all our addrs and enable a faster roaming. + if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount { + //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is + // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. + f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) + hostinfo.lastRebindCount = f.rebindCount + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") + } + } + + var err error + out.Payload, err = ci.eKey.EncryptDanger(out.Payload, out.Payload, p, c, nb) + if noiseutil.EncryptLockNeeded { + ci.writeLock.Unlock() + } + if err != nil { + hostinfo.logger(f.l).WithError(err). + WithField("udpAddr", remote).WithField("counter", c). + WithField("attemptedCounter", c). 
+ Error("Failed to encrypt outgoing packet") + return + } + + if remote.IsValid() { + err = f.writers[q].Prep(out, remote) + if err != nil { + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet") + } + } else if hostinfo.remote.IsValid() { + err = f.writers[q].Prep(out, hostinfo.remote) + if err != nil { + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet") + } + } else { + // Try to send via a relay + for _, relayIP := range hostinfo.relayState.CopyRelayIps() { + relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) + if err != nil { + hostinfo.relayState.DeleteRelay(relayIP) + hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") + continue + } + //todo vector!! + f.SendVia(relayHostInfo, relay, out.Payload, nb, fullOut[:header.Len+len(out.Payload)], true) + break + } + } +} diff --git a/interface.go b/interface.go index 9f83d18..0a0b472 100644 --- a/interface.go +++ b/interface.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "net/netip" "os" "runtime" @@ -18,10 +17,12 @@ import ( "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/overlay" + "github.com/slackhq/nebula/packet" "github.com/slackhq/nebula/udp" ) const mtu = 9001 +const batch = 1024 //todo config! type InterfaceConfig struct { HostMap *HostMap @@ -86,12 +87,18 @@ type Interface struct { conntrackCacheTimeout time.Duration writers []udp.Conn - readers []io.ReadWriteCloser + readers []overlay.TunDev metricHandshakes metrics.Histogram messageMetrics *MessageMetrics cachedPacketMetrics *cachedPacketMetrics + listenInN int + listenOutN int + + listenInMetric metrics.Histogram + listenOutMetric metrics.Histogram + l *logrus.Logger } @@ -177,7 +184,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { routines: c.routines, version: c.version, writers: make([]udp.Conn, c.routines), - readers: make([]io.ReadWriteCloser, c.routines), + readers: make([]overlay.TunDev, c.routines), myVpnNetworks: cs.myVpnNetworks, myVpnNetworksTable: cs.myVpnNetworksTable, myVpnAddrs: cs.myVpnAddrs, @@ -196,6 +203,8 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { l: c.l, } + ifce.listenInMetric = metrics.GetOrRegisterHistogram("vhost.listenIn.n", nil, metrics.NewExpDecaySample(1028, 0.015)) + ifce.listenOutMetric = metrics.GetOrRegisterHistogram("vhost.listenOut.n", nil, metrics.NewExpDecaySample(1028, 0.015)) ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.reQueryEvery.Store(c.reQueryEvery) @@ -232,7 +241,7 @@ func (f *Interface) activate() { metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) // Prepare n tun queues - var reader io.ReadWriteCloser = f.inside + var reader overlay.TunDev = f.inside for i := 0; i < f.routines; i++ { if i > 0 { reader, err = f.inside.NewMultiQueueReader() @@ -261,40 +270,72 @@ func (f *Interface) run() { } } -func (f *Interface) listenOut(i int) { +func (f *Interface) listenOut(q int) { runtime.LockOSThread() var li udp.Conn - if i > 0 { - li = f.writers[i] + if q > 0 { + li = f.writers[q] } else { li = f.outside } ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() - plaintext := make([]byte, udp.MTU) + + outPackets := make([]*packet.OutPacket, batch) + for i := 0; i < batch; i++ { + outPackets[i] = packet.NewOut() + } + h 
:= &header.H{} fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) - li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { - f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) + toSend := make([][]byte, batch) + + li.ListenOut(func(pkts []*packet.Packet) { + toSend = toSend[:0] + for i := range outPackets { + outPackets[i].Valid = false + outPackets[i].SegCounter = 0 + } + + //todo f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) + f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l), time.Now()) + //we opportunistically tx, but try to also send stragglers + if _, err := f.readers[q].WriteMany(outPackets, q); err != nil { + f.l.WithError(err).Error("Failed to send packets") + } + //todo I broke this + //n := len(toSend) + //if f.l.Level == logrus.DebugLevel { + // f.listenOutMetric.Update(int64(n)) + //} + //f.listenOutN = n + }) } -func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { +func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) { runtime.LockOSThread() - packet := make([]byte, mtu) - out := make([]byte, mtu) fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) + packets := make([]*packet.VirtIOPacket, batch) + outPackets := make([]*packet.Packet, batch) + for i := 0; i < batch; i++ { + packets[i] = packet.NewVIO() + outPackets[i] = packet.New(false) //todo? + } + for { - n, err := reader.Read(packet) + n, err := reader.ReadMany(packets, queueNum) + + //todo!! if err != nil { if errors.Is(err, os.ErrClosed) && f.closed.Load() { return @@ -305,7 +346,22 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { os.Exit(2) } - f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) + if f.l.Level == logrus.DebugLevel { + f.listenInMetric.Update(int64(n)) + } + f.listenInN = n + + now := time.Now() + for i, pkt := range packets[:n] { + outPackets[i].OutLen = -1 + f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now) + reader.RecycleRxSeg(pkt, i == (n-1), queueNum) //todo handle err? 
+ pkt.Reset() + } + _, err = f.writers[queueNum].WriteBatch(outPackets[:n]) + if err != nil { + f.l.WithError(err).Error("Error while writing outbound packets") + } } } @@ -443,6 +499,11 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { } else { certMaxVersion.Update(int64(certState.v1Cert.Version())) } + if f.l.Level != logrus.DebugLevel { + f.listenInMetric.Update(int64(f.listenInN)) + f.listenOutMetric.Update(int64(f.listenOutN)) + } + } } } diff --git a/outside.go b/outside.go index b1a28e5..30a17e2 100644 --- a/outside.go +++ b/outside.go @@ -7,6 +7,7 @@ import ( "time" "github.com/google/gopacket/layers" + "github.com/slackhq/nebula/packet" "golang.org/x/net/ipv6" "github.com/sirupsen/logrus" @@ -19,7 +20,7 @@ const ( minFwPacketLen = 4 ) -func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) { err := h.Parse(packet) if err != nil { // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors @@ -60,7 +61,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, switch h.Subtype { case header.MessageNone: - if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) { + if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, now) { return } case header.MessageRelay: @@ -102,7 +103,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, relay: relay, IsRelayed: true, } - f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now) return case ForwardingType: // Find the target HostInfo relay object @@ -223,6 +224,217 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, f.connectionManager.In(hostinfo) } +func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) { + for i, pkt := range packets { + out[i].Scratch = out[i].Scratch[:0] + ip := pkt.AddrPort() + + //l.Error("in packet ", header, packet[HeaderLen:]) + if ip.IsValid() { + if f.myVpnNetworksTable.Contains(ip.Addr()) { + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") + } + return + } + } + + //todo per-segment! 
+ for segment := range pkt.Segments() { + + err := h.Parse(segment) + if err != nil { + // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors + if len(segment) > 1 { + f.l.WithField("packet", pkt).Infof("Error while parsing inbound packet from %s: %s", ip, err) + } + return + } + + var hostinfo *HostInfo + // verify if we've seen this index before, otherwise respond to the handshake initiation + if h.Type == header.Message && h.Subtype == header.MessageRelay { + hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) + } else { + hostinfo = f.hostMap.QueryIndex(h.RemoteIndex) + } + + var ci *ConnectionState + if hostinfo != nil { + ci = hostinfo.ConnectionState + } + + switch h.Type { + case header.Message: + // TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case. + if !f.handleEncrypted(ci, ip, h) { + return + } + + switch h.Subtype { + case header.MessageNone: + if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i], pkt, segment, fwPacket, nb, q, localCache, now) { + return + } + case header.MessageRelay: + // The entire body is sent as AD, not encrypted. + // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. + // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's + // otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice + // which will gracefully fail in the DecryptDanger call. + signedPayload := segment[:len(segment)-hostinfo.ConnectionState.dKey.Overhead()] + signatureValue := segment[len(segment)-hostinfo.ConnectionState.dKey.Overhead():] + out[i].Scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(out[i].Scratch, signedPayload, signatureValue, h.MessageCounter, nb) + if err != nil { + return + } + // Successfully validated the thing. Get rid of the Relay header. + signedPayload = signedPayload[header.Len:] + // Pull the Roaming parts up here, and return in all call paths. + f.handleHostRoaming(hostinfo, ip) + // Track usage of both the HostInfo and the Relay for the received & authenticated packet + f.connectionManager.In(hostinfo) + f.connectionManager.RelayUsed(h.RemoteIndex) + + relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) + if !ok { + // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing + // its internal mapping. This should never happen. + hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") + return + } + + switch relay.Type { + case TerminalType: + // If I am the target of this relay, process the unwrapped packet + // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. 
+ f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i].Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now) + return + case ForwardingType: + // Find the target HostInfo relay object + targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) + if err != nil { + hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip") + return + } + + // If that relay is Established, forward the payload through it + if targetRelay.State == Established { + switch targetRelay.Type { + case ForwardingType: + // Forward this packet through the relay tunnel + // Find the target HostInfo + f.SendVia(targetHI, targetRelay, signedPayload, nb, out[i].Scratch, false) + return + case TerminalType: + hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") + } + } else { + hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") + return + } + } + } + + case header.LightHouse: + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + if !f.handleEncrypted(ci, ip, h) { + return + } + + d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb) + if err != nil { + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). + WithField("packet", segment). + Error("Failed to decrypt lighthouse packet") + return + } + + lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f) + + // Fallthrough to the bottom to record incoming traffic + + case header.Test: + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + if !f.handleEncrypted(ci, ip, h) { + return + } + + d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb) + if err != nil { + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). + WithField("packet", segment). + Error("Failed to decrypt test packet") + return + } + + if h.Subtype == header.TestRequest { + // This testRequest might be from TryPromoteBest, so we should roam + // to the new IP address before responding + f.handleHostRoaming(hostinfo, ip) + f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out[i].Scratch) + } + + // Fallthrough to the bottom to record incoming traffic + + // Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they + // are unauthenticated + + case header.Handshake: + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + f.handshakeManager.HandleIncoming(ip, nil, segment, h) + return + + case header.RecvError: + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + f.handleRecvError(ip, h) + return + + case header.CloseTunnel: + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + if !f.handleEncrypted(ci, ip, h) { + return + } + + hostinfo.logger(f.l).WithField("udpAddr", ip). + Info("Close tunnel received, tearing down.") + + f.closeTunnel(hostinfo) + return + + case header.Control: + if !f.handleEncrypted(ci, ip, h) { + return + } + + d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb) + if err != nil { + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). + WithField("packet", segment). 
+ Error("Failed to decrypt Control packet") + return + } + + f.relayManager.HandleControlMsg(hostinfo, d, f) + + default: + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip) + return + } + + f.handleHostRoaming(hostinfo, ip) + + f.connectionManager.In(hostinfo) + + } + _, err := f.readers[q].WriteOne(out[i], false, q) + if err != nil { + f.l.WithError(err).Error("Failed to write packet") + } + } +} + // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote func (f *Interface) closeTunnel(hostInfo *HostInfo) { final := f.hostMap.DeleteHostInfo(hostInfo) @@ -472,7 +684,55 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] return out, nil } -func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool { +func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool { + var err error + + seg, err := f.readers[q].AllocSeg(out, q) + if err != nil { + f.l.WithError(err).Errorln("decryptToTunDelayWrite: failed to allocate segment") + return false + } + + out.SegmentPayloads[seg] = out.SegmentPayloads[seg][:0] + out.SegmentPayloads[seg], err = hostinfo.ConnectionState.dKey.DecryptDanger(out.SegmentPayloads[seg], inSegment[:header.Len], inSegment[header.Len:], messageCounter, nb) + if err != nil { + hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") + return false + } + + err = newPacket(out.SegmentPayloads[seg], true, fwPacket) + if err != nil { + hostinfo.logger(f.l).WithError(err).WithField("packet", out). + Warnf("Error while validating inbound packet") + return false + } + + if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) { + hostinfo.logger(f.l).WithField("fwPacket", fwPacket). + Debugln("dropping out of window packet") + return false + } + + dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now) + if dropReason != nil { + // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore + // This gives us a buffer to build the reject packet in + f.rejectOutside(out.SegmentPayloads[seg], hostinfo.ConnectionState, hostinfo, nb, inSegment, q) + if f.l.Level >= logrus.DebugLevel { + hostinfo.logger(f.l).WithField("fwPacket", fwPacket). + WithField("reason", dropReason). 
+ Debugln("dropping inbound packet") + } + return false + } + + f.connectionManager.In(hostinfo) + pkt.OutLen += len(inSegment) + out.Segments[seg] = out.Segments[seg][:len(out.SegmentHeaders[seg])+len(out.SegmentPayloads[seg])] + return true +} + +func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool { var err error out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) @@ -494,7 +754,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return false } - dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now) if dropReason != nil { // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // This gives us a buffer to build the reject packet in diff --git a/overlay/device.go b/overlay/device.go index b6077ab..63de9f3 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -1,18 +1,17 @@ package overlay import ( - "io" "net/netip" "github.com/slackhq/nebula/routing" ) type Device interface { - io.ReadWriteCloser + TunDev Activate() error Networks() []netip.Prefix Name() string RoutesFor(netip.Addr) routing.Gateways SupportsMultiqueue() bool - NewMultiQueueReader() (io.ReadWriteCloser, error) + NewMultiQueueReader() (TunDev, error) } diff --git a/overlay/eventfd/eventfd.go b/overlay/eventfd/eventfd.go new file mode 100644 index 0000000..cf3dd0d --- /dev/null +++ b/overlay/eventfd/eventfd.go @@ -0,0 +1,91 @@ +package eventfd + +import ( + "encoding/binary" + "syscall" + + "golang.org/x/sys/unix" +) + +type EventFD struct { + fd int + buf [8]byte +} + +func New() (EventFD, error) { + fd, err := unix.Eventfd(0, unix.EFD_NONBLOCK) + if err != nil { + return EventFD{}, err + } + return EventFD{ + fd: fd, + buf: [8]byte{}, + }, nil +} + +func (e *EventFD) Kick() error { + binary.LittleEndian.PutUint64(e.buf[:], 1) //is this right??? + _, err := syscall.Write(int(e.fd), e.buf[:]) + return err +} + +func (e *EventFD) Close() error { + if e.fd != 0 { + return unix.Close(e.fd) + } + return nil +} + +func (e *EventFD) FD() int { + return e.fd +} + +type Epoll struct { + fd int + buf [8]byte + events []syscall.EpollEvent +} + +func NewEpoll() (Epoll, error) { + fd, err := unix.EpollCreate1(0) + if err != nil { + return Epoll{}, err + } + return Epoll{ + fd: fd, + buf: [8]byte{}, + events: make([]syscall.EpollEvent, 1), + }, nil +} + +func (ep *Epoll) AddEvent(fdToAdd int) error { + event := syscall.EpollEvent{ + Events: syscall.EPOLLIN, + Fd: int32(fdToAdd), + } + return syscall.EpollCtl(ep.fd, syscall.EPOLL_CTL_ADD, fdToAdd, &event) +} + +func (ep *Epoll) Block() (int, error) { + n, err := syscall.EpollWait(ep.fd, ep.events, -1) + if err != nil { + //goland:noinspection GoDirectComparisonOfErrors + if err == syscall.EINTR { + return 0, nil //?? 
+ } + return -1, err + } + return n, nil +} + +func (ep *Epoll) Clear() error { + _, err := syscall.Read(int(ep.events[0].Fd), ep.buf[:]) + return err +} + +func (ep *Epoll) Close() error { + if ep.fd != 0 { + return unix.Close(ep.fd) + } + return nil +} diff --git a/overlay/tun.go b/overlay/tun.go index 3a61d18..1dc914c 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -2,16 +2,29 @@ package overlay import ( "fmt" + "io" "net" "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/packet" "github.com/slackhq/nebula/util" ) const DefaultMTU = 1300 +type TunDev interface { + io.WriteCloser + ReadMany(x []*packet.VirtIOPacket, q int) (int, error) + + //todo this interface sux + AllocSeg(pkt *packet.OutPacket, q int) (int, error) + WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) + WriteMany(x []*packet.OutPacket, q int) (int, error) + RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error +} + // TODO: We may be able to remove routines type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) @@ -26,11 +39,11 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref } } -func NewFdDeviceFromConfig(fd *int) DeviceFactory { - return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { - return newTunFromFd(c, l, *fd, vpnNetworks) - } -} +//func NewFdDeviceFromConfig(fd *int) DeviceFactory { +// return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { +// return newTunFromFd(c, l, *fd, vpnNetworks) +// } +//} func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) { if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") { diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index aa3ddda..086a676 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -9,6 +9,8 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/overlay/virtqueue" + "github.com/slackhq/nebula/packet" "github.com/slackhq/nebula/routing" ) @@ -22,6 +24,10 @@ type disabledTun struct { l *logrus.Logger } +func (*disabledTun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error { + return nil +} + func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { tun := &disabledTun{ vpnNetworks: vpnNetworks, @@ -40,6 +46,10 @@ func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled boo return tun } +func (*disabledTun) GetQueues() []*virtqueue.SplitQueue { + return nil +} + func (*disabledTun) Activate() error { return nil } @@ -109,7 +119,23 @@ func (t *disabledTun) SupportsMultiqueue() bool { return true } -func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *disabledTun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) { + return 0, fmt.Errorf("tun_disabled: AllocSeg not implemented") +} + +func (t *disabledTun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) { + return 0, fmt.Errorf("tun_disabled: WriteOne not implemented") +} + +func (t *disabledTun) WriteMany(x []*packet.OutPacket, q int) (int, error) { + return 0, fmt.Errorf("tun_disabled: WriteMany not implemented") +} + +func (t *disabledTun) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) { + return t.Read(b[0].Payload) 
+} + +func (t *disabledTun) NewMultiQueueReader() (TunDev, error) { return t, nil } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 32bf51f..386b37d 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -5,7 +5,6 @@ package overlay import ( "fmt" - "io" "net" "net/netip" "os" @@ -17,15 +16,19 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay/vhostnet" + "github.com/slackhq/nebula/packet" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" + "github.com/slackhq/nebula/util/virtio" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" ) type tun struct { - io.ReadWriteCloser + file *os.File fd int + vdev []*vhostnet.Device Device string vpnNetworks []netip.Prefix MaxMTU int @@ -40,7 +43,8 @@ type tun struct { useSystemRoutes bool useSystemRoutesBufferSize int - l *logrus.Logger + isV6 bool + l *logrus.Logger } func (t *tun) Networks() []netip.Prefix { @@ -102,7 +106,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu } var req ifReq - req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI) + req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_TUN_EXCL | unix.IFF_VNET_HDR | unix.IFF_NAPI) if multiqueue { req.Flags |= unix.IFF_MULTI_QUEUE } @@ -112,20 +116,47 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu } name := strings.Trim(string(req.Name[:]), "\x00") + if err = unix.SetNonblock(fd, true); err != nil { + _ = unix.Close(fd) + return nil, fmt.Errorf("make file descriptor non-blocking: %w", err) + } + file := os.NewFile(uintptr(fd), "/dev/net/tun") + + err = unix.IoctlSetPointerInt(fd, unix.TUNSETVNETHDRSZ, virtio.NetHdrSize) + if err != nil { + return nil, fmt.Errorf("set vnethdr size: %w", err) + } + + flags := 0 + //flags = //unix.TUN_F_CSUM //| unix.TUN_F_TSO4 | unix.TUN_F_USO4 | unix.TUN_F_TSO6 | unix.TUN_F_USO6 + err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, flags) + if err != nil { + return nil, fmt.Errorf("set offloads: %w", err) + } + t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { return nil, err } - + t.fd = fd t.Device = name + vdev, err := vhostnet.NewDevice( + vhostnet.WithBackendFD(fd), + vhostnet.WithQueueSize(8192), //todo config + ) + if err != nil { + return nil, err + } + t.vdev = []*vhostnet.Device{vdev} + return t, nil } func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { t := &tun{ - ReadWriteCloser: file, + file: file, fd: int(file.Fd()), vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), @@ -133,6 +164,9 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0), l: l, } + if len(vpnNetworks) != 0 { + t.isV6 = vpnNetworks[0].Addr().Is6() //todo what about multi-IP? 
+ } err := t.reload(c, true) if err != nil { @@ -220,7 +254,7 @@ func (t *tun) SupportsMultiqueue() bool { return true } -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (t *tun) NewMultiQueueReader() (TunDev, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { return nil, err @@ -233,9 +267,17 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, err } - file := os.NewFile(uintptr(fd), "/dev/net/tun") + vdev, err := vhostnet.NewDevice( + vhostnet.WithBackendFD(fd), + vhostnet.WithQueueSize(8192), //todo config + ) + if err != nil { + return nil, err + } - return file, nil + t.vdev = append(t.vdev, vdev) + + return t, nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { @@ -243,29 +285,6 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { return r } -func (t *tun) Write(b []byte) (int, error) { - var nn int - maximum := len(b) - - for { - n, err := unix.Write(t.fd, b[nn:maximum]) - if n > 0 { - nn += n - } - if nn == len(b) { - return nn, err - } - - if err != nil { - return nn, err - } - - if n == 0 { - return nn, io.ErrUnexpectedEOF - } - } -} - func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c) @@ -689,8 +708,14 @@ func (t *tun) Close() error { close(t.routeChan) } - if t.ReadWriteCloser != nil { - _ = t.ReadWriteCloser.Close() + for _, v := range t.vdev { + if v != nil { + _ = v.Close() + } + } + + if t.file != nil { + _ = t.file.Close() } if t.ioctlFd > 0 { @@ -699,3 +724,65 @@ func (t *tun) Close() error { return nil } + +func (t *tun) ReadMany(p []*packet.VirtIOPacket, q int) (int, error) { + n, err := t.vdev[q].ReceivePackets(p) //we are TXing + if err != nil { + return 0, err + } + return n, nil +} + +func (t *tun) Write(b []byte) (int, error) { + maximum := len(b) //we are RXing + + //todo garbagey + out := packet.NewOut() + x, err := t.AllocSeg(out, 0) + if err != nil { + return 0, err + } + copy(out.SegmentPayloads[x], b) + err = t.vdev[0].TransmitPacket(out, true) + + if err != nil { + t.l.WithError(err).Error("Transmitting packet") + return 0, err + } + return maximum, nil +} + +func (t *tun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) { + idx, buf, err := t.vdev[q].GetPacketForTx() + if err != nil { + return 0, err + } + x := pkt.UseSegment(idx, buf, t.isV6) + return x, nil +} + +func (t *tun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) { + if err := t.vdev[q].TransmitPacket(x, kick); err != nil { + t.l.WithError(err).Error("Transmitting packet") + return 0, err + } + return 1, nil +} + +func (t *tun) WriteMany(x []*packet.OutPacket, q int) (int, error) { + maximum := len(x) //we are RXing + if maximum == 0 { + return 0, nil + } + + err := t.vdev[q].TransmitPackets(x) + if err != nil { + t.l.WithError(err).Error("Transmitting packet") + return 0, err + } + return maximum, nil +} + +func (t *tun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error { + return t.vdev[q].ReceiveQueue.OfferDescriptorChains(pkt.Chains, kick) +} diff --git a/overlay/user.go b/overlay/user.go index 1f92d4e..0a5857e 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -1,11 +1,13 @@ package overlay import ( + "fmt" "io" "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/packet" "github.com/slackhq/nebula/routing" ) @@ -36,6 +38,10 @@ type UserDevice struct { inboundWriter *io.PipeWriter } +func (d *UserDevice) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error { + return 
nil +} + func (d *UserDevice) Activate() error { return nil } @@ -50,7 +56,7 @@ func (d *UserDevice) SupportsMultiqueue() bool { return true } -func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { +func (d *UserDevice) NewMultiQueueReader() (TunDev, error) { return d, nil } @@ -69,3 +75,19 @@ func (d *UserDevice) Close() error { d.outboundWriter.Close() return nil } + +func (d *UserDevice) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) { + return d.Read(b[0].Payload) +} + +func (d *UserDevice) AllocSeg(pkt *packet.OutPacket, q int) (int, error) { + return 0, fmt.Errorf("user: AllocSeg not implemented") +} + +func (d *UserDevice) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) { + return 0, fmt.Errorf("user: WriteOne not implemented") +} + +func (d *UserDevice) WriteMany(x []*packet.OutPacket, q int) (int, error) { + return 0, fmt.Errorf("user: WriteMany not implemented") +} diff --git a/overlay/vhost/README.md b/overlay/vhost/README.md new file mode 100644 index 0000000..1116e3d --- /dev/null +++ b/overlay/vhost/README.md @@ -0,0 +1,23 @@ +Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go + +MIT License + +Copyright (c) 2025 Hetzner Cloud GmbH + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/overlay/vhost/doc.go b/overlay/vhost/doc.go new file mode 100644 index 0000000..3d519b2 --- /dev/null +++ b/overlay/vhost/doc.go @@ -0,0 +1,4 @@ +// Package vhost implements the basic ioctl requests needed to interact with the +// kernel-level virtio server that provides accelerated virtio devices for +// networking and more. +package vhost diff --git a/overlay/vhost/ioctl.go b/overlay/vhost/ioctl.go new file mode 100644 index 0000000..7e855db --- /dev/null +++ b/overlay/vhost/ioctl.go @@ -0,0 +1,218 @@ +package vhost + +import ( + "fmt" + "unsafe" + + "github.com/slackhq/nebula/overlay/virtqueue" + "github.com/slackhq/nebula/util/virtio" + "golang.org/x/sys/unix" +) + +const ( + // vhostIoctlGetFeatures can be used to retrieve the features supported by + // the vhost implementation in the kernel. + // + // Response payload: [virtio.Feature] + // Kernel name: VHOST_GET_FEATURES + vhostIoctlGetFeatures = 0x8008af00 + + // vhostIoctlSetFeatures can be used to communicate the features supported + // by this virtio implementation to the kernel. 
+ // + // Request payload: [virtio.Feature] + // Kernel name: VHOST_SET_FEATURES + vhostIoctlSetFeatures = 0x4008af00 + + // vhostIoctlSetOwner can be used to set the current process as the + // exclusive owner of a control file descriptor. + // + // Request payload: none + // Kernel name: VHOST_SET_OWNER + vhostIoctlSetOwner = 0x0000af01 + + // vhostIoctlSetMemoryLayout can be used to set up or modify the memory + // layout which describes the IOTLB mappings in the kernel. + // + // Request payload: [MemoryLayout] with custom serialization + // Kernel name: VHOST_SET_MEM_TABLE + vhostIoctlSetMemoryLayout = 0x4008af03 + + // vhostIoctlSetQueueSize can be used to set the size of the virtqueue. + // + // Request payload: [QueueState] + // Kernel name: VHOST_SET_VRING_NUM + vhostIoctlSetQueueSize = 0x4008af10 + + // vhostIoctlSetQueueAddress can be used to set the addresses of the + // different parts of the virtqueue. + // + // Request payload: [QueueAddresses] + // Kernel name: VHOST_SET_VRING_ADDR + vhostIoctlSetQueueAddress = 0x4028af11 + + // vhostIoctlSetAvailableRingBase can be used to set the index of the next + // available ring entry the device will process. + // + // Request payload: [QueueState] + // Kernel name: VHOST_SET_VRING_BASE + vhostIoctlSetAvailableRingBase = 0x4008af12 + + // vhostIoctlSetQueueKickEventFD can be used to set the event file + // descriptor to signal the device when descriptor chains were added to the + // available ring. + // + // Request payload: [QueueFile] + // Kernel name: VHOST_SET_VRING_KICK + vhostIoctlSetQueueKickEventFD = 0x4008af20 + + // vhostIoctlSetQueueCallEventFD can be used to set the event file + // descriptor that gets signaled by the device when descriptor chains have + // been used by it. + // + // Request payload: [QueueFile] + // Kernel name: VHOST_SET_VRING_CALL + vhostIoctlSetQueueCallEventFD = 0x4008af21 +) + +// QueueState is an ioctl request payload that can hold a queue index and any +// 32-bit number. +// +// Kernel name: vhost_vring_state +type QueueState struct { + // QueueIndex is the index of the virtqueue. + QueueIndex uint32 + // Num is any 32-bit number, depending on the request. + Num uint32 +} + +// QueueAddresses is an ioctl request payload that can hold the addresses of the +// different parts of a virtqueue. +// +// Kernel name: vhost_vring_addr +type QueueAddresses struct { + // QueueIndex is the index of the virtqueue. + QueueIndex uint32 + // Flags that are not used in this implementation. + Flags uint32 + // DescriptorTableAddress is the address of the descriptor table in user + // space memory. It must be 16-byte aligned. + DescriptorTableAddress uintptr + // UsedRingAddress is the address of the used ring in user space memory. It + // must be 4-byte aligned. + UsedRingAddress uintptr + // AvailableRingAddress is the address of the available ring in user space + // memory. It must be 2-byte aligned. + AvailableRingAddress uintptr + // LogAddress is used for an optional logging support, not supported by this + // implementation. + LogAddress uintptr +} + +// QueueFile is an ioctl request payload that can hold a queue index and a file +// descriptor. +// +// Kernel name: vhost_vring_file +type QueueFile struct { + // QueueIndex is the index of the virtqueue. + QueueIndex uint32 + // FD is the file descriptor of the file. Pass -1 to unbind from a file. + FD int32 +} + +// IoctlPtr is a copy of the similarly named unexported function from the Go +// unix package. 
This is needed to do custom ioctl requests not supported by the +// standard library. +func IoctlPtr(fd int, req uint, arg unsafe.Pointer) error { + _, _, err := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(arg)) + if err != 0 { + return fmt.Errorf("ioctl request %d: %w", req, err) + } + return nil +} + +// GetFeatures requests the supported feature bits from the virtio device +// associated with the given control file descriptor. +func GetFeatures(controlFD int) (virtio.Feature, error) { + var features virtio.Feature + if err := IoctlPtr(controlFD, vhostIoctlGetFeatures, unsafe.Pointer(&features)); err != nil { + return 0, fmt.Errorf("get features: %w", err) + } + return features, nil +} + +// SetFeatures communicates the feature bits supported by this implementation +// to the virtio device associated with the given control file descriptor. +func SetFeatures(controlFD int, features virtio.Feature) error { + if err := IoctlPtr(controlFD, vhostIoctlSetFeatures, unsafe.Pointer(&features)); err != nil { + return fmt.Errorf("set features: %w", err) + } + return nil +} + +// OwnControlFD sets the current process as the exclusive owner for the +// given control file descriptor. This must be called before interacting with +// the control file descriptor in any other way. +func OwnControlFD(controlFD int) error { + if err := IoctlPtr(controlFD, vhostIoctlSetOwner, unsafe.Pointer(nil)); err != nil { + return fmt.Errorf("set control file descriptor owner: %w", err) + } + return nil +} + +// SetMemoryLayout sets up or modifies the memory layout for the kernel-level +// virtio device associated with the given control file descriptor. +func SetMemoryLayout(controlFD int, layout MemoryLayout) error { + payload := layout.serializePayload() + if err := IoctlPtr(controlFD, vhostIoctlSetMemoryLayout, unsafe.Pointer(&payload[0])); err != nil { + return fmt.Errorf("set memory layout: %w", err) + } + return nil +} + +// RegisterQueue registers a virtio queue with the kernel-level virtio server. +// The virtqueue will be linked to the given control file descriptor and will +// have the given index. The kernel will use this queue until the control file +// descriptor is closed. 
+func RegisterQueue(controlFD int, queueIndex uint32, queue *virtqueue.SplitQueue) error { + if err := IoctlPtr(controlFD, vhostIoctlSetQueueSize, unsafe.Pointer(&QueueState{ + QueueIndex: queueIndex, + Num: uint32(queue.Size()), + })); err != nil { + return fmt.Errorf("set queue size: %w", err) + } + + if err := IoctlPtr(controlFD, vhostIoctlSetQueueAddress, unsafe.Pointer(&QueueAddresses{ + QueueIndex: queueIndex, + Flags: 0, + DescriptorTableAddress: queue.DescriptorTable().Address(), + UsedRingAddress: queue.UsedRing().Address(), + AvailableRingAddress: queue.AvailableRing().Address(), + LogAddress: 0, + })); err != nil { + return fmt.Errorf("set queue addresses: %w", err) + } + + if err := IoctlPtr(controlFD, vhostIoctlSetAvailableRingBase, unsafe.Pointer(&QueueState{ + QueueIndex: queueIndex, + Num: 0, + })); err != nil { + return fmt.Errorf("set available ring base: %w", err) + } + + if err := IoctlPtr(controlFD, vhostIoctlSetQueueKickEventFD, unsafe.Pointer(&QueueFile{ + QueueIndex: queueIndex, + FD: int32(queue.KickEventFD()), + })); err != nil { + return fmt.Errorf("set kick event file descriptor: %w", err) + } + + if err := IoctlPtr(controlFD, vhostIoctlSetQueueCallEventFD, unsafe.Pointer(&QueueFile{ + QueueIndex: queueIndex, + FD: int32(queue.CallEventFD()), + })); err != nil { + return fmt.Errorf("set call event file descriptor: %w", err) + } + + return nil +} diff --git a/overlay/vhost/ioctl_test.go b/overlay/vhost/ioctl_test.go new file mode 100644 index 0000000..0732b96 --- /dev/null +++ b/overlay/vhost/ioctl_test.go @@ -0,0 +1,21 @@ +package vhost_test + +import ( + "testing" + "unsafe" + + "github.com/slackhq/nebula/overlay/vhost" + "github.com/stretchr/testify/assert" +) + +func TestQueueState_Size(t *testing.T) { + assert.EqualValues(t, 8, unsafe.Sizeof(vhost.QueueState{})) +} + +func TestQueueAddresses_Size(t *testing.T) { + assert.EqualValues(t, 40, unsafe.Sizeof(vhost.QueueAddresses{})) +} + +func TestQueueFile_Size(t *testing.T) { + assert.EqualValues(t, 8, unsafe.Sizeof(vhost.QueueFile{})) +} diff --git a/overlay/vhost/memory.go b/overlay/vhost/memory.go new file mode 100644 index 0000000..d9a94c3 --- /dev/null +++ b/overlay/vhost/memory.go @@ -0,0 +1,73 @@ +package vhost + +import ( + "encoding/binary" + "fmt" + "unsafe" + + "github.com/slackhq/nebula/overlay/virtqueue" +) + +// MemoryRegion describes a region of userspace memory which is being made +// accessible to a vhost device. +// +// Kernel name: vhost_memory_region +type MemoryRegion struct { + // GuestPhysicalAddress is the physical address of the memory region within + // the guest, when virtualization is used. When no virtualization is used, + // this should be the same as UserspaceAddress. + GuestPhysicalAddress uintptr + // Size is the size of the memory region. + Size uint64 + // UserspaceAddress is the virtual address in the userspace of the host + // where the memory region can be found. + UserspaceAddress uintptr + // Padding and room for flags. Currently unused. + _ uint64 +} + +// MemoryLayout is a list of [MemoryRegion]s. +type MemoryLayout []MemoryRegion + +// NewMemoryLayoutForQueues returns a new [MemoryLayout] that describes the +// memory pages used by the descriptor tables of the given queues. 
+func NewMemoryLayoutForQueues(queues []*virtqueue.SplitQueue) MemoryLayout { + regions := make([]MemoryRegion, 0) + for _, queue := range queues { + for address, size := range queue.DescriptorTable().BufferAddresses() { + regions = append(regions, MemoryRegion{ + // There is no virtualization in play here, so the guest address + // is the same as in the host's userspace. + GuestPhysicalAddress: address, + Size: uint64(size), + UserspaceAddress: address, + }) + } + } + return regions +} + +// serializePayload serializes the list of memory regions into a format that is +// compatible to the vhost_memory kernel struct. The returned byte slice can be +// used as a payload for the vhostIoctlSetMemoryLayout ioctl. +func (regions MemoryLayout) serializePayload() []byte { + regionCount := len(regions) + regionSize := int(unsafe.Sizeof(MemoryRegion{})) + payload := make([]byte, 8+regionCount*regionSize) + + // The first 32 bits contain the number of memory regions. The following 32 + // bits are padding. + binary.LittleEndian.PutUint32(payload[0:4], uint32(regionCount)) + + if regionCount > 0 { + // The underlying byte array of the slice should already have the correct + // format, so just copy that. + copied := copy(payload[8:], unsafe.Slice((*byte)(unsafe.Pointer(®ions[0])), regionCount*regionSize)) + if copied != regionCount*regionSize { + panic(fmt.Sprintf("copied only %d bytes of the memory regions, but expected %d", + copied, regionCount*regionSize)) + } + } + + return payload +} diff --git a/overlay/vhost/memory_internal_test.go b/overlay/vhost/memory_internal_test.go new file mode 100644 index 0000000..7257805 --- /dev/null +++ b/overlay/vhost/memory_internal_test.go @@ -0,0 +1,42 @@ +package vhost + +import ( + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" +) + +func TestMemoryRegion_Size(t *testing.T) { + assert.EqualValues(t, 32, unsafe.Sizeof(MemoryRegion{})) +} + +func TestMemoryLayout_SerializePayload(t *testing.T) { + layout := MemoryLayout([]MemoryRegion{ + { + GuestPhysicalAddress: 42, + Size: 100, + UserspaceAddress: 142, + }, { + GuestPhysicalAddress: 99, + Size: 100, + UserspaceAddress: 99, + }, + }) + payload := layout.serializePayload() + + assert.Equal(t, []byte{ + 0x02, 0x00, 0x00, 0x00, // nregions + 0x00, 0x00, 0x00, 0x00, // padding + // region 0 + 0x2a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // guest_phys_addr + 0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // memory_size + 0x8e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // userspace_addr + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // flags_padding + // region 1 + 0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // guest_phys_addr + 0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // memory_size + 0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // userspace_addr + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // flags_padding + }, payload) +} diff --git a/overlay/vhostnet/README.md b/overlay/vhostnet/README.md new file mode 100644 index 0000000..1116e3d --- /dev/null +++ b/overlay/vhostnet/README.md @@ -0,0 +1,23 @@ +Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go + +MIT License + +Copyright (c) 2025 Hetzner Cloud GmbH + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of 
the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/overlay/vhostnet/device.go b/overlay/vhostnet/device.go new file mode 100644 index 0000000..435424b --- /dev/null +++ b/overlay/vhostnet/device.go @@ -0,0 +1,372 @@ +package vhostnet + +import ( + "context" + "errors" + "fmt" + "os" + "runtime" + + "github.com/slackhq/nebula/overlay/vhost" + "github.com/slackhq/nebula/overlay/virtqueue" + "github.com/slackhq/nebula/packet" + "github.com/slackhq/nebula/util/virtio" + "golang.org/x/sys/unix" +) + +// ErrDeviceClosed is returned when the [Device] is closed while operations are +// still running. +var ErrDeviceClosed = errors.New("device was closed") + +// The indexes for the receive and transmit queues. +const ( + receiveQueueIndex = 0 + transmitQueueIndex = 1 +) + +// Device represents a vhost networking device within the kernel-level virtio +// implementation and provides methods to interact with it. +type Device struct { + initialized bool + controlFD int + + fullTable bool + ReceiveQueue *virtqueue.SplitQueue + TransmitQueue *virtqueue.SplitQueue +} + +// NewDevice initializes a new vhost networking device within the +// kernel-level virtio implementation, sets up the virtqueues and returns a +// [Device] instance that can be used to communicate with that vhost device. +// +// There are multiple options that can be passed to this constructor to +// influence device creation: +// - [WithQueueSize] +// - [WithBackendFD] +// - [WithBackendDevice] +// +// Remember to call [Device.Close] after use to free up resources. +func NewDevice(options ...Option) (*Device, error) { + var err error + opts := optionDefaults + opts.apply(options) + if err = opts.validate(); err != nil { + return nil, fmt.Errorf("invalid options: %w", err) + } + + dev := Device{ + controlFD: -1, + } + + // Clean up a partially initialized device when something fails. + defer func() { + if err != nil { + _ = dev.Close() + } + }() + + // Retrieve a new control file descriptor. This will be used to configure + // the vhost networking device in the kernel. + dev.controlFD, err = unix.Open("/dev/vhost-net", os.O_RDWR, 0666) + if err != nil { + return nil, fmt.Errorf("get control file descriptor: %w", err) + } + if err = vhost.OwnControlFD(dev.controlFD); err != nil { + return nil, fmt.Errorf("own control file descriptor: %w", err) + } + + // Advertise the supported features. This isn't much for now. + // TODO: Add feature options and implement proper feature negotiation. 
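+	// As a rough sketch of where this is headed (assumed helper semantics, not
+	// what the code below does yet), proper negotiation would intersect the
+	// bits the device offers with the bits this driver supports and fail when
+	// a required bit is missing:
+	//
+	//	offered, _ := vhost.GetFeatures(dev.controlFD)
+	//	wanted := virtio.FeatureVersion1 | virtio.FeatureNetMergeRXBuffers
+	//	negotiated := offered & wanted
+	//	if negotiated&virtio.FeatureVersion1 == 0 {
+	//		return nil, errors.New("device does not offer VIRTIO_F_VERSION_1")
+	//	}
+	//	_ = vhost.SetFeatures(dev.controlFD, negotiated)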
+	// Query the device's offered feature bits. GetFeatures has been observed
+	// to return 0x1033D008000 here; the offered bits are not validated yet
+	// (see the TODO above).
+	if _, err = vhost.GetFeatures(dev.controlFD); err != nil {
+		return nil, fmt.Errorf("get features: %w", err)
+	}
+	features := virtio.FeatureVersion1 | virtio.FeatureNetMergeRXBuffers
+	if err = vhost.SetFeatures(dev.controlFD, features); err != nil {
+		return nil, fmt.Errorf("set features: %w", err)
+	}
+
+	itemSize := os.Getpagesize() * 4 // TODO: make configurable
+
+	// Initialize and register the queues needed for the networking device.
+	if dev.ReceiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize, itemSize); err != nil {
+		return nil, fmt.Errorf("create receive queue: %w", err)
+	}
+	if dev.TransmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize, itemSize); err != nil {
+		return nil, fmt.Errorf("create transmit queue: %w", err)
+	}
+
+	// Set up memory mappings for all buffers used by the queues. This has to
+	// happen before a backend for the queues can be registered.
+	memoryLayout := vhost.NewMemoryLayoutForQueues(
+		[]*virtqueue.SplitQueue{dev.ReceiveQueue, dev.TransmitQueue},
+	)
+	if err = vhost.SetMemoryLayout(dev.controlFD, memoryLayout); err != nil {
+		return nil, fmt.Errorf("setup memory layout: %w", err)
+	}
+
+	// Set the queue backends. This activates the queues within the kernel.
+	if err = SetQueueBackend(dev.controlFD, receiveQueueIndex, opts.backendFD); err != nil {
+		return nil, fmt.Errorf("set receive queue backend: %w", err)
+	}
+	if err = SetQueueBackend(dev.controlFD, transmitQueueIndex, opts.backendFD); err != nil {
+		return nil, fmt.Errorf("set transmit queue backend: %w", err)
+	}
+
+	// Fully populate the receive queue with available buffers which the device
+	// can write new packets into.
+	if err = dev.refillReceiveQueue(); err != nil {
+		return nil, fmt.Errorf("refill receive queue: %w", err)
+	}
+
+	dev.initialized = true
+
+	// Make sure to clean up even when the device gets garbage collected without
+	// Close being called first.
+	devPtr := &dev
+	runtime.SetFinalizer(devPtr, (*Device).Close)
+
+	return devPtr, nil
+}
+
+// refillReceiveQueue offers as many new device-writable buffers to the device
+// as the queue can fit. The device will then use these to write received
+// packets.
+func (dev *Device) refillReceiveQueue() error {
+	for {
+		_, err := dev.ReceiveQueue.OfferInDescriptorChains()
+		if err != nil {
+			if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
+				// Queue is full, job is done.
+				return nil
+			}
+			return fmt.Errorf("offer descriptor chain: %w", err)
+		}
+	}
+}
+
+// Close cleans up the vhost networking device within the kernel and releases
+// all resources used for it.
+// The implementation will try to release as many resources as possible and
+// collect potential errors before returning them.
+func (dev *Device) Close() error {
+	dev.initialized = false
+
+	// Closing the control file descriptor will unregister all queues from the
+	// kernel.
+	if dev.controlFD >= 0 {
+		if err := unix.Close(dev.controlFD); err != nil {
+			// Return an error and do not continue, because the memory used for
+			// the queues should not be released before they were unregistered
+			// from the kernel.
+ return fmt.Errorf("close control file descriptor: %w", err) + } + dev.controlFD = -1 + } + + var errs []error + + if dev.ReceiveQueue != nil { + if err := dev.ReceiveQueue.Close(); err == nil { + dev.ReceiveQueue = nil + } else { + errs = append(errs, fmt.Errorf("close receive queue: %w", err)) + } + } + + if dev.TransmitQueue != nil { + if err := dev.TransmitQueue.Close(); err == nil { + dev.TransmitQueue = nil + } else { + errs = append(errs, fmt.Errorf("close transmit queue: %w", err)) + } + } + + if len(errs) == 0 { + // Everything was cleaned up. No need to run the finalizer anymore. + runtime.SetFinalizer(dev, nil) + } + + return errors.Join(errs...) +} + +// createQueue creates a new virtqueue and registers it with the vhost device +// using the given index. +func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) { + var ( + queue *virtqueue.SplitQueue + err error + ) + if queue, err = virtqueue.NewSplitQueue(queueSize, itemSize); err != nil { + return nil, fmt.Errorf("create virtqueue: %w", err) + } + if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil { + return nil, fmt.Errorf("register virtqueue with index %d: %w", queueIndex, err) + } + return queue, nil +} + +func (dev *Device) GetPacketForTx() (uint16, []byte, error) { + var err error + var idx uint16 + if !dev.fullTable { + idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs() + if err == virtqueue.ErrNotEnoughFreeDescriptors { + dev.fullTable = true + idx, err = dev.TransmitQueue.TakeSingle(context.TODO()) + } + } else { + idx, err = dev.TransmitQueue.TakeSingle(context.TODO()) + } + if err != nil { + return 0, nil, fmt.Errorf("transmit queue: %w", err) + } + buf, err := dev.TransmitQueue.GetDescriptorItem(idx) + if err != nil { + return 0, nil, fmt.Errorf("get descriptor chain: %w", err) + } + return idx, buf, nil +} + +func (dev *Device) TransmitPacket(pkt *packet.OutPacket, kick bool) error { + if len(pkt.SegmentIDs) == 0 { + return nil + } + for idx := range pkt.SegmentIDs { + segmentID := pkt.SegmentIDs[idx] + dev.TransmitQueue.SetDescSize(segmentID, len(pkt.Segments[idx])) + } + err := dev.TransmitQueue.OfferDescriptorChains(pkt.SegmentIDs, false) + if err != nil { + return fmt.Errorf("offer descriptor chains: %w", err) + } + pkt.Reset() + if kick { + if err := dev.TransmitQueue.Kick(); err != nil { + return err + } + } + + return nil +} + +func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error { + if len(pkts) == 0 { + return nil + } + + for i := range pkts { + if err := dev.TransmitPacket(pkts[i], false); err != nil { + return err + } + } + if err := dev.TransmitQueue.Kick(); err != nil { + return err + } + return nil +} + +// TODO: Make above methods cancelable by taking a context.Context argument? +// TODO: Implement zero-copy variants to transmit and receive packets? + +// processChains processes as many chains as needed to create one packet. The number of processed chains is returned. +func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) { + //read first element to see how many descriptors we need: + pkt.Reset() + + err := dev.ReceiveQueue.GetDescriptorInbuffers(uint16(chains[0].DescriptorIndex), &pkt.ChainRefs) + if err != nil { + return 0, fmt.Errorf("get descriptor chain: %w", err) + } + if len(pkt.ChainRefs) == 0 { + return 1, nil + } + + // The specification requires that the first descriptor chain starts + // with a virtio-net header. 
It is not clear, whether it is also + // required to be fully contained in the first buffer of that + // descriptor chain, but it is reasonable to assume that this is + // always the case. + // The decode method already does the buffer length check. + if err = pkt.Header.Decode(pkt.ChainRefs[0][0:]); err != nil { + // The device misbehaved. There is no way we can gracefully + // recover from this, because we don't know how many of the + // following descriptor chains belong to this packet. + return 0, fmt.Errorf("decode vnethdr: %w", err) + } + + //we have the header now: what do we need to do? + if int(pkt.Header.NumBuffers) > len(chains) { + return 0, fmt.Errorf("number of buffers is greater than number of chains %d", len(chains)) + } + if int(pkt.Header.NumBuffers) != 1 { + return 0, fmt.Errorf("too smol-brain to handle more than one chain right now: %d chains", len(chains)) + } + if chains[0].Length > 16000 { + //todo! + return 1, fmt.Errorf("too big packet length: %d", chains[0].Length) + } + + //shift the buffer out of out: + pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length] + pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex)) + return 1, nil + + //cursor := n - virtio.NetHdrSize + // + //if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 { + // pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize] + // return 1, nil + //} + // + //i := 1 + //// we used chain 0 already + //for i = 1; i < len(chains); i++ { + // n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length)) + // if err != nil { + // // When this fails we may miss to free some descriptor chains. We + // // could try to mitigate this by deferring the freeing somehow, but + // // it's not worth the hassle. When this method fails, the queue will + // // be in a broken state anyway. + // return i, fmt.Errorf("get descriptor chain: %w", err) + // } + // cursor += n + //} + ////todo this has to be wrong + //pkt.Payload = pkt.Payload[:cursor] + //return i, nil +} + +func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) { + //todo optimize? + var chains []virtqueue.UsedElement + var err error + + chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), len(out)) + if err != nil { + return 0, err + } + if len(chains) == 0 { + return 0, nil + } + + numPackets := 0 + chainsIdx := 0 + for numPackets = 0; chainsIdx < len(chains); numPackets++ { + if numPackets >= len(out) { + return numPackets, fmt.Errorf("dropping %d packets, no room", len(chains)-numPackets) + } + numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:]) + if err != nil { + return 0, err + } + chainsIdx += numChains + } + + return numPackets, nil +} diff --git a/overlay/vhostnet/doc.go b/overlay/vhostnet/doc.go new file mode 100644 index 0000000..126d326 --- /dev/null +++ b/overlay/vhostnet/doc.go @@ -0,0 +1,3 @@ +// Package vhostnet implements methods to initialize vhost networking devices +// within the kernel-level virtio implementation and communicate with them. +package vhostnet diff --git a/overlay/vhostnet/ioctl.go b/overlay/vhostnet/ioctl.go new file mode 100644 index 0000000..ba97629 --- /dev/null +++ b/overlay/vhostnet/ioctl.go @@ -0,0 +1,31 @@ +package vhostnet + +import ( + "fmt" + "unsafe" + + "github.com/slackhq/nebula/overlay/vhost" +) + +const ( + // vhostNetIoctlSetBackend can be used to attach a virtqueue to a RAW socket + // or TAP device. 
+ // + // Request payload: [vhost.QueueFile] + // Kernel name: VHOST_NET_SET_BACKEND + vhostNetIoctlSetBackend = 0x4008af30 +) + +// SetQueueBackend attaches a virtqueue of the vhost networking device +// described by controlFD to the given backend file descriptor. +// The backend file descriptor can either be a RAW socket or a TAP device. When +// it is -1, the queue will be detached. +func SetQueueBackend(controlFD int, queueIndex uint32, backendFD int) error { + if err := vhost.IoctlPtr(controlFD, vhostNetIoctlSetBackend, unsafe.Pointer(&vhost.QueueFile{ + QueueIndex: queueIndex, + FD: int32(backendFD), + })); err != nil { + return fmt.Errorf("set queue backend file descriptor: %w", err) + } + return nil +} diff --git a/overlay/vhostnet/options.go b/overlay/vhostnet/options.go new file mode 100644 index 0000000..1d61914 --- /dev/null +++ b/overlay/vhostnet/options.go @@ -0,0 +1,69 @@ +package vhostnet + +import ( + "errors" + + "github.com/slackhq/nebula/overlay/virtqueue" +) + +type optionValues struct { + queueSize int + backendFD int +} + +func (o *optionValues) apply(options []Option) { + for _, option := range options { + option(o) + } +} + +func (o *optionValues) validate() error { + if o.queueSize == -1 { + return errors.New("queue size is required") + } + if err := virtqueue.CheckQueueSize(o.queueSize); err != nil { + return err + } + if o.backendFD == -1 { + return errors.New("backend file descriptor is required") + } + return nil +} + +var optionDefaults = optionValues{ + // Required. + queueSize: -1, + // Required. + backendFD: -1, +} + +// Option can be passed to [NewDevice] to influence device creation. +type Option func(*optionValues) + +// WithQueueSize returns an [Option] that sets the size of the TX and RX queues +// that are to be created for the device. It specifies the number of +// entries/buffers each queue can hold. This also affects the memory +// consumption. +// This is required and must be an integer from 1 to 32768 that is also a power +// of 2. +func WithQueueSize(queueSize int) Option { + return func(o *optionValues) { o.queueSize = queueSize } +} + +// WithBackendFD returns an [Option] that sets the file descriptor of the +// backend that will be used for the queues of the device. The device will write +// and read packets to/from that backend. The file descriptor can either be of a +// RAW socket or TUN/TAP device. +// Either this or [WithBackendDevice] is required. +func WithBackendFD(backendFD int) Option { + return func(o *optionValues) { o.backendFD = backendFD } +} + +//// WithBackendDevice returns an [Option] that sets the given TAP device as the +//// backend that will be used for the queues of the device. The device will +//// write and read packets to/from that backend. The TAP device should have been +//// created with the [tuntap.WithVirtioNetHdr] option enabled. +//// Either this or [WithBackendFD] is required. 
+//func WithBackendDevice(dev *tuntap.Device) Option { +// return func(o *optionValues) { o.backendFD = int(dev.File().Fd()) } +//} diff --git a/overlay/virtqueue/README.md b/overlay/virtqueue/README.md new file mode 100644 index 0000000..1116e3d --- /dev/null +++ b/overlay/virtqueue/README.md @@ -0,0 +1,23 @@ +Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go + +MIT License + +Copyright (c) 2025 Hetzner Cloud GmbH + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/overlay/virtqueue/available_ring.go b/overlay/virtqueue/available_ring.go new file mode 100644 index 0000000..abe540b --- /dev/null +++ b/overlay/virtqueue/available_ring.go @@ -0,0 +1,140 @@ +package virtqueue + +import ( + "fmt" + "unsafe" +) + +// availableRingFlag is a flag that describes an [AvailableRing]. +type availableRingFlag uint16 + +const ( + // availableRingFlagNoInterrupt is used by the guest to advise the host to + // not interrupt it when consuming a buffer. It's unreliable, so it's simply + // an optimization. + availableRingFlagNoInterrupt availableRingFlag = 1 << iota +) + +// availableRingSize is the number of bytes needed to store an [AvailableRing] +// with the given queue size in memory. +func availableRingSize(queueSize int) int { + return 6 + 2*queueSize +} + +// availableRingAlignment is the minimum alignment of an [AvailableRing] +// in memory, as required by the virtio spec. +const availableRingAlignment = 2 + +// AvailableRing is used by the driver to offer descriptor chains to the device. +// Each ring entry refers to the head of a descriptor chain. It is only written +// to by the driver and read by the device. +// +// Because the size of the ring depends on the queue size, we cannot define a +// Go struct with a static size that maps to the memory of the ring. Instead, +// this struct only contains pointers to the corresponding memory areas. +type AvailableRing struct { + initialized bool + + // flags that describe this ring. + flags *availableRingFlag + // ringIndex indicates where the driver would put the next entry into the + // ring (modulo the queue size). + ringIndex *uint16 + // ring references buffers using the index of the head of the descriptor + // chain in the [DescriptorTable]. It wraps around at queue size. + ring []uint16 + // usedEvent is not used by this implementation, but we reserve it anyway to + // avoid issues in case a device may try to access it, contrary to the + // virtio specification. 
+ usedEvent *uint16 +} + +// newAvailableRing creates an available ring that uses the given underlying +// memory. The length of the memory slice must match the size needed for the +// ring (see [availableRingSize]) for the given queue size. +func newAvailableRing(queueSize int, mem []byte) *AvailableRing { + ringSize := availableRingSize(queueSize) + if len(mem) != ringSize { + panic(fmt.Sprintf("memory size (%v) does not match required size "+ + "for available ring: %v", len(mem), ringSize)) + } + + return &AvailableRing{ + initialized: true, + flags: (*availableRingFlag)(unsafe.Pointer(&mem[0])), + ringIndex: (*uint16)(unsafe.Pointer(&mem[2])), + ring: unsafe.Slice((*uint16)(unsafe.Pointer(&mem[4])), queueSize), + usedEvent: (*uint16)(unsafe.Pointer(&mem[ringSize-2])), + } +} + +// Address returns the pointer to the beginning of the ring in memory. +// Do not modify the memory directly to not interfere with this implementation. +func (r *AvailableRing) Address() uintptr { + if !r.initialized { + panic("available ring is not initialized") + } + return uintptr(unsafe.Pointer(r.flags)) +} + +// offer adds the given descriptor chain heads to the available ring and +// advances the ring index accordingly to make the device process the new +// descriptor chains. +func (r *AvailableRing) offerElements(chains []UsedElement) { + //always called under lock + //r.mu.Lock() + //defer r.mu.Unlock() + + // Add descriptor chain heads to the ring. + for offset, x := range chains { + // The 16-bit ring index may overflow. This is expected and is not an + // issue because the size of the ring array (which equals the queue + // size) is always a power of 2 and smaller than the highest possible + // 16-bit value. + insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring) + r.ring[insertIndex] = x.GetHead() + } + + // Increase the ring index by the number of descriptor chains added to the + // ring. + *r.ringIndex += uint16(len(chains)) +} + +func (r *AvailableRing) offer(chains []uint16) { + //always called under lock + //r.mu.Lock() + //defer r.mu.Unlock() + + // Add descriptor chain heads to the ring. + for offset, x := range chains { + // The 16-bit ring index may overflow. This is expected and is not an + // issue because the size of the ring array (which equals the queue + // size) is always a power of 2 and smaller than the highest possible + // 16-bit value. + insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring) + r.ring[insertIndex] = x + } + + // Increase the ring index by the number of descriptor chains added to the + // ring. + *r.ringIndex += uint16(len(chains)) +} + +func (r *AvailableRing) offerSingle(x uint16) { + //always called under lock + //r.mu.Lock() + //defer r.mu.Unlock() + + offset := 0 + // Add descriptor chain heads to the ring. + + // The 16-bit ring index may overflow. This is expected and is not an + // issue because the size of the ring array (which equals the queue + // size) is always a power of 2 and smaller than the highest possible + // 16-bit value. + insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring) + r.ring[insertIndex] = x + + // Increase the ring index by the number of descriptor chains added to the ring. 
+ *r.ringIndex += 1 +} diff --git a/overlay/virtqueue/available_ring_internal_test.go b/overlay/virtqueue/available_ring_internal_test.go new file mode 100644 index 0000000..aa330b9 --- /dev/null +++ b/overlay/virtqueue/available_ring_internal_test.go @@ -0,0 +1,71 @@ +package virtqueue + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAvailableRing_MemoryLayout(t *testing.T) { + const queueSize = 2 + + memory := make([]byte, availableRingSize(queueSize)) + r := newAvailableRing(queueSize, memory) + + *r.flags = 0x01ff + *r.ringIndex = 1 + r.ring[0] = 0x1234 + r.ring[1] = 0x5678 + + assert.Equal(t, []byte{ + 0xff, 0x01, + 0x01, 0x00, + 0x34, 0x12, + 0x78, 0x56, + 0x00, 0x00, + }, memory) +} + +func TestAvailableRing_Offer(t *testing.T) { + const queueSize = 8 + + chainHeads := []uint16{42, 33, 69} + + tests := []struct { + name string + startRingIndex uint16 + expectedRingIndex uint16 + expectedRing []uint16 + }{ + { + name: "no overflow", + startRingIndex: 0, + expectedRingIndex: 3, + expectedRing: []uint16{42, 33, 69, 0, 0, 0, 0, 0}, + }, + { + name: "ring overflow", + startRingIndex: 6, + expectedRingIndex: 9, + expectedRing: []uint16{69, 0, 0, 0, 0, 0, 42, 33}, + }, + { + name: "index overflow", + startRingIndex: 65535, + expectedRingIndex: 2, + expectedRing: []uint16{33, 69, 0, 0, 0, 0, 0, 42}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + memory := make([]byte, availableRingSize(queueSize)) + r := newAvailableRing(queueSize, memory) + *r.ringIndex = tt.startRingIndex + + r.offer(chainHeads) + + assert.Equal(t, tt.expectedRingIndex, *r.ringIndex) + assert.Equal(t, tt.expectedRing, r.ring) + }) + } +} diff --git a/overlay/virtqueue/descriptor.go b/overlay/virtqueue/descriptor.go new file mode 100644 index 0000000..fc6b3ef --- /dev/null +++ b/overlay/virtqueue/descriptor.go @@ -0,0 +1,43 @@ +package virtqueue + +// descriptorFlag is a flag that describes a [Descriptor]. +type descriptorFlag uint16 + +const ( + // descriptorFlagHasNext marks a descriptor chain as continuing via the next + // field. + descriptorFlagHasNext descriptorFlag = 1 << iota + // descriptorFlagWritable marks a buffer as device write-only (otherwise + // device read-only). + descriptorFlagWritable + // descriptorFlagIndirect means the buffer contains a list of buffer + // descriptors to provide an additional layer of indirection. + // Only allowed when the [virtio.FeatureIndirectDescriptors] feature was + // negotiated. + descriptorFlagIndirect +) + +// descriptorSize is the number of bytes needed to store a [Descriptor] in +// memory. +const descriptorSize = 16 + +// Descriptor describes (a part of) a buffer which is either read-only for the +// device or write-only for the device (depending on [descriptorFlagWritable]). +// Multiple descriptors can be chained to produce a "descriptor chain" that can +// contain both device-readable and device-writable buffers. Device-readable +// descriptors always come first in a chain. A single, large buffer may be +// split up by chaining multiple similar descriptors that reference different +// memory pages. This is required, because buffers may exceed a single page size +// and the memory accessed by the device is expected to be continuous. +type Descriptor struct { + // address is the address to the continuous memory holding the data for this + // descriptor. + address uintptr + // length is the amount of bytes stored at address. + length uint32 + // flags that describe this descriptor. 
+ flags descriptorFlag + // next contains the index of the next descriptor continuing this descriptor + // chain when the [descriptorFlagHasNext] flag is set. + next uint16 +} diff --git a/overlay/virtqueue/descriptor_internal_test.go b/overlay/virtqueue/descriptor_internal_test.go new file mode 100644 index 0000000..cd043a1 --- /dev/null +++ b/overlay/virtqueue/descriptor_internal_test.go @@ -0,0 +1,12 @@ +package virtqueue + +import ( + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" +) + +func TestDescriptor_Size(t *testing.T) { + assert.EqualValues(t, descriptorSize, unsafe.Sizeof(Descriptor{})) +} diff --git a/overlay/virtqueue/descriptor_table.go b/overlay/virtqueue/descriptor_table.go new file mode 100644 index 0000000..298036f --- /dev/null +++ b/overlay/virtqueue/descriptor_table.go @@ -0,0 +1,465 @@ +package virtqueue + +import ( + "errors" + "fmt" + "math" + "unsafe" + + "golang.org/x/sys/unix" +) + +var ( + // ErrDescriptorChainEmpty is returned when a descriptor chain would contain + // no buffers, which is not allowed. + ErrDescriptorChainEmpty = errors.New("empty descriptor chains are not allowed") + + // ErrNotEnoughFreeDescriptors is returned when the free descriptors are + // exhausted, meaning that the queue is full. + ErrNotEnoughFreeDescriptors = errors.New("not enough free descriptors, queue is full") + + // ErrInvalidDescriptorChain is returned when a descriptor chain is not + // valid for a given operation. + ErrInvalidDescriptorChain = errors.New("invalid descriptor chain") +) + +// noFreeHead is used to mark when all descriptors are in use and we have no +// free chain. This value is impossible to occur as an index naturally, because +// it exceeds the maximum queue size. +const noFreeHead = uint16(math.MaxUint16) + +// descriptorTableSize is the number of bytes needed to store a +// [DescriptorTable] with the given queue size in memory. +func descriptorTableSize(queueSize int) int { + return descriptorSize * queueSize +} + +// descriptorTableAlignment is the minimum alignment of a [DescriptorTable] +// in memory, as required by the virtio spec. +const descriptorTableAlignment = 16 + +// DescriptorTable is a table that holds [Descriptor]s, addressed via their +// index in the slice. +type DescriptorTable struct { + descriptors []Descriptor + + // freeHeadIndex is the index of the head of the descriptor chain which + // contains all currently unused descriptors. When all descriptors are in + // use, this has the special value of noFreeHead. + freeHeadIndex uint16 + // freeNum tracks the number of descriptors which are currently not in use. + freeNum uint16 + + bufferBase uintptr + bufferSize int + itemSize int +} + +// newDescriptorTable creates a descriptor table that uses the given underlying +// memory. The Length of the memory slice must match the size needed for the +// descriptor table (see [descriptorTableSize]) for the given queue size. +// +// Before this descriptor table can be used, [initialize] must be called. +func newDescriptorTable(queueSize int, mem []byte, itemSize int) *DescriptorTable { + dtSize := descriptorTableSize(queueSize) + if len(mem) != dtSize { + panic(fmt.Sprintf("memory size (%v) does not match required size "+ + "for descriptor table: %v", len(mem), dtSize)) + } + + return &DescriptorTable{ + descriptors: unsafe.Slice((*Descriptor)(unsafe.Pointer(&mem[0])), queueSize), + // We have no free descriptors until they were initialized. + freeHeadIndex: noFreeHead, + freeNum: 0, + itemSize: itemSize, //todo configurable? 
needs to be page-aligned
+	}
+}
+
+// Address returns the pointer to the beginning of the descriptor table in
+// memory. Do not modify the memory directly to not interfere with this
+// implementation.
+func (dt *DescriptorTable) Address() uintptr {
+	if dt.descriptors == nil {
+		panic("descriptor table is not initialized")
+	}
+	// Should be the same as dt.bufferBase.
+	return uintptr(unsafe.Pointer(&dt.descriptors[0]))
+}
+
+func (dt *DescriptorTable) Size() uintptr {
+	if dt.descriptors == nil {
+		panic("descriptor table is not initialized")
+	}
+	return uintptr(dt.bufferSize)
+}
+
+// BufferAddresses returns a map of pointer->size for all allocations used by the table
+func (dt *DescriptorTable) BufferAddresses() map[uintptr]int {
+	if dt.descriptors == nil {
+		panic("descriptor table is not initialized")
+	}
+
+	return map[uintptr]int{dt.bufferBase: dt.bufferSize}
+}
+
+// initializeDescriptors allocates buffers with the size of a full memory page
+// for each descriptor in the table. While this may be a bit wasteful, it makes
+// dealing with descriptors way easier. Without this preallocation, we would
+// have to allocate and free memory on demand, increasing complexity.
+//
+// All descriptors will be marked as free and will form a free chain. The
+// addresses of all descriptors will be populated while their length remains
+// zero.
+func (dt *DescriptorTable) initializeDescriptors() error {
+	numDescriptors := len(dt.descriptors)
+
+	// Allocate ONE large region for all buffers.
+	totalSize := dt.itemSize * numDescriptors
+	basePtr, err := unix.MmapPtr(-1, 0, nil, uintptr(totalSize),
+		unix.PROT_READ|unix.PROT_WRITE,
+		unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
+	if err != nil {
+		return fmt.Errorf("allocate buffer memory for descriptors: %w", err)
+	}
+
+	// Store the base for cleanup later.
+	dt.bufferBase = uintptr(basePtr)
+	dt.bufferSize = totalSize
+
+	for i := range dt.descriptors {
+		dt.descriptors[i] = Descriptor{
+			address: dt.bufferBase + uintptr(i*dt.itemSize),
+			length:  0,
+			// All descriptors should form a free chain that loops around.
+			flags: descriptorFlagHasNext,
+			next:  uint16((i + 1) % len(dt.descriptors)),
+		}
+	}
+
+	// All descriptors are free to use now.
+	dt.freeHeadIndex = 0
+	dt.freeNum = uint16(len(dt.descriptors))
+
+	return nil
+}
+
+// releaseBuffers releases all allocated buffers for this descriptor table.
+// The implementation will try to release as many buffers as possible and
+// collect potential errors before returning them.
+// The descriptor table should no longer be used after calling this.
+func (dt *DescriptorTable) releaseBuffers() error {
+	for i := range dt.descriptors {
+		descriptor := &dt.descriptors[i]
+		descriptor.address = 0
+	}
+
+	// As a safety measure, make sure no descriptors can be used anymore.
+	dt.freeHeadIndex = noFreeHead
+	dt.freeNum = 0
+
+	if dt.bufferBase != 0 {
+		// The pointer points to memory not managed by Go, so this conversion
+		// is safe. See https://github.com/golang/go/issues/58625
+		//goland:noinspection GoVetUnsafePointer
+		base := unsafe.Pointer(dt.bufferBase)
+		// Only clear the base after the mapping was actually released, so we
+		// do not try to unmap address 0 and leak the real allocation.
+		dt.bufferBase = 0
+		if err := unix.MunmapPtr(base, uintptr(dt.bufferSize)); err != nil {
+			return fmt.Errorf("release buffer memory: %w", err)
+		}
+	}
+
+	return nil
+}
+
+func (dt *DescriptorTable) CreateDescriptorForOutputs() (uint16, error) {
+	// TODO: consider just pre-filling the whole table.
+	// Do we still have enough free descriptors?
+ + if 1 > dt.freeNum { + return 0, ErrNotEnoughFreeDescriptors + } + + // Above validation ensured that there is at least one free descriptor, so + // the free descriptor chain head should be valid. + if dt.freeHeadIndex == noFreeHead { + panic("free descriptor chain head is unset but there should be free descriptors") + } + + // To avoid having to iterate over the whole table to find the descriptor + // pointing to the head just to replace the free head, we instead always + // create descriptor chains from the descriptors coming after the head. + // This way we only have to touch the head as a last resort, when all other + // descriptors are already used. + head := dt.descriptors[dt.freeHeadIndex].next + desc := &dt.descriptors[head] + next := desc.next + + checkUnusedDescriptorLength(head, desc) + + // Give the device the maximum available number of bytes to write into. + desc.length = uint32(dt.itemSize) + desc.flags = 0 // descriptorFlagWritable + desc.next = 0 // Not necessary to clear this, it's just for looks. + + dt.freeNum -= 1 + + if dt.freeNum == 0 { + // The last descriptor in the chain should be the free chain head + // itself. + if next != dt.freeHeadIndex { + panic("descriptor chain takes up all free descriptors but does not end with the free chain head") + } + + // When this new chain takes up all remaining descriptors, we no longer + // have a free chain. + dt.freeHeadIndex = noFreeHead + } else { + // We took some descriptors out of the free chain, so make sure to close + // the circle again. + dt.descriptors[dt.freeHeadIndex].next = next + } + + return head, nil +} + +func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) { + // Do we still have enough free descriptors? + if 1 > dt.freeNum { + return 0, ErrNotEnoughFreeDescriptors + } + + // Above validation ensured that there is at least one free descriptor, so + // the free descriptor chain head should be valid. + if dt.freeHeadIndex == noFreeHead { + panic("free descriptor chain head is unset but there should be free descriptors") + } + + // To avoid having to iterate over the whole table to find the descriptor + // pointing to the head just to replace the free head, we instead always + // create descriptor chains from the descriptors coming after the head. + // This way we only have to touch the head as a last resort, when all other + // descriptors are already used. + head := dt.descriptors[dt.freeHeadIndex].next + desc := &dt.descriptors[head] + next := desc.next + + checkUnusedDescriptorLength(head, desc) + + // Give the device the maximum available number of bytes to write into. + desc.length = uint32(dt.itemSize) + desc.flags = descriptorFlagWritable + desc.next = 0 // Not necessary to clear this, it's just for looks. + + dt.freeNum -= 1 + + if dt.freeNum == 0 { + // The last descriptor in the chain should be the free chain head + // itself. + if next != dt.freeHeadIndex { + panic("descriptor chain takes up all free descriptors but does not end with the free chain head") + } + + // When this new chain takes up all remaining descriptors, we no longer + // have a free chain. + dt.freeHeadIndex = noFreeHead + } else { + // We took some descriptors out of the free chain, so make sure to close + // the circle again. + dt.descriptors[dt.freeHeadIndex].next = next + } + + return head, nil +} + +// TODO: Implement a zero-copy variant of createDescriptorChain? 
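Both CreateDescriptorForOutputs and createDescriptorForInputs above carve exactly one descriptor out of the circular free chain by unlinking the node that follows the free-chain head, and only fall back to consuming the head itself once everything else is in use. The following standalone sketch (hypothetical, simplified types; not part of this change) isolates that bookkeeping:

    // freeList models the circular free chain kept by DescriptorTable: next[i]
    // names the free node following node i, and head stays in place until the
    // very last node is handed out.
    type freeList struct {
    	next []uint16 // next[i] is the free node following node i
    	head uint16   // current free-chain head
    	free uint16   // number of free nodes left
    }

    const noHead = uint16(0xffff) // mirrors noFreeHead: "no free chain at all"

    // takeOne unlinks and returns the node directly after the head, the same
    // move the create* helpers make so the rest of the chain stays untouched.
    func (f *freeList) takeOne() (uint16, bool) {
    	if f.free == 0 {
    		return 0, false
    	}
    	taken := f.next[f.head]
    	f.free--
    	if f.free == 0 {
    		// The chain is exhausted; the head itself was the node just taken.
    		f.head = noHead
    	} else {
    		// Close the circle again by skipping over the taken node.
    		f.next[f.head] = f.next[taken]
    	}
    	return taken, true
    }

freeDescriptorChain further down is the inverse operation: it splices a returned chain back in right after the head, or reinstates the head when the free chain had been exhausted.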
+ +// getDescriptorChain returns the device-readable buffers (out buffers) and +// device-writable buffers (in buffers) of the descriptor chain that starts with +// the given head index. The descriptor chain must have been created using +// [createDescriptorChain] and must not have been freed yet (meaning that the +// head index must not be contained in the free chain). +// +// Be careful to only access the returned buffer slices when the device has not +// yet or is no longer using them. They must not be accessed after +// [freeDescriptorChain] has been called. +func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) { + if int(head) > len(dt.descriptors) { + return nil, nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain) + } + + // Iterate over the chain. The iteration is limited to the queue size to + // avoid ending up in an endless loop when things go very wrong. + next := head + for range len(dt.descriptors) { + if next == dt.freeHeadIndex { + return nil, nil, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain) + } + + desc := &dt.descriptors[next] + + // The descriptor address points to memory not managed by Go, so this + // conversion is safe. See https://github.com/golang/go/issues/58625 + //goland:noinspection GoVetUnsafePointer + bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length) + + if desc.flags&descriptorFlagWritable == 0 { + outBuffers = append(outBuffers, bs) + } else { + inBuffers = append(inBuffers, bs) + } + + // Is this the tail of the chain? + if desc.flags&descriptorFlagHasNext == 0 { + break + } + + // Detect loops. + if desc.next == head { + return nil, nil, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain) + } + + next = desc.next + } + + return +} + +func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) { + if int(head) > len(dt.descriptors) { + return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain) + } + + desc := &dt.descriptors[head] //todo this is a pretty nasty hack with no checks + + // The descriptor address points to memory not managed by Go, so this + // conversion is safe. See https://github.com/golang/go/issues/58625 + //goland:noinspection GoVetUnsafePointer + bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length) + return bs, nil +} + +func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]byte) error { + if int(head) > len(dt.descriptors) { + return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain) + } + + // Iterate over the chain. The iteration is limited to the queue size to + // avoid ending up in an endless loop when things go very wrong. + next := head + for range len(dt.descriptors) { + if next == dt.freeHeadIndex { + return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain) + } + + desc := &dt.descriptors[next] + + // The descriptor address points to memory not managed by Go, so this + // conversion is safe. See https://github.com/golang/go/issues/58625 + //goland:noinspection GoVetUnsafePointer + bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length) + + if desc.flags&descriptorFlagWritable == 0 { + return fmt.Errorf("there should not be an outbuffer in %d", head) + } else { + *inBuffers = append(*inBuffers, bs) + } + + // Is this the tail of the chain? + if desc.flags&descriptorFlagHasNext == 0 { + break + } + + // Detect loops. 
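+		// (This check only catches a chain that points back at its own head;
+		// any other malformed cycle is cut off by the bounded
+		// `for range len(dt.descriptors)` loop above.)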
+ if desc.next == head { + return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain) + } + + next = desc.next + } + + return nil +} + +// freeDescriptorChain can be used to free a descriptor chain when it is no +// longer in use. The descriptor chain that starts with the given index will be +// put back into the free chain, so the descriptors can be used for later calls +// of [createDescriptorChain]. +// The descriptor chain must have been created using [createDescriptorChain] and +// must not have been freed yet (meaning that the head index must not be +// contained in the free chain). +func (dt *DescriptorTable) freeDescriptorChain(head uint16) error { + if int(head) > len(dt.descriptors) { + return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain) + } + + // Iterate over the chain. The iteration is limited to the queue size to + // avoid ending up in an endless loop when things go very wrong. + next := head + var tailDesc *Descriptor + var chainLen uint16 + for range len(dt.descriptors) { + if next == dt.freeHeadIndex { + return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain) + } + + desc := &dt.descriptors[next] + chainLen++ + + // Set the length of all unused descriptors back to zero. + desc.length = 0 + + // Unset all flags except the next flag. + desc.flags &= descriptorFlagHasNext + + // Is this the tail of the chain? + if desc.flags&descriptorFlagHasNext == 0 { + tailDesc = desc + break + } + + // Detect loops. + if desc.next == head { + return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain) + } + + next = desc.next + } + if tailDesc == nil { + // A descriptor chain longer than the queue size but without loops + // should be impossible. + panic(fmt.Sprintf("could not find a tail for descriptor chain starting at %d", head)) + } + + // The tail descriptor does not have the next flag set, but when it comes + // back into the free chain, it should have. + tailDesc.flags = descriptorFlagHasNext + + if dt.freeHeadIndex == noFreeHead { + // The whole free chain was used up, so we turn this returned descriptor + // chain into the new free chain by completing the circle and using its + // head. + tailDesc.next = head + dt.freeHeadIndex = head + } else { + // Attach the returned chain at the beginning of the free chain but + // right after the free chain head. + freeHeadDesc := &dt.descriptors[dt.freeHeadIndex] + tailDesc.next = freeHeadDesc.next + freeHeadDesc.next = head + } + + dt.freeNum += chainLen + + return nil +} + +// checkUnusedDescriptorLength asserts that the length of an unused descriptor +// is zero, as it should be. +// This is not a requirement by the virtio spec but rather a thing we do to +// notice when our algorithm goes sideways. +func checkUnusedDescriptorLength(index uint16, desc *Descriptor) { + if desc.length != 0 { + panic(fmt.Sprintf("descriptor %d should be unused but has a non-zero length", index)) + } +} diff --git a/overlay/virtqueue/doc.go b/overlay/virtqueue/doc.go new file mode 100644 index 0000000..158d13b --- /dev/null +++ b/overlay/virtqueue/doc.go @@ -0,0 +1,7 @@ +// Package virtqueue implements the driver-side for a virtio queue as described +// in the specification: +// https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-270006 +// This package does not make assumptions about the device that consumes the +// queue. It rather just allocates the queue structures in memory and provides +// methods to interact with it. 
+package virtqueue diff --git a/overlay/virtqueue/eventfd_test.go b/overlay/virtqueue/eventfd_test.go new file mode 100644 index 0000000..c42b3be --- /dev/null +++ b/overlay/virtqueue/eventfd_test.go @@ -0,0 +1,45 @@ +package virtqueue + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gvisor.dev/gvisor/pkg/eventfd" +) + +// Tests how an eventfd and a waiting goroutine can be gracefully closed. +// Extends the eventfd test suite: +// https://github.com/google/gvisor/blob/0799336d64be65eb97d330606c30162dc3440cab/pkg/eventfd/eventfd_test.go +func TestEventFD_CancelWait(t *testing.T) { + efd, err := eventfd.Create() + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, efd.Close()) + }) + + var stop bool + + done := make(chan struct{}) + go func() { + for !stop { + _ = efd.Wait() + } + close(done) + }() + select { + case <-done: + t.Fatalf("goroutine ended early") + case <-time.After(500 * time.Millisecond): + } + + stop = true + assert.NoError(t, efd.Notify()) + select { + case <-done: + break + case <-time.After(5 * time.Second): + t.Error("goroutine did not end") + } +} diff --git a/overlay/virtqueue/size.go b/overlay/virtqueue/size.go new file mode 100644 index 0000000..605ca70 --- /dev/null +++ b/overlay/virtqueue/size.go @@ -0,0 +1,33 @@ +package virtqueue + +import ( + "errors" + "fmt" +) + +// ErrQueueSizeInvalid is returned when a queue size is invalid. +var ErrQueueSizeInvalid = errors.New("queue size is invalid") + +// CheckQueueSize checks if the given value would be a valid size for a +// virtqueue and returns an [ErrQueueSizeInvalid], if not. +func CheckQueueSize(queueSize int) error { + if queueSize <= 0 { + return fmt.Errorf("%w: %d is too small", ErrQueueSizeInvalid, queueSize) + } + + // The queue size must always be a power of 2. + // This ensures that ring indexes wrap correctly when the 16-bit integers + // overflow. + if queueSize&(queueSize-1) != 0 { + return fmt.Errorf("%w: %d is not a power of 2", ErrQueueSizeInvalid, queueSize) + } + + // The largest power of 2 that fits into a 16-bit integer is 32768. + // 2 * 32768 would be 65536 which no longer fits. 
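+	// (If a size of 65536 were allowed, the 16-bit ring indexes would wrap a
+	// whole lap, making a completely full ring indistinguishable from an
+	// empty one.)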
+ if queueSize > 32768 { + return fmt.Errorf("%w: %d is larger than the maximum possible queue size 32768", + ErrQueueSizeInvalid, queueSize) + } + + return nil +} diff --git a/overlay/virtqueue/size_test.go b/overlay/virtqueue/size_test.go new file mode 100644 index 0000000..707f58c --- /dev/null +++ b/overlay/virtqueue/size_test.go @@ -0,0 +1,59 @@ +package virtqueue + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCheckQueueSize(t *testing.T) { + tests := []struct { + name string + queueSize int + containsErr string + }{ + { + name: "negative", + queueSize: -1, + containsErr: "too small", + }, + { + name: "zero", + queueSize: 0, + containsErr: "too small", + }, + { + name: "not a power of 2", + queueSize: 24, + containsErr: "not a power of 2", + }, + { + name: "too large", + queueSize: 65536, + containsErr: "larger than the maximum", + }, + { + name: "valid 1", + queueSize: 1, + }, + { + name: "valid 256", + queueSize: 256, + }, + + { + name: "valid 32768", + queueSize: 32768, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckQueueSize(tt.queueSize) + if tt.containsErr != "" { + assert.ErrorContains(t, err, tt.containsErr) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/overlay/virtqueue/split_virtqueue.go b/overlay/virtqueue/split_virtqueue.go new file mode 100644 index 0000000..59421ea --- /dev/null +++ b/overlay/virtqueue/split_virtqueue.go @@ -0,0 +1,421 @@ +package virtqueue + +import ( + "context" + "errors" + "fmt" + "os" + + "github.com/slackhq/nebula/overlay/eventfd" + "golang.org/x/sys/unix" +) + +// SplitQueue is a virtqueue that consists of several parts, where each part is +// writeable by either the driver or the device, but not both. +type SplitQueue struct { + // size is the size of the queue. + size int + // buf is the underlying memory used for the queue. + buf []byte + + descriptorTable *DescriptorTable + availableRing *AvailableRing + usedRing *UsedRing + + // kickEventFD is used to signal the device when descriptor chains were + // added to the available ring. + kickEventFD eventfd.EventFD + // callEventFD is used by the device to signal when it has used descriptor + // chains and put them in the used ring. + callEventFD eventfd.EventFD + + // stop is used by [SplitQueue.Close] to cancel the goroutine that handles + // used buffer notifications. It blocks until the goroutine ended. + stop func() error + + itemSize int + + epoll eventfd.Epoll + more int +} + +// NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size +// specifies the number of entries/buffers the queue can hold. This also affects +// the memory consumption. +func NewSplitQueue(queueSize int, itemSize int) (_ *SplitQueue, err error) { + if err = CheckQueueSize(queueSize); err != nil { + return nil, err + } + + if itemSize%os.Getpagesize() != 0 { + return nil, errors.New("split queue size must be multiple of os.Getpagesize()") + } + + sq := SplitQueue{ + size: queueSize, + itemSize: itemSize, + } + + // Clean up a partially initialized queue when something fails. + defer func() { + if err != nil { + _ = sq.Close() + } + }() + + // There are multiple ways for how the memory for the virtqueue could be + // allocated. We could use Go native structs with arrays inside them, but + // this wouldn't allow us to make the queue size configurable. 
And including + // a slice in the Go structs wouldn't work, because this would just put the + // Go slice descriptor into the memory region which the virtio device will + // not understand. + // Additionally, Go does not allow us to ensure a correct alignment of the + // parts of the virtqueue, as it is required by the virtio specification. + // + // To resolve this, let's just allocate the memory manually by allocating + // one or more memory pages, depending on the queue size. Making the + // virtqueue start at the beginning of a page is not strictly necessary, as + // the virtio specification does not require it to be continuous in the + // physical memory of the host (e.g. the vhost implementation in the kernel + // always uses copy_from_user to access it), but this makes it very easy to + // guarantee the alignment. Also, it is not required for the virtqueue parts + // to be in the same memory region, as we pass separate pointers to them to + // the device, but this design just makes things easier to implement. + // + // One added benefit of allocating the memory manually is, that we have full + // control over its lifetime and don't risk the garbage collector to collect + // our valuable structures while the device still works with them. + + // The descriptor table is at the start of the page, so alignment is not an + // issue here. + descriptorTableStart := 0 + descriptorTableEnd := descriptorTableStart + descriptorTableSize(queueSize) + availableRingStart := align(descriptorTableEnd, availableRingAlignment) + availableRingEnd := availableRingStart + availableRingSize(queueSize) + usedRingStart := align(availableRingEnd, usedRingAlignment) + usedRingEnd := usedRingStart + usedRingSize(queueSize) + + sq.buf, err = unix.Mmap(-1, 0, usedRingEnd, + unix.PROT_READ|unix.PROT_WRITE, + unix.MAP_PRIVATE|unix.MAP_ANONYMOUS) + if err != nil { + return nil, fmt.Errorf("allocate virtqueue buffer: %w", err) + } + + sq.descriptorTable = newDescriptorTable(queueSize, sq.buf[descriptorTableStart:descriptorTableEnd], sq.itemSize) + sq.availableRing = newAvailableRing(queueSize, sq.buf[availableRingStart:availableRingEnd]) + sq.usedRing = newUsedRing(queueSize, sq.buf[usedRingStart:usedRingEnd]) + + sq.kickEventFD, err = eventfd.New() + if err != nil { + return nil, fmt.Errorf("create kick event file descriptor: %w", err) + } + sq.callEventFD, err = eventfd.New() + if err != nil { + return nil, fmt.Errorf("create call event file descriptor: %w", err) + } + + if err = sq.descriptorTable.initializeDescriptors(); err != nil { + return nil, fmt.Errorf("initialize descriptors: %w", err) + } + + sq.epoll, err = eventfd.NewEpoll() + if err != nil { + return nil, err + } + err = sq.epoll.AddEvent(sq.callEventFD.FD()) + if err != nil { + return nil, err + } + + // Consume used buffer notifications in the background. + sq.stop = sq.startConsumeUsedRing() + + return &sq, nil +} + +// Size returns the size of this queue, which is the number of entries/buffers +// this queue can hold. +func (sq *SplitQueue) Size() int { + return sq.size +} + +// DescriptorTable returns the [DescriptorTable] behind this queue. +func (sq *SplitQueue) DescriptorTable() *DescriptorTable { + return sq.descriptorTable +} + +// AvailableRing returns the [AvailableRing] behind this queue. +func (sq *SplitQueue) AvailableRing() *AvailableRing { + return sq.availableRing +} + +// UsedRing returns the [UsedRing] behind this queue. 
+func (sq *SplitQueue) UsedRing() *UsedRing { + return sq.usedRing +} + +// KickEventFD returns the kick event file descriptor behind this queue. +// The returned file descriptor should be used with great care to not interfere +// with this implementation. +func (sq *SplitQueue) KickEventFD() int { + return sq.kickEventFD.FD() +} + +// CallEventFD returns the call event file descriptor behind this queue. +// The returned file descriptor should be used with great care to not interfere +// with this implementation. +func (sq *SplitQueue) CallEventFD() int { + return sq.callEventFD.FD() +} + +// startConsumeUsedRing starts a goroutine that runs [consumeUsedRing]. +// A function is returned that can be used to gracefully cancel it. todo rename +func (sq *SplitQueue) startConsumeUsedRing() func() error { + return func() error { + + // The goroutine blocks until it receives a signal on the event file + // descriptor, so it will never notice the context being canceled. + // To resolve this, we can just produce a fake-signal ourselves to wake + // it up. + if err := sq.callEventFD.Kick(); err != nil { + return fmt.Errorf("wake up goroutine: %w", err) + } + return nil + } +} + +func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) { + var n int + var err error + for ctx.Err() == nil { + out, ok := sq.usedRing.takeOne() + if ok { + return out, nil + } + // Wait for a signal from the device. + if n, err = sq.epoll.Block(); err != nil { + return 0, fmt.Errorf("wait: %w", err) + } + + if n > 0 { + out, ok = sq.usedRing.takeOne() + if ok { + _ = sq.epoll.Clear() //??? + return out, nil + } else { + continue //??? + } + } + } + return 0, ctx.Err() +} + +func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) { + var n int + var err error + for ctx.Err() == nil { + + //we have leftovers in the fridge + if sq.more > 0 { + stillNeedToTake, out := sq.usedRing.take(maxToTake) + sq.more = stillNeedToTake + return out, nil + } + //look inside the fridge + stillNeedToTake, out := sq.usedRing.take(maxToTake) + if len(out) > 0 { + sq.more = stillNeedToTake + return out, nil + } + //fridge is empty I guess + + // Wait for a signal from the device. + if n, err = sq.epoll.Block(); err != nil { + return nil, fmt.Errorf("wait: %w", err) + } + if n > 0 { + _ = sq.epoll.Clear() //??? + stillNeedToTake, out = sq.usedRing.take(maxToTake) + sq.more = stillNeedToTake + return out, nil + } + } + + return nil, ctx.Err() +} + +// OfferDescriptorChain offers a descriptor chain to the device which contains a +// number of device-readable buffers (out buffers) and device-writable buffers +// (in buffers). +// +// All buffers in the outBuffers slice will be concatenated by chaining +// descriptors, one for each buffer in the slice. When a buffer is too large to +// fit into a single descriptor (limited by the system's page size), it will be +// split up into multiple descriptors within the chain. +// When numInBuffers is greater than zero, the given number of device-writable +// descriptors will be appended to the end of the chain, each referencing a +// whole memory page (see [os.Getpagesize]). +// +// When the queue is full and no more descriptor chains can be added, a wrapped +// [ErrNotEnoughFreeDescriptors] will be returned. If you set waitFree to true, +// this method will handle this error and will block instead until there are +// enough free descriptors again. 
+// +// After defining the descriptor chain in the [DescriptorTable], the index of +// the head of the chain will be made available to the device using the +// [AvailableRing] and will be returned by this method. +// Callers should read from the [SplitQueue.UsedDescriptorChains] channel to be +// notified when the descriptor chain was used by the device and should free the +// used descriptor chains again using [SplitQueue.FreeDescriptorChain] when +// they're done with them. When this does not happen, the queue will run full +// and any further calls to [SplitQueue.OfferDescriptorChain] will stall. + +func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) { + // Create a descriptor chain for the given buffers. + var ( + head uint16 + err error + ) + for { + head, err = sq.descriptorTable.createDescriptorForInputs() + if err == nil { + break + } + + // I don't wanna use errors.Is, it's slow + //goland:noinspection GoDirectComparisonOfErrors + if err == ErrNotEnoughFreeDescriptors { + return 0, err + } else { + return 0, fmt.Errorf("create descriptor chain: %w", err) + } + } + + // Make the descriptor chain available to the device. + sq.availableRing.offerSingle(head) + + // Notify the device to make it process the updated available ring. + if err := sq.kickEventFD.Kick(); err != nil { + return head, fmt.Errorf("notify device: %w", err) + } + + return head, nil +} + +// GetDescriptorChain returns the device-readable buffers (out buffers) and +// device-writable buffers (in buffers) of the descriptor chain with the given +// head index. +// The head index must be one that was returned by a previous call to +// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been +// freed yet. +// +// Be careful to only access the returned buffer slices when the device is no +// longer using them. They must not be accessed after +// [SplitQueue.FreeDescriptorChain] has been called. +func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) { + return sq.descriptorTable.getDescriptorChain(head) +} + +func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) { + sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize) + return sq.descriptorTable.getDescriptorItem(head) +} + +func (sq *SplitQueue) GetDescriptorInbuffers(head uint16, inBuffers *[][]byte) error { + return sq.descriptorTable.getDescriptorInbuffers(head, inBuffers) +} + +// FreeDescriptorChain frees the descriptor chain with the given head index. +// The head index must be one that was returned by a previous call to +// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been +// freed yet. +// +// This creates new room in the queue which can be used by following +// [SplitQueue.OfferDescriptorChain] calls. +// When there are outstanding calls for [SplitQueue.OfferDescriptorChain] that +// are waiting for free room in the queue, they may become unblocked by this. +func (sq *SplitQueue) FreeDescriptorChain(head uint16) error { + //not called under lock + if err := sq.descriptorTable.freeDescriptorChain(head); err != nil { + return fmt.Errorf("free: %w", err) + } + + return nil +} + +func (sq *SplitQueue) SetDescSize(head uint16, sz int) { + //not called under lock + sq.descriptorTable.descriptors[int(head)].length = uint32(sz) +} + +func (sq *SplitQueue) OfferDescriptorChains(chains []uint16, kick bool) error { + //todo not doing this may break eventually? 
+ //not called under lock + //if err := sq.descriptorTable.freeDescriptorChain(head); err != nil { + // return fmt.Errorf("free: %w", err) + //} + + // Make the descriptor chain available to the device. + sq.availableRing.offer(chains) + + // Notify the device to make it process the updated available ring. + if kick { + return sq.Kick() + } + + return nil +} + +func (sq *SplitQueue) Kick() error { + if err := sq.kickEventFD.Kick(); err != nil { + return fmt.Errorf("notify device: %w", err) + } + return nil +} + +// Close releases all resources used for this queue. +// The implementation will try to release as many resources as possible and +// collect potential errors before returning them. +func (sq *SplitQueue) Close() error { + var errs []error + + if sq.stop != nil { + // This has to happen before the event file descriptors may be closed. + if err := sq.stop(); err != nil { + errs = append(errs, fmt.Errorf("stop consume used ring: %w", err)) + } + + // Make sure that this code block is executed only once. + sq.stop = nil + } + + if err := sq.kickEventFD.Close(); err != nil { + errs = append(errs, fmt.Errorf("close kick event file descriptor: %w", err)) + } + if err := sq.callEventFD.Close(); err != nil { + errs = append(errs, fmt.Errorf("close call event file descriptor: %w", err)) + } + + if err := sq.descriptorTable.releaseBuffers(); err != nil { + errs = append(errs, fmt.Errorf("release descriptor buffers: %w", err)) + } + + if sq.buf != nil { + if err := unix.Munmap(sq.buf); err == nil { + sq.buf = nil + } else { + errs = append(errs, fmt.Errorf("unmap virtqueue buffer: %w", err)) + } + } + + return errors.Join(errs...) +} + +func align(index, alignment int) int { + remainder := index % alignment + if remainder == 0 { + return index + } + return index + alignment - remainder +} diff --git a/overlay/virtqueue/used_element.go b/overlay/virtqueue/used_element.go new file mode 100644 index 0000000..a4d5d26 --- /dev/null +++ b/overlay/virtqueue/used_element.go @@ -0,0 +1,21 @@ +package virtqueue + +// usedElementSize is the number of bytes needed to store a [UsedElement] in +// memory. +const usedElementSize = 8 + +// UsedElement is an element of the [UsedRing] and describes a descriptor chain +// that was used by the device. +type UsedElement struct { + // DescriptorIndex is the index of the head of the used descriptor chain in + // the [DescriptorTable]. + // The index is 32-bit here for padding reasons. + DescriptorIndex uint32 + // Length is the number of bytes written into the device writable portion of + // the buffer described by the descriptor chain. + Length uint32 +} + +func (u *UsedElement) GetHead() uint16 { + return uint16(u.DescriptorIndex) +} diff --git a/overlay/virtqueue/used_element_internal_test.go b/overlay/virtqueue/used_element_internal_test.go new file mode 100644 index 0000000..3ac65a4 --- /dev/null +++ b/overlay/virtqueue/used_element_internal_test.go @@ -0,0 +1,12 @@ +package virtqueue + +import ( + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" +) + +func TestUsedElement_Size(t *testing.T) { + assert.EqualValues(t, usedElementSize, unsafe.Sizeof(UsedElement{})) +} diff --git a/overlay/virtqueue/used_ring.go b/overlay/virtqueue/used_ring.go new file mode 100644 index 0000000..824c07c --- /dev/null +++ b/overlay/virtqueue/used_ring.go @@ -0,0 +1,184 @@ +package virtqueue + +import ( + "fmt" + "unsafe" +) + +// usedRingFlag is a flag that describes a [UsedRing]. 
+type usedRingFlag uint16
+
+const (
+	// usedRingFlagNoNotify is used by the host to advise the guest to not
+	// kick it when adding a buffer. It's unreliable, so it's simply an
+	// optimization. The guest will still kick when it's out of buffers.
+	usedRingFlagNoNotify usedRingFlag = 1 << iota
+)
+
+// usedRingSize is the number of bytes needed to store a [UsedRing] with the
+// given queue size in memory.
+func usedRingSize(queueSize int) int {
+	return 6 + usedElementSize*queueSize
+}
+
+// usedRingAlignment is the minimum alignment of a [UsedRing] in memory, as
+// required by the virtio spec.
+const usedRingAlignment = 4
+
+// UsedRing is where the device returns descriptor chains once it is done with
+// them. Each ring entry is a [UsedElement]. It is only written to by the device
+// and read by the driver.
+//
+// Because the size of the ring depends on the queue size, we cannot define a
+// Go struct with a static size that maps to the memory of the ring. Instead,
+// this struct only contains pointers to the corresponding memory areas.
+type UsedRing struct {
+	initialized bool
+
+	// flags that describe this ring.
+	flags *usedRingFlag
+	// ringIndex indicates where the device would put the next entry into the
+	// ring (modulo the queue size).
+	ringIndex *uint16
+	// ring contains the [UsedElement]s. It wraps around at queue size.
+	ring []UsedElement
+	// availableEvent is not used by this implementation, but we reserve it
+	// anyway to avoid issues in case a device may try to write to it, contrary
+	// to the virtio specification.
+	availableEvent *uint16
+
+	// lastIndex is the internal ringIndex up to which all [UsedElement]s were
+	// processed.
+	lastIndex uint16
+}
+
+// newUsedRing creates a used ring that uses the given underlying memory. The
+// length of the memory slice must match the size needed for the ring (see
+// [usedRingSize]) for the given queue size.
+func newUsedRing(queueSize int, mem []byte) *UsedRing {
+	ringSize := usedRingSize(queueSize)
+	if len(mem) != ringSize {
+		panic(fmt.Sprintf("memory size (%v) does not match required size "+
+			"for used ring: %v", len(mem), ringSize))
+	}
+
+	r := UsedRing{
+		initialized:    true,
+		flags:          (*usedRingFlag)(unsafe.Pointer(&mem[0])),
+		ringIndex:      (*uint16)(unsafe.Pointer(&mem[2])),
+		ring:           unsafe.Slice((*UsedElement)(unsafe.Pointer(&mem[4])), queueSize),
+		availableEvent: (*uint16)(unsafe.Pointer(&mem[ringSize-2])),
+	}
+	r.lastIndex = *r.ringIndex
+	return &r
+}
+
+// Address returns the pointer to the beginning of the ring in memory.
+// Do not modify the memory directly to not interfere with this implementation.
+func (r *UsedRing) Address() uintptr {
+	if !r.initialized {
+		panic("used ring is not initialized")
+	}
+	return uintptr(unsafe.Pointer(r.flags))
+}
+
+// take returns up to maxToTake new [UsedElement]s that the device put into the
+// ring and that weren't already returned by a previous call to this method.
+// A maxToTake of zero or less means no limit. The first return value is the
+// number of new elements that were left in the ring because of the cap.
+//
+// The ring is no longer locked internally; it must only be consumed by one
+// goroutine at a time.
+func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
+	ringIndex := *r.ringIndex
+	if ringIndex == r.lastIndex {
+		// Nothing new.
+		return 0, nil
+	}
+
+	// Calculate the number of new used elements that we can read from the
+	// ring. The unsigned 16-bit subtraction wraps, so the count stays correct
+	// when the ring index overflows.
+	count := int(ringIndex - r.lastIndex)
+
+	stillNeedToTake := 0
+
+	if maxToTake > 0 {
+		stillNeedToTake = count - maxToTake
+		if stillNeedToTake < 0 {
+			stillNeedToTake = 0
+		}
+		count = min(count, maxToTake)
+	}
+
+	// The number of new elements can never exceed the queue size.
+	if count > len(r.ring) {
+		panic("used ring contains more new elements than the ring is long")
+	}
+
+	elems := make([]UsedElement, count)
+	for i := range count {
+		elems[i] = r.ring[r.lastIndex%uint16(len(r.ring))]
+		r.lastIndex++
+	}
+
+	return stillNeedToTake, elems
+}
+
+// takeOne returns the head index of a single new [UsedElement], or false when
+// the device has not put anything new into the ring.
+func (r *UsedRing) takeOne() (uint16, bool) {
+	ringIndex := *r.ringIndex
+	if ringIndex == r.lastIndex {
+		// Nothing new.
+		return 0xffff, false
+	}
+
+	// Calculate the number of new used elements that we can read from the
+	// ring. The unsigned 16-bit subtraction wraps, so the count stays correct
+	// when the ring index overflows.
+	count := int(ringIndex - r.lastIndex)
+
+	// The number of new elements can never exceed the queue size.
+	if count > len(r.ring) {
+		panic("used ring contains more new elements than the ring is long")
+	}
+
+	if count == 0 {
+		return 0xffff, false
+	}
+
+	out := r.ring[r.lastIndex%uint16(len(r.ring))].GetHead()
+	r.lastIndex++
+
+	return out, true
+}
+
+// InitOfferSingle is only used to pre-fill the used ring at startup and must
+// not be used while the device is running.
+func (r *UsedRing) InitOfferSingle(x uint16, size int) {
+	// Always called while the caller holds the queue lock.
+
+	// Add the descriptor chain head to the ring.
+	//
+	// The 16-bit ring index may overflow. This is expected and is not an
+	// issue because the size of the ring array (which equals the queue
+	// size) is always a power of 2 and smaller than the highest possible
+	// 16-bit value.
+	insertIndex := int(*r.ringIndex) % len(r.ring)
+	r.ring[insertIndex] = UsedElement{
+		DescriptorIndex: uint32(x),
+		Length:          uint32(size),
+	}
+
+	// Advance the ring index by one.
+ *r.ringIndex += 1 +} diff --git a/overlay/virtqueue/used_ring_internal_test.go b/overlay/virtqueue/used_ring_internal_test.go new file mode 100644 index 0000000..3e6faf2 --- /dev/null +++ b/overlay/virtqueue/used_ring_internal_test.go @@ -0,0 +1,136 @@ +package virtqueue + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUsedRing_MemoryLayout(t *testing.T) { + const queueSize = 2 + + memory := make([]byte, usedRingSize(queueSize)) + r := newUsedRing(queueSize, memory) + + *r.flags = 0x01ff + *r.ringIndex = 1 + r.ring[0] = UsedElement{ + DescriptorIndex: 0x0123, + Length: 0x4567, + } + r.ring[1] = UsedElement{ + DescriptorIndex: 0x89ab, + Length: 0xcdef, + } + + assert.Equal(t, []byte{ + 0xff, 0x01, + 0x01, 0x00, + 0x23, 0x01, 0x00, 0x00, + 0x67, 0x45, 0x00, 0x00, + 0xab, 0x89, 0x00, 0x00, + 0xef, 0xcd, 0x00, 0x00, + 0x00, 0x00, + }, memory) +} + +//func TestUsedRing_Take(t *testing.T) { +// const queueSize = 8 +// +// tests := []struct { +// name string +// ring []UsedElement +// ringIndex uint16 +// lastIndex uint16 +// expected []UsedElement +// }{ +// { +// name: "nothing new", +// ring: []UsedElement{ +// {DescriptorIndex: 1}, +// {DescriptorIndex: 2}, +// {DescriptorIndex: 3}, +// {DescriptorIndex: 4}, +// {}, +// {}, +// {}, +// {}, +// }, +// ringIndex: 4, +// lastIndex: 4, +// expected: nil, +// }, +// { +// name: "no overflow", +// ring: []UsedElement{ +// {DescriptorIndex: 1}, +// {DescriptorIndex: 2}, +// {DescriptorIndex: 3}, +// {DescriptorIndex: 4}, +// {}, +// {}, +// {}, +// {}, +// }, +// ringIndex: 4, +// lastIndex: 1, +// expected: []UsedElement{ +// {DescriptorIndex: 2}, +// {DescriptorIndex: 3}, +// {DescriptorIndex: 4}, +// }, +// }, +// { +// name: "ring overflow", +// ring: []UsedElement{ +// {DescriptorIndex: 9}, +// {DescriptorIndex: 10}, +// {DescriptorIndex: 3}, +// {DescriptorIndex: 4}, +// {DescriptorIndex: 5}, +// {DescriptorIndex: 6}, +// {DescriptorIndex: 7}, +// {DescriptorIndex: 8}, +// }, +// ringIndex: 10, +// lastIndex: 7, +// expected: []UsedElement{ +// {DescriptorIndex: 8}, +// {DescriptorIndex: 9}, +// {DescriptorIndex: 10}, +// }, +// }, +// { +// name: "index overflow", +// ring: []UsedElement{ +// {DescriptorIndex: 9}, +// {DescriptorIndex: 10}, +// {DescriptorIndex: 3}, +// {DescriptorIndex: 4}, +// {DescriptorIndex: 5}, +// {DescriptorIndex: 6}, +// {DescriptorIndex: 7}, +// {DescriptorIndex: 8}, +// }, +// ringIndex: 2, +// lastIndex: 65535, +// expected: []UsedElement{ +// {DescriptorIndex: 8}, +// {DescriptorIndex: 9}, +// {DescriptorIndex: 10}, +// }, +// }, +// } +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// memory := make([]byte, usedRingSize(queueSize)) +// r := newUsedRing(queueSize, memory) +// +// copy(r.ring, tt.ring) +// *r.ringIndex = tt.ringIndex +// r.lastIndex = tt.lastIndex +// +// assert.Equal(t, tt.expected, r.take()) +// }) +// } +//} diff --git a/packet/outpacket.go b/packet/outpacket.go new file mode 100644 index 0000000..3ae13bb --- /dev/null +++ b/packet/outpacket.go @@ -0,0 +1,70 @@ +package packet + +import ( + "github.com/slackhq/nebula/util/virtio" + "golang.org/x/sys/unix" +) + +type OutPacket struct { + Segments [][]byte + SegmentPayloads [][]byte + SegmentHeaders [][]byte + SegmentIDs []uint16 + //todo virtio header? 
+ SegSize int + SegCounter int + Valid bool + wasSegmented bool + + Scratch []byte +} + +func NewOut() *OutPacket { + out := new(OutPacket) + out.Segments = make([][]byte, 0, 64) + out.SegmentHeaders = make([][]byte, 0, 64) + out.SegmentPayloads = make([][]byte, 0, 64) + out.SegmentIDs = make([]uint16, 0, 64) + out.Scratch = make([]byte, Size) + return out +} + +func (pkt *OutPacket) Reset() { + pkt.Segments = pkt.Segments[:0] + pkt.SegmentPayloads = pkt.SegmentPayloads[:0] + pkt.SegmentHeaders = pkt.SegmentHeaders[:0] + pkt.SegmentIDs = pkt.SegmentIDs[:0] + pkt.SegSize = 0 + pkt.Valid = false + pkt.wasSegmented = false +} + +func (pkt *OutPacket) UseSegment(segID uint16, seg []byte, isV6 bool) int { + pkt.Valid = true + pkt.SegmentIDs = append(pkt.SegmentIDs, segID) + pkt.Segments = append(pkt.Segments, seg) //todo do we need this? + + vhdr := virtio.NetHdr{ //todo + Flags: unix.VIRTIO_NET_HDR_F_DATA_VALID, + GSOType: unix.VIRTIO_NET_HDR_GSO_NONE, + HdrLen: 0, + GSOSize: 0, + CsumStart: 0, + CsumOffset: 0, + NumBuffers: 0, + } + + hdr := seg[0 : virtio.NetHdrSize+14] + _ = vhdr.Encode(hdr) + if isV6 { + hdr[virtio.NetHdrSize+14-2] = 0x86 + hdr[virtio.NetHdrSize+14-1] = 0xdd + } else { + hdr[virtio.NetHdrSize+14-2] = 0x08 + hdr[virtio.NetHdrSize+14-1] = 0x00 + } + + pkt.SegmentHeaders = append(pkt.SegmentHeaders, hdr) + pkt.SegmentPayloads = append(pkt.SegmentPayloads, seg[virtio.NetHdrSize+14:]) + return len(pkt.SegmentIDs) - 1 +} diff --git a/packet/packet.go b/packet/packet.go new file mode 100644 index 0000000..31b9fd9 --- /dev/null +++ b/packet/packet.go @@ -0,0 +1,119 @@ +package packet + +import ( + "encoding/binary" + "iter" + "net/netip" + "slices" + "syscall" + "unsafe" + + "golang.org/x/sys/unix" +) + +const Size = 0xffff + +type Packet struct { + Payload []byte + Control []byte + Name []byte + SegSize int + + //todo should this hold out as well? 
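+	// OutLen is the number of payload bytes to write for this packet. Update
+	// resets it to -1 and Conn.Prep sets it; WriteBatch skips packets whose
+	// OutLen is still -1.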
+ OutLen int + + wasSegmented bool + isV4 bool +} + +func New(isV4 bool) *Packet { + return &Packet{ + Payload: make([]byte, Size), + Control: make([]byte, unix.CmsgSpace(2)), + Name: make([]byte, unix.SizeofSockaddrInet6), + isV4: isV4, + } +} + +func (p *Packet) AddrPort() netip.AddrPort { + var ip netip.Addr + // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic + if p.isV4 { + ip, _ = netip.AddrFromSlice(p.Name[4:8]) + } else { + ip, _ = netip.AddrFromSlice(p.Name[8:24]) + } + return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(p.Name[2:4])) +} + +func (p *Packet) updateCtrl(ctrlLen int) { + p.SegSize = len(p.Payload) + p.wasSegmented = false + if ctrlLen == 0 { + return + } + if len(p.Control) == 0 { + return + } + cmsgs, err := unix.ParseSocketControlMessage(p.Control) + if err != nil { + return // oh well + } + + for _, c := range cmsgs { + if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 { + p.wasSegmented = true + p.SegSize = int(binary.LittleEndian.Uint16(c.Data[:2])) + return + } + } +} + +// Update sets a Packet into "just received, not processed" state +func (p *Packet) Update(ctrlLen int) { + p.OutLen = -1 + p.updateCtrl(ctrlLen) +} + +func (p *Packet) SetSegSizeForTX() { + p.SegSize = len(p.Payload) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&p.Control[0])) + hdr.Level = unix.SOL_UDP + hdr.Type = unix.UDP_SEGMENT + hdr.SetLen(syscall.CmsgLen(2)) + binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize)) +} + +func (p *Packet) CompatibleForSegmentationWith(otherP *Packet, currentTotalSize int) bool { + //same dest + if !slices.Equal(p.Name, otherP.Name) { + return false + } + + //don't get too big + if len(p.Payload)+currentTotalSize >= 0xffff { + return false + } + + //same body len + //todo allow single different size at end + if len(p.Payload) != len(otherP.Payload) { + return false //todo technically you can cram one extra in + } + return true +} + +func (p *Packet) Segments() iter.Seq[[]byte] { + return func(yield func([]byte) bool) { + //cursor := 0 + for offset := 0; offset < len(p.Payload); offset += p.SegSize { + end := offset + p.SegSize + if end > len(p.Payload) { + end = len(p.Payload) + } + if !yield(p.Payload[offset:end]) { + return + } + } + } +} diff --git a/packet/virtio.go b/packet/virtio.go new file mode 100644 index 0000000..d03bd86 --- /dev/null +++ b/packet/virtio.go @@ -0,0 +1,37 @@ +package packet + +import ( + "github.com/slackhq/nebula/util/virtio" +) + +type VirtIOPacket struct { + Payload []byte + Header virtio.NetHdr + Chains []uint16 + ChainRefs [][]byte + // OfferDescriptorChains(chains []uint16, kick bool) error +} + +func NewVIO() *VirtIOPacket { + out := new(VirtIOPacket) + out.Payload = nil + out.ChainRefs = make([][]byte, 0, 4) + out.Chains = make([]uint16, 0, 8) + return out +} + +func (v *VirtIOPacket) Reset() { + v.Payload = nil + v.ChainRefs = v.ChainRefs[:0] + v.Chains = v.Chains[:0] +} + +type VirtIOTXPacket struct { + VirtIOPacket +} + +func NewVIOTX(isV4 bool) *VirtIOTXPacket { + out := new(VirtIOTXPacket) + out.VirtIOPacket = *NewVIO() + return out +} diff --git a/udp/conn.go b/udp/conn.go index 1ae585c..f249389 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -4,13 +4,13 @@ import ( "net/netip" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/packet" ) const MTU = 9001 type EncReader func( - addr netip.AddrPort, - payload []byte, + []*packet.Packet, ) type Conn interface { @@ -19,6 +19,8 @@ 
type Conn interface { ListenOut(r EncReader) WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) + Prep(pkt *packet.Packet, addr netip.AddrPort) error + WriteBatch(pkt []*packet.Packet) (int, error) SupportsMultipleReaders() bool Close() error } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index e775932..e39af7b 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -14,22 +14,22 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/packet" "golang.org/x/sys/unix" ) -type StdConn struct { - sysFd int - isV4 bool - l *logrus.Logger - batch int -} +const iovMax = 128 //1024 //no unix constant for this? from limits.h +//todo I'd like this to be 1024 but we seem to hit errors around ~130? -func maybeIPV4(ip net.IP) (net.IP, bool) { - ip4 := ip.To4() - if ip4 != nil { - return ip4, true - } - return ip, false +type StdConn struct { + sysFd int + isV4 bool + l *logrus.Logger + batch int + enableGRO bool + + msgs []rawMessage + iovs [][]iovec } func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { @@ -69,7 +69,20 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in return nil, fmt.Errorf("unable to bind to socket: %s", err) } - return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err + const batchSize = 8192 + msgs := make([]rawMessage, 0, batchSize) //todo configure + iovs := make([][]iovec, batchSize) + for i := range iovs { + iovs[i] = make([]iovec, iovMax) + } + return &StdConn{ + sysFd: fd, + isV4: ip.Is4(), + l: l, + batch: batch, + msgs: msgs, + iovs: iovs, + }, err } func (u *StdConn) SupportsMultipleReaders() bool { @@ -123,9 +136,7 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) { } func (u *StdConn) ListenOut(r EncReader) { - var ip netip.Addr - - msgs, buffers, names := u.PrepareRawMessages(u.batch) + msgs, packets := u.PrepareRawMessages(u.batch, u.isV4) read := u.ReadMulti if u.batch == 1 { read = u.ReadSingle @@ -139,13 +150,12 @@ func (u *StdConn) ListenOut(r EncReader) { } for i := 0; i < n; i++ { - // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic - if u.isV4 { - ip, _ = netip.AddrFromSlice(names[i][4:8]) - } else { - ip, _ = netip.AddrFromSlice(names[i][8:24]) - } - r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) + packets[i].Payload = packets[i].Payload[:msgs[i].Len] + packets[i].Update(getRawMessageControlLen(&msgs[i])) + } + r(packets[:n]) + for i := 0; i < n; i++ { //todo reset this in prev loop, but this makes debug ez + msgs[i].Hdr.Controllen = uint64(unix.CmsgSpace(2)) } } } @@ -198,6 +208,147 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { return u.writeTo6(b, ip) } +func (u *StdConn) WriteToBatch(b []byte, ip netip.AddrPort) error { + if u.isV4 { + return u.writeTo4(b, ip) + } + return u.writeTo6(b, ip) +} + +func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error { + nl, err := u.encodeSockaddr(pkt.Name, addr) + if err != nil { + return err + } + pkt.Name = pkt.Name[:nl] + pkt.OutLen = len(pkt.Payload) + return nil +} + +func (u *StdConn) WriteBatch(pkts []*packet.Packet) (int, error) { + if len(pkts) == 0 { + return 0, nil + } + + u.msgs = u.msgs[:0] + //u.iovs = u.iovs[:0] + + sent := 0 + var mostRecentPkt *packet.Packet + mostRecentPktSize := 0 + //segmenting := false + idx := 0 + for _, pkt := range pkts { + if len(pkt.Payload) == 0 
|| pkt.OutLen == -1 { + sent++ + continue + } + lastIdx := idx - 1 + if mostRecentPkt != nil && pkt.CompatibleForSegmentationWith(mostRecentPkt, mostRecentPktSize) && u.msgs[lastIdx].Hdr.Iovlen < iovMax { + u.msgs[lastIdx].Hdr.Controllen = uint64(len(mostRecentPkt.Control)) + u.msgs[lastIdx].Hdr.Control = &mostRecentPkt.Control[0] + + u.iovs[lastIdx][u.msgs[lastIdx].Hdr.Iovlen].Base = &pkt.Payload[0] + u.iovs[lastIdx][u.msgs[lastIdx].Hdr.Iovlen].Len = uint64(len(pkt.Payload)) + u.msgs[lastIdx].Hdr.Iovlen++ + + mostRecentPktSize += len(pkt.Payload) + mostRecentPkt.SetSegSizeForTX() + } else { + u.msgs = append(u.msgs, rawMessage{}) + u.iovs[idx][0] = iovec{ + Base: &pkt.Payload[0], + Len: uint64(len(pkt.Payload)), + } + + msg := &u.msgs[idx] + iov := &u.iovs[idx][0] + idx++ + + msg.Hdr.Iov = iov + msg.Hdr.Iovlen = 1 + setRawMessageControl(msg, nil) + msg.Hdr.Flags = 0 + + msg.Hdr.Name = &pkt.Name[0] + msg.Hdr.Namelen = uint32(len(pkt.Name)) + mostRecentPkt = pkt + mostRecentPktSize = len(pkt.Payload) + } + } + + if len(u.msgs) == 0 { + return sent, nil + } + + offset := 0 + for offset < len(u.msgs) { + n, _, errno := unix.Syscall6( + unix.SYS_SENDMMSG, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&u.msgs[offset])), + uintptr(len(u.msgs)-offset), + 0, + 0, + 0, + ) + + if errno != 0 { + if errno == unix.EINTR { + continue + } + //for i := 0; i < len(u.msgs); i++ { + // for j := 0; j < int(u.msgs[i].Hdr.Iovlen); j++ { + // u.l.WithFields(logrus.Fields{ + // "msg_index": i, + // "iov idx": j, + // "iov": fmt.Sprintf("%+v", u.iovs[i][j]), + // }).Warn("failed to send message") + // } + // + //} + u.l.WithFields(logrus.Fields{ + "errno": errno, + "idx": idx, + "len": len(u.msgs), + "deets": fmt.Sprintf("%+v", u.msgs), + "lastIOV": fmt.Sprintf("%+v", u.iovs[len(u.msgs)-1][u.msgs[len(u.msgs)-1].Hdr.Iovlen-1]), + }).Error("failed to send message") + return sent + offset, &net.OpError{Op: "sendmmsg", Err: errno} + } + + if n == 0 { + break + } + offset += int(n) + } + + return sent + len(u.msgs), nil +} + +func (u *StdConn) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) { + if u.isV4 { + if !addr.Addr().Is4() { + return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote") + } + var sa unix.RawSockaddrInet4 + sa.Family = unix.AF_INET + sa.Addr = addr.Addr().As4() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port()) + size := unix.SizeofSockaddrInet4 + copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:]) + return uint32(size), nil + } + + var sa unix.RawSockaddrInet6 + sa.Family = unix.AF_INET6 + sa.Addr = addr.Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port()) + size := unix.SizeofSockaddrInet6 + copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:]) + return uint32(size), nil +} + func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 @@ -298,6 +449,27 @@ func (u *StdConn) ReloadConfig(c *config.C) { u.l.WithError(err).Error("Failed to set listen.so_mark") } } + u.configureGRO(true) +} + +func (u *StdConn) configureGRO(enable bool) { + if enable == u.enableGRO { + return + } + + if enable { + if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil { + u.l.WithError(err).Warn("Failed to enable UDP GRO") + return + } + u.enableGRO = true + u.l.Info("UDP GRO enabled") + } else { + if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil && 
err != unix.ENOPROTOOPT { + u.l.WithError(err).Warn("Failed to disable UDP GRO") + } + u.enableGRO = false + } } func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 48c5a97..e9e3ccb 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -7,6 +7,7 @@ package udp import ( + "github.com/slackhq/nebula/packet" "golang.org/x/sys/unix" ) @@ -33,25 +34,59 @@ type rawMessage struct { Pad0 [4]byte } -func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { +func setRawMessageControl(msg *rawMessage, buf []byte) { + if len(buf) == 0 { + msg.Hdr.Control = nil + msg.Hdr.Controllen = 0 + return + } + msg.Hdr.Control = &buf[0] + msg.Hdr.Controllen = uint64(len(buf)) +} + +func getRawMessageControlLen(msg *rawMessage) int { + return int(msg.Hdr.Controllen) +} + +func setCmsgLen(h *unix.Cmsghdr, l int) { + h.Len = uint64(l) +} + +func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet.Packet) { msgs := make([]rawMessage, n) - buffers := make([][]byte, n) - names := make([][]byte, n) + packets := make([]*packet.Packet, n) for i := range msgs { - buffers[i] = make([]byte, MTU) - names[i] = make([]byte, unix.SizeofSockaddrInet6) + packets[i] = packet.New(isV4) vs := []iovec{ - {Base: &buffers[i][0], Len: uint64(len(buffers[i]))}, + {Base: &packets[i].Payload[0], Len: uint64(packet.Size)}, } msgs[i].Hdr.Iov = &vs[0] msgs[i].Hdr.Iovlen = uint64(len(vs)) - msgs[i].Hdr.Name = &names[i][0] - msgs[i].Hdr.Namelen = uint32(len(names[i])) + msgs[i].Hdr.Name = &packets[i].Name[0] + msgs[i].Hdr.Namelen = uint32(len(packets[i].Name)) + + if u.enableGRO { + msgs[i].Hdr.Control = &packets[i].Control[0] + msgs[i].Hdr.Controllen = uint64(len(packets[i].Control)) + } else { + msgs[i].Hdr.Control = nil + msgs[i].Hdr.Controllen = 0 + } } - return msgs, buffers, names + return msgs, packets +} + +func setIovecSlice(iov *iovec, b []byte) { + if len(b) == 0 { + iov.Base = nil + iov.Len = 0 + return + } + iov.Base = &b[0] + iov.Len = uint64(len(b)) } diff --git a/util/virtio/doc.go b/util/virtio/doc.go new file mode 100644 index 0000000..ff866f0 --- /dev/null +++ b/util/virtio/doc.go @@ -0,0 +1,3 @@ +// Package virtio contains some generic types and concepts related to the virtio +// protocol. +package virtio diff --git a/util/virtio/features.go b/util/virtio/features.go new file mode 100644 index 0000000..ff5c873 --- /dev/null +++ b/util/virtio/features.go @@ -0,0 +1,136 @@ +package virtio + +// Feature contains feature bits that describe a virtio device or driver. +type Feature uint64 + +// Device-independent feature bits. +// +// Source: https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-6600006 +const ( + // FeatureIndirectDescriptors indicates that the driver can use descriptors + // with an additional layer of indirection. + FeatureIndirectDescriptors Feature = 1 << 28 + + // FeatureVersion1 indicates compliance with version 1.0 of the virtio + // specification. + FeatureVersion1 Feature = 1 << 32 +) + +// Feature bits for networking devices. +// +// Source: https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-2200003 +const ( + // FeatureNetDeviceCsum indicates that the device can handle packets with + // partial checksum (checksum offload). + FeatureNetDeviceCsum Feature = 1 << 0 + + // FeatureNetDriverCsum indicates that the driver can handle packets with + // partial checksum. 
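+	// In the kernel headers this is VIRTIO_NET_F_GUEST_CSUM; the "driver"
+	// features in this file correspond to the spec's GUEST_* bits.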
+ FeatureNetDriverCsum Feature = 1 << 1 + + // FeatureNetCtrlDriverOffloads indicates support for dynamic offload state + // reconfiguration. + FeatureNetCtrlDriverOffloads Feature = 1 << 2 + + // FeatureNetMTU indicates that the device reports a maximum MTU value. + FeatureNetMTU Feature = 1 << 3 + + // FeatureNetMAC indicates that the device provides a MAC address. + FeatureNetMAC Feature = 1 << 5 + + // FeatureNetDriverTSO4 indicates that the driver supports the TCP + // segmentation offload for received IPv4 packets. + FeatureNetDriverTSO4 Feature = 1 << 7 + + // FeatureNetDriverTSO6 indicates that the driver supports the TCP + // segmentation offload for received IPv6 packets. + FeatureNetDriverTSO6 Feature = 1 << 8 + + // FeatureNetDriverECN indicates that the driver supports the TCP + // segmentation offload with ECN for received packets. + FeatureNetDriverECN Feature = 1 << 9 + + // FeatureNetDriverUFO indicates that the driver supports the UDP + // fragmentation offload for received packets. + FeatureNetDriverUFO Feature = 1 << 10 + + // FeatureNetDeviceTSO4 indicates that the device supports the TCP + // segmentation offload for received IPv4 packets. + FeatureNetDeviceTSO4 Feature = 1 << 11 + + // FeatureNetDeviceTSO6 indicates that the device supports the TCP + // segmentation offload for received IPv6 packets. + FeatureNetDeviceTSO6 Feature = 1 << 12 + + // FeatureNetDeviceECN indicates that the device supports the TCP + // segmentation offload with ECN for received packets. + FeatureNetDeviceECN Feature = 1 << 13 + + // FeatureNetDeviceUFO indicates that the device supports the UDP + // fragmentation offload for received packets. + FeatureNetDeviceUFO Feature = 1 << 14 + + // FeatureNetMergeRXBuffers indicates that the driver can handle merged + // receive buffers. + // When this feature is negotiated, devices may merge multiple descriptor + // chains together to transport large received packets. [NetHdr.NumBuffers] + // will then contain the number of merged descriptor chains. + FeatureNetMergeRXBuffers Feature = 1 << 15 + + // FeatureNetStatus indicates that the device configuration status field is + // available. + FeatureNetStatus Feature = 1 << 16 + + // FeatureNetCtrlVQ indicates that a control channel virtqueue is + // available. + FeatureNetCtrlVQ Feature = 1 << 17 + + // FeatureNetCtrlRX indicates support for RX mode control (e.g. promiscuous + // or all-multicast) for packet receive filtering. + FeatureNetCtrlRX Feature = 1 << 18 + + // FeatureNetCtrlVLAN indicates support for VLAN filtering through the + // control channel. + FeatureNetCtrlVLAN Feature = 1 << 19 + + // FeatureNetDriverAnnounce indicates that the driver can send gratuitous + // packets. + FeatureNetDriverAnnounce Feature = 1 << 21 + + // FeatureNetMQ indicates that the device supports multiqueue with automatic + // receive steering. + FeatureNetMQ Feature = 1 << 22 + + // FeatureNetCtrlMACAddr indicates that the MAC address can be set through + // the control channel. + FeatureNetCtrlMACAddr Feature = 1 << 23 + + // FeatureNetDeviceUSO indicates that the device supports the UDP + // segmentation offload for received packets. + FeatureNetDeviceUSO Feature = 1 << 56 + + // FeatureNetHashReport indicates that the device can report a per-packet + // hash value and type. + FeatureNetHashReport Feature = 1 << 57 + + // FeatureNetDriverHdrLen indicates that the driver can provide the exact + // header length value (see [NetHdr.HdrLen]). + // Devices may benefit from knowing the exact header length. 
+ FeatureNetDriverHdrLen Feature = 1 << 59 + + // FeatureNetRSS indicates that the device supports RSS (receive-side + // scaling) with configurable hash parameters. + FeatureNetRSS Feature = 1 << 60 + + // FeatureNetRSCExt indicates that the device can process duplicated ACKs + // and report the number of coalesced segments and duplicated ACKs. + FeatureNetRSCExt Feature = 1 << 61 + + // FeatureNetStandby indicates that the device may act as a standby for a + // primary device with the same MAC address. + FeatureNetStandby Feature = 1 << 62 + + // FeatureNetSpeedDuplex indicates that the device can report link speed and + // duplex mode. + FeatureNetSpeedDuplex Feature = 1 << 63 +) diff --git a/util/virtio/net_hdr.go b/util/virtio/net_hdr.go new file mode 100644 index 0000000..8acb7c8 --- /dev/null +++ b/util/virtio/net_hdr.go @@ -0,0 +1,77 @@ +package virtio + +import ( + "errors" + "unsafe" + + "golang.org/x/sys/unix" +) + +// Workaround to make Go doc links work. +var _ unix.Errno + +// NetHdrSize is the number of bytes needed to store a [NetHdr] in memory. +const NetHdrSize = 12 + +// ErrNetHdrBufferTooSmall is returned when a buffer is too small to fit a +// virtio_net_hdr. +var ErrNetHdrBufferTooSmall = errors.New("the buffer is too small to fit a virtio_net_hdr") + +// NetHdr defines the virtio_net_hdr as described by the virtio specification. +type NetHdr struct { + // Flags that describe the packet. + // Possible values are: + // - [unix.VIRTIO_NET_HDR_F_NEEDS_CSUM] + // - [unix.VIRTIO_NET_HDR_F_DATA_VALID] + // - [unix.VIRTIO_NET_HDR_F_RSC_INFO] + Flags uint8 + // GSOType contains the type of segmentation offload that should be used for + // the packet. + // Possible values are: + // - [unix.VIRTIO_NET_HDR_GSO_NONE] + // - [unix.VIRTIO_NET_HDR_GSO_TCPV4] + // - [unix.VIRTIO_NET_HDR_GSO_UDP] + // - [unix.VIRTIO_NET_HDR_GSO_TCPV6] + // - [unix.VIRTIO_NET_HDR_GSO_UDP_L4] + // - [unix.VIRTIO_NET_HDR_GSO_ECN] + GSOType uint8 + // HdrLen contains the length of the headers that need to be replicated by + // segmentation offloads. It's the number of bytes from the beginning of the + // packet to the beginning of the transport payload. + // Only used when [FeatureNetDriverHdrLen] is negotiated. + HdrLen uint16 + // GSOSize contains the maximum size of each segmented packet beyond the + // header (payload size). In case of TCP, this is the MSS. + GSOSize uint16 + // CsumStart contains the offset within the packet from which on the + // checksum should be computed. + CsumStart uint16 + // CsumOffset specifies how many bytes after [NetHdr.CsumStart] the computed + // 16-bit checksum should be inserted. + CsumOffset uint16 + // NumBuffers contains the number of merged descriptor chains when + // [FeatureNetMergeRXBuffers] is negotiated. + // This field is only used for packets received by the driver and should be + // zero for transmitted packets. + NumBuffers uint16 +} + +// Decode decodes the [NetHdr] from the given byte slice. The slice must contain +// at least [NetHdrSize] bytes. +func (v *NetHdr) Decode(data []byte) error { + if len(data) < NetHdrSize { + return ErrNetHdrBufferTooSmall + } + copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), NetHdrSize), data[:NetHdrSize]) + return nil +} + +// Encode encodes the [NetHdr] into the given byte slice. The slice must have +// room for at least [NetHdrSize] bytes. 
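+// Note that both Encode and Decode copy the struct's in-memory representation
+// directly, so multi-byte fields use the host's byte order.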
+func (v *NetHdr) Encode(data []byte) error { + if len(data) < NetHdrSize { + return ErrNetHdrBufferTooSmall + } + copy(data[:NetHdrSize], unsafe.Slice((*byte)(unsafe.Pointer(v)), NetHdrSize)) + return nil +} diff --git a/util/virtio/net_hdr_test.go b/util/virtio/net_hdr_test.go new file mode 100644 index 0000000..81fd22b --- /dev/null +++ b/util/virtio/net_hdr_test.go @@ -0,0 +1,43 @@ +package virtio + +import ( + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" +) + +func TestNetHdr_Size(t *testing.T) { + assert.EqualValues(t, NetHdrSize, unsafe.Sizeof(NetHdr{})) +} + +func TestNetHdr_Encoding(t *testing.T) { + vnethdr := NetHdr{ + Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + GSOType: unix.VIRTIO_NET_HDR_GSO_UDP_L4, + HdrLen: 42, + GSOSize: 1472, + CsumStart: 34, + CsumOffset: 6, + NumBuffers: 16, + } + + buf := make([]byte, NetHdrSize) + require.NoError(t, vnethdr.Encode(buf)) + + assert.Equal(t, []byte{ + 0x01, 0x05, + 0x2a, 0x00, + 0xc0, 0x05, + 0x22, 0x00, + 0x06, 0x00, + 0x10, 0x00, + }, buf) + + var decoded NetHdr + require.NoError(t, decoded.Decode(buf)) + + assert.Equal(t, vnethdr, decoded) +}
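+
+// TestNetHdr_EncodeIntoFrame is an illustrative sketch (not required by the
+// virtio spec): it encodes a NetHdr in front of a 14-byte Ethernet header,
+// mirroring the NetHdrSize+14 layout that the TX path in this change uses,
+// and checks that the header decodes back unchanged.
+func TestNetHdr_EncodeIntoFrame(t *testing.T) {
+	frame := make([]byte, NetHdrSize+14+1280)
+
+	vnethdr := NetHdr{
+		Flags:   unix.VIRTIO_NET_HDR_F_DATA_VALID,
+		GSOType: unix.VIRTIO_NET_HDR_GSO_NONE,
+	}
+	require.NoError(t, vnethdr.Encode(frame[:NetHdrSize]))
+
+	// The EtherType for IPv6 goes into the last two bytes of the Ethernet header.
+	frame[NetHdrSize+12] = 0x86
+	frame[NetHdrSize+13] = 0xdd
+
+	var decoded NetHdr
+	require.NoError(t, decoded.Decode(frame))
+	assert.Equal(t, vnethdr, decoded)
+}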