diff --git a/interface.go b/interface.go index 5cef0b7..4fcfd67 100644 --- a/interface.go +++ b/interface.go @@ -18,6 +18,7 @@ import ( "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/overlay" + "github.com/slackhq/nebula/packet" "github.com/slackhq/nebula/udp" ) @@ -268,12 +269,9 @@ func (f *Interface) listenOut(q int) { ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() - plaintexts := make([][]byte, batch) - outNeedsTun := make([]*int, batch) + outPackets := make([]*packet.OutPacket, batch) for i := 0; i < batch; i++ { - plaintexts[i] = make([]byte, udp.MTU) - outNeedsTun[i] = new(int) - *outNeedsTun[i] = -1 + outPackets[i] = packet.NewOut() } h := &header.H{} @@ -282,16 +280,23 @@ func (f *Interface) listenOut(q int) { toSend := make([][]byte, batch) - li.ListenOut(func(fromUdpAddrs []netip.AddrPort, payloads [][]byte) { + li.ListenOut(func(pkts []*packet.Packet) { toSend = toSend[:0] - for i := range plaintexts { - plaintexts[i] = plaintexts[i][:0] + for i := range outPackets { + outPackets[i].Valid = false + outPackets[i].SegCounter = 0 } - f.readOutsidePacketsMany(fromUdpAddrs, plaintexts, outNeedsTun, payloads, h, fwPacket, lhh, nb, q, ctCache.Get(f.l)) - for i := range plaintexts { - if *outNeedsTun[i] != -1 { - toSend = append(toSend, plaintexts[i][:*outNeedsTun[i]]) - *outNeedsTun[i] = -1 + + f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l)) + for i := range outPackets { + if pkts[i].OutLen != -1 { + for j := 0; j < outPackets[i].SegCounter; j++ { + if len(outPackets[i].Segments[j]) > 0 { + toSend = append(toSend, outPackets[i].Segments[j]) + } + + } + //toSend = append(toSend, outPackets[i]) //toSendCount++ } } diff --git a/outside.go b/outside.go index 83a5ae0..b9ddac7 100644 --- a/outside.go +++ b/outside.go @@ -7,6 +7,7 @@ import ( "time" "github.com/google/gopacket/layers" + "github.com/slackhq/nebula/packet" "golang.org/x/net/ipv6" "github.com/sirupsen/logrus" @@ -216,21 +217,14 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] f.connectionManager.In(hostinfo) } -func (f *Interface) readOutsidePacketsMany(ip []netip.AddrPort, out [][]byte, outNeedsTun []*int, packets [][]byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { - for i, packet := range packets { - - err := h.Parse(packet) - if err != nil { - // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors - if len(packet) > 1 { - f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err) - } - return - } +func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { + for i, pkt := range packets { + out[i].Scratch = out[i].Scratch[:0] + ip := pkt.AddrPort() //l.Error("in packet ", header, packet[HeaderLen:]) - if ip[i].IsValid() { - if f.myVpnNetworksTable.Contains(ip[i].Addr()) { + if ip.IsValid() { + if f.myVpnNetworksTable.Contains(ip.Addr()) { if f.l.Level >= logrus.DebugLevel { f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") } @@ -238,182 +232,194 @@ func (f *Interface) readOutsidePacketsMany(ip []netip.AddrPort, out [][]byte, ou } } - var hostinfo *HostInfo - // verify if we've seen this index before, otherwise respond to the handshake initiation - if h.Type == header.Message && h.Subtype == header.MessageRelay { - hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) - } else { - hostinfo = f.hostMap.QueryIndex(h.RemoteIndex) - } + //todo per-segment! + for segment := range pkt.Segments() { - var ci *ConnectionState - if hostinfo != nil { - ci = hostinfo.ConnectionState - } - - switch h.Type { - case header.Message: - // TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case. - if !f.handleEncrypted(ci, ip[i], h) { + err := h.Parse(segment) + if err != nil { + // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors + if len(segment) > 1 { + f.l.WithField("packet", pkt).Infof("Error while parsing inbound packet from %s: %s", ip, err) + } return } - switch h.Subtype { - case header.MessageNone: - out[i] = out[i][:0] - if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i][:0], outNeedsTun[i], packet, fwPacket, nb, q, localCache) { - return - } - case header.MessageRelay: - // The entire body is sent as AD, not encrypted. - // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. - // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's - // otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice - // which will gracefully fail in the DecryptDanger call. - signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()] - signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():] - out[i], err = hostinfo.ConnectionState.dKey.DecryptDanger(out[i], signedPayload, signatureValue, h.MessageCounter, nb) - if err != nil { - return - } - // Successfully validated the thing. Get rid of the Relay header. - signedPayload = signedPayload[header.Len:] - // Pull the Roaming parts up here, and return in all call paths. - f.handleHostRoaming(hostinfo, ip[i]) - // Track usage of both the HostInfo and the Relay for the received & authenticated packet - f.connectionManager.In(hostinfo) - f.connectionManager.RelayUsed(h.RemoteIndex) + var hostinfo *HostInfo + // verify if we've seen this index before, otherwise respond to the handshake initiation + if h.Type == header.Message && h.Subtype == header.MessageRelay { + hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) + } else { + hostinfo = f.hostMap.QueryIndex(h.RemoteIndex) + } - relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) - if !ok { - // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing - // its internal mapping. This should never happen. - hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") + var ci *ConnectionState + if hostinfo != nil { + ci = hostinfo.ConnectionState + } + + switch h.Type { + case header.Message: + // TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case. + if !f.handleEncrypted(ci, ip, h) { return } - switch relay.Type { - case TerminalType: - // If I am the target of this relay, process the unwrapped packet - // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. - f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i][:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) - return - case ForwardingType: - // Find the target HostInfo relay object - targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) + switch h.Subtype { + case header.MessageNone: + if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i], pkt, segment, fwPacket, nb, q, localCache) { + return + } + case header.MessageRelay: + // The entire body is sent as AD, not encrypted. + // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. + // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's + // otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice + // which will gracefully fail in the DecryptDanger call. + signedPayload := segment[:len(segment)-hostinfo.ConnectionState.dKey.Overhead()] + signatureValue := segment[len(segment)-hostinfo.ConnectionState.dKey.Overhead():] + out[i].Scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(out[i].Scratch, signedPayload, signatureValue, h.MessageCounter, nb) if err != nil { - hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip") + return + } + // Successfully validated the thing. Get rid of the Relay header. + signedPayload = signedPayload[header.Len:] + // Pull the Roaming parts up here, and return in all call paths. + f.handleHostRoaming(hostinfo, ip) + // Track usage of both the HostInfo and the Relay for the received & authenticated packet + f.connectionManager.In(hostinfo) + f.connectionManager.RelayUsed(h.RemoteIndex) + + relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) + if !ok { + // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing + // its internal mapping. This should never happen. + hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") return } - // If that relay is Established, forward the payload through it - if targetRelay.State == Established { - switch targetRelay.Type { - case ForwardingType: - // Forward this packet through the relay tunnel - // Find the target HostInfo - f.SendVia(targetHI, targetRelay, signedPayload, nb, out[i], false) - return - case TerminalType: - hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") - } - } else { - hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") + switch relay.Type { + case TerminalType: + // If I am the target of this relay, process the unwrapped packet + // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. + f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i].Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) return + case ForwardingType: + // Find the target HostInfo relay object + targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) + if err != nil { + hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip") + return + } + + // If that relay is Established, forward the payload through it + if targetRelay.State == Established { + switch targetRelay.Type { + case ForwardingType: + // Forward this packet through the relay tunnel + // Find the target HostInfo + f.SendVia(targetHI, targetRelay, signedPayload, nb, out[i].Scratch, false) + return + case TerminalType: + hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") + } + } else { + hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") + return + } } } - } - case header.LightHouse: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, ip[i], h) { + case header.LightHouse: + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + if !f.handleEncrypted(ci, ip, h) { + return + } + + d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb) + if err != nil { + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). + WithField("packet", segment). + Error("Failed to decrypt lighthouse packet") + return + } + + lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f) + + // Fallthrough to the bottom to record incoming traffic + + case header.Test: + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + if !f.handleEncrypted(ci, ip, h) { + return + } + + d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb) + if err != nil { + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). + WithField("packet", segment). + Error("Failed to decrypt test packet") + return + } + + if h.Subtype == header.TestRequest { + // This testRequest might be from TryPromoteBest, so we should roam + // to the new IP address before responding + f.handleHostRoaming(hostinfo, ip) + f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out[i].Scratch) + } + + // Fallthrough to the bottom to record incoming traffic + + // Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they + // are unauthenticated + + case header.Handshake: + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + f.handshakeManager.HandleIncoming(ip, nil, segment, h) + return + + case header.RecvError: + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + f.handleRecvError(ip, h) + return + + case header.CloseTunnel: + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + if !f.handleEncrypted(ci, ip, h) { + return + } + + hostinfo.logger(f.l).WithField("udpAddr", ip). + Info("Close tunnel received, tearing down.") + + f.closeTunnel(hostinfo) + return + + case header.Control: + if !f.handleEncrypted(ci, ip, h) { + return + } + + d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb) + if err != nil { + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). + WithField("packet", segment). + Error("Failed to decrypt Control packet") + return + } + + f.relayManager.HandleControlMsg(hostinfo, d, f) + + default: + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip) return } - d, err := f.decrypt(hostinfo, h.MessageCounter, out[i], packet, h, nb) - if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). - WithField("packet", packet). - Error("Failed to decrypt lighthouse packet") - return - } + f.handleHostRoaming(hostinfo, ip) - lhf.HandleRequest(ip[i], hostinfo.vpnAddrs, d, f) - - // Fallthrough to the bottom to record incoming traffic - - case header.Test: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, ip[i], h) { - return - } - - d, err := f.decrypt(hostinfo, h.MessageCounter, out[i], packet, h, nb) - if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). - WithField("packet", packet). - Error("Failed to decrypt test packet") - return - } - - if h.Subtype == header.TestRequest { - // This testRequest might be from TryPromoteBest, so we should roam - // to the new IP address before responding - f.handleHostRoaming(hostinfo, ip[i]) - f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out[i]) - } - - // Fallthrough to the bottom to record incoming traffic - - // Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they - // are unauthenticated - - case header.Handshake: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handshakeManager.HandleIncoming(ip[i], nil, packet, h) - return - - case header.RecvError: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handleRecvError(ip[i], h) - return - - case header.CloseTunnel: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, ip[i], h) { - return - } - - hostinfo.logger(f.l).WithField("udpAddr", ip). - Info("Close tunnel received, tearing down.") - - f.closeTunnel(hostinfo) - return - - case header.Control: - if !f.handleEncrypted(ci, ip[i], h) { - return - } - - d, err := f.decrypt(hostinfo, h.MessageCounter, out[i], packet, h, nb) - if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). - WithField("packet", packet). - Error("Failed to decrypt Control packet") - return - } - - f.relayManager.HandleControlMsg(hostinfo, d, f) - - default: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip) - return + f.connectionManager.In(hostinfo) } - - f.handleHostRoaming(hostinfo, ip[i]) - - f.connectionManager.In(hostinfo) } } @@ -666,16 +672,17 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] return out, nil } -func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out []byte, outNeedsTun *int, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool { +func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool { var err error - out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) + out.Segments[out.SegCounter] = out.Segments[out.SegCounter][:0] + out.Segments[out.SegCounter], err = hostinfo.ConnectionState.dKey.DecryptDanger(out.Segments[out.SegCounter], inSegment[:header.Len], inSegment[header.Len:], messageCounter, nb) if err != nil { hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") return false } - err = newPacket(out, true, fwPacket) + err = newPacket(out.Segments[out.SegCounter], true, fwPacket) if err != nil { hostinfo.logger(f.l).WithError(err).WithField("packet", out). Warnf("Error while validating inbound packet") @@ -692,7 +699,7 @@ func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter ui if dropReason != nil { // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // This gives us a buffer to build the reject packet in - f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) + f.rejectOutside(out.Segments[out.SegCounter], hostinfo.ConnectionState, hostinfo, nb, inSegment, q) if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l).WithField("fwPacket", fwPacket). WithField("reason", dropReason). @@ -702,7 +709,8 @@ func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter ui } f.connectionManager.In(hostinfo) - *outNeedsTun = len(out) + pkt.OutLen += len(inSegment) + out.SegCounter++ return true } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index c53a54d..722091c 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -128,8 +128,10 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu return nil, fmt.Errorf("set vnethdr size: %w", err) } + flags := 0 + //flags := unix.TUN_F_CSUM //|unix.TUN_F_USO4|unix.TUN_F_USO6 - err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, 0) //todo! + err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, flags) if err != nil { return nil, fmt.Errorf("set offloads: %w", err) } diff --git a/packet/outpacket.go b/packet/outpacket.go new file mode 100644 index 0000000..e345afe --- /dev/null +++ b/packet/outpacket.go @@ -0,0 +1,23 @@ +package packet + +type OutPacket struct { + Segments [][]byte + //todo virtio header? + SegSize int + SegCounter int + Valid bool + wasSegmented bool + + Scratch []byte +} + +func NewOut() *OutPacket { + out := new(OutPacket) + const numSegments = 64 + out.Segments = make([][]byte, numSegments) + for i := 0; i < numSegments; i++ { //todo this is dumb + out.Segments[i] = make([]byte, Size) + } + out.Scratch = make([]byte, Size) + return out +} diff --git a/packet/packet.go b/packet/packet.go new file mode 100644 index 0000000..8eeb448 --- /dev/null +++ b/packet/packet.go @@ -0,0 +1,117 @@ +package packet + +import ( + "encoding/binary" + "iter" + "net/netip" + + "golang.org/x/sys/unix" +) + +const Size = 0xffff + +type Packet struct { + Payload []byte + Control []byte + Name []byte + SegSize int + + //todo should this hold out as well? + OutLen int + + wasSegmented bool + isV4 bool + //Addr netip.AddrPort +} + +func New(isV4 bool) *Packet { + return &Packet{ + Payload: make([]byte, Size), + Control: make([]byte, unix.CmsgSpace(2)), + Name: make([]byte, unix.SizeofSockaddrInet6), + isV4: isV4, + } +} + +func (p *Packet) AddrPort() netip.AddrPort { + var ip netip.Addr + // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic + if p.isV4 { + ip, _ = netip.AddrFromSlice(p.Name[4:8]) + } else { + ip, _ = netip.AddrFromSlice(p.Name[8:24]) + } + return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(p.Name[2:4])) +} + +func (p *Packet) updateCtrl(ctrlLen int) { + p.SegSize = len(p.Payload) + p.wasSegmented = false + if ctrlLen == 0 { + return + } + if len(p.Control) == 0 { + return + } + cmsgs, err := unix.ParseSocketControlMessage(p.Control) + if err != nil { + return // oh well + } + + for _, c := range cmsgs { + if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 { + p.wasSegmented = true + p.SegSize = int(binary.LittleEndian.Uint16(c.Data[:2])) + return + } + } +} + +// Update sets a Packet into "just received, not processed" state +func (p *Packet) Update(ctrlLen int) { + p.OutLen = -1 + p.updateCtrl(ctrlLen) +} + +func (p *Packet) Segments() iter.Seq[[]byte] { + return func(yield func([]byte) bool) { + //cursor := 0 + for offset := 0; offset < len(p.Payload); offset += p.SegSize { + end := offset + p.SegSize + if end > len(p.Payload) { + end = len(p.Payload) + } + if !yield(p.Payload[offset:end]) { + return + } + } + + //if p.SegSize > 0 && p.SegSize < len(p.Payload) { + // + //} else { + // f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l)) + //} + + } +} + +//type Pool struct { +// pool sync.Pool +//} +// +//var bigPool = &Pool{ +// pool: sync.Pool{New: func() any { return New() }}, +//} +// +//func GetPool() *Pool { +// return bigPool +//} +// +//func (p *Pool) Get() *Packet { +// return p.pool.Get().(*Packet) +//} +// +//func (p *Pool) Put(x *Packet) { +// x.Payload = x.Payload[:Size] +// p.pool.Put(x) +//} diff --git a/udp/conn.go b/udp/conn.go index 6a3c3b4..d7583f2 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -4,13 +4,13 @@ import ( "net/netip" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/packet" ) const MTU = 9001 type EncReader func( - addrs []netip.AddrPort, - payload [][]byte, + []*packet.Packet, ) type Conn interface { diff --git a/udp/udp_linux.go b/udp/udp_linux.go index e1ba229..ec5eeb3 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -18,18 +18,11 @@ import ( ) type StdConn struct { - sysFd int - isV4 bool - l *logrus.Logger - batch int -} - -func maybeIPV4(ip net.IP) (net.IP, bool) { - ip4 := ip.To4() - if ip4 != nil { - return ip4, true - } - return ip, false + sysFd int + isV4 bool + l *logrus.Logger + batch int + enableGRO bool } func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { @@ -119,9 +112,7 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) { } func (u *StdConn) ListenOut(r EncReader) { - var ip netip.Addr - addrPorts := make([]netip.AddrPort, u.batch) - msgs, buffers, names := u.PrepareRawMessages(u.batch) + msgs, packets := u.PrepareRawMessages(u.batch, u.isV4) read := u.ReadMulti if u.batch == 1 { read = u.ReadSingle @@ -135,17 +126,13 @@ func (u *StdConn) ListenOut(r EncReader) { } for i := 0; i < n; i++ { - // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic - if u.isV4 { - ip, _ = netip.AddrFromSlice(names[i][4:8]) - } else { - ip, _ = netip.AddrFromSlice(names[i][8:24]) - } - addrPorts[i] = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])) - buffers[i] = buffers[i][:msgs[i].Len] - + packets[i].Payload = packets[i].Payload[:msgs[i].Len] + packets[i].Update(getRawMessageControlLen(&msgs[i])) + } + r(packets) + for i := 0; i < n; i++ { //todo reset this in prev loop, but this makes debug ez + msgs[i].Hdr.Controllen = uint64(unix.CmsgSpace(2)) } - r(addrPorts, buffers) } } @@ -297,6 +284,27 @@ func (u *StdConn) ReloadConfig(c *config.C) { u.l.WithError(err).Error("Failed to set listen.so_mark") } } + u.configureGRO(true) +} + +func (u *StdConn) configureGRO(enable bool) { + if enable == u.enableGRO { + return + } + + if enable { + if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil { + u.l.WithError(err).Warn("Failed to enable UDP GRO") + return + } + u.enableGRO = true + u.l.Info("UDP GRO enabled") + } else { + if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil && err != unix.ENOPROTOOPT { + u.l.WithError(err).Warn("Failed to disable UDP GRO") + } + u.enableGRO = false + } } func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 48c5a97..8e778f0 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -7,6 +7,7 @@ package udp import ( + "github.com/slackhq/nebula/packet" "golang.org/x/sys/unix" ) @@ -33,25 +34,49 @@ type rawMessage struct { Pad0 [4]byte } -func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { +func setRawMessageControl(msg *rawMessage, buf []byte) { + if len(buf) == 0 { + msg.Hdr.Control = nil + msg.Hdr.Controllen = 0 + return + } + msg.Hdr.Control = &buf[0] + msg.Hdr.Controllen = uint64(len(buf)) +} + +func getRawMessageControlLen(msg *rawMessage) int { + return int(msg.Hdr.Controllen) +} + +func setCmsgLen(h *unix.Cmsghdr, l int) { + h.Len = uint64(l) +} + +func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet.Packet) { msgs := make([]rawMessage, n) - buffers := make([][]byte, n) - names := make([][]byte, n) + packets := make([]*packet.Packet, n) for i := range msgs { - buffers[i] = make([]byte, MTU) - names[i] = make([]byte, unix.SizeofSockaddrInet6) + packets[i] = packet.New(isV4) vs := []iovec{ - {Base: &buffers[i][0], Len: uint64(len(buffers[i]))}, + {Base: &packets[i].Payload[0], Len: uint64(packet.Size)}, } msgs[i].Hdr.Iov = &vs[0] msgs[i].Hdr.Iovlen = uint64(len(vs)) - msgs[i].Hdr.Name = &names[i][0] - msgs[i].Hdr.Namelen = uint32(len(names[i])) + msgs[i].Hdr.Name = &packets[i].Name[0] + msgs[i].Hdr.Namelen = uint32(len(packets[i].Name)) + + if u.enableGRO { + msgs[i].Hdr.Control = &packets[i].Control[0] + msgs[i].Hdr.Controllen = uint64(len(packets[i].Control)) + } else { + msgs[i].Hdr.Control = nil + msgs[i].Hdr.Controllen = 0 + } } - return msgs, buffers, names + return msgs, packets }