Compare commits

..

2 Commits

Author SHA1 Message Date
Nate Brown
ee8e4d2017 Start of the changelog 2025-11-18 23:00:04 -06:00
Nate Brown
8d656fb890 Pull in v1.9.5-v1.9.7 CHANGELOG 2025-11-18 21:58:26 -06:00
36 changed files with 371 additions and 613 deletions

View File

@@ -7,12 +7,64 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
## [1.10.0] - ????
### Added
- PKCS11 support for P256 keys when built with `pkcs11` tag (#1153)
- ASN.1 based v2 nebula certificates with support for ipv6 and multiple ip addresses.
Certificates now have a unified interface for external implementations. (#1212, #1216, #1345)
**TODO: External documentation link!**
- Add the ability to mark packets on linux to better target nebula packets in iptables/nftables. (#1331)
- Add ECMP support for `unsafe_routes`. (#1332)
### Changed ### Changed
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule - `default_local_cidr_any` now defaults to false, meaning that any firewall rule
intended to target an `unsafe_routes` entry must explicitly declare it via the intended to target an `unsafe_routes` entry must explicitly declare it via the
`local_cidr` field. This is almost always the intended behavior. This flag is `local_cidr` field. This is almost always the intended behavior. This flag is
deprecated and will be removed in a future release. deprecated and will be removed in a future release. (#1373)
### Fixed
- Fix moving a udp address from one vpn address to another in the `static_host_map`
which could cause rapid re-handshaking with an incorrect remote. (#1259)
- Improve smoke tests in environments where the docker network is not the default. (#1347)
## [1.9.7] - 2025-10-10
### Security
- Fix an issue where Nebula could incorrectly accept and process a packet from an erroneous source IP when the sender's
certificate is configured with unsafe_routes (cert v1/v2) or multiple IPs (cert v2). (#1494)
### Changed
- Disable sending `recv_error` messages when a packet is received outside the allowable counter window. (#1459)
- Improve error messages and remove some unnecessary fatal conditions in the Windows and generic udp listener. (#1543)
## [1.9.6] - 2025-7-15
### Added
- Support dropping inactive tunnels. This is disabled by default in this release but can be enabled with `tunnels.drop_inactive`. See example config for more details. (#1413)
### Fixed
- Fix Darwin freeze due to presence of some Network Extensions (#1426)
- Ensure the same relay tunnel is always used when multiple relay tunnels are present (#1422)
- Fix Windows freeze due to ICMP error handling (#1412)
- Fix relay migration panic (#1403)
## [1.9.5] - 2024-12-05
### Added
- Gracefully ignore v2 certificates. (#1282)
### Fixed
- Fix relays that refuse to re-establish after one of the remote tunnel pairs breaks. (#1277)
## [1.9.4] - 2024-09-09 ## [1.9.4] - 2024-09-09
@@ -671,7 +723,11 @@ created.)
- Initial public release. - Initial public release.
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.4...HEAD [Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.0...HEAD
[1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0
[1.9.7]: https://github.com/slackhq/nebula/releases/tag/v1.9.7
[1.9.6]: https://github.com/slackhq/nebula/releases/tag/v1.9.6
[1.9.5]: https://github.com/slackhq/nebula/releases/tag/v1.9.5
[1.9.4]: https://github.com/slackhq/nebula/releases/tag/v1.9.4 [1.9.4]: https://github.com/slackhq/nebula/releases/tag/v1.9.4
[1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3 [1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3
[1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2 [1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2

View File

@@ -43,7 +43,7 @@ type signFlags struct {
func newSignFlags() *signFlags { func newSignFlags() *signFlags {
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)} sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
sf.set.Usage = func() {} sf.set.Usage = func() {}
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use. The default is to match the version of the signing CA") sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key") sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert") sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname") sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
@@ -167,10 +167,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return fmt.Errorf("ca certificate is expired") return fmt.Errorf("ca certificate is expired")
} }
if version == 0 {
version = caCert.Version()
}
// if no duration is given, expire one second before the root expires // if no duration is given, expire one second before the root expires
if *sf.duration <= 0 { if *sf.duration <= 0 {
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1 *sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
@@ -283,19 +279,21 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
notBefore := time.Now() notBefore := time.Now()
notAfter := notBefore.Add(*sf.duration) notAfter := notBefore.Add(*sf.duration)
switch version { if version == 0 || version == cert.Version1 {
case cert.Version1: // Make sure we at least have an ip
// Make sure we have only one ipv4 address
if len(v4Networks) != 1 { if len(v4Networks) != 1 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address") return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
} }
if len(v6Networks) > 0 { if version == cert.Version1 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only contain ipv4 addresses") // If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
} if len(v6Networks) > 0 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
}
if len(v6UnsafeNetworks) > 0 { if len(v6UnsafeNetworks) > 0 {
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses") return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
}
} }
t := &cert.TBSCertificate{ t := &cert.TBSCertificate{
@@ -325,8 +323,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
} }
crts = append(crts, nc) crts = append(crts, nc)
}
case cert.Version2: if version == 0 || version == cert.Version2 {
t := &cert.TBSCertificate{ t := &cert.TBSCertificate{
Version: cert.Version2, Version: cert.Version2,
Name: *sf.name, Name: *sf.name,
@@ -354,9 +353,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
} }
crts = append(crts, nc) crts = append(crts, nc)
default:
// this should be unreachable
return fmt.Errorf("invalid version: %d", version)
} }
if !isP11 && *sf.inPubPath == "" { if !isP11 && *sf.inPubPath == "" {

View File

@@ -55,7 +55,7 @@ func Test_signHelp(t *testing.T) {
" -unsafe-networks string\n"+ " -unsafe-networks string\n"+
" \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+ " \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
" -version uint\n"+ " -version uint\n"+
" \tOptional: version of the certificate format to use. The default is to match the version of the signing CA\n", " \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
ob.String(), ob.String(),
) )
} }
@@ -204,7 +204,7 @@ func Test_signCert(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses") assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())

View File

@@ -50,6 +50,11 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
} }
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()} static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
b.Update(l, 0)
hs, err := noise.NewHandshakeState(noise.Config{ hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: ncs, CipherSuite: ncs,
Random: rand.Reader, Random: rand.Reader,
@@ -69,7 +74,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
ci := &ConnectionState{ ci := &ConnectionState{
H: hs, H: hs,
initiator: initiator, initiator: initiator,
window: NewBits(ReplayWindow), window: b,
myCert: crt, myCert: crt,
} }
// always start the counter from 2, as packet 1 and packet 2 are handshake packets. // always start the counter from 2, as packet 1 and packet 2 are handshake packets.

View File

@@ -25,12 +25,11 @@ import (
func BenchmarkHotPath(b *testing.B) { func BenchmarkHotPath(b *testing.B) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse // Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers // Start the servers
myControl.Start() myControl.Start()
@@ -39,9 +38,6 @@ func BenchmarkHotPath(b *testing.B) {
r := router.NewR(b, myControl, theirControl) r := router.NewR(b, myControl, theirControl)
r.CancelFlowLogs() r.CancelFlowLogs()
assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl) _ = r.RouteForAllUntilTxTun(theirControl)
@@ -51,39 +47,6 @@ func BenchmarkHotPath(b *testing.B) {
theirControl.Stop() theirControl.Stop()
} }
func BenchmarkHotPathRelay(b *testing.B) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(b, myControl, relayControl, theirControl)
r.CancelFlowLogs()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
b.ResetTimer()
for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl)
}
myControl.Stop()
theirControl.Stop()
relayControl.Stop()
}
func TestGoodHandshake(t *testing.T) { func TestGoodHandshake(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)

View File

@@ -292,7 +292,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
} }
} }
func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
// Send a packet from them to me // Send a packet from them to me
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B")) controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
bPacket := r.RouteForAllUntilTxTun(controlA) bPacket := r.RouteForAllUntilTxTun(controlA)
@@ -325,7 +325,7 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index") assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
} }
func assertUdpPacket(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
if toIp.Is6() { if toIp.Is6() {
assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort) assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
} else { } else {
@@ -333,7 +333,7 @@ func assertUdpPacket(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr,
} }
} }
func assertUdpPacket6(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy) packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6) v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
assert.NotNil(t, v6, "No ipv6 data found") assert.NotNil(t, v6, "No ipv6 data found")
@@ -352,7 +352,7 @@ func assertUdpPacket6(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr,
assert.Equal(t, expected, data.Payload(), "Data was incorrect") assert.Equal(t, expected, data.Payload(), "Data was incorrect")
} }
func assertUdpPacket4(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy) packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
assert.NotNil(t, v4, "No ipv4 data found") assert.NotNil(t, v4, "No ipv4 data found")

View File

@@ -383,9 +383,8 @@ firewall:
# host: `any` or a literal hostname, ie `test-host` # host: `any` or a literal hostname, ie `test-host`
# group: `any` or a literal group name, ie `default-group` # group: `any` or a literal group name, ie `default-group`
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address. # cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address. # local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes.
# This can be used to filter destinations when using unsafe_routes.
# By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true. # By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true.
# If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case. # If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case.
# ca_name: An issuing CA name # ca_name: An issuing CA name

View File

@@ -8,7 +8,6 @@ import (
"hash/fnv" "hash/fnv"
"net/netip" "net/netip"
"reflect" "reflect"
"slices"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -23,7 +22,7 @@ import (
) )
type FirewallInterface interface { type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
} }
type conn struct { type conn struct {
@@ -248,11 +247,22 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
} }
// AddRule properly creates the in memory rule structure for a firewall table. // AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
// https://github.com/golang/go/issues/14131
sIp := ""
if ip.IsValid() {
sIp = ip.String()
}
lIp := ""
if localIp.IsValid() {
lIp = localIp.String()
}
// We need this rule string because we generate a hash. Removing this will break firewall reload. // We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf( ruleString := fmt.Sprintf(
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha, incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha,
) )
f.rules += ruleString + "\n" f.rules += ruleString + "\n"
@@ -260,7 +270,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
if !incoming { if !incoming {
direction = "outgoing" direction = "outgoing"
} }
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}). f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}).
Info("Firewall rule added") Info("Firewall rule added")
var ( var (
@@ -287,7 +297,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
return fmt.Errorf("unknown protocol %v", proto) return fmt.Errorf("unknown protocol %v", proto)
} }
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha) return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha)
} }
// GetRuleHash returns a hash representation of all inbound and outbound rules // GetRuleHash returns a hash representation of all inbound and outbound rules
@@ -327,6 +337,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
} }
for i, t := range rs { for i, t := range rs {
var groups []string
r, err := convertRule(l, t, table, i) r, err := convertRule(l, t, table, i)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; %s", table, i, err) return fmt.Errorf("%s rule #%v; %s", table, i, err)
@@ -336,10 +347,23 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i) return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
} }
if r.Host == "" && len(r.Groups) == 0 && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" { if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i) return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
} }
if len(r.Groups) > 0 {
groups = r.Groups
}
if r.Group != "" {
// Check if we have both groups and group provided in the rule config
if len(groups) > 0 {
return fmt.Errorf("%s rule #%v; only one of group or groups should be defined, both provided", table, i)
}
groups = []string{r.Group}
}
var sPort, errPort string var sPort, errPort string
if r.Code != "" { if r.Code != "" {
errPort = "code" errPort = "code"
@@ -368,25 +392,23 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
} }
if r.Cidr != "" && r.Cidr != "any" { var cidr netip.Prefix
_, err = netip.ParsePrefix(r.Cidr) if r.Cidr != "" {
cidr, err = netip.ParsePrefix(r.Cidr)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err) return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
} }
} }
if r.LocalCidr != "" && r.LocalCidr != "any" { var localCidr netip.Prefix
_, err = netip.ParsePrefix(r.LocalCidr) if r.LocalCidr != "" {
localCidr, err = netip.ParsePrefix(r.LocalCidr)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err) return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
} }
} }
if warning := r.sanity(); warning != nil { err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha)
l.Warnf("%s rule #%v; %s", table, i, warning)
}
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; `%s`", table, i, err) return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
} }
@@ -633,7 +655,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC
return false return false
} }
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
if startPort > endPort { if startPort > endPort {
return fmt.Errorf("start port was lower than end port") return fmt.Errorf("start port was lower than end port")
} }
@@ -646,7 +668,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou
} }
} }
if err := fp[i].addRule(f, groups, host, cidr, localCidr, caName, caSha); err != nil { if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil {
return err return err
} }
} }
@@ -677,7 +699,7 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
return fp[firewall.PortAny].match(p, c, caPool) return fp[firewall.PortAny].match(p, c, caPool)
} }
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, localCidr, caName, caSha string) error { func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error {
fr := func() *FirewallRule { fr := func() *FirewallRule {
return &FirewallRule{ return &FirewallRule{
Hosts: make(map[string]*firewallLocalCIDR), Hosts: make(map[string]*firewallLocalCIDR),
@@ -691,14 +713,14 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, l
fc.Any = fr() fc.Any = fr()
} }
return fc.Any.addRule(f, groups, host, cidr, localCidr) return fc.Any.addRule(f, groups, host, ip, localIp)
} }
if caSha != "" { if caSha != "" {
if _, ok := fc.CAShas[caSha]; !ok { if _, ok := fc.CAShas[caSha]; !ok {
fc.CAShas[caSha] = fr() fc.CAShas[caSha] = fr()
} }
err := fc.CAShas[caSha].addRule(f, groups, host, cidr, localCidr) err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp)
if err != nil { if err != nil {
return err return err
} }
@@ -708,7 +730,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, l
if _, ok := fc.CANames[caName]; !ok { if _, ok := fc.CANames[caName]; !ok {
fc.CANames[caName] = fr() fc.CANames[caName] = fr()
} }
err := fc.CANames[caName].addRule(f, groups, host, cidr, localCidr) err := fc.CANames[caName].addRule(f, groups, host, ip, localIp)
if err != nil { if err != nil {
return err return err
} }
@@ -740,24 +762,24 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool
return fc.CANames[s.Certificate.Name()].match(p, c) return fc.CANames[s.Certificate.Name()].match(p, c)
} }
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localCidr string) error { func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
flc := func() *firewallLocalCIDR { flc := func() *firewallLocalCIDR {
return &firewallLocalCIDR{ return &firewallLocalCIDR{
LocalCIDR: new(bart.Lite), LocalCIDR: new(bart.Lite),
} }
} }
if fr.isAny(groups, host, cidr) { if fr.isAny(groups, host, ip) {
if fr.Any == nil { if fr.Any == nil {
fr.Any = flc() fr.Any = flc()
} }
return fr.Any.addRule(f, localCidr) return fr.Any.addRule(f, localCIDR)
} }
if len(groups) > 0 { if len(groups) > 0 {
nlc := flc() nlc := flc()
err := nlc.addRule(f, localCidr) err := nlc.addRule(f, localCIDR)
if err != nil { if err != nil {
return err return err
} }
@@ -773,34 +795,30 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localC
if nlc == nil { if nlc == nil {
nlc = flc() nlc = flc()
} }
err := nlc.addRule(f, localCidr) err := nlc.addRule(f, localCIDR)
if err != nil { if err != nil {
return err return err
} }
fr.Hosts[host] = nlc fr.Hosts[host] = nlc
} }
if cidr != "" { if ip.IsValid() {
c, err := netip.ParsePrefix(cidr) nlc, _ := fr.CIDR.Get(ip)
if err != nil {
return err
}
nlc, _ := fr.CIDR.Get(c)
if nlc == nil { if nlc == nil {
nlc = flc() nlc = flc()
} }
err = nlc.addRule(f, localCidr) err := nlc.addRule(f, localCIDR)
if err != nil { if err != nil {
return err return err
} }
fr.CIDR.Insert(c, nlc) fr.CIDR.Insert(ip, nlc)
} }
return nil return nil
} }
func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool { func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool {
if len(groups) == 0 && host == "" && cidr == "" { if len(groups) == 0 && host == "" && !ip.IsValid() {
return true return true
} }
@@ -814,7 +832,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool {
return true return true
} }
if cidr == "any" { if ip.IsValid() && ip.Bits() == 0 {
return true return true
} }
@@ -866,13 +884,8 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
return false return false
} }
func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error { func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
if localCidr == "any" { if !localIp.IsValid() {
flc.Any = true
return nil
}
if localCidr == "" {
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny { if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
flc.Any = true flc.Any = true
return nil return nil
@@ -883,13 +896,12 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
} }
return nil return nil
} else if localIp.Bits() == 0 {
flc.Any = true
return nil
} }
c, err := netip.ParsePrefix(localCidr) flc.LocalCIDR.Insert(localIp)
if err != nil {
return err
}
flc.LocalCIDR.Insert(c)
return nil return nil
} }
@@ -910,6 +922,7 @@ type rule struct {
Code string Code string
Proto string Proto string
Host string Host string
Group string
Groups []string Groups []string
Cidr string Cidr string
LocalCidr string LocalCidr string
@@ -951,8 +964,7 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i) l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
m["group"] = v[0] m["group"] = v[0]
} }
r.Group = toString("group", m)
singleGroup := toString("group", m)
if rg, ok := m["groups"]; ok { if rg, ok := m["groups"]; ok {
switch reflect.TypeOf(rg).Kind() { switch reflect.TypeOf(rg).Kind() {
@@ -969,60 +981,9 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
} }
} }
//flatten group vs groups
if singleGroup != "" {
// Check if we have both groups and group provided in the rule config
if len(r.Groups) > 0 {
return r, fmt.Errorf("only one of group or groups should be defined, both provided")
}
r.Groups = []string{singleGroup}
}
return r, nil return r, nil
} }
// sanity returns an error if the rule would be evaluated in a way that would short-circuit a configured check on a wildcard value
// rules are evaluated as "port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND local_cidr"
func (r *rule) sanity() error {
//port, proto, local_cidr are AND, no need to check here
//ca_sha and ca_name don't have a wildcard value, no need to check here
groupsEmpty := len(r.Groups) == 0
hostEmpty := r.Host == ""
cidrEmpty := r.Cidr == ""
if (groupsEmpty && hostEmpty && cidrEmpty) == true {
return nil //no content!
}
groupsHasAny := slices.Contains(r.Groups, "any")
if groupsHasAny && len(r.Groups) > 1 {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the other groups specified", r.Groups)
}
if r.Host == "any" {
if !groupsEmpty {
return fmt.Errorf("groups specified as %s, but host=any will match any host, regardless of groups", r.Groups)
}
if !cidrEmpty {
return fmt.Errorf("cidr specified as %s, but host=any will match any host, regardless of cidr", r.Cidr)
}
}
if groupsHasAny {
if !hostEmpty && r.Host != "any" {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified host %s", r.Groups, r.Host)
}
if !cidrEmpty {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified cidr %s", r.Groups, r.Cidr)
}
}
//todo alert on cidr-any
return nil
}
func parsePort(s string) (startPort, endPort int32, err error) { func parsePort(s string) (startPort, endPort int32, err error) {
if s == "any" { if s == "any" {
startPort = firewall.PortAny startPort = firewall.PortAny

View File

@@ -73,106 +73,78 @@ func TestFirewall_AddRule(t *testing.T) {
ti6, err := netip.ParsePrefix("fd12::34/128") ti6, err := netip.ParsePrefix("fd12::34/128")
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
// An empty rule is any // An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6) _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6) ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp, err := netip.ParsePrefix("0.0.0.0/0") anyIp, err := netip.ParsePrefix("0.0.0.0/0")
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
table, ok := fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
assert.True(t, table.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp6, err := netip.ParsePrefix("::/0") anyIp6, err := netip.ParsePrefix("::/0")
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
assert.True(t, table.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions // Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", "")) require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", "")) require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
} }
func TestFirewall_Drop(t *testing.T) { func TestFirewall_Drop(t *testing.T) {
@@ -208,7 +180,7 @@ func TestFirewall_Drop(t *testing.T) {
h.buildNetworks(myVpnNetworksTable, &c) h.buildNetworks(myVpnNetworksTable, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@@ -227,28 +199,28 @@ func TestFirewall_Drop(t *testing.T) {
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
@@ -287,7 +259,7 @@ func TestFirewall_DropV6(t *testing.T) {
h.buildNetworks(myVpnNetworksTable, &c) h.buildNetworks(myVpnNetworksTable, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@@ -306,28 +278,28 @@ func TestFirewall_DropV6(t *testing.T) {
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
@@ -338,12 +310,12 @@ func BenchmarkFirewallTable_match(b *testing.B) {
} }
pfix := netip.MustParsePrefix("172.1.1.1/32") pfix := netip.MustParsePrefix("172.1.1.1/32")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix.String(), "", "", "") _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix.String(), "", "") _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
pfix6 := netip.MustParsePrefix("fd11::11/128") pfix6 := netip.MustParsePrefix("fd11::11/128")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6.String(), "", "", "") _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix6.String(), "", "") _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
cp := cert.NewCAPool() cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) { b.Run("fail on proto", func(b *testing.B) {
@@ -532,7 +504,7 @@ func TestFirewall_Drop2(t *testing.T) {
h1.buildNetworks(myVpnNetworksTable, c1.Certificate) h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// h1/c1 lacks the proper groups // h1/c1 lacks the proper groups
@@ -612,8 +584,8 @@ func TestFirewall_Drop3(t *testing.T) {
h3.buildNetworks(myVpnNetworksTable, c3.Certificate) h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// c1 should pass because host match // c1 should pass because host match
@@ -627,7 +599,7 @@ func TestFirewall_Drop3(t *testing.T) {
// Test a remote address match // Test a remote address match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
} }
@@ -665,7 +637,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
// Test a remote address match // Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
cp := cert.NewCAPool() cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
@@ -704,7 +676,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
h.buildNetworks(myVpnNetworksTable, c.Certificate) h.buildNetworks(myVpnNetworksTable, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@@ -717,7 +689,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw := fw oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@@ -726,7 +698,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw = fw oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@@ -765,7 +737,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Packet spoofed by `c1`. Note that the remote addr is not a valid one. // Packet spoofed by `c1`. Note that the remote addr is not a valid one.
@@ -959,28 +931,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf := &mockFirewall{} mf := &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding udp rule // Test adding udp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding icmp rule // Test adding icmp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding any rule // Test adding any rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with cidr // Test adding rule with cidr
cidr := netip.MustParsePrefix("10.0.0.0/8") cidr := netip.MustParsePrefix("10.0.0.0/8")
@@ -988,14 +960,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with local_cidr // Test adding rule with local_cidr
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr.String()}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
// Test adding rule with cidr ipv6 // Test adding rule with cidr ipv6
cidr6 := netip.MustParsePrefix("fd00::/8") cidr6 := netip.MustParsePrefix("fd00::/8")
@@ -1003,75 +975,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with any cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
// Test adding rule with junk cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with local_cidr ipv6 // Test adding rule with local_cidr ipv6
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
// Test adding rule with any local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
// Test adding rule with junk local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with ca_sha // Test adding rule with ca_sha
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name // Test adding rule with ca_name
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
// Test single group // Test single group
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test single groups // Test single groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test multiple AND groups // Test multiple AND groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test Add error // Test Add error
conf = config.NewC(l) conf = config.NewC(l)
@@ -1094,7 +1040,7 @@ func TestFirewall_convertRule(t *testing.T) {
r, err := convertRule(l, c, "test", 1) r, err := convertRule(l, c, "test", 1)
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"group1"}, r.Groups) assert.Equal(t, "group1", r.Group)
// Ensure group array of > 1 is errord // Ensure group array of > 1 is errord
ob.Reset() ob.Reset()
@@ -1114,63 +1060,7 @@ func TestFirewall_convertRule(t *testing.T) {
r, err = convertRule(l, c, "test", 1) r, err = convertRule(l, c, "test", 1)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"group1"}, r.Groups) assert.Equal(t, "group1", r.Group)
}
func TestFirewall_convertRuleSanity(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
noWarningPlease := []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"host": "bob"},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range noWarningPlease {
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
require.NoError(t, r.sanity(), "should not generate a sanity warning, %+v", c)
}
yesWarningPlease := []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range yesWarningPlease {
c["host"] = "any"
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
err = r.sanity()
require.Error(t, err, "I wanted a warning: %+v", c)
}
//reset the list
yesWarningPlease = []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range yesWarningPlease {
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
r.Groups = append(r.Groups, "any")
err = r.sanity()
require.Error(t, err, "I wanted a warning: %+v", c)
}
} }
type testcase struct { type testcase struct {
@@ -1251,7 +1141,7 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
myVpnNetworksTable.Insert(prefix) myVpnNetworksTable.Insert(prefix)
} }
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
return testsetup{ return testsetup{
c: c, c: c,
@@ -1332,7 +1222,7 @@ func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3") tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
tc.err = ErrNoMatchingRule tc.err = ErrNoMatchingRule
tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", "")) require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, unsafePrefix, "", ""))
tc.err = nil tc.err = nil
tc.Test(t, unsafeSetup.fw) //should pass tc.Test(t, unsafeSetup.fw) //should pass
}) })
@@ -1345,8 +1235,8 @@ type addRuleCall struct {
endPort int32 endPort int32
groups []string groups []string
host string host string
ip string ip netip.Prefix
localIp string localIp netip.Prefix
caName string caName string
caSha string caSha string
} }
@@ -1356,7 +1246,7 @@ type mockFirewall struct {
nextCallReturn error nextCallReturn error
} }
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error { func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
mf.lastCall = addRuleCall{ mf.lastCall = addRuleCall{
incoming: incoming, incoming: incoming,
proto: proto, proto: proto,

4
go.mod
View File

@@ -23,9 +23,9 @@ require (
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
github.com/vishvananda/netlink v1.3.1 github.com/vishvananda/netlink v1.3.1
go.yaml.in/yaml/v3 v3.0.4 go.yaml.in/yaml/v3 v3.0.4
golang.org/x/crypto v0.45.0 golang.org/x/crypto v0.44.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.47.0 golang.org/x/net v0.46.0
golang.org/x/sync v0.18.0 golang.org/x/sync v0.18.0
golang.org/x/sys v0.38.0 golang.org/x/sys v0.38.0
golang.org/x/term v0.37.0 golang.org/x/term v0.37.0

8
go.sum
View File

@@ -162,8 +162,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
@@ -182,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

View File

@@ -99,11 +99,11 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
return true return true
} }
func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H) { func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
cs := f.pki.getCertState() cs := f.pki.getCertState()
crt := cs.GetDefaultCertificate() crt := cs.GetDefaultCertificate()
if crt == nil { if crt == nil {
f.l.WithField("from", via). f.l.WithField("udpAddr", addr).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.initiatingVersion). WithField("certVersion", cs.initiatingVersion).
Error("Unable to handshake with host because no certificate is available") Error("Unable to handshake with host because no certificate is available")
@@ -112,7 +112,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
if err != nil { if err != nil {
f.l.WithError(err).WithField("from", via). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to create connection state") Error("Failed to create connection state")
return return
@@ -123,7 +123,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil { if err != nil {
f.l.WithError(err).WithField("from", via). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to call noise.ReadMessage") Error("Failed to call noise.ReadMessage")
return return
@@ -132,7 +132,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
hs := &NebulaHandshake{} hs := &NebulaHandshake{}
err = hs.Unmarshal(msg) err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil { if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("from", via). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed unmarshal handshake message") Error("Failed unmarshal handshake message")
return return
@@ -140,7 +140,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil { if err != nil {
f.l.WithError(err).WithField("from", via). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake did not contain a certificate") Info("Handshake did not contain a certificate")
return return
@@ -153,7 +153,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
fp = "<error generating certificate fingerprint>" fp = "<error generating certificate fingerprint>"
} }
e := f.l.WithError(err).WithField("from", via). e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVpnNetworks", rc.Networks()). WithField("certVpnNetworks", rc.Networks()).
WithField("certFingerprint", fp) WithField("certFingerprint", fp)
@@ -172,7 +172,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
if myCertOtherVersion == nil { if myCertOtherVersion == nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithError(err).WithFields(m{ f.l.WithError(err).WithFields(m{
"from": via, "udpAddr": addr,
"handshake": m{"stage": 1, "style": "ix_psk0"}, "handshake": m{"stage": 1, "style": "ix_psk0"},
"cert": remoteCert, "cert": remoteCert,
}).Debug("Might be unable to handshake with host due to missing certificate version") }).Debug("Might be unable to handshake with host due to missing certificate version")
@@ -184,7 +184,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
} }
if len(remoteCert.Certificate.Networks()) == 0 { if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("from", via). f.l.WithError(err).WithField("udpAddr", addr).
WithField("cert", remoteCert). WithField("cert", remoteCert).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("No networks in certificate") Info("No networks in certificate")
@@ -201,7 +201,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
vpnAddrs := make([]netip.Addr, len(vpnNetworks)) vpnAddrs := make([]netip.Addr, len(vpnNetworks))
for i, network := range vpnNetworks { for i, network := range vpnNetworks {
if f.myVpnAddrsTable.Contains(network.Addr()) { if f.myVpnAddrsTable.Contains(network.Addr()) {
f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via). f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -215,18 +215,18 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
} }
} }
if !via.IsRelayed { if addr.IsValid() {
// addr can be invalid when the tunnel is being relayed.
// We only want to apply the remote allow list for direct tunnels here // We only want to apply the remote allow list for direct tunnels here
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) { if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
Debug("lighthouse.remote_allow_list denied incoming handshake")
return return
} }
} }
myIndex, err := generateIndex(f.l) myIndex, err := generateIndex(f.l)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via). f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -251,7 +251,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
msgRxL := f.l.WithFields(m{ msgRxL := f.l.WithFields(m{
"vpnAddrs": vpnAddrs, "vpnAddrs": vpnAddrs,
"from": via, "udpAddr": addr,
"certName": certName, "certName": certName,
"certVersion": certVersion, "certVersion": certVersion,
"fingerprint": fingerprint, "fingerprint": fingerprint,
@@ -283,7 +283,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
hsBytes, err := hs.Marshal() hsBytes, err := hs.Marshal()
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -295,7 +295,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -303,7 +303,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return return
} else if dKey == nil || eKey == nil { } else if dKey == nil || eKey == nil {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -329,9 +329,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
ci.eKey = NewNebulaCipherState(eKey) ci.eKey = NewNebulaCipherState(eKey)
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
if !via.IsRelayed { hostinfo.SetRemote(addr)
hostinfo.SetRemote(via.UdpAddr)
}
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
@@ -339,7 +337,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
switch err { switch err {
case ErrAlreadySeen: case ErrAlreadySeen:
// Update remote if preferred // Update remote if preferred
if existing.SetRemoteIfPreferred(f.hostMap, via) { if existing.SetRemoteIfPreferred(f.hostMap, addr) {
// Send a test packet to ensure the other side has also switched to // Send a test packet to ensure the other side has also switched to
// the preferred remote // the preferred remote
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
@@ -347,21 +345,21 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
msg = existing.HandshakePacket[2] msg = existing.HandshakePacket[2]
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
if !via.IsRelayed { if addr.IsValid() {
err := f.outside.WriteTo(msg, via.UdpAddr) err := f.outside.WriteTo(msg, addr)
if err != nil { if err != nil {
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via). f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message") WithError(err).Error("Failed to send handshake message")
} else { } else {
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via). f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent") Info("Handshake message sent")
} }
return return
} else { } else {
if via.relay == nil { if via == nil {
f.l.Error("Handshake send failed: both addr and via.relay are nil.") f.l.Error("Handshake send failed: both addr and via are nil.")
return return
} }
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
@@ -373,7 +371,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
} }
case ErrExistingHostInfo: case ErrExistingHostInfo:
// This means there was an existing tunnel and this handshake was older than the one we are currently based on // This means there was an existing tunnel and this handshake was older than the one we are currently based on
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("oldHandshakeTime", existing.lastHandshakeTime). WithField("oldHandshakeTime", existing.lastHandshakeTime).
@@ -389,7 +387,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
return return
case ErrLocalIndexCollision: case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -402,7 +400,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
default: default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here // And we forget to update it here
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via). f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -416,23 +414,30 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
// Do the send // Do the send
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
if !via.IsRelayed { if addr.IsValid() {
err = f.outside.WriteTo(msg, via.UdpAddr) err = f.outside.WriteTo(msg, addr)
log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
if err != nil { if err != nil {
log.WithError(err).Error("Failed to send handshake") f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake")
} else { } else {
log.Info("Handshake message sent") f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake message sent")
} }
} else { } else {
if via.relay == nil { if via == nil {
f.l.Error("Handshake send failed: both addr and via.relay are nil.") f.l.Error("Handshake send failed: both addr and via are nil.")
return return
} }
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
@@ -457,7 +462,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
return return
} }
func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
if hh == nil { if hh == nil {
// Nothing here to tear down, got a bogus stage 2 packet // Nothing here to tear down, got a bogus stage 2 packet
return true return true
@@ -467,10 +472,10 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
defer hh.Unlock() defer hh.Unlock()
hostinfo := hh.hostinfo hostinfo := hh.hostinfo
if !via.IsRelayed { if addr.IsValid() {
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list. // The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake") f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return false return false
} }
} }
@@ -478,7 +483,7 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
ci := hostinfo.ConnectionState ci := hostinfo.ConnectionState
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Failed to call noise.ReadMessage") Error("Failed to call noise.ReadMessage")
@@ -487,7 +492,7 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
// near future // near future
return false return false
} else if dKey == nil || eKey == nil { } else if dKey == nil || eKey == nil {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key") Error("Noise did not arrive at a key")
@@ -499,7 +504,7 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
hs := &NebulaHandshake{} hs := &NebulaHandshake{}
err = hs.Unmarshal(msg) err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil { if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
@@ -508,7 +513,7 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil { if err != nil {
f.l.WithError(err).WithField("from", via). f.l.WithError(err).WithField("udpAddr", addr).
WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake did not contain a certificate") Info("Handshake did not contain a certificate")
@@ -522,7 +527,7 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
fp = "<error generating certificate fingerprint>" fp = "<error generating certificate fingerprint>"
} }
e := f.l.WithError(err).WithField("from", via). e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("certFingerprint", fp). WithField("certFingerprint", fp).
@@ -537,7 +542,7 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
} }
if len(remoteCert.Certificate.Networks()) == 0 { if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("from", via). f.l.WithError(err).WithField("udpAddr", addr).
WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("cert", remoteCert). WithField("cert", remoteCert).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
@@ -560,8 +565,8 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
ci.eKey = NewNebulaCipherState(eKey) ci.eKey = NewNebulaCipherState(eKey)
// Make sure the current udpAddr being used is set for responding // Make sure the current udpAddr being used is set for responding
if !via.IsRelayed { if addr.IsValid() {
hostinfo.SetRemote(via.UdpAddr) hostinfo.SetRemote(addr)
} else { } else {
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
} }
@@ -583,7 +588,7 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
// Ensure the right host responded // Ensure the right host responded
if !correctHostResponded { if !correctHostResponded {
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("from", via). WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
@@ -597,7 +602,7 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
// Block the current used address // Block the current used address
newHH.hostinfo.remotes = hostinfo.remotes newHH.hostinfo.remotes = hostinfo.remotes
newHH.hostinfo.remotes.BlockRemote(via) newHH.hostinfo.remotes.BlockRemote(addr)
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()). f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
WithField("vpnNetworks", vpnNetworks). WithField("vpnNetworks", vpnNetworks).
@@ -620,7 +625,7 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
ci.window.Update(f.l, 2) ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds() duration := time.Since(hh.startTime).Nanoseconds()
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).

View File

@@ -136,11 +136,11 @@ func (hm *HandshakeManager) Run(ctx context.Context) {
} }
} }
func (hm *HandshakeManager) HandleIncoming(via *ViaSender, packet []byte, h *header.H) { func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
// First remote allow list check before we know the vpnIp // First remote allow list check before we know the vpnIp
if !via.IsRelayed { if addr.IsValid() {
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) { if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) {
hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake") hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return return
} }
} }
@@ -149,11 +149,11 @@ func (hm *HandshakeManager) HandleIncoming(via *ViaSender, packet []byte, h *hea
case header.HandshakeIXPSK0: case header.HandshakeIXPSK0:
switch h.MessageCounter { switch h.MessageCounter {
case 1: case 1:
ixHandshakeStage1(hm.f, via, packet, h) ixHandshakeStage1(hm.f, addr, via, packet, h)
case 2: case 2:
newHostinfo := hm.queryIndex(h.RemoteIndex) newHostinfo := hm.queryIndex(h.RemoteIndex)
tearDown := ixHandshakeStage2(hm.f, via, newHostinfo, packet, h) tearDown := ixHandshakeStage2(hm.f, addr, via, newHostinfo, packet, h)
if tearDown && newHostinfo != nil { if tearDown && newHostinfo != nil {
hm.DeleteHostInfo(newHostinfo.hostinfo) hm.DeleteHostInfo(newHostinfo.hostinfo)
} }

View File

@@ -1,9 +1,7 @@
package nebula package nebula
import ( import (
"encoding/json"
"errors" "errors"
"fmt"
"net" "net"
"net/netip" "net/netip"
"slices" "slices"
@@ -278,25 +276,9 @@ type HostInfo struct {
} }
type ViaSender struct { type ViaSender struct {
UdpAddr netip.AddrPort
relayHI *HostInfo // relayHI is the host info object of the relay relayHI *HostInfo // relayHI is the host info object of the relay
remoteIdx uint32 // remoteIdx is the index included in the header of the received packet remoteIdx uint32 // remoteIdx is the index included in the header of the received packet
relay *Relay // relay contains the rest of the relay information, including the PeerIP of the host trying to communicate with us. relay *Relay // relay contains the rest of the relay information, including the PeerIP of the host trying to communicate with us.
IsRelayed bool // IsRelayed is true if the packet was sent through a relay
}
func (v *ViaSender) String() string {
if v.IsRelayed {
return fmt.Sprintf("%s (relayed)", v.UdpAddr)
}
return v.UdpAddr.String()
}
func (v *ViaSender) MarshalJSON() ([]byte, error) {
if v.IsRelayed {
return json.Marshal(m{"direct": v.UdpAddr})
}
return json.Marshal(m{"relay": v.UdpAddr})
} }
type cachedPacket struct { type cachedPacket struct {
@@ -712,7 +694,6 @@ func (i *HostInfo) GetCert() *cert.CachedCertificate {
return nil return nil
} }
// TODO: Maybe use ViaSender here?
func (i *HostInfo) SetRemote(remote netip.AddrPort) { func (i *HostInfo) SetRemote(remote netip.AddrPort) {
// We copy here because we likely got this remote from a source that reuses the object // We copy here because we likely got this remote from a source that reuses the object
if i.remote != remote { if i.remote != remote {
@@ -723,14 +704,14 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) {
// SetRemoteIfPreferred returns true if the remote was changed. The lastRoam // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
// time on the HostInfo will also be updated. // time on the HostInfo will also be updated.
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, via *ViaSender) bool { func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool {
if via.IsRelayed { if !newRemote.IsValid() {
// relays have nil udp Addrs
return false return false
} }
currentRemote := i.remote currentRemote := i.remote
if !currentRemote.IsValid() { if !currentRemote.IsValid() {
i.SetRemote(via.UdpAddr) i.SetRemote(newRemote)
return true return true
} }
@@ -743,7 +724,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, via *ViaSender) bool {
return false return false
} }
if l.Contains(via.UdpAddr.Addr()) { if l.Contains(newRemote.Addr()) {
newIsPreferred = true newIsPreferred = true
} }
} }
@@ -753,7 +734,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, via *ViaSender) bool {
i.lastRoam = time.Now() i.lastRoam = time.Now()
i.lastRoamRemote = currentRemote i.lastRoamRemote = currentRemote
i.SetRemote(via.UdpAddr) i.SetRemote(newRemote)
return true return true
} }

View File

@@ -222,13 +222,6 @@ func (f *Interface) activate() {
WithField("boringcrypto", boringEnabled()). WithField("boringcrypto", boringEnabled()).
Info("Nebula interface is active") Info("Nebula interface is active")
if f.routines > 1 {
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
f.routines = 1
f.l.Warn("routines is not supported on this platform, falling back to a single routine")
}
}
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
// Prepare n tun queues // Prepare n tun queues
@@ -279,7 +272,7 @@ func (f *Interface) listenOut(i int) {
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(&ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
}) })
} }

View File

@@ -19,21 +19,21 @@ const (
minFwPacketLen = 4 minFwPacketLen = 4
) )
func (f *Interface) readOutsidePackets(via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet) err := h.Parse(packet)
if err != nil { if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(packet) > 1 { if len(packet) > 1 {
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err) f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
} }
return return
} }
//l.Error("in packet ", header, packet[HeaderLen:]) //l.Error("in packet ", header, packet[HeaderLen:])
if !via.IsRelayed { if ip.IsValid() {
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) { if f.myVpnNetworksTable.Contains(ip.Addr()) {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("from", via).Debug("Refusing to process double encrypted packet") f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
} }
return return
} }
@@ -54,7 +54,8 @@ func (f *Interface) readOutsidePackets(via *ViaSender, out []byte, packet []byte
switch h.Type { switch h.Type {
case header.Message: case header.Message:
if !f.handleEncrypted(ci, via, h) { // TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
if !f.handleEncrypted(ci, ip, h) {
return return
} }
@@ -78,7 +79,7 @@ func (f *Interface) readOutsidePackets(via *ViaSender, out []byte, packet []byte
// Successfully validated the thing. Get rid of the Relay header. // Successfully validated the thing. Get rid of the Relay header.
signedPayload = signedPayload[header.Len:] signedPayload = signedPayload[header.Len:]
// Pull the Roaming parts up here, and return in all call paths. // Pull the Roaming parts up here, and return in all call paths.
f.handleHostRoaming(hostinfo, via) f.handleHostRoaming(hostinfo, ip)
// Track usage of both the HostInfo and the Relay for the received & authenticated packet // Track usage of both the HostInfo and the Relay for the received & authenticated packet
f.connectionManager.In(hostinfo) f.connectionManager.In(hostinfo)
f.connectionManager.RelayUsed(h.RemoteIndex) f.connectionManager.RelayUsed(h.RemoteIndex)
@@ -95,14 +96,7 @@ func (f *Interface) readOutsidePackets(via *ViaSender, out []byte, packet []byte
case TerminalType: case TerminalType:
// If I am the target of this relay, process the unwrapped packet // If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
rVia := ViaSender{ f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
UdpAddr: via.UdpAddr,
relayHI: hostinfo,
remoteIdx: relay.RemoteIndex,
relay: relay,
IsRelayed: true,
}
f.readOutsidePackets(&rVia, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return return
case ForwardingType: case ForwardingType:
// Find the target HostInfo relay object // Find the target HostInfo relay object
@@ -132,32 +126,31 @@ func (f *Interface) readOutsidePackets(via *ViaSender, out []byte, packet []byte
case header.LightHouse: case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) { if !f.handleEncrypted(ci, ip, h) {
return return
} }
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via). hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet). WithField("packet", packet).
Error("Failed to decrypt lighthouse packet") Error("Failed to decrypt lighthouse packet")
return return
} }
//TODO: assert via is not relayed lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic // Fallthrough to the bottom to record incoming traffic
case header.Test: case header.Test:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) { if !f.handleEncrypted(ci, ip, h) {
return return
} }
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via). hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet). WithField("packet", packet).
Error("Failed to decrypt test packet") Error("Failed to decrypt test packet")
return return
@@ -166,7 +159,7 @@ func (f *Interface) readOutsidePackets(via *ViaSender, out []byte, packet []byte
if h.Subtype == header.TestRequest { if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam // This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding // to the new IP address before responding
f.handleHostRoaming(hostinfo, via) f.handleHostRoaming(hostinfo, ip)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
} }
@@ -177,34 +170,34 @@ func (f *Interface) readOutsidePackets(via *ViaSender, out []byte, packet []byte
case header.Handshake: case header.Handshake:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handshakeManager.HandleIncoming(via, packet, h) f.handshakeManager.HandleIncoming(ip, via, packet, h)
return return
case header.RecvError: case header.RecvError:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handleRecvError(via.UdpAddr, h) f.handleRecvError(ip, h)
return return
case header.CloseTunnel: case header.CloseTunnel:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, via, h) { if !f.handleEncrypted(ci, ip, h) {
return return
} }
hostinfo.logger(f.l).WithField("from", via). hostinfo.logger(f.l).WithField("udpAddr", ip).
Info("Close tunnel received, tearing down.") Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo) f.closeTunnel(hostinfo)
return return
case header.Control: case header.Control:
if !f.handleEncrypted(ci, via, h) { if !f.handleEncrypted(ci, ip, h) {
return return
} }
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("from", via). hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet). WithField("packet", packet).
Error("Failed to decrypt Control packet") Error("Failed to decrypt Control packet")
return return
@@ -214,11 +207,11 @@ func (f *Interface) readOutsidePackets(via *ViaSender, out []byte, packet []byte
default: default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via) hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
return return
} }
f.handleHostRoaming(hostinfo, via) f.handleHostRoaming(hostinfo, ip)
f.connectionManager.In(hostinfo) f.connectionManager.In(hostinfo)
} }
@@ -237,36 +230,36 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
} }
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via *ViaSender) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) {
if !via.IsRelayed && hostinfo.remote != via.UdpAddr { if udpAddr.IsValid() && hostinfo.remote != udpAddr {
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) {
hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming") hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming")
return return
} }
if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
} }
return return
} }
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
Info("Host roamed to new udp ip/port.") Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now() hostinfo.lastRoam = time.Now()
hostinfo.lastRoamRemote = hostinfo.remote hostinfo.lastRoamRemote = hostinfo.remote
hostinfo.SetRemote(via.UdpAddr) hostinfo.SetRemote(udpAddr)
} }
} }
// handleEncrypted returns true if a packet should be processed, false otherwise // handleEncrypted returns true if a packet should be processed, false otherwise
func (f *Interface) handleEncrypted(ci *ConnectionState, via *ViaSender, h *header.H) bool { func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect // If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
if ci == nil { if ci == nil {
if !via.IsRelayed { if addr.IsValid() {
f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex) f.maybeSendRecvError(addr, h.RemoteIndex)
} }
return false return false
} }

