diff --git a/control.go b/control.go index 20d4bc0..4c6648a 100644 --- a/control.go +++ b/control.go @@ -192,10 +192,7 @@ func (c *Control) Hook(t NebulaMessageSubType, w func([]byte) error) error { // The provided payload will be encapsulated in a Nebula Firewall packet // (IPv4 plus ports) from the node IP to the provided destination nebula IP. // Any protocol handling above layer 3 (IP) must be managed by the caller. -func (c *Control) Send(ip uint32, port uint16, t NebulaMessageSubType, payload []byte) { - hostinfo := c.f.getOrHandshake(ip) - ci := hostinfo.ConnectionState - +func (c *Control) Send(ip uint32, port uint16, st NebulaMessageSubType, payload []byte) { headerLen := ipv4.HeaderLen + minFwPacketLen length := headerLen + len(payload) packet := make([]byte, length) @@ -206,13 +203,14 @@ func (c *Control) Send(ip uint32, port uint16, t NebulaMessageSubType, payload [ binary.BigEndian.PutUint32(packet[16:20], ip) // Set identical values for src and dst port as they're only - // used for nebula firewall rule mataching. + // used for nebula firewall rule/conntrack matching. binary.BigEndian.PutUint16(packet[20:22], port) binary.BigEndian.PutUint16(packet[22:24], port) copy(packet[headerLen:], payload) + fp := &FirewallPacket{} nb := make([]byte, 12) out := make([]byte, mtu) - c.f.sendNoMetrics(message, t, ci, hostinfo, hostinfo.remote, packet, nb, out) + c.f.consumeInsidePacket(st, packet, fp, nb, out) } diff --git a/inside.go b/inside.go index 3e36b49..d1d6a81 100644 --- a/inside.go +++ b/inside.go @@ -7,7 +7,7 @@ import ( "github.com/sirupsen/logrus" ) -func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte) { +func (f *Interface) consumeInsidePacket(st NebulaMessageSubType, packet []byte, fwPacket *FirewallPacket, nb, out []byte) { err := newPacket(packet, false, fwPacket) if err != nil { l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) @@ -45,7 +45,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, // the packet queue. ci.queueLock.Lock() if !ci.ready { - hostinfo.cachePacket(message, 0, packet, f.sendMessageNow) + hostinfo.cachePacket(message, st, packet, f.sendMessageNow) ci.queueLock.Unlock() return } @@ -54,7 +54,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs) if dropReason == nil { - mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out) + mc := f.sendNoMetrics(message, st, ci, hostinfo, hostinfo.remote, packet, nb, out) if f.lightHouse != nil && mc%5000 == 0 { f.lightHouse.Query(fwPacket.RemoteIP, f) } diff --git a/interface.go b/interface.go index 694d25d..1f3b687 100644 --- a/interface.go +++ b/interface.go @@ -196,7 +196,7 @@ func (f *Interface) listenIn(i int) { os.Exit(2) } - f.consumeInsidePacket(packet[:n], fwPacket, nb, out) + f.consumeInsidePacket(subTypeNone, packet[:n], fwPacket, nb, out) } }