diff --git a/interface.go b/interface.go
index 082906d..5cef0b7 100644
--- a/interface.go
+++ b/interface.go
@@ -86,7 +86,7 @@ type Interface struct {
 	conntrackCacheTimeout time.Duration
 
 	writers []udp.Conn
-	readers []io.ReadWriteCloser
+	readers []overlay.TunDev
 
 	metricHandshakes metrics.Histogram
 	messageMetrics *MessageMetrics
@@ -177,7 +177,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		routines: c.routines,
 		version: c.version,
 		writers: make([]udp.Conn, c.routines),
-		readers: make([]io.ReadWriteCloser, c.routines),
+		readers: make([]overlay.TunDev, c.routines),
 		myVpnNetworks: cs.myVpnNetworks,
 		myVpnNetworksTable: cs.myVpnNetworksTable,
 		myVpnAddrs: cs.myVpnAddrs,
@@ -225,7 +225,7 @@ func (f *Interface) activate() {
 	metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
 
 	// Prepare n tun queues
-	var reader io.ReadWriteCloser = f.inside
+	var reader overlay.TunDev = f.inside
 	for i := 0; i < f.routines; i++ {
 		if i > 0 {
 			reader, err = f.inside.NewMultiQueueReader()
@@ -254,25 +254,66 @@ func (f *Interface) run() {
 	}
 }
 
-func (f *Interface) listenOut(i int) {
+// listenOut runs the receive loop for udp queue q. Every batch handed over by the socket is processed by
+// readOutsidePacketsMany and whatever decrypted for the tun device is written to tun queue q with WriteMany.
+func (f *Interface) listenOut(q int) {
 	runtime.LockOSThread()
 
 	var li udp.Conn
-	if i > 0 {
-		li = f.writers[i]
+	if q > 0 {
+		li = f.writers[q]
 	} else {
 		li = f.outside
 	}
 
+	const batch = 64 //todo
+
 	ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 	lhh := f.lightHouse.NewRequestHandler()
-	plaintext := make([]byte, udp.MTU)
+	plaintexts := make([][]byte, batch)
+	outNeedsTun := make([]*int, batch)
+	for i := 0; i < batch; i++ {
+		plaintexts[i] = make([]byte, udp.MTU)
+		outNeedsTun[i] = new(int)
+		*outNeedsTun[i] = -1
+	}
+
 	h := &header.H{}
 	fwPacket := &firewall.Packet{}
 	nb := make([]byte, 12, 12)
 
-	li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
-		f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
+	toSend := make([][]byte, batch)
+
+	li.ListenOut(func(fromUdpAddrs []netip.AddrPort, payloads [][]byte) {
+		// The socket sizes its batches on its own, so walk the read in chunks of at most `batch` packets
+		// instead of indexing past the scratch buffers allocated above.
+		for len(payloads) > 0 {
+			n := len(payloads)
+			if n > batch {
+				n = batch
+			}
+
+			toSend = toSend[:0]
+			for i := 0; i < n; i++ {
+				plaintexts[i] = plaintexts[i][:0]
+			}
+			f.readOutsidePacketsMany(fromUdpAddrs[:n], plaintexts[:n], outNeedsTun[:n], payloads[:n], h, fwPacket, lhh, nb, q, ctCache.Get(f.l))
+			for i := 0; i < n; i++ {
+				if *outNeedsTun[i] != -1 {
+					toSend = append(toSend, plaintexts[i][:*outNeedsTun[i]])
+					*outNeedsTun[i] = -1
+				}
+			}
+			fromUdpAddrs, payloads = fromUdpAddrs[n:], payloads[n:]
+
+			// Nothing in this chunk decrypted to a tun packet, don't bother the device
+			if len(toSend) == 0 {
+				continue
+			}
+			if _, err := f.readers[q].WriteMany(toSend); err != nil {
+				f.l.WithError(err).Error("Failed to write messages")
+			}
+		}
 	})
 }
 
diff --git a/outside.go b/outside.go
index 5ff87bd..83a5ae0 100644
--- a/outside.go
+++ b/outside.go
@@ -216,6 +216,211 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 	f.connectionManager.In(hostinfo)
 }
 
+// readOutsidePacketsMany is the batched form of readOutsidePackets. packets[i] was received from ip[i] and uses
+// out[i] as its scratch/plaintext buffer. When a packet decrypts to something that belongs on the tun device,
+// *outNeedsTun[i] is set to the plaintext length in out[i] and the caller is expected to do the write.
+// A packet that is dropped or fully handled here never stops the rest of the batch from being processed.
+func (f *Interface) readOutsidePacketsMany(ip []netip.AddrPort, out [][]byte, outNeedsTun []*int, packets [][]byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
+	for i, packet := range packets {
+
+		err := h.Parse(packet)
+		if err != nil {
+			// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
+			if len(packet) > 1 {
+				f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip[i], err)
+			}
+			continue
+		}
+
+		//l.Error("in packet ", header, packet[HeaderLen:])
+		if ip[i].IsValid() {
+			if f.myVpnNetworksTable.Contains(ip[i].Addr()) {
+				if f.l.Level >= logrus.DebugLevel {
+					f.l.WithField("udpAddr", ip[i]).Debug("Refusing to process double encrypted packet")
+				}
+				continue
+			}
+		}
+
+		var hostinfo *HostInfo
+		// verify if we've seen this index before, otherwise respond to the handshake initiation
+		if h.Type == header.Message && h.Subtype == header.MessageRelay {
+			hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
+		} else {
+			hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
+		}
+
+		var ci *ConnectionState
+		if hostinfo != nil {
+			ci = hostinfo.ConnectionState
+		}
+
+		switch h.Type {
+		case header.Message:
+			// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
+			if !f.handleEncrypted(ci, ip[i], h) {
+				continue
+			}
+
+			switch h.Subtype {
+			case header.MessageNone:
+				out[i] = out[i][:0]
+				if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i][:0], outNeedsTun[i], packet, fwPacket, nb, q, localCache) {
+					continue
+				}
+			case header.MessageRelay:
+				// The entire body is sent as AD, not encrypted.
+				// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
+				// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
+				// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
+				// which will gracefully fail in the DecryptDanger call.
+				signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()]
+				signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():]
+				out[i], err = hostinfo.ConnectionState.dKey.DecryptDanger(out[i], signedPayload, signatureValue, h.MessageCounter, nb)
+				if err != nil {
+					continue
+				}
+				// Successfully validated the thing. Get rid of the Relay header.
+				signedPayload = signedPayload[header.Len:]
+				// Pull the Roaming parts up here, and return in all call paths.
+				f.handleHostRoaming(hostinfo, ip[i])
+				// Track usage of both the HostInfo and the Relay for the received & authenticated packet
+				f.connectionManager.In(hostinfo)
+				f.connectionManager.RelayUsed(h.RemoteIndex)
+
+				relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
+				if !ok {
+					// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
+					// its internal mapping. This should never happen.
+					hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
+					continue
+				}
+
+				switch relay.Type {
+				case TerminalType:
+					// If I am the target of this relay, process the unwrapped packet
+					// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
+					f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i][:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
+					continue
+				case ForwardingType:
+					// Find the target HostInfo relay object
+					targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
+					if err != nil {
+						hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
+						continue
+					}
+
+					// If that relay is Established, forward the payload through it
+					if targetRelay.State == Established {
+						switch targetRelay.Type {
+						case ForwardingType:
+							// Forward this packet through the relay tunnel
+							// Find the target HostInfo
+							f.SendVia(targetHI, targetRelay, signedPayload, nb, out[i], false)
+							continue
+						case TerminalType:
+							hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
+						}
+					} else {
+						hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
+						continue
+					}
+				}
+			}
+
+		case header.LightHouse:
+			f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+			if !f.handleEncrypted(ci, ip[i], h) {
+				continue
+			}
+
+			d, err := f.decrypt(hostinfo, h.MessageCounter, out[i], packet, h, nb)
+			if err != nil {
+				hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip[i]).
+					WithField("packet", packet).
+					Error("Failed to decrypt lighthouse packet")
+				continue
+			}
+
+			lhf.HandleRequest(ip[i], hostinfo.vpnAddrs, d, f)
+
+			// Fallthrough to the bottom to record incoming traffic
+
+		case header.Test:
+			f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+			if !f.handleEncrypted(ci, ip[i], h) {
+				continue
+			}
+
+			d, err := f.decrypt(hostinfo, h.MessageCounter, out[i], packet, h, nb)
+			if err != nil {
+				hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip[i]).
+					WithField("packet", packet).
+					Error("Failed to decrypt test packet")
+				continue
+			}
+
+			if h.Subtype == header.TestRequest {
+				// This testRequest might be from TryPromoteBest, so we should roam
+				// to the new IP address before responding
+				f.handleHostRoaming(hostinfo, ip[i])
+				f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out[i])
+			}
+
+			// Fallthrough to the bottom to record incoming traffic
+
+			// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
+			// are unauthenticated
+
+		case header.Handshake:
+			f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+			f.handshakeManager.HandleIncoming(ip[i], nil, packet, h)
+			continue
+
+		case header.RecvError:
+			f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+			f.handleRecvError(ip[i], h)
+			continue
+
+		case header.CloseTunnel:
+			f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+			if !f.handleEncrypted(ci, ip[i], h) {
+				continue
+			}
+
+			hostinfo.logger(f.l).WithField("udpAddr", ip[i]).
+				Info("Close tunnel received, tearing down.")
+
+			f.closeTunnel(hostinfo)
+			continue
+
+		case header.Control:
+			if !f.handleEncrypted(ci, ip[i], h) {
+				continue
+			}
+
+			d, err := f.decrypt(hostinfo, h.MessageCounter, out[i], packet, h, nb)
+			if err != nil {
+				hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip[i]).
+					WithField("packet", packet).
+					Error("Failed to decrypt Control packet")
+				continue
+			}
+
+			f.relayManager.HandleControlMsg(hostinfo, d, f)
+
+		default:
+			f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+			hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip[i])
+			continue
+		}
+
+		f.handleHostRoaming(hostinfo, ip[i])
+
+		f.connectionManager.In(hostinfo)
+	}
+}
+
 // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
 func (f *Interface) closeTunnel(hostInfo *HostInfo) {
 	final := f.hostMap.DeleteHostInfo(hostInfo)
@@ -465,6 +670,49 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
 	return out, nil
 }
 
