Compare commits

..

5 Commits

Author SHA1 Message Date
Nate Brown
e25016a946 Fix e2e unsafe inbound test 2025-11-21 15:51:02 -06:00
Nate Brown
c2381e7019 Add some comments to make table config clearer 2025-11-21 14:29:46 -06:00
Nate Brown
2b0d57b464 Track unsafe in the mock firewall 2025-11-21 14:23:00 -06:00
Nate Brown
c69b009650 Change name from forward to unsafe 2025-11-21 14:21:08 -06:00
Nate Brown
281a9017ce Add forward tables to handle unsafe network packets distinctly from vpn network packets 2025-11-21 14:18:26 -06:00
11 changed files with 308 additions and 362 deletions

View File

@@ -7,85 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
## [1.10.0] - 2025-12-04
See the [v1.10.0](https://github.com/slackhq/nebula/milestone/16?closed=1) milestone for a complete list of changes.
### Added
- Support for ipv6 and multiple ipv4/6 addresses in the overlay.
A new v2 ASN.1 based certificate format.
Certificates now have a unified interface for external implementations.
(#1212, #1216, #1345, #1359, #1381, #1419, #1464, #1466, #1451, #1476, #1467, #1481, #1399, #1488, #1492, #1495, #1468, #1521, #1535, #1538)
- Add the ability to mark packets on linux to better target nebula packets in iptables/nftables. (#1331)
- Add ECMP support for `unsafe_routes`. (#1332)
- PKCS11 support for P256 keys when built with `pkcs11` tag (#1153, #1482)
### Changed ### Changed
- **NOTE**: `default_local_cidr_any` now defaults to false, meaning that any firewall rule - `default_local_cidr_any` now defaults to false, meaning that any firewall rule
intended to target an `unsafe_routes` entry must explicitly declare it via the intended to target an `unsafe_routes` entry must explicitly declare it via the
`local_cidr` field. This is almost always the intended behavior. This flag is `local_cidr` field. This is almost always the intended behavior. This flag is
deprecated and will be removed in a future release. (#1373) deprecated and will be removed in a future release.
- Improve logging when a relay is in use on an inbound packet. (#1533)
- Avoid fatal errors if `rountines` is > 1 on systems that don't support more than 1 routine. (#1531)
- Log a warning if a firewall rule contains an `any` that negates a more restrictive filter. (#1513)
- Accept encrypted CA passphrase from an environment variable. (#1421)
- Allow handshaking with any trusted remote. (#1509)
- Log only the count of blocklisted certificate fingerprints instead of the entire list. (#1525)
- Don't fatal when the ssh server is unable to be configured successfully. (#1520)
- Update to build against go v1.25. (#1483)
- Allow projects using `nebula` as a library with userspace networking to configure the `logger` and build version. (#1239)
- Upgrade to `yaml.v3`. (#1148, #1371, #1438, #1478)
### Fixed
- Fix a potential bug with udp ipv4 only on darwin. (#1532)
- Improve lost packet statistics. (#1441, #1537)
- Honor `remote_allow_list` in hole punch response. (#1186)
- Fix a panic when `tun.use_system_route_table` is `true` and a route lacks a destination. (#1437)
- Fix an issue when `tun.use_system_route_table: true` could result in heavy CPU utilization when many thousands of routes
are present. (#1326)
- Fix tests for 32 bit machines. (#1394)
- Fix a possible 32bit integer underflow in config handling. (#1353)
- Fix moving a udp address from one vpn address to another in the `static_host_map`
which could cause rapid re-handshaking with an incorrect remote. (#1259)
- Improve smoke tests in environments where the docker network is not the default. (#1347)
## [1.9.7] - 2025-10-10
### Security
- Fix an issue where Nebula could incorrectly accept and process a packet from an erroneous source IP when the sender's
certificate is configured with unsafe_routes (cert v1/v2) or multiple IPs (cert v2). (#1494)
### Changed
- Disable sending `recv_error` messages when a packet is received outside the allowable counter window. (#1459)
- Improve error messages and remove some unnecessary fatal conditions in the Windows and generic udp listener. (#1453)
## [1.9.6] - 2025-7-15
### Added
- Support dropping inactive tunnels. This is disabled by default in this release but can be enabled with `tunnels.drop_inactive`. See example config for more details. (#1413)
### Fixed
- Fix Darwin freeze due to presence of some Network Extensions (#1426)
- Ensure the same relay tunnel is always used when multiple relay tunnels are present (#1422)
- Fix Windows freeze due to ICMP error handling (#1412)
- Fix relay migration panic (#1403)
## [1.9.5] - 2024-12-05
### Added
- Gracefully ignore v2 certificates. (#1282)
### Fixed
- Fix relays that refuse to re-establish after one of the remote tunnel pairs breaks. (#1277)
## [1.9.4] - 2024-09-09 ## [1.9.4] - 2024-09-09
@@ -744,11 +671,7 @@ created.)
- Initial public release. - Initial public release.
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.0...HEAD [Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.4...HEAD
[1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0
[1.9.7]: https://github.com/slackhq/nebula/releases/tag/v1.9.7
[1.9.6]: https://github.com/slackhq/nebula/releases/tag/v1.9.6
[1.9.5]: https://github.com/slackhq/nebula/releases/tag/v1.9.5
[1.9.4]: https://github.com/slackhq/nebula/releases/tag/v1.9.4 [1.9.4]: https://github.com/slackhq/nebula/releases/tag/v1.9.4
[1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3 [1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3
[1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2 [1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2

View File

@@ -25,12 +25,11 @@ import (
func BenchmarkHotPath(b *testing.B) { func BenchmarkHotPath(b *testing.B) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse // Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers // Start the servers
myControl.Start() myControl.Start()
@@ -39,9 +38,6 @@ func BenchmarkHotPath(b *testing.B) {
r := router.NewR(b, myControl, theirControl) r := router.NewR(b, myControl, theirControl)
r.CancelFlowLogs() r.CancelFlowLogs()
assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl) _ = r.RouteForAllUntilTxTun(theirControl)
@@ -51,39 +47,6 @@ func BenchmarkHotPath(b *testing.B) {
theirControl.Stop() theirControl.Stop()
} }
func BenchmarkHotPathRelay(b *testing.B) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(b, myControl, relayControl, theirControl)
r.CancelFlowLogs()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
b.ResetTimer()
for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl)
}
myControl.Stop()
theirControl.Stop()
relayControl.Stop()
}
func TestGoodHandshake(t *testing.T) { func TestGoodHandshake(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
@@ -1378,6 +1341,13 @@ func TestGoodHandshakeUnsafeDest(t *testing.T) {
"tun": m{ "tun": m{
"unsafe_routes": []m{route}, "unsafe_routes": []m{route},
}, },
"firewall": m{
"unsafe_outbound": []m{{
"port": "any",
"proto": "any",
"host": "any",
}},
},
} }
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg) myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg)
t.Logf("my config %v", myConfig) t.Logf("my config %v", myConfig)

View File

@@ -85,8 +85,9 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
}} }}
var unsafeNetworks []netip.Prefix var unsafeNetworks []netip.Prefix
var firewallUnsafeInbound []m
if sUnsafeNetworks != "" { if sUnsafeNetworks != "" {
firewallInbound = []m{{ firewallUnsafeInbound = []m{{
"proto": "any", "proto": "any",
"port": "any", "port": "any",
"host": "any", "host": "any",
@@ -122,7 +123,8 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
"port": "any", "port": "any",
"host": "any", "host": "any",
}}, }},
"inbound": firewallInbound, "inbound": firewallInbound,
"unsafe_inbound": firewallUnsafeInbound,
}, },
//"handshakes": m{ //"handshakes": m{
// "try_interval": "1s", // "try_interval": "1s",
@@ -292,7 +294,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
} }
} }
func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
// Send a packet from them to me // Send a packet from them to me
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B")) controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
bPacket := r.RouteForAllUntilTxTun(controlA) bPacket := r.RouteForAllUntilTxTun(controlA)
@@ -325,7 +327,7 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index") assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
} }
func assertUdpPacket(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
if toIp.Is6() { if toIp.Is6() {
assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort) assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
} else { } else {
@@ -333,7 +335,7 @@ func assertUdpPacket(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr,
} }
} }
func assertUdpPacket6(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy) packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6) v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
assert.NotNil(t, v6, "No ipv6 data found") assert.NotNil(t, v6, "No ipv6 data found")
@@ -352,7 +354,7 @@ func assertUdpPacket6(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr,
assert.Equal(t, expected, data.Payload(), "Data was incorrect") assert.Equal(t, expected, data.Payload(), "Data was incorrect")
} }
func assertUdpPacket4(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy) packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
assert.NotNil(t, v4, "No ipv4 data found") assert.NotNil(t, v4, "No ipv4 data found")

View File

@@ -23,16 +23,15 @@ import (
) )
type FirewallInterface interface { type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error AddRule(unsafe, incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error
} }
type conn struct { type conn struct {
Expires time.Time // Time when this conntrack entry will expire Expires time.Time // Time when this conntrack entry will expire
// record why the original connection passed the firewall, so we can re-validate // record why the original connection passed the firewall, so we can re-validate after ruleset changes.
// after ruleset changes. Note, rulesVersion is a uint16 so that these two
// fields pack for free after the uint32 above
incoming bool incoming bool
unsafe bool
rulesVersion uint16 rulesVersion uint16
} }
@@ -40,8 +39,10 @@ type conn struct {
type Firewall struct { type Firewall struct {
Conntrack *FirewallConntrack Conntrack *FirewallConntrack
InRules *FirewallTable InRules *FirewallTable
OutRules *FirewallTable OutRules *FirewallTable
UnsafeInRules *FirewallTable
UnsafeOutRules *FirewallTable
InSendReject bool InSendReject bool
OutSendReject bool OutSendReject bool
@@ -54,7 +55,7 @@ type Firewall struct {
// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate. // routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
// The vpn addresses are a full bit match while the unsafe networks only match the prefix // The vpn addresses are a full bit match while the unsafe networks only match the prefix
routableNetworks *bart.Lite routableNetworks *bart.Table[NetworkType]
// assignedNetworks is a list of vpn networks assigned to us in the certificate. // assignedNetworks is a list of vpn networks assigned to us in the certificate.
assignedNetworks []netip.Prefix assignedNetworks []netip.Prefix
@@ -149,17 +150,16 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
tmax = defaultTimeout tmax = defaultTimeout
} }
routableNetworks := new(bart.Lite) routableNetworks := new(bart.Table[NetworkType])
var assignedNetworks []netip.Prefix var assignedNetworks []netip.Prefix
for _, network := range c.Networks() { for _, network := range c.Networks() {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) routableNetworks.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), NetworkTypeVPN)
routableNetworks.Insert(nprefix)
assignedNetworks = append(assignedNetworks, network) assignedNetworks = append(assignedNetworks, network)
} }
hasUnsafeNetworks := false hasUnsafeNetworks := false
for _, n := range c.UnsafeNetworks() { for _, n := range c.UnsafeNetworks() {
routableNetworks.Insert(n) routableNetworks.Insert(n, NetworkTypeUnsafe)
hasUnsafeNetworks = true hasUnsafeNetworks = true
} }
@@ -170,6 +170,8 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
}, },
InRules: newFirewallTable(), InRules: newFirewallTable(),
OutRules: newFirewallTable(), OutRules: newFirewallTable(),
UnsafeInRules: newFirewallTable(),
UnsafeOutRules: newFirewallTable(),
TCPTimeout: tcpTimeout, TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout, UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout, DefaultTimeout: defaultTimeout,
@@ -212,6 +214,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false) fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
//TODO: do we also need firewall.unsafe_inbound_action and firewall.unsafe_outbound_action?
inboundAction := c.GetString("firewall.inbound_action", "drop") inboundAction := c.GetString("firewall.inbound_action", "drop")
switch inboundAction { switch inboundAction {
case "reject": case "reject":
@@ -234,12 +237,26 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
fw.OutSendReject = false fw.OutSendReject = false
} }
err := AddFirewallRulesFromConfig(l, false, c, fw) // outbound rules
err := AddFirewallRulesFromConfig(l, false, false, c, fw)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = AddFirewallRulesFromConfig(l, true, c, fw) // unsafe outbound rules
err = AddFirewallRulesFromConfig(l, true, false, c, fw)
if err != nil {
return nil, err
}
// inbound rules
err = AddFirewallRulesFromConfig(l, false, true, c, fw)
if err != nil {
return nil, err
}
// unsafe inbound rules
err = AddFirewallRulesFromConfig(l, true, true, c, fw)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -248,11 +265,11 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
} }
// AddRule properly creates the in memory rule structure for a firewall table. // AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { func (f *Firewall) AddRule(unsafe, incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
// We need this rule string because we generate a hash. Removing this will break firewall reload. // We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf( ruleString := fmt.Sprintf(
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", "unsafe: %v, incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha, unsafe, incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha,
) )
f.rules += ruleString + "\n" f.rules += ruleString + "\n"
@@ -260,8 +277,12 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
if !incoming { if !incoming {
direction = "outgoing" direction = "outgoing"
} }
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}).
Info("Firewall rule added") fields := m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}
if unsafe {
fields["unsafe"] = true
}
f.l.WithField("firewallRule", fields).Info("Firewall rule added")
var ( var (
ft *FirewallTable ft *FirewallTable
@@ -269,9 +290,18 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
) )
if incoming { if incoming {
ft = f.InRules if unsafe {
ft = f.UnsafeInRules
} else {
ft = f.InRules
}
} else { } else {
ft = f.OutRules if unsafe {
ft = f.UnsafeOutRules
} else {
ft = f.OutRules
}
} }
switch proto { switch proto {
@@ -308,12 +338,21 @@ func (f *Firewall) GetRuleHashes() string {
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10) return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
} }
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error { func AddFirewallRulesFromConfig(l *logrus.Logger, unsafe, inbound bool, c *config.C, fw FirewallInterface) error {
var table string var table string
if inbound { if inbound {
table = "firewall.inbound" if unsafe {
table = "firewall.unsafe_inbound"
} else {
table = "firewall.inbound"
}
} else { } else {
table = "firewall.outbound" if unsafe {
table = "firewall.unsafe_outbound"
} else {
table = "firewall.outbound"
}
} }
r := c.Get(table) r := c.Get(table)
@@ -386,7 +425,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
l.Warnf("%s rule #%v; %s", table, i, warning) l.Warnf("%s rule #%v; %s", table, i, warning)
} }
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha) err = fw.AddRule(unsafe, inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; `%s`", table, i, err) return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
} }
@@ -409,6 +448,9 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
return nil return nil
} }
var remoteNetworkType NetworkType
var ok bool
// Make sure remote address matches nebula certificate, and determine how to treat it // Make sure remote address matches nebula certificate, and determine how to treat it
if h.networks == nil { if h.networks == nil {
// Simple case: Certificate has one address and no unsafe networks // Simple case: Certificate has one address and no unsafe networks
@@ -416,13 +458,14 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
f.metrics(incoming).droppedRemoteAddr.Inc(1) f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
} }
remoteNetworkType = NetworkTypeVPN
} else { } else {
nwType, ok := h.networks.Lookup(fp.RemoteAddr) remoteNetworkType, ok = h.networks.Lookup(fp.RemoteAddr)
if !ok { if !ok {
f.metrics(incoming).droppedRemoteAddr.Inc(1) f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
} }
switch nwType { switch remoteNetworkType {
case NetworkTypeVPN: case NetworkTypeVPN:
break // nothing special break // nothing special
case NetworkTypeVPNPeer: case NetworkTypeVPNPeer:
@@ -437,14 +480,27 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
} }
// Make sure we are supposed to be handling this local ip address // Make sure we are supposed to be handling this local ip address
if !f.routableNetworks.Contains(fp.LocalAddr) { localNetworkType, ok := f.routableNetworks.Lookup(fp.LocalAddr)
if !ok {
f.metrics(incoming).droppedLocalAddr.Inc(1) f.metrics(incoming).droppedLocalAddr.Inc(1)
return ErrInvalidLocalIP return ErrInvalidLocalIP
} }
table := f.OutRules useUnsafe := remoteNetworkType == NetworkTypeUnsafe || localNetworkType == NetworkTypeUnsafe
var table *FirewallTable
if incoming { if incoming {
table = f.InRules if useUnsafe {
table = f.UnsafeInRules
} else {
table = f.InRules
}
} else {
if useUnsafe {
table = f.UnsafeOutRules
} else {
table = f.OutRules
}
} }
// We now know which firewall table to check against // We now know which firewall table to check against
@@ -454,12 +510,13 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
} }
// We always want to conntrack since it is a faster operation // We always want to conntrack since it is a faster operation
f.addConn(fp, incoming) f.addConn(fp, useUnsafe, incoming)
return nil return nil
} }
func (f *Firewall) metrics(incoming bool) firewallMetrics { func (f *Firewall) metrics(incoming bool) firewallMetrics {
//TODO: need unsafe metrics too
if incoming { if incoming {
return f.incomingMetrics return f.incomingMetrics
} else { } else {
@@ -499,7 +556,6 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
} }
c, ok := conntrack.Conns[fp] c, ok := conntrack.Conns[fp]
if !ok { if !ok {
conntrack.Unlock() conntrack.Unlock()
return false return false
@@ -508,9 +564,19 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
if c.rulesVersion != f.rulesVersion { if c.rulesVersion != f.rulesVersion {
// This conntrack entry was for an older rule set, validate // This conntrack entry was for an older rule set, validate
// it still passes with the current rule set // it still passes with the current rule set
table := f.OutRules var table *FirewallTable
if c.incoming { if c.incoming {
table = f.InRules if c.unsafe {
table = f.UnsafeInRules
} else {
table = f.InRules
}
} else {
if c.unsafe {
table = f.UnsafeOutRules
} else {
table = f.OutRules
}
} }
// We now know which firewall table to check against // We now know which firewall table to check against
@@ -519,6 +585,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
h.logger(f.l). h.logger(f.l).
WithField("fwPacket", fp). WithField("fwPacket", fp).
WithField("incoming", c.incoming). WithField("incoming", c.incoming).
WithField("unsafe", c.unsafe).
WithField("rulesVersion", f.rulesVersion). WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion). WithField("oldRulesVersion", c.rulesVersion).
Debugln("dropping old conntrack entry, does not match new ruleset") Debugln("dropping old conntrack entry, does not match new ruleset")
@@ -532,6 +599,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
h.logger(f.l). h.logger(f.l).
WithField("fwPacket", fp). WithField("fwPacket", fp).
WithField("incoming", c.incoming). WithField("incoming", c.incoming).
WithField("unsafe", c.unsafe).
WithField("rulesVersion", f.rulesVersion). WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion). WithField("oldRulesVersion", c.rulesVersion).
Debugln("keeping old conntrack entry, does match new ruleset") Debugln("keeping old conntrack entry, does match new ruleset")
@@ -558,7 +626,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
return true return true
} }
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { func (f *Firewall) addConn(fp firewall.Packet, unsafe, incoming bool) {
var timeout time.Duration var timeout time.Duration
c := &conn{} c := &conn{}
@@ -581,6 +649,7 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
// Record which rulesVersion allowed this connection, so we can retest after // Record which rulesVersion allowed this connection, so we can retest after
// firewall reload // firewall reload
c.incoming = incoming c.incoming = incoming
c.unsafe = unsafe
c.rulesVersion = f.rulesVersion c.rulesVersion = f.rulesVersion
c.Expires = time.Now().Add(timeout) c.Expires = time.Now().Add(timeout)
conntrack.Conns[fp] = c conntrack.Conns[fp] = c
@@ -937,6 +1006,7 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
r.Code = toString("code", m) r.Code = toString("code", m)
r.Proto = toString("proto", m) r.Proto = toString("proto", m)
r.Host = toString("host", m) r.Host = toString("host", m)
//TODO: create an alias to remote_cidr and deprecate cidr?
r.Cidr = toString("cidr", m) r.Cidr = toString("cidr", m)
r.LocalCidr = toString("local_cidr", m) r.LocalCidr = toString("local_cidr", m)
r.CAName = toString("ca_name", m) r.CAName = toString("ca_name", m)

View File

@@ -73,65 +73,65 @@ func TestFirewall_AddRule(t *testing.T) {
ti6, err := netip.ParsePrefix("fd12::34/128") ti6, err := netip.ParsePrefix("fd12::34/128")
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", ""))
// An empty rule is any // An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", ""))
assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", "")) require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", "")) require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6) _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", "")) require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", "")) require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6) ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", "")) require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp, err := netip.ParsePrefix("0.0.0.0/0") anyIp, err := netip.ParsePrefix("0.0.0.0/0")
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", "")) require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
table, ok := fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1")) table, ok := fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
assert.True(t, table.Any) assert.True(t, table.Any)
@@ -142,7 +142,7 @@ func TestFirewall_AddRule(t *testing.T) {
anyIp6, err := netip.ParsePrefix("::/0") anyIp6, err := netip.ParsePrefix("::/0")
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", "")) require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9")) table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
assert.True(t, table.Any) assert.True(t, table.Any)
@@ -150,29 +150,29 @@ func TestFirewall_AddRule(t *testing.T) {
assert.False(t, ok) assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", "")) require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", "")) require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1"))) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9"))) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", "")) require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9"))) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1"))) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", "")) require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions // Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", "")) require.Error(t, fw.AddRule(false, true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", ""))
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", "")) require.Error(t, fw.AddRule(false, true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", ""))
} }
func TestFirewall_Drop(t *testing.T) { func TestFirewall_Drop(t *testing.T) {
@@ -208,7 +208,7 @@ func TestFirewall_Drop(t *testing.T) {
h.buildNetworks(myVpnNetworksTable, &c) h.buildNetworks(myVpnNetworksTable, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@@ -227,28 +227,28 @@ func TestFirewall_Drop(t *testing.T) {
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
@@ -287,7 +287,7 @@ func TestFirewall_DropV6(t *testing.T) {
h.buildNetworks(myVpnNetworksTable, &c) h.buildNetworks(myVpnNetworksTable, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@@ -306,28 +306,28 @@ func TestFirewall_DropV6(t *testing.T) {
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
@@ -532,7 +532,7 @@ func TestFirewall_Drop2(t *testing.T) {
h1.buildNetworks(myVpnNetworksTable, c1.Certificate) h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// h1/c1 lacks the proper groups // h1/c1 lacks the proper groups
@@ -612,8 +612,8 @@ func TestFirewall_Drop3(t *testing.T) {
h3.buildNetworks(myVpnNetworksTable, c3.Certificate) h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha"))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// c1 should pass because host match // c1 should pass because host match
@@ -627,7 +627,7 @@ func TestFirewall_Drop3(t *testing.T) {
// Test a remote address match // Test a remote address match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
} }
@@ -665,7 +665,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
// Test a remote address match // Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
cp := cert.NewCAPool() cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
@@ -704,7 +704,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
h.buildNetworks(myVpnNetworksTable, c.Certificate) h.buildNetworks(myVpnNetworksTable, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@@ -717,7 +717,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw := fw oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@@ -726,7 +726,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw = fw oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@@ -765,7 +765,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Packet spoofed by `c1`. Note that the remote addr is not a valid one. // Packet spoofed by `c1`. Note that the remote addr is not a valid one.
@@ -958,28 +958,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
conf := config.NewC(l) conf := config.NewC(l)
mf := &mockFirewall{} mf := &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding udp rule // Test adding udp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding icmp rule // Test adding icmp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding any rule // Test adding any rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding rule with cidr // Test adding rule with cidr
@@ -987,14 +987,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
// Test adding rule with local_cidr // Test adding rule with local_cidr
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr.String()}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr.String()}, mf.lastCall)
// Test adding rule with cidr ipv6 // Test adding rule with cidr ipv6
@@ -1002,75 +1002,75 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
// Test adding rule with any cidr // Test adding rule with any cidr
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
// Test adding rule with junk cidr // Test adding rule with junk cidr
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") require.EqualError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with local_cidr ipv6 // Test adding rule with local_cidr ipv6
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
// Test adding rule with any local_cidr // Test adding rule with any local_cidr
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
// Test adding rule with junk local_cidr // Test adding rule with junk local_cidr
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") require.EqualError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with ca_sha // Test adding rule with ca_sha
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name // Test adding rule with ca_name
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
// Test single group // Test single group
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
// Test single groups // Test single groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
// Test multiple AND groups // Test multiple AND groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
// Test Add error // Test Add error
@@ -1078,7 +1078,7 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf = &mockFirewall{} mf = &mockFirewall{}
mf.nextCallReturn = errors.New("test error") mf.nextCallReturn = errors.New("test error")
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`") require.EqualError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf), "firewall.inbound rule #0; `test error`")
} }
func TestFirewall_convertRule(t *testing.T) { func TestFirewall_convertRule(t *testing.T) {
@@ -1251,7 +1251,7 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
myVpnNetworksTable.Insert(prefix) myVpnNetworksTable.Insert(prefix)
} }
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
return testsetup{ return testsetup{
c: c, c: c,
@@ -1332,13 +1332,14 @@ func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3") tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
tc.err = ErrNoMatchingRule tc.err = ErrNoMatchingRule
tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", "")) require.NoError(t, unsafeSetup.fw.AddRule(true, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
tc.err = nil tc.err = nil
tc.Test(t, unsafeSetup.fw) //should pass tc.Test(t, unsafeSetup.fw) //should pass
}) })
} }
type addRuleCall struct { type addRuleCall struct {
unsafe bool
incoming bool incoming bool
proto uint8 proto uint8
startPort int32 startPort int32
@@ -1356,8 +1357,9 @@ type mockFirewall struct {
nextCallReturn error nextCallReturn error
} }
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error { func (mf *mockFirewall) AddRule(unsafe, incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error {
mf.lastCall = addRuleCall{ mf.lastCall = addRuleCall{
unsafe: unsafe,
incoming: incoming, incoming: incoming,
proto: proto, proto: proto,
startPort: startPort, startPort: startPort,

View File

@@ -99,11 +99,11 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
return true return true
} }
func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) { func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
cs := f.pki.getCertState() cs := f.pki.getCertState()
crt := cs.GetDefaultCertificate() crt := cs.GetDefaultCertificate()
if crt == nil { if crt == nil {
f.l.WithField("from", via). f.l.WithField("udpAddr", addr).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.initiatingVersion). WithField("certVersion", cs.initiatingVersion).
Error("Unable to handshake with host because no certificate is available") Error("Unable to handshake with host because no certificate is available")
@@ -112,7 +112,7 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
if err != nil { if err != nil {
f.l.WithError(err).WithField("from", via). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to create connection state") Error("Failed to create connection state")
return return
@@ -123,7 +123,7 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil { if err != nil {
f.l.WithError(err).WithField("from", via). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to call noise.ReadMessage") Error("Failed to call noise.ReadMessage")
return return
@@ -132,7 +132,7 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
hs := &NebulaHandshake{} hs := &NebulaHandshake{}
err = hs.Unmarshal(msg) err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil { if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("from", via). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed unmarshal handshake message") Error("Failed unmarshal handshake message")
return return
@@ -140,7 +140,7 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil { if err != nil {
f.l.WithError(err).WithField("from", via). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake did not contain a certificate") Info("Handshake did not contain a certificate")
return return
@@ -153,7 +153,7 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
fp = "<error generating certificate fingerprint>" fp = "<error generating certificate fingerprint>"
} }
e := f.l.WithError(err).WithField("from", via). e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVpnNetworks", rc.Networks()). WithField("certVpnNetworks", rc.Networks()).
WithField("certFingerprint", fp) WithField("certFingerprint", fp)
@@ -172,7 +172,7 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
if myCertOtherVersion == nil { if myCertOtherVersion == nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithError(err).WithFields(m{ f.l.WithError(err).WithFields(m{
"from": via, "udpAddr": addr,
"handshake": m{"stage": 1, "style": "ix_psk0"}, "handshake": m{"stage": 1, "style": "ix_psk0"},
"cert": remoteCert, "cert": remoteCert,
}).Debug("Might be unable to handshake with host due to missing certificate version") }).Debug("Might be unable to handshake with host due to missing certificate version")
@@ -184,7 +184,7 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
} }
if len(remoteCert.Certificate.Networks()) == 0 { if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("from", via). f.l.WithError(err).WithField("udpAddr", addr).
WithField("cert", remoteCert). WithField("cert", remoteCert).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("No networks in certificate") Info("No networks in certificate")
@@ -201,7 +201,7 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
vpnAddrs := make([]netip.Addr, len(vpnNetworks)) vpnAddrs := make([]netip.Addr, len(vpnNetworks))
for i, network := range vpnNetworks { for i, network := range vpnNetworks {
if f.myVpnAddrsTable.Contains(network.Addr()) { if f.myVpnAddrsTable.Contains(network.Addr()) {
f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via). f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -215,18 +215,18 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
} }
} }
if !via.IsRelayed { if addr.IsValid() {
// addr can be invalid when the tunnel is being relayed.
// We only want to apply the remote allow list for direct tunnels here // We only want to apply the remote allow list for direct tunnels here
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) { if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
Debug("lighthouse.remote_allow_list denied incoming handshake")
return return
} }
} }
myIndex, err := generateIndex(f.l) myIndex, err := generateIndex(f.l)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via). f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -251,7 +251,7 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
msgRxL := f.l.WithFields(m{ msgRxL := f.l.WithFields(m{
"vpnAddrs": vpnAddrs, "vpnAddrs": vpnAddrs,
"from": via, "udpAddr": addr,
"certName": certName, "certName": certName,
"certVersion": certVersion, "certVersion": certVersion,
"fingerprint": fingerprint, "fingerprint": fingerprint,
@@ -283,7 +283,7 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
hsBytes, err := hs.Marshal() hsBytes, err := hs.Marshal()
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -295,7 +295,7 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -303,7 +303,7 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return return
} else if dKey == nil || eKey == nil { } else if dKey == nil || eKey == nil {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -329,9 +329,7 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
ci.eKey = NewNebulaCipherState(eKey) ci.eKey = NewNebulaCipherState(eKey)
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
if !via.IsRelayed { hostinfo.SetRemote(addr)
hostinfo.SetRemote(via.UdpAddr)
}
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
@@ -339,7 +337,7 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
switch err { switch err {
case ErrAlreadySeen: case ErrAlreadySeen:
// Update remote if preferred // Update remote if preferred
if existing.SetRemoteIfPreferred(f.hostMap, via) { if existing.SetRemoteIfPreferred(f.hostMap, addr) {
// Send a test packet to ensure the other side has also switched to // Send a test packet to ensure the other side has also switched to
// the preferred remote // the preferred remote
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
@@ -347,21 +345,21 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
msg = existing.HandshakePacket[2] msg = existing.HandshakePacket[2]
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
if !via.IsRelayed { if addr.IsValid() {
err := f.outside.WriteTo(msg, via.UdpAddr) err := f.outside.WriteTo(msg, addr)
if err != nil { if err != nil {
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via). f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message") WithError(err).Error("Failed to send handshake message")
} else { } else {
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via). f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent") Info("Handshake message sent")
} }
return return
} else { } else {
if via.relay == nil { if via == nil {
f.l.Error("Handshake send failed: both addr and via.relay are nil.") f.l.Error("Handshake send failed: both addr and via are nil.")
return return
} }
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
@@ -373,7 +371,7 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
} }
case ErrExistingHostInfo: case ErrExistingHostInfo:
// This means there was an existing tunnel and this handshake was older than the one we are currently based on // This means there was an existing tunnel and this handshake was older than the one we are currently based on
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("oldHandshakeTime", existing.lastHandshakeTime). WithField("oldHandshakeTime", existing.lastHandshakeTime).
@@ -389,7 +387,7 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
return return
case ErrLocalIndexCollision: case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -402,7 +400,7 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
default: default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here // And we forget to update it here
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via). f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -416,23 +414,30 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
// Do the send // Do the send
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
if !via.IsRelayed { if addr.IsValid() {
err = f.outside.WriteTo(msg, via.UdpAddr) err = f.outside.WriteTo(msg, addr)
log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
if err != nil { if err != nil {
log.WithError(err).Error("Failed to send handshake") f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake")
} else { } else {
log.Info("Handshake message sent") f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake message sent")
} }
} else { } else {
if via.relay == nil { if via == nil {
f.l.Error("Handshake send failed: both addr and via.relay are nil.") f.l.Error("Handshake send failed: both addr and via are nil.")
return return
} }
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
@@ -457,7 +462,7 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
return return
} }
func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
if hh == nil { if hh == nil {
// Nothing here to tear down, got a bogus stage 2 packet // Nothing here to tear down, got a bogus stage 2 packet
return true return true
@@ -467,10 +472,10 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
defer hh.Unlock() defer hh.Unlock()
hostinfo := hh.hostinfo hostinfo := hh.hostinfo
if !via.IsRelayed { if addr.IsValid() {
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list. // The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake") f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return false return false
} }
} }
@@ -478,7 +483,7 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
ci := hostinfo.ConnectionState ci := hostinfo.ConnectionState
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Failed to call noise.ReadMessage") Error("Failed to call noise.ReadMessage")
@@ -487,7 +492,7 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
// near future // near future
return false return false
} else if dKey == nil || eKey == nil { } else if dKey == nil || eKey == nil {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key") Error("Noise did not arrive at a key")
@@ -499,7 +504,7 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
hs := &NebulaHandshake{} hs := &NebulaHandshake{}
err = hs.Unmarshal(msg) err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil { if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
@@ -508,7 +513,7 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil { if err != nil {
f.l.WithError(err).WithField("from", via). f.l.WithError(err).WithField("udpAddr", addr).
WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake did not contain a certificate") Info("Handshake did not contain a certificate")
@@ -522,7 +527,7 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
fp = "<error generating certificate fingerprint>" fp = "<error generating certificate fingerprint>"
} }
e := f.l.WithError(err).WithField("from", via). e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("certFingerprint", fp). WithField("certFingerprint", fp).
@@ -537,7 +542,7 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
} }
if len(remoteCert.Certificate.Networks()) == 0 { if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("from", via). f.l.WithError(err).WithField("udpAddr", addr).
WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("cert", remoteCert). WithField("cert", remoteCert).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
@@ -560,8 +565,8 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
ci.eKey = NewNebulaCipherState(eKey) ci.eKey = NewNebulaCipherState(eKey)
// Make sure the current udpAddr being used is set for responding // Make sure the current udpAddr being used is set for responding
if !via.IsRelayed { if addr.IsValid() {
hostinfo.SetRemote(via.UdpAddr) hostinfo.SetRemote(addr)
} else { } else {
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
} }
@@ -583,7 +588,7 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
// Ensure the right host responded // Ensure the right host responded
if !correctHostResponded { if !correctHostResponded {
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("from", via). WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
@@ -597,7 +602,7 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
// Block the current used address // Block the current used address
newHH.hostinfo.remotes = hostinfo.remotes newHH.hostinfo.remotes = hostinfo.remotes
newHH.hostinfo.remotes.BlockRemote(via) newHH.hostinfo.remotes.BlockRemote(addr)
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()). f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
WithField("vpnNetworks", vpnNetworks). WithField("vpnNetworks", vpnNetworks).
@@ -620,7 +625,7 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
ci.window.Update(f.l, 2) ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds() duration := time.Since(hh.startTime).Nanoseconds()
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).

View File

@@ -136,11 +136,11 @@ func (hm *HandshakeManager) Run(ctx context.Context) {
} }
} }
func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *header.H) { func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
// First remote allow list check before we know the vpnIp // First remote allow list check before we know the vpnIp
if !via.IsRelayed { if addr.IsValid() {
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) { if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) {
hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake") hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return return
} }
} }
@@ -149,11 +149,11 @@ func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *head
case header.HandshakeIXPSK0: case header.HandshakeIXPSK0:
switch h.MessageCounter { switch h.MessageCounter {
case 1: case 1:
ixHandshakeStage1(hm.f, via, packet, h) ixHandshakeStage1(hm.f, addr, via, packet, h)
case 2: case 2:
newHostinfo := hm.queryIndex(h.RemoteIndex) newHostinfo := hm.queryIndex(h.RemoteIndex)
tearDown := ixHandshakeStage2(hm.f, via, newHostinfo, packet, h) tearDown := ixHandshakeStage2(hm.f, addr, via, newHostinfo, packet, h)
if tearDown && newHostinfo != nil { if tearDown && newHostinfo != nil {
hm.DeleteHostInfo(newHostinfo.hostinfo) hm.DeleteHostInfo(newHostinfo.hostinfo)
} }

View File

@@ -1,9 +1,7 @@
package nebula package nebula
import ( import (
"encoding/json"
"errors" "errors"
"fmt"
"net" "net"
"net/netip" "net/netip"
"slices" "slices"
@@ -278,25 +276,9 @@ type HostInfo struct {
} }
type ViaSender struct { type ViaSender struct {
UdpAddr netip.AddrPort
relayHI *HostInfo // relayHI is the host info object of the relay relayHI *HostInfo // relayHI is the host info object of the relay
remoteIdx uint32 // remoteIdx is the index included in the header of the received packet remoteIdx uint32 // remoteIdx is the index included in the header of the received packet
relay *Relay // relay contains the rest of the relay information, including the PeerIP of the host trying to communicate with us. relay *Relay // relay contains the rest of the relay information, including the PeerIP of the host trying to communicate with us.
IsRelayed bool // IsRelayed is true if the packet was sent through a relay
}
func (v ViaSender) String() string {
if v.IsRelayed {
return fmt.Sprintf("%s (relayed)", v.UdpAddr)
}
return v.UdpAddr.String()
}
func (v ViaSender) MarshalJSON() ([]byte, error) {
if v.IsRelayed {
return json.Marshal(m{"relay": v.UdpAddr})
}
return json.Marshal(m{"direct": v.UdpAddr})
} }
type cachedPacket struct { type cachedPacket struct {
@@ -712,7 +694,6 @@ func (i *HostInfo) GetCert() *cert.CachedCertificate {
return nil return nil
} }
// TODO: Maybe use ViaSender here?
func (i *HostInfo) SetRemote(remote netip.AddrPort) { func (i *HostInfo) SetRemote(remote netip.AddrPort) {
// We copy here because we likely got this remote from a source that reuses the object // We copy here because we likely got this remote from a source that reuses the object
if i.remote != remote { if i.remote != remote {
@@ -723,14 +704,14 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) {
// SetRemoteIfPreferred returns true if the remote was changed. The lastRoam // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
// time on the HostInfo will also be updated. // time on the HostInfo will also be updated.
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, via ViaSender) bool { func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool {
if via.IsRelayed { if !newRemote.IsValid() {
// relays have nil udp Addrs
return false return false
} }
currentRemote := i.remote currentRemote := i.remote
if !currentRemote.IsValid() { if !currentRemote.IsValid() {
i.SetRemote(via.UdpAddr) i.SetRemote(newRemote)
return true return true
} }
@@ -743,7 +724,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, via ViaSender) bool {
return false return false
} }
if l.Contains(via.UdpAddr.Addr()) { if l.Contains(newRemote.Addr()) {
newIsPreferred = true newIsPreferred = true
} }
} }
@@ -753,7 +734,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, via ViaSender) bool {
i.lastRoam = time.Now() i.lastRoam = time.Now()
i.lastRoamRemote = currentRemote i.lastRoamRemote = currentRemote
i.SetRemote(via.UdpAddr) i.SetRemote(newRemote)
return true return true
} }

View File

@@ -279,7 +279,7 @@ func (f *Interface) listenOut(i int) {
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
}) })
} }

View File

@@ -19,21 +19,21 @@ const (
minFwPacketLen = 4 minFwPacketLen = 4
) )
func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet) err := h.Parse(packet)
if err != nil { if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(packet) > 1 { if len(packet) > 1 {
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err) f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
} }
return return
} }
//l.Error("in packet ", header, packet[HeaderLen:]) //l.Error("in packet ", header, packet[HeaderLen:])
if !via.IsRelayed { if ip.IsValid() {
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) { if f.myVpnNetworksTable.Contains(ip.Addr()) {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("from", via).Debug("Refusing to process double encrypted packet") f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
} }
return return
} }
@@ -54,7 +54,8 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
switch h.Type { switch h.Type {
case header.Message: case header.Message:
if !f.handleEncrypted(ci, via, h) { // TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
if !f.handleEncrypted(ci, ip, h) {
return return
} }
@@ -78,7 +79,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
// Successfully validated the thing. Get rid of the Relay header. // Successfully validated the thing. Get rid of the Relay header.
signedPayload = signedPayload[header.Len:] signedPayload = signedPayload[header.Len:]
// Pull the Roaming parts up here, and return in all call paths. // Pull the Roaming parts up here, and return in all call paths.
f.handleHostRoaming(hostinfo, via) f.handleHostRoaming(hostinfo, ip)
// Track usage of both the HostInfo and the Relay for the received & authenticated packet // Track usage of both the HostInfo and the Relay for the received & authenticated packet
f.connectionManager.In(hostinfo) f.connectionManager.In(hostinfo)
f.connectionManager.RelayUsed(h.RemoteIndex) f.connectionManager.RelayUsed(h.RemoteIndex)
@@ -95,14 +96,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
case TerminalType: case TerminalType:
// If I am the target of this relay, process the unwrapped packet // If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
via = ViaSender{ f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
UdpAddr: via.UdpAddr,
relayHI: hostinfo,
remoteIdx: relay.RemoteIndex,
relay: relay,
IsRelayed: true,
}
f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return return
case ForwardingType: case ForwardingType:
// Find the target HostInfo relay object // Find the target HostInfo relay object
@@ -132,32 +126,31 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
case header.LightHouse: case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) { if !f.handleEncrypted(ci, ip, h) {
return return
} }
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via). hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet). WithField("packet", packet).
Error("Failed to decrypt lighthouse packet") Error("Failed to decrypt lighthouse packet")
return return
} }
//TODO: assert via is not relayed lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic // Fallthrough to the bottom to record incoming traffic
case header.Test: case header.Test:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) { if !f.handleEncrypted(ci, ip, h) {
return return
} }
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via). hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet). WithField("packet", packet).
Error("Failed to decrypt test packet") Error("Failed to decrypt test packet")
return return
@@ -166,7 +159,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
if h.Subtype == header.TestRequest { if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam // This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding // to the new IP address before responding
f.handleHostRoaming(hostinfo, via) f.handleHostRoaming(hostinfo, ip)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
} }
@@ -177,34 +170,34 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
case header.Handshake: case header.Handshake:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handshakeManager.HandleIncoming(via, packet, h) f.handshakeManager.HandleIncoming(ip, via, packet, h)
return return
case header.RecvError: case header.RecvError:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handleRecvError(via.UdpAddr, h) f.handleRecvError(ip, h)
return return
case header.CloseTunnel: case header.CloseTunnel:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) { if !f.handleEncrypted(ci, ip, h) {
return return
} }
hostinfo.logger(f.l).WithField("from", via). hostinfo.logger(f.l).WithField("udpAddr", ip).
Info("Close tunnel received, tearing down.") Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo) f.closeTunnel(hostinfo)
return return
case header.Control: case header.Control:
if !f.handleEncrypted(ci, via, h) { if !f.handleEncrypted(ci, ip, h) {
return return
} }
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via). hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet). WithField("packet", packet).
Error("Failed to decrypt Control packet") Error("Failed to decrypt Control packet")
return return
@@ -214,11 +207,11 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
default: default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via) hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
return return
} }
f.handleHostRoaming(hostinfo, via) f.handleHostRoaming(hostinfo, ip)
f.connectionManager.In(hostinfo) f.connectionManager.In(hostinfo)
} }
@@ -237,36 +230,36 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
} }
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) {
if !via.IsRelayed && hostinfo.remote != via.UdpAddr { if udpAddr.IsValid() && hostinfo.remote != udpAddr {
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) {
hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming") hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming")
return return
} }
if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
} }
return return
} }
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
Info("Host roamed to new udp ip/port.") Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now() hostinfo.lastRoam = time.Now()
hostinfo.lastRoamRemote = hostinfo.remote hostinfo.lastRoamRemote = hostinfo.remote
hostinfo.SetRemote(via.UdpAddr) hostinfo.SetRemote(udpAddr)
} }
} }
// handleEncrypted returns true if a packet should be processed, false otherwise // handleEncrypted returns true if a packet should be processed, false otherwise
func (f *Interface) handleEncrypted(ci *ConnectionState, via ViaSender, h *header.H) bool { func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect // If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
if ci == nil { if ci == nil {
if !via.IsRelayed { if addr.IsValid() {
f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex) f.maybeSendRecvError(addr, h.RemoteIndex)
} }
return false return false
} }

View File

@@ -338,21 +338,21 @@ func (r *RemoteList) CopyCache() *CacheMap {
} }
// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
func (r *RemoteList) BlockRemote(bad ViaSender) { func (r *RemoteList) BlockRemote(bad netip.AddrPort) {
if bad.IsRelayed { if !bad.IsValid() {
// relays can have nil udp Addrs
return return
} }
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
// Check if we already blocked this addr // Check if we already blocked this addr
if r.unlockedIsBad(bad.UdpAddr) { if r.unlockedIsBad(bad) {
return return
} }
// We copy here because we are taking something else's memory and we can't trust everything // We copy here because we are taking something else's memory and we can't trust everything
r.badRemotes = append(r.badRemotes, bad.UdpAddr) r.badRemotes = append(r.badRemotes, bad)
// Mark the next interaction must recollect/dedupe // Mark the next interaction must recollect/dedupe
r.shouldRebuild = true r.shouldRebuild = true