mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-23 17:04:25 +01:00
what about with bad GRO on UDP
This commit is contained in:
31
interface.go
31
interface.go
@@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
|
"github.com/slackhq/nebula/packet"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -268,12 +269,9 @@ func (f *Interface) listenOut(q int) {
|
|||||||
|
|
||||||
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
lhh := f.lightHouse.NewRequestHandler()
|
lhh := f.lightHouse.NewRequestHandler()
|
||||||
plaintexts := make([][]byte, batch)
|
outPackets := make([]*packet.OutPacket, batch)
|
||||||
outNeedsTun := make([]*int, batch)
|
|
||||||
for i := 0; i < batch; i++ {
|
for i := 0; i < batch; i++ {
|
||||||
plaintexts[i] = make([]byte, udp.MTU)
|
outPackets[i] = packet.NewOut()
|
||||||
outNeedsTun[i] = new(int)
|
|
||||||
*outNeedsTun[i] = -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
h := &header.H{}
|
h := &header.H{}
|
||||||
@@ -282,16 +280,23 @@ func (f *Interface) listenOut(q int) {
|
|||||||
|
|
||||||
toSend := make([][]byte, batch)
|
toSend := make([][]byte, batch)
|
||||||
|
|
||||||
li.ListenOut(func(fromUdpAddrs []netip.AddrPort, payloads [][]byte) {
|
li.ListenOut(func(pkts []*packet.Packet) {
|
||||||
toSend = toSend[:0]
|
toSend = toSend[:0]
|
||||||
for i := range plaintexts {
|
for i := range outPackets {
|
||||||
plaintexts[i] = plaintexts[i][:0]
|
outPackets[i].Valid = false
|
||||||
|
outPackets[i].SegCounter = 0
|
||||||
}
|
}
|
||||||
f.readOutsidePacketsMany(fromUdpAddrs, plaintexts, outNeedsTun, payloads, h, fwPacket, lhh, nb, q, ctCache.Get(f.l))
|
|
||||||
for i := range plaintexts {
|
f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l))
|
||||||
if *outNeedsTun[i] != -1 {
|
for i := range outPackets {
|
||||||
toSend = append(toSend, plaintexts[i][:*outNeedsTun[i]])
|
if pkts[i].OutLen != -1 {
|
||||||
*outNeedsTun[i] = -1
|
for j := 0; j < outPackets[i].SegCounter; j++ {
|
||||||
|
if len(outPackets[i].Segments[j]) > 0 {
|
||||||
|
toSend = append(toSend, outPackets[i].Segments[j])
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
//toSend = append(toSend, outPackets[i])
|
||||||
//toSendCount++
|
//toSendCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
354
outside.go
354
outside.go
@@ -7,6 +7,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/slackhq/nebula/packet"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -216,21 +217,14 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
f.connectionManager.In(hostinfo)
|
f.connectionManager.In(hostinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) readOutsidePacketsMany(ip []netip.AddrPort, out [][]byte, outNeedsTun []*int, packets [][]byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
for i, packet := range packets {
|
for i, pkt := range packets {
|
||||||
|
out[i].Scratch = out[i].Scratch[:0]
|
||||||
err := h.Parse(packet)
|
ip := pkt.AddrPort()
|
||||||
if err != nil {
|
|
||||||
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
|
||||||
if len(packet) > 1 {
|
|
||||||
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
//l.Error("in packet ", header, packet[HeaderLen:])
|
//l.Error("in packet ", header, packet[HeaderLen:])
|
||||||
if ip[i].IsValid() {
|
if ip.IsValid() {
|
||||||
if f.myVpnNetworksTable.Contains(ip[i].Addr()) {
|
if f.myVpnNetworksTable.Contains(ip.Addr()) {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
|
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
|
||||||
}
|
}
|
||||||
@@ -238,182 +232,194 @@ func (f *Interface) readOutsidePacketsMany(ip []netip.AddrPort, out [][]byte, ou
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var hostinfo *HostInfo
|
//todo per-segment!
|
||||||
// verify if we've seen this index before, otherwise respond to the handshake initiation
|
for segment := range pkt.Segments() {
|
||||||
if h.Type == header.Message && h.Subtype == header.MessageRelay {
|
|
||||||
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
|
|
||||||
} else {
|
|
||||||
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
|
|
||||||
}
|
|
||||||
|
|
||||||
var ci *ConnectionState
|
err := h.Parse(segment)
|
||||||
if hostinfo != nil {
|
if err != nil {
|
||||||
ci = hostinfo.ConnectionState
|
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
||||||
}
|
if len(segment) > 1 {
|
||||||
|
f.l.WithField("packet", pkt).Infof("Error while parsing inbound packet from %s: %s", ip, err)
|
||||||
switch h.Type {
|
}
|
||||||
case header.Message:
|
|
||||||
// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
|
|
||||||
if !f.handleEncrypted(ci, ip[i], h) {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch h.Subtype {
|
var hostinfo *HostInfo
|
||||||
case header.MessageNone:
|
// verify if we've seen this index before, otherwise respond to the handshake initiation
|
||||||
out[i] = out[i][:0]
|
if h.Type == header.Message && h.Subtype == header.MessageRelay {
|
||||||
if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i][:0], outNeedsTun[i], packet, fwPacket, nb, q, localCache) {
|
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
|
||||||
return
|
} else {
|
||||||
}
|
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
|
||||||
case header.MessageRelay:
|
}
|
||||||
// The entire body is sent as AD, not encrypted.
|
|
||||||
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
|
|
||||||
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
|
|
||||||
// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
|
|
||||||
// which will gracefully fail in the DecryptDanger call.
|
|
||||||
signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()]
|
|
||||||
signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():]
|
|
||||||
out[i], err = hostinfo.ConnectionState.dKey.DecryptDanger(out[i], signedPayload, signatureValue, h.MessageCounter, nb)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Successfully validated the thing. Get rid of the Relay header.
|
|
||||||
signedPayload = signedPayload[header.Len:]
|
|
||||||
// Pull the Roaming parts up here, and return in all call paths.
|
|
||||||
f.handleHostRoaming(hostinfo, ip[i])
|
|
||||||
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
|
|
||||||
f.connectionManager.In(hostinfo)
|
|
||||||
f.connectionManager.RelayUsed(h.RemoteIndex)
|
|
||||||
|
|
||||||
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
|
var ci *ConnectionState
|
||||||
if !ok {
|
if hostinfo != nil {
|
||||||
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
|
ci = hostinfo.ConnectionState
|
||||||
// its internal mapping. This should never happen.
|
}
|
||||||
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
|
|
||||||
|
switch h.Type {
|
||||||
|
case header.Message:
|
||||||
|
// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
|
||||||
|
if !f.handleEncrypted(ci, ip, h) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch relay.Type {
|
switch h.Subtype {
|
||||||
case TerminalType:
|
case header.MessageNone:
|
||||||
// If I am the target of this relay, process the unwrapped packet
|
if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i], pkt, segment, fwPacket, nb, q, localCache) {
|
||||||
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
|
return
|
||||||
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i][:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
|
}
|
||||||
return
|
case header.MessageRelay:
|
||||||
case ForwardingType:
|
// The entire body is sent as AD, not encrypted.
|
||||||
// Find the target HostInfo relay object
|
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
|
||||||
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
|
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
|
||||||
|
// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
|
||||||
|
// which will gracefully fail in the DecryptDanger call.
|
||||||
|
signedPayload := segment[:len(segment)-hostinfo.ConnectionState.dKey.Overhead()]
|
||||||
|
signatureValue := segment[len(segment)-hostinfo.ConnectionState.dKey.Overhead():]
|
||||||
|
out[i].Scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(out[i].Scratch, signedPayload, signatureValue, h.MessageCounter, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
|
return
|
||||||
|
}
|
||||||
|
// Successfully validated the thing. Get rid of the Relay header.
|
||||||
|
signedPayload = signedPayload[header.Len:]
|
||||||
|
// Pull the Roaming parts up here, and return in all call paths.
|
||||||
|
f.handleHostRoaming(hostinfo, ip)
|
||||||
|
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
|
||||||
|
f.connectionManager.In(hostinfo)
|
||||||
|
f.connectionManager.RelayUsed(h.RemoteIndex)
|
||||||
|
|
||||||
|
relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
|
||||||
|
if !ok {
|
||||||
|
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
|
||||||
|
// its internal mapping. This should never happen.
|
||||||
|
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// If that relay is Established, forward the payload through it
|
switch relay.Type {
|
||||||
if targetRelay.State == Established {
|
case TerminalType:
|
||||||
switch targetRelay.Type {
|
// If I am the target of this relay, process the unwrapped packet
|
||||||
case ForwardingType:
|
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
|
||||||
// Forward this packet through the relay tunnel
|
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i].Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
|
||||||
// Find the target HostInfo
|
|
||||||
f.SendVia(targetHI, targetRelay, signedPayload, nb, out[i], false)
|
|
||||||
return
|
|
||||||
case TerminalType:
|
|
||||||
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
|
|
||||||
return
|
return
|
||||||
|
case ForwardingType:
|
||||||
|
// Find the target HostInfo relay object
|
||||||
|
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If that relay is Established, forward the payload through it
|
||||||
|
if targetRelay.State == Established {
|
||||||
|
switch targetRelay.Type {
|
||||||
|
case ForwardingType:
|
||||||
|
// Forward this packet through the relay tunnel
|
||||||
|
// Find the target HostInfo
|
||||||
|
f.SendVia(targetHI, targetRelay, signedPayload, nb, out[i].Scratch, false)
|
||||||
|
return
|
||||||
|
case TerminalType:
|
||||||
|
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
case header.LightHouse:
|
case header.LightHouse:
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
if !f.handleEncrypted(ci, ip[i], h) {
|
if !f.handleEncrypted(ci, ip, h) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
||||||
|
WithField("packet", segment).
|
||||||
|
Error("Failed to decrypt lighthouse packet")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
|
||||||
|
|
||||||
|
// Fallthrough to the bottom to record incoming traffic
|
||||||
|
|
||||||
|
case header.Test:
|
||||||
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
|
if !f.handleEncrypted(ci, ip, h) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
||||||
|
WithField("packet", segment).
|
||||||
|
Error("Failed to decrypt test packet")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.Subtype == header.TestRequest {
|
||||||
|
// This testRequest might be from TryPromoteBest, so we should roam
|
||||||
|
// to the new IP address before responding
|
||||||
|
f.handleHostRoaming(hostinfo, ip)
|
||||||
|
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out[i].Scratch)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallthrough to the bottom to record incoming traffic
|
||||||
|
|
||||||
|
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
|
||||||
|
// are unauthenticated
|
||||||
|
|
||||||
|
case header.Handshake:
|
||||||
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
|
f.handshakeManager.HandleIncoming(ip, nil, segment, h)
|
||||||
|
return
|
||||||
|
|
||||||
|
case header.RecvError:
|
||||||
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
|
f.handleRecvError(ip, h)
|
||||||
|
return
|
||||||
|
|
||||||
|
case header.CloseTunnel:
|
||||||
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
|
if !f.handleEncrypted(ci, ip, h) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hostinfo.logger(f.l).WithField("udpAddr", ip).
|
||||||
|
Info("Close tunnel received, tearing down.")
|
||||||
|
|
||||||
|
f.closeTunnel(hostinfo)
|
||||||
|
return
|
||||||
|
|
||||||
|
case header.Control:
|
||||||
|
if !f.handleEncrypted(ci, ip, h) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
||||||
|
WithField("packet", segment).
|
||||||
|
Error("Failed to decrypt Control packet")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
f.relayManager.HandleControlMsg(hostinfo, d, f)
|
||||||
|
|
||||||
|
default:
|
||||||
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
|
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i], packet, h, nb)
|
f.handleHostRoaming(hostinfo, ip)
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
|
||||||
WithField("packet", packet).
|
|
||||||
Error("Failed to decrypt lighthouse packet")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
lhf.HandleRequest(ip[i], hostinfo.vpnAddrs, d, f)
|
f.connectionManager.In(hostinfo)
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
|
||||||
|
|
||||||
case header.Test:
|
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
|
||||||
if !f.handleEncrypted(ci, ip[i], h) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i], packet, h, nb)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
|
||||||
WithField("packet", packet).
|
|
||||||
Error("Failed to decrypt test packet")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.Subtype == header.TestRequest {
|
|
||||||
// This testRequest might be from TryPromoteBest, so we should roam
|
|
||||||
// to the new IP address before responding
|
|
||||||
f.handleHostRoaming(hostinfo, ip[i])
|
|
||||||
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
|
||||||
|
|
||||||
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
|
|
||||||
// are unauthenticated
|
|
||||||
|
|
||||||
case header.Handshake:
|
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
|
||||||
f.handshakeManager.HandleIncoming(ip[i], nil, packet, h)
|
|
||||||
return
|
|
||||||
|
|
||||||
case header.RecvError:
|
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
|
||||||
f.handleRecvError(ip[i], h)
|
|
||||||
return
|
|
||||||
|
|
||||||
case header.CloseTunnel:
|
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
|
||||||
if !f.handleEncrypted(ci, ip[i], h) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hostinfo.logger(f.l).WithField("udpAddr", ip).
|
|
||||||
Info("Close tunnel received, tearing down.")
|
|
||||||
|
|
||||||
f.closeTunnel(hostinfo)
|
|
||||||
return
|
|
||||||
|
|
||||||
case header.Control:
|
|
||||||
if !f.handleEncrypted(ci, ip[i], h) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out[i], packet, h, nb)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
|
||||||
WithField("packet", packet).
|
|
||||||
Error("Failed to decrypt Control packet")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f.relayManager.HandleControlMsg(hostinfo, d, f)
|
|
||||||
|
|
||||||
default:
|
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
|
||||||
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
f.handleHostRoaming(hostinfo, ip[i])
|
|
||||||
|
|
||||||
f.connectionManager.In(hostinfo)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -666,16 +672,17 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out []byte, outNeedsTun *int, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
|
func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
out.Segments[out.SegCounter] = out.Segments[out.SegCounter][:0]
|
||||||
|
out.Segments[out.SegCounter], err = hostinfo.ConnectionState.dKey.DecryptDanger(out.Segments[out.SegCounter], inSegment[:header.Len], inSegment[header.Len:], messageCounter, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
err = newPacket(out, true, fwPacket)
|
err = newPacket(out.Segments[out.SegCounter], true, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
|
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
|
||||||
Warnf("Error while validating inbound packet")
|
Warnf("Error while validating inbound packet")
|
||||||
@@ -692,7 +699,7 @@ func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter ui
|
|||||||
if dropReason != nil {
|
if dropReason != nil {
|
||||||
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
||||||
// This gives us a buffer to build the reject packet in
|
// This gives us a buffer to build the reject packet in
|
||||||
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
|
f.rejectOutside(out.Segments[out.SegCounter], hostinfo.ConnectionState, hostinfo, nb, inSegment, q)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
||||||
WithField("reason", dropReason).
|
WithField("reason", dropReason).
|
||||||
@@ -702,7 +709,8 @@ func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter ui
|
|||||||
}
|
}
|
||||||
|
|
||||||
f.connectionManager.In(hostinfo)
|
f.connectionManager.In(hostinfo)
|
||||||
*outNeedsTun = len(out)
|
pkt.OutLen += len(inSegment)
|
||||||
|
out.SegCounter++
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -128,8 +128,10 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||||||
return nil, fmt.Errorf("set vnethdr size: %w", err)
|
return nil, fmt.Errorf("set vnethdr size: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
flags := 0
|
||||||
|
//flags := unix.TUN_F_CSUM
|
||||||
//|unix.TUN_F_USO4|unix.TUN_F_USO6
|
//|unix.TUN_F_USO4|unix.TUN_F_USO6
|
||||||
err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, 0) //todo!
|
err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, flags)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("set offloads: %w", err)
|
return nil, fmt.Errorf("set offloads: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
23
packet/outpacket.go
Normal file
23
packet/outpacket.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package packet
|
||||||
|
|
||||||
|
type OutPacket struct {
|
||||||
|
Segments [][]byte
|
||||||
|
//todo virtio header?
|
||||||
|
SegSize int
|
||||||
|
SegCounter int
|
||||||
|
Valid bool
|
||||||
|
wasSegmented bool
|
||||||
|
|
||||||
|
Scratch []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOut() *OutPacket {
|
||||||
|
out := new(OutPacket)
|
||||||
|
const numSegments = 64
|
||||||
|
out.Segments = make([][]byte, numSegments)
|
||||||
|
for i := 0; i < numSegments; i++ { //todo this is dumb
|
||||||
|
out.Segments[i] = make([]byte, Size)
|
||||||
|
}
|
||||||
|
out.Scratch = make([]byte, Size)
|
||||||
|
return out
|
||||||
|
}
|
||||||
117
packet/packet.go
Normal file
117
packet/packet.go
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
package packet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"iter"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
const Size = 0xffff
|
||||||
|
|
||||||
|
type Packet struct {
|
||||||
|
Payload []byte
|
||||||
|
Control []byte
|
||||||
|
Name []byte
|
||||||
|
SegSize int
|
||||||
|
|
||||||
|
//todo should this hold out as well?
|
||||||
|
OutLen int
|
||||||
|
|
||||||
|
wasSegmented bool
|
||||||
|
isV4 bool
|
||||||
|
//Addr netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(isV4 bool) *Packet {
|
||||||
|
return &Packet{
|
||||||
|
Payload: make([]byte, Size),
|
||||||
|
Control: make([]byte, unix.CmsgSpace(2)),
|
||||||
|
Name: make([]byte, unix.SizeofSockaddrInet6),
|
||||||
|
isV4: isV4,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Packet) AddrPort() netip.AddrPort {
|
||||||
|
var ip netip.Addr
|
||||||
|
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
||||||
|
if p.isV4 {
|
||||||
|
ip, _ = netip.AddrFromSlice(p.Name[4:8])
|
||||||
|
} else {
|
||||||
|
ip, _ = netip.AddrFromSlice(p.Name[8:24])
|
||||||
|
}
|
||||||
|
return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(p.Name[2:4]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Packet) updateCtrl(ctrlLen int) {
|
||||||
|
p.SegSize = len(p.Payload)
|
||||||
|
p.wasSegmented = false
|
||||||
|
if ctrlLen == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(p.Control) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cmsgs, err := unix.ParseSocketControlMessage(p.Control)
|
||||||
|
if err != nil {
|
||||||
|
return // oh well
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cmsgs {
|
||||||
|
if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
|
||||||
|
p.wasSegmented = true
|
||||||
|
p.SegSize = int(binary.LittleEndian.Uint16(c.Data[:2]))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update sets a Packet into "just received, not processed" state
|
||||||
|
func (p *Packet) Update(ctrlLen int) {
|
||||||
|
p.OutLen = -1
|
||||||
|
p.updateCtrl(ctrlLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Packet) Segments() iter.Seq[[]byte] {
|
||||||
|
return func(yield func([]byte) bool) {
|
||||||
|
//cursor := 0
|
||||||
|
for offset := 0; offset < len(p.Payload); offset += p.SegSize {
|
||||||
|
end := offset + p.SegSize
|
||||||
|
if end > len(p.Payload) {
|
||||||
|
end = len(p.Payload)
|
||||||
|
}
|
||||||
|
if !yield(p.Payload[offset:end]) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//if p.SegSize > 0 && p.SegSize < len(p.Payload) {
|
||||||
|
//
|
||||||
|
//} else {
|
||||||
|
// f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
|
||||||
|
//}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//type Pool struct {
|
||||||
|
// pool sync.Pool
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//var bigPool = &Pool{
|
||||||
|
// pool: sync.Pool{New: func() any { return New() }},
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//func GetPool() *Pool {
|
||||||
|
// return bigPool
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//func (p *Pool) Get() *Packet {
|
||||||
|
// return p.pool.Get().(*Packet)
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//func (p *Pool) Put(x *Packet) {
|
||||||
|
// x.Payload = x.Payload[:Size]
|
||||||
|
// p.pool.Put(x)
|
||||||
|
//}
|
||||||
@@ -4,13 +4,13 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/packet"
|
||||||
)
|
)
|
||||||
|
|
||||||
const MTU = 9001
|
const MTU = 9001
|
||||||
|
|
||||||
type EncReader func(
|
type EncReader func(
|
||||||
addrs []netip.AddrPort,
|
[]*packet.Packet,
|
||||||
payload [][]byte,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Conn interface {
|
type Conn interface {
|
||||||
|
|||||||
@@ -18,18 +18,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type StdConn struct {
|
type StdConn struct {
|
||||||
sysFd int
|
sysFd int
|
||||||
isV4 bool
|
isV4 bool
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
batch int
|
batch int
|
||||||
}
|
enableGRO bool
|
||||||
|
|
||||||
func maybeIPV4(ip net.IP) (net.IP, bool) {
|
|
||||||
ip4 := ip.To4()
|
|
||||||
if ip4 != nil {
|
|
||||||
return ip4, true
|
|
||||||
}
|
|
||||||
return ip, false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||||
@@ -119,9 +112,7 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) ListenOut(r EncReader) {
|
func (u *StdConn) ListenOut(r EncReader) {
|
||||||
var ip netip.Addr
|
msgs, packets := u.PrepareRawMessages(u.batch, u.isV4)
|
||||||
addrPorts := make([]netip.AddrPort, u.batch)
|
|
||||||
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
|
||||||
read := u.ReadMulti
|
read := u.ReadMulti
|
||||||
if u.batch == 1 {
|
if u.batch == 1 {
|
||||||
read = u.ReadSingle
|
read = u.ReadSingle
|
||||||
@@ -135,17 +126,13 @@ func (u *StdConn) ListenOut(r EncReader) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
packets[i].Payload = packets[i].Payload[:msgs[i].Len]
|
||||||
if u.isV4 {
|
packets[i].Update(getRawMessageControlLen(&msgs[i]))
|
||||||
ip, _ = netip.AddrFromSlice(names[i][4:8])
|
}
|
||||||
} else {
|
r(packets)
|
||||||
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
for i := 0; i < n; i++ { //todo reset this in prev loop, but this makes debug ez
|
||||||
}
|
msgs[i].Hdr.Controllen = uint64(unix.CmsgSpace(2))
|
||||||
addrPorts[i] = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
|
|
||||||
buffers[i] = buffers[i][:msgs[i].Len]
|
|
||||||
|
|
||||||
}
|
}
|
||||||
r(addrPorts, buffers)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -297,6 +284,27 @@ func (u *StdConn) ReloadConfig(c *config.C) {
|
|||||||
u.l.WithError(err).Error("Failed to set listen.so_mark")
|
u.l.WithError(err).Error("Failed to set listen.so_mark")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
u.configureGRO(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) configureGRO(enable bool) {
|
||||||
|
if enable == u.enableGRO {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if enable {
|
||||||
|
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
|
||||||
|
u.l.WithError(err).Warn("Failed to enable UDP GRO")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
u.enableGRO = true
|
||||||
|
u.l.Info("UDP GRO enabled")
|
||||||
|
} else {
|
||||||
|
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil && err != unix.ENOPROTOOPT {
|
||||||
|
u.l.WithError(err).Warn("Failed to disable UDP GRO")
|
||||||
|
}
|
||||||
|
u.enableGRO = false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
package udp
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/slackhq/nebula/packet"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -33,25 +34,49 @@ type rawMessage struct {
|
|||||||
Pad0 [4]byte
|
Pad0 [4]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
func setRawMessageControl(msg *rawMessage, buf []byte) {
|
||||||
|
if len(buf) == 0 {
|
||||||
|
msg.Hdr.Control = nil
|
||||||
|
msg.Hdr.Controllen = 0
|
||||||
|
return
|
||||||
|
}
|
||||||
|
msg.Hdr.Control = &buf[0]
|
||||||
|
msg.Hdr.Controllen = uint64(len(buf))
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRawMessageControlLen(msg *rawMessage) int {
|
||||||
|
return int(msg.Hdr.Controllen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setCmsgLen(h *unix.Cmsghdr, l int) {
|
||||||
|
h.Len = uint64(l)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet.Packet) {
|
||||||
msgs := make([]rawMessage, n)
|
msgs := make([]rawMessage, n)
|
||||||
buffers := make([][]byte, n)
|
packets := make([]*packet.Packet, n)
|
||||||
names := make([][]byte, n)
|
|
||||||
|
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
buffers[i] = make([]byte, MTU)
|
packets[i] = packet.New(isV4)
|
||||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
|
||||||
|
|
||||||
vs := []iovec{
|
vs := []iovec{
|
||||||
{Base: &buffers[i][0], Len: uint64(len(buffers[i]))},
|
{Base: &packets[i].Payload[0], Len: uint64(packet.Size)},
|
||||||
}
|
}
|
||||||
|
|
||||||
msgs[i].Hdr.Iov = &vs[0]
|
msgs[i].Hdr.Iov = &vs[0]
|
||||||
msgs[i].Hdr.Iovlen = uint64(len(vs))
|
msgs[i].Hdr.Iovlen = uint64(len(vs))
|
||||||
|
|
||||||
msgs[i].Hdr.Name = &names[i][0]
|
msgs[i].Hdr.Name = &packets[i].Name[0]
|
||||||
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
msgs[i].Hdr.Namelen = uint32(len(packets[i].Name))
|
||||||
|
|
||||||
|
if u.enableGRO {
|
||||||
|
msgs[i].Hdr.Control = &packets[i].Control[0]
|
||||||
|
msgs[i].Hdr.Controllen = uint64(len(packets[i].Control))
|
||||||
|
} else {
|
||||||
|
msgs[i].Hdr.Control = nil
|
||||||
|
msgs[i].Hdr.Controllen = 0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return msgs, buffers, names
|
return msgs, packets
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user