+// decryptToTunDelayWrite decrypts and validates an inbound message packet but does not write it to the tun device.
+// On success it stores the plaintext length in *outNeedsTun and returns true; the caller relies on the plaintext
+// having landed in the buffer it passed as out. On any failure it returns false and leaves *outNeedsTun untouched.
+func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out []byte, outNeedsTun *int, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
+	var err error
+
+	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
+	if err != nil {
+		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
+		return false
+	}
+
+	err = newPacket(out, true, fwPacket)
+	if err != nil {
+		hostinfo.logger(f.l).WithError(err).WithField("packet", out).
+			Warnf("Error while validating inbound packet")
+		return false
+	}
+
+	if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
+		hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
+			Debugln("dropping out of window packet")
+		return false
+	}
+
+	dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
+	if dropReason != nil {
+		// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
+		// This gives us a buffer to build the reject packet in
+		f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
+		if f.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
+				WithField("reason", dropReason).
+				Debugln("dropping inbound packet")
+		}
+		return false
+	}
+
+	f.connectionManager.In(hostinfo)
+	*outNeedsTun = len(out)
+	return true
+}
+
 func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
 	var err error
 
diff --git a/overlay/device.go b/overlay/device.go
index 07146ab..5b058b5 100644
--- a/overlay/device.go
+++ b/overlay/device.go
@@ -1,17 +1,16 @@
 package overlay
 
 import (
-	"io"
 	"net/netip"
 
 	"github.com/slackhq/nebula/routing"
 )
 
 type Device interface {
-	io.ReadWriteCloser
+	TunDev
 	Activate() error
 	Networks() []netip.Prefix
 	Name() string
 	RoutesFor(netip.Addr) routing.Gateways
-	NewMultiQueueReader() (io.ReadWriteCloser, error)
+	NewMultiQueueReader() (TunDev, error)
 }