View File

@@ -13,6 +13,5 @@ type Device interface {
Networks() []netip.Prefix Networks() []netip.Prefix
Name() string Name() string
RoutesFor(netip.Addr) routing.Gateways RoutesFor(netip.Addr) routing.Gateways
SupportsMultiqueue() bool
NewMultiQueueReader() (io.ReadWriteCloser, error) NewMultiQueueReader() (io.ReadWriteCloser, error)
} }

View File

@@ -95,10 +95,6 @@ func (t *tun) Name() string {
return "android" return "android"
} }
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for android") return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
} }

View File

@@ -549,10 +549,6 @@ func (t *tun) Name() string {
return t.Device return t.Device
} }
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
} }

View File

@@ -105,10 +105,6 @@ func (t *disabledTun) Write(b []byte) (int, error) {
return len(b), nil return len(b), nil
} }
func (t *disabledTun) SupportsMultiqueue() bool {
return true
}
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return t, nil return t, nil
} }

View File

@@ -450,10 +450,6 @@ func (t *tun) Name() string {
return t.Device return t.Device
} }
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
} }

View File

@@ -151,10 +151,6 @@ func (t *tun) Name() string {
return "iOS" return "iOS"
} }
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios") return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
} }

View File

@@ -216,10 +216,6 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (t *tun) SupportsMultiqueue() bool {
return true
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {
@@ -586,42 +582,48 @@ func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
} }
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
var gateways routing.Gateways var gateways routing.Gateways
link, err := netlink.LinkByName(t.Device) link, err := netlink.LinkByName(t.Device)
if err != nil { if err != nil {
t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name") t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
return gateways return gateways
} }
// If this route is relevant to our interface and there is a gateway then add it // If this route is relevant to our interface and there is a gateway then add it
if r.LinkIndex == link.Attrs().Index { if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
gwAddr, ok := getGatewayAddr(r.Gw, r.Via) gwAddr, ok := netip.AddrFromSlice(r.Gw)
if ok { if !ok {
if t.isGatewayInVpnNetworks(gwAddr) { t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
} else {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
}
} else { } else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address") gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
}
} }
} }
for _, p := range r.MultiPath { for _, p := range r.MultiPath {
// If this route is relevant to our interface and there is a gateway then add it // If this route is relevant to our interface and there is a gateway then add it
if p.LinkIndex == link.Attrs().Index { if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
gwAddr, ok := getGatewayAddr(p.Gw, p.Via) gwAddr, ok := netip.AddrFromSlice(p.Gw)
if ok { if !ok {
if t.isGatewayInVpnNetworks(gwAddr) { t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
} else {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
}
} else { } else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address") gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
// p.Hops+1 = weight of the route
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
}
} }
} }
} }
@@ -630,27 +632,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
return gateways return gateways
} }
func getGatewayAddr(gw net.IP, via netlink.Destination) (netip.Addr, bool) {
// Try to use the old RTA_GATEWAY first
gwAddr, ok := netip.AddrFromSlice(gw)
if !ok {
// Fallback to the new RTA_VIA
rVia, ok := via.(*netlink.Via)
if ok {
gwAddr, ok = netip.AddrFromSlice(rVia.Addr)
}
}
if gwAddr.IsValid() {
gwAddr = gwAddr.Unmap()
return gwAddr, true
}
return netip.Addr{}, false
}
func (t *tun) updateRoutes(r netlink.RouteUpdate) { func (t *tun) updateRoutes(r netlink.RouteUpdate) {
gateways := t.getGatewaysFromRoute(&r.Route) gateways := t.getGatewaysFromRoute(&r.Route)
if len(gateways) == 0 { if len(gateways) == 0 {
// No gateways relevant to our network, no routing changes required. // No gateways relevant to our network, no routing changes required.
t.l.WithField("route", r).Debug("Ignoring route update, no gateways") t.l.WithField("route", r).Debug("Ignoring route update, no gateways")

View File

@@ -390,10 +390,6 @@ func (t *tun) Name() string {
return t.Device return t.Device
} }
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd") return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
} }

View File

@@ -310,10 +310,6 @@ func (t *tun) Name() string {
return t.Device return t.Device
} }
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd") return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
} }

