diff --git a/handler.go b/handler.go index 45c75dc..369b4dd 100644 --- a/handler.go +++ b/handler.go @@ -1,22 +1,30 @@ package nebula -func (f *Interface) handleMessagePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) { - if !f.handleEncrypted(ci, addr, header) { - return +func (f *Interface) encrypted(h InsideHandler) InsideHandler { + return func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) { + if !f.handleEncrypted(ci, addr, header) { + return + } + + h(hostInfo, ci, addr, header, out, packet, fwPacket, nb) + + f.handleHostRoaming(hostInfo, addr) + f.connectionManager.In(hostInfo.hostId) } +} +func (f *Interface) rxMetrics(h InsideHandler) InsideHandler { + return func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) { + f.messageMetrics.Rx(header.Type, header.Subtype, 1) + h(hostInfo, ci, addr, header, out, packet, fwPacket, nb) + } +} + +func (f *Interface) handleMessagePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) { f.decryptToTun(hostInfo, header.MessageCounter, out, packet, fwPacket, nb) - - f.handleHostRoaming(hostInfo, addr) - f.connectionManager.In(hostInfo.hostId) } func (f *Interface) handleLighthousePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) { - f.messageMetrics.Rx(header.Type, header.Subtype, 1) - if !f.handleEncrypted(ci, addr, header) { - return - } - d, err := f.decrypt(hostInfo, header.MessageCounter, out, packet, header, nb) if err != nil { hostInfo.logger().WithError(err).WithField("udpAddr", addr). @@ -29,17 +37,9 @@ func (f *Interface) handleLighthousePacket(hostInfo *HostInfo, ci *ConnectionSta } f.lightHouse.HandleRequest(addr, hostInfo.hostId, d, hostInfo.GetCert(), f) - - f.handleHostRoaming(hostInfo, addr) - f.connectionManager.In(hostInfo.hostId) } func (f *Interface) handleTestPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) { - f.messageMetrics.Rx(header.Type, header.Subtype, 1) - if !f.handleEncrypted(ci, addr, header) { - return - } - d, err := f.decrypt(hostInfo, header.MessageCounter, out, packet, header, nb) if err != nil { hostInfo.logger().WithError(err).WithField("udpAddr", addr). @@ -57,28 +57,18 @@ func (f *Interface) handleTestPacket(hostInfo *HostInfo, ci *ConnectionState, ad f.handleHostRoaming(hostInfo, addr) f.send(test, testReply, ci, hostInfo, hostInfo.remote, d, nb, out) } - - f.handleHostRoaming(hostInfo, addr) - f.connectionManager.In(hostInfo.hostId) } func (f *Interface) handleHandshakePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) { - f.messageMetrics.Rx(header.Type, header.Subtype, 1) HandleIncomingHandshake(f, addr, packet, header, hostInfo) } func (f *Interface) handleRecvErrorPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) { - f.messageMetrics.Rx(header.Type, header.Subtype, 1) // TODO: Remove this with recv_error deprecation f.handleRecvError(addr, header) } func (f *Interface) handleCloseTunnelPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) { - f.messageMetrics.Rx(header.Type, header.Subtype, 1) - if !f.handleEncrypted(ci, addr, header) { - return - } - hostInfo.logger().WithField("udpAddr", addr). Info("Close tunnel received, tearing down.") diff --git a/interface.go b/interface.go index 3a80838..694d25d 100644 --- a/interface.go +++ b/interface.go @@ -111,23 +111,23 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) { ifce.handlers = map[uint8]map[NebulaMessageType]map[NebulaMessageSubType]InsideHandler{ Version: { handshake: { - handshakeIXPSK0: ifce.handleHandshakePacket, + handshakeIXPSK0: ifce.rxMetrics(ifce.handleHandshakePacket), }, message: { - subTypeNone: ifce.handleMessagePacket, + subTypeNone: ifce.encrypted(ifce.handleMessagePacket), }, recvError: { - subTypeNone: ifce.handleRecvErrorPacket, + subTypeNone: ifce.rxMetrics(ifce.handleRecvErrorPacket), }, lightHouse: { - subTypeNone: ifce.handleLighthousePacket, + subTypeNone: ifce.rxMetrics(ifce.encrypted(ifce.handleLighthousePacket)), }, test: { - testRequest: ifce.handleTestPacket, - testReply: ifce.handleTestPacket, + testRequest: ifce.rxMetrics(ifce.encrypted(ifce.handleTestPacket)), + testReply: ifce.rxMetrics(ifce.encrypted(ifce.handleTestPacket)), }, closeTunnel: { - subTypeNone: ifce.handleCloseTunnelPacket, + subTypeNone: ifce.rxMetrics(ifce.encrypted(ifce.handleCloseTunnelPacket)), }, }, }