diff --git a/overlay/tun.go b/overlay/tun.go
index adee8de..7c84d97 100644
--- a/overlay/tun.go
+++ b/overlay/tun.go
@@ -2,6 +2,7 @@ package overlay
 
 import (
 	"fmt"
+	"io"
 	"net"
 	"net/netip"
 
@@ -12,6 +13,14 @@ import (
 
 const DefaultMTU = 1300
 
+// TunDev is a tun device queue that can additionally accept a batch of packets in one call.
+// WriteMany writes every element of the given slice as its own packet and returns the total number of
+// bytes written.
+type TunDev interface {
+	io.ReadWriteCloser
+	WriteMany([][]byte) (int, error)
+}
+
 // TODO: We may be able to remove routines
 type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
 
diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go
index 131879d..875fa3c 100644
--- a/overlay/tun_disabled.go
+++ b/overlay/tun_disabled.go
@@ -105,7 +105,20 @@ func (t *disabledTun) Write(b []byte) (int, error) {
 	return len(b), nil
 }
 
-func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
+// WriteMany passes each packet in b to Write, stopping at the first error, and returns the bytes written so far.
+func (t *disabledTun) WriteMany(b [][]byte) (int, error) {
+	out := 0
+	for i := range b {
+		x, err := t.Write(b[i])
+		if err != nil {
+			return out, err
+		}
+		out += x
+	}
+	return out, nil
+}
+
+func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {
 	return t, nil
 }
 
diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go
index 31bc4a3..c53a54d 100644
--- a/overlay/tun_linux.go
+++ b/overlay/tun_linux.go
@@ -257,7 +257,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
+func (t *tun) NewMultiQueueReader() (TunDev, error) {
 	//fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	//if err != nil {
 	//	return nil, err
@@ -741,3 +741,34 @@ func (t *tun) Write(b []byte) (int, error) {
 	}
 	return maximum, nil
 }
+
+// WriteMany hands all packets in b to the vhost device in a single TransmitPackets call.
+// It returns the total number of bytes in b on success and 0 with the error otherwise.
+func (t *tun) WriteMany(b [][]byte) (int, error) {
+	if len(b) == 0 {
+		return 0, nil
+	}
+
+	hdr := virtio.NetHdr{ //todo
+		Flags:      unix.VIRTIO_NET_HDR_F_DATA_VALID,
+		GSOType:    unix.VIRTIO_NET_HDR_GSO_NONE,
+		HdrLen:     0,
+		GSOSize:    0,
+		CsumStart:  0,
+		CsumOffset: 0,
+		NumBuffers: 0,
+	}
+
+	err := t.vdev.TransmitPackets(hdr, b)
+	if err != nil {
+		t.l.WithError(err).Error("Transmitting packet")
+		return 0, err
+	}
+
+	// Report bytes, not the packet count, so this agrees with the other TunDev implementations
+	written := 0
+	for i := range b {
+		written += len(b[i])
+	}
+	return written, nil
+}
diff --git a/overlay/user.go b/overlay/user.go
index 8a56d66..a1a937c 100644
--- a/overlay/user.go
+++ b/overlay/user.go
@@ -46,7 +46,7 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
 	return routing.Gateways{routing.NewGateway(ip, 1)}
 }
 
-func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
+func (d *UserDevice) NewMultiQueueReader() (TunDev, error) {
 	return d, nil
 }
 
@@ -65,3 +65,16 @@ func (d *UserDevice) Close() error {
 	d.outboundWriter.Close()
 	return nil
 }
+
+// WriteMany passes each packet in b to Write, stopping at the first error, and returns the bytes written so far.
+func (d *UserDevice) WriteMany(b [][]byte) (int, error) {
+	out := 0
+	for i := range b {
+		x, err := d.Write(b[i])
+		if err != nil {
+			return out, err
+		}
+		out += x
+	}
+	return out, nil
+}
diff --git a/overlay/vhostnet/device.go b/overlay/vhostnet/device.go
index b7531ee..53f0308 100644
--- a/overlay/vhostnet/device.go
+++ b/overlay/vhostnet/device.go
@@ -311,6 +311,42 @@ func (dev *Device) TransmitPacket(vnethdr virtio.NetHdr, packet []byte) error {
 	return nil
 }
 
+// TransmitPackets offers every packet in packets to the transmit queue as its own descriptor chain, each
+// prefixed with the same encoded vnethdr, then waits on the transmitted signal of every chain.
+// All chains are waited on and released even when releasing one fails; the first such error is returned.
+func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) error {
+	if len(packets) == 0 {
+		return nil
+	}
+
+	// Prepend the packet with its virtio-net header.
+	vnethdrBuf := make([]byte, virtio.NetHdrSize+14) //todo WHY
+	if err := vnethdr.Encode(vnethdrBuf); err != nil {
+		return fmt.Errorf("encode vnethdr: %w", err)
+	}
+	vnethdrBuf[virtio.NetHdrSize+14-2] = 0x86
+	vnethdrBuf[virtio.NetHdrSize+14-1] = 0xdd //todo ipv6 ethertype
+
+	chainIndexes, err := dev.transmitQueue.OfferOutDescriptorChains(vnethdrBuf, packets, true)
+	if err != nil {
+		return fmt.Errorf("offer descriptor chain: %w", err)
+	}
+
+	//todo blocking here suxxxx
+	// Wait for the packets to have been transmitted.
+	var firstErr error
+	for i := range chainIndexes {
+		<-dev.transmitted[chainIndexes[i]]
+
+		// Keep going after a failed free, returning early would strand the remaining chains and their completions.
+		if err = dev.transmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil && firstErr == nil {
+			firstErr = fmt.Errorf("free descriptor chain: %w", err)
+		}
+	}
+
+	return firstErr
+}
+
 // ReceivePacket reads the next available packet from the receive queue of this
 // device and returns its [virtio.NetHdr] and packet data separately.
 //