View File

@@ -132,10 +132,6 @@ func (t *TestTun) Read(b []byte) (int, error) {
return len(p), nil return len(p), nil
} }
func (t *TestTun) SupportsMultiqueue() bool {
return false
}
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented") return nil, fmt.Errorf("TODO: multiqueue not implemented")
} }

View File

@@ -234,10 +234,6 @@ func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0) return t.tun.Write(b, 0)
} }
func (t *winTun) SupportsMultiqueue() bool {
return false
}
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
} }

View File

@@ -46,10 +46,6 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
return routing.Gateways{routing.NewGateway(ip, 1)} return routing.Gateways{routing.NewGateway(ip, 1)}
} }
func (d *UserDevice) SupportsMultiqueue() bool {
return true
}
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return d, nil return d, nil
} }

View File

@@ -338,21 +338,21 @@ func (r *RemoteList) CopyCache() *CacheMap {
} }
// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
func (r *RemoteList) BlockRemote(bad *ViaSender) { func (r *RemoteList) BlockRemote(bad netip.AddrPort) {
if bad.IsRelayed { if !bad.IsValid() {
// relays can have nil udp Addrs
return return
} }
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
// Check if we already blocked this addr // Check if we already blocked this addr
if r.unlockedIsBad(bad.UdpAddr) { if r.unlockedIsBad(bad) {
return return
} }
// We copy here because we are taking something else's memory and we can't trust everything // We copy here because we are taking something else's memory and we can't trust everything
r.badRemotes = append(r.badRemotes, bad.UdpAddr) r.badRemotes = append(r.badRemotes, bad)
// Mark the next interaction must recollect/dedupe // Mark the next interaction must recollect/dedupe
r.shouldRebuild = true r.shouldRebuild = true

View File

@@ -34,10 +34,6 @@ func (NoopTun) Write([]byte) (int, error) {
return 0, nil return 0, nil
} }
func (NoopTun) SupportsMultiqueue() bool {
return false
}
func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, errors.New("unsupported") return nil, errors.New("unsupported")
} }

View File

@@ -19,7 +19,6 @@ type Conn interface {
ListenOut(r EncReader) ListenOut(r EncReader)
WriteTo(b []byte, addr netip.AddrPort) error WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C) ReloadConfig(c *config.C)
SupportsMultipleReaders() bool
Close() error Close() error
} }
@@ -34,9 +33,6 @@ func (NoopConn) LocalAddr() (netip.AddrPort, error) {
func (NoopConn) ListenOut(_ EncReader) { func (NoopConn) ListenOut(_ EncReader) {
return return
} }
func (NoopConn) SupportsMultipleReaders() bool {
return false
}
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil return nil
} }

View File

@@ -98,9 +98,9 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
return ErrInvalidIPv6RemoteForSocket return ErrInvalidIPv6RemoteForSocket
} }
var rsa unix.RawSockaddrInet4 var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET rsa.Family = unix.AF_INET6
rsa.Addr = ap.Addr().As4() rsa.Addr = ap.Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port()) binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
sa = unsafe.Pointer(&rsa) sa = unsafe.Pointer(&rsa)
addrLen = syscall.SizeofSockaddrInet4 addrLen = syscall.SizeofSockaddrInet4
@@ -184,10 +184,6 @@ func (u *StdConn) ListenOut(r EncReader) {
} }
} }
func (u *StdConn) SupportsMultipleReaders() bool {
return false
}
func (u *StdConn) Rebind() error { func (u *StdConn) Rebind() error {
var err error var err error
if u.isV4 { if u.isV4 {

View File

@@ -85,7 +85,3 @@ func (u *GenericConn) ListenOut(r EncReader) {
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
} }
} }
func (u *GenericConn) SupportsMultipleReaders() bool {
return false
}

View File

@@ -72,10 +72,6 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
} }
func (u *StdConn) SupportsMultipleReaders() bool {
return true
}
func (u *StdConn) Rebind() error { func (u *StdConn) Rebind() error {
return nil return nil
} }

View File

@@ -315,10 +315,6 @@ func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
} }
func (u *RIOConn) SupportsMultipleReaders() bool {
return false
}
func (u *RIOConn) Rebind() error { func (u *RIOConn) Rebind() error {
return nil return nil
} }

View File

@@ -127,10 +127,6 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
return u.Addr, nil return u.Addr, nil
} }
func (u *TesterConn) SupportsMultipleReaders() bool {
return false
}
func (u *TesterConn) Rebind() error { func (u *TesterConn) Rebind() error {
return nil return nil
} }