diff --git a/overlay/virtqueue/split_virtqueue.go b/overlay/virtqueue/split_virtqueue.go
index a88dff9..9a0ba76 100644
--- a/overlay/virtqueue/split_virtqueue.go
+++ b/overlay/virtqueue/split_virtqueue.go
@@ -345,6 +345,73 @@ func (sq *SplitQueue) OfferDescriptorChain(outBuffers [][]byte, numInBuffers int
 	return head, nil
 }
 
+// OfferOutDescriptorChains creates one device-readable descriptor chain per entry in outBuffers, each holding
+// prepend followed by that entry, makes them all available to the device and notifies it once.
+// The returned head indexes line up with outBuffers. With waitFree set this blocks while the queue is out of
+// free descriptors, otherwise [ErrNotEnoughFreeDescriptors] is returned.
+// NOTE(review): chains created before a failure are not released here - confirm who should free them.
+func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte, waitFree bool) ([]uint16, error) {
+	sq.ensureInitialized()
+
+	if len(outBuffers) == 0 {
+		return nil, nil
+	}
+
+	// Synchronize the offering of descriptor chains. While the descriptor table
+	// and available ring are synchronized on their own as well, this does not
+	// protect us from interleaved calls which could cause reordering.
+	// By locking here, we can ensure that all descriptor chains are made
+	// available to the device in the same order as this method was called.
+	sq.offerMutex.Lock()
+	defer sq.offerMutex.Unlock()
+
+	chains := make([]uint16, len(outBuffers))
+
+	// Create a descriptor chain for each of the given buffers.
+	var (
+		head uint16
+		err  error
+	)
+	for i := range outBuffers {
+		// Each descriptor can only hold a whole memory page, so split large
+		// buffers inside this packet's own chain. Splitting the whole batch up
+		// front would hand every fragment to the device as a separate packet.
+		bufs := splitBuffers([][]byte{prepend, outBuffers[i]}, sq.pageSize)
+		for {
+			head, err = sq.descriptorTable.createDescriptorChain(bufs, 0)
+			if err == nil {
+				break
+			}
+
+			// I don't wanna use errors.Is, it's slow
+			//goland:noinspection GoDirectComparisonOfErrors
+			if err == ErrNotEnoughFreeDescriptors {
+				if waitFree {
+					// Wait for more free descriptors to be put back into the queue.
+					// If the number of free descriptors is still not sufficient, we'll
+					// land here again.
+					sq.blockForMoreDescriptors()
+					continue
+				} else {
+					return nil, err
+				}
+			}
+			return nil, fmt.Errorf("create descriptor chain: %w", err)
+		}
+		chains[i] = head
+	}
+
+	// Make the descriptor chains available to the device.
+	sq.availableRing.offer(chains)
+
+	// Notify the device to make it process the updated available ring.
+	if err := sq.kickEventFD.Kick(); err != nil {
+		return chains, fmt.Errorf("notify device: %w", err)
+	}
+
+	return chains, nil
+}
+
 // GetDescriptorChain returns the device-readable buffers (out buffers) and
 // device-writable buffers (in buffers) of the descriptor chain with the given
 // head index.
diff --git a/udp/conn.go b/udp/conn.go
index 895b0df..6a3c3b4 100644
--- a/udp/conn.go
+++ b/udp/conn.go
@@ -9,8 +9,10 @@ import (
 const MTU = 9001
 
+// EncReader is called with one batch of received datagrams, payload[i] was sent by addrs[i].
+// The linux listener reuses both slices for its next read, so they must not be retained after the call returns.
 type EncReader func(
-	addr netip.AddrPort,
-	payload []byte,
+	addrs []netip.AddrPort,
+	payload [][]byte,
 )
 
 type Conn interface {
diff --git a/udp/udp_linux.go b/udp/udp_linux.go
index ec0bf64..e1ba229 100644
--- a/udp/udp_linux.go
+++ b/udp/udp_linux.go
@@ -120,7 +120,7 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
 
 func (u *StdConn) ListenOut(r EncReader) {
 	var ip netip.Addr
-
+	addrPorts := make([]netip.AddrPort, u.batch)
 	msgs, buffers, names := u.PrepareRawMessages(u.batch)
 	read := u.ReadMulti
 	if u.batch == 1 {
@@ -141,8 +141,12 @@ func (u *StdConn) ListenOut(r EncReader) {
 			} else {
 				ip, _ = netip.AddrFromSlice(names[i][8:24])
 			}
-			r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
+			addrPorts[i] = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
+			buffers[i] = buffers[i][:msgs[i].Len]
 		}
+
+		// Only the first n slots were filled by this read, anything past that is left over from an earlier batch
+		r(addrPorts[:n], buffers[:n])
 	}
 }
 