mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 08:24:25 +01:00
Compare commits
2 Commits
udpaddr-lo
...
changelog-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ee8e4d2017 | ||
|
|
8d656fb890 |
60
CHANGELOG.md
60
CHANGELOG.md
@@ -7,12 +7,64 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
|
## [1.10.0] - ????
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- PKCS11 support for P256 keys when built with `pkcs11` tag (#1153)
|
||||||
|
- ASN.1 based v2 nebula certificates with support for ipv6 and multiple ip addresses.
|
||||||
|
Certificates now have a unified interface for external implementations. (#1212, #1216, #1345)
|
||||||
|
**TODO: External documentation link!**
|
||||||
|
- Add the ability to mark packets on linux to better target nebula packets in iptables/nftables. (#1331)
|
||||||
|
- Add ECMP support for `unsafe_routes`. (#1332)
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
|
||||||
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
|
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
|
||||||
intended to target an `unsafe_routes` entry must explicitly declare it via the
|
intended to target an `unsafe_routes` entry must explicitly declare it via the
|
||||||
`local_cidr` field. This is almost always the intended behavior. This flag is
|
`local_cidr` field. This is almost always the intended behavior. This flag is
|
||||||
deprecated and will be removed in a future release.
|
deprecated and will be removed in a future release. (#1373)
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Fix moving a udp address from one vpn address to another in the `static_host_map`
|
||||||
|
which could cause rapid re-handshaking with an incorrect remote. (#1259)
|
||||||
|
- Improve smoke tests in environments where the docker network is not the default. (#1347)
|
||||||
|
|
||||||
|
## [1.9.7] - 2025-10-10
|
||||||
|
|
||||||
|
### Security
|
||||||
|
|
||||||
|
- Fix an issue where Nebula could incorrectly accept and process a packet from an erroneous source IP when the sender's
|
||||||
|
certificate is configured with unsafe_routes (cert v1/v2) or multiple IPs (cert v2). (#1494)
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- Disable sending `recv_error` messages when a packet is received outside the allowable counter window. (#1459)
|
||||||
|
- Improve error messages and remove some unnecessary fatal conditions in the Windows and generic udp listener. (#1543)
|
||||||
|
|
||||||
|
## [1.9.6] - 2025-7-15
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- Support dropping inactive tunnels. This is disabled by default in this release but can be enabled with `tunnels.drop_inactive`. See example config for more details. (#1413)
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Fix Darwin freeze due to presence of some Network Extensions (#1426)
|
||||||
|
- Ensure the same relay tunnel is always used when multiple relay tunnels are present (#1422)
|
||||||
|
- Fix Windows freeze due to ICMP error handling (#1412)
|
||||||
|
- Fix relay migration panic (#1403)
|
||||||
|
|
||||||
|
## [1.9.5] - 2024-12-05
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- Gracefully ignore v2 certificates. (#1282)
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Fix relays that refuse to re-establish after one of the remote tunnel pairs breaks. (#1277)
|
||||||
|
|
||||||
## [1.9.4] - 2024-09-09
|
## [1.9.4] - 2024-09-09
|
||||||
|
|
||||||
@@ -671,7 +723,11 @@ created.)
|
|||||||
|
|
||||||
- Initial public release.
|
- Initial public release.
|
||||||
|
|
||||||
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.4...HEAD
|
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.0...HEAD
|
||||||
|
[1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0
|
||||||
|
[1.9.7]: https://github.com/slackhq/nebula/releases/tag/v1.9.7
|
||||||
|
[1.9.6]: https://github.com/slackhq/nebula/releases/tag/v1.9.6
|
||||||
|
[1.9.5]: https://github.com/slackhq/nebula/releases/tag/v1.9.5
|
||||||
[1.9.4]: https://github.com/slackhq/nebula/releases/tag/v1.9.4
|
[1.9.4]: https://github.com/slackhq/nebula/releases/tag/v1.9.4
|
||||||
[1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3
|
[1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3
|
||||||
[1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2
|
[1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ type signFlags struct {
|
|||||||
func newSignFlags() *signFlags {
|
func newSignFlags() *signFlags {
|
||||||
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
|
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
|
||||||
sf.set.Usage = func() {}
|
sf.set.Usage = func() {}
|
||||||
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use. The default is to match the version of the signing CA")
|
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
|
||||||
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
|
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
|
||||||
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
|
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
|
||||||
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
|
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
|
||||||
@@ -167,10 +167,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
return fmt.Errorf("ca certificate is expired")
|
return fmt.Errorf("ca certificate is expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
if version == 0 {
|
|
||||||
version = caCert.Version()
|
|
||||||
}
|
|
||||||
|
|
||||||
// if no duration is given, expire one second before the root expires
|
// if no duration is given, expire one second before the root expires
|
||||||
if *sf.duration <= 0 {
|
if *sf.duration <= 0 {
|
||||||
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
|
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
|
||||||
@@ -283,19 +279,21 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
notBefore := time.Now()
|
notBefore := time.Now()
|
||||||
notAfter := notBefore.Add(*sf.duration)
|
notAfter := notBefore.Add(*sf.duration)
|
||||||
|
|
||||||
switch version {
|
if version == 0 || version == cert.Version1 {
|
||||||
case cert.Version1:
|
// Make sure we at least have an ip
|
||||||
// Make sure we have only one ipv4 address
|
|
||||||
if len(v4Networks) != 1 {
|
if len(v4Networks) != 1 {
|
||||||
return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
|
return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(v6Networks) > 0 {
|
if version == cert.Version1 {
|
||||||
return newHelpErrorf("invalid -networks definition: v1 certificates can only contain ipv4 addresses")
|
// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
|
||||||
}
|
if len(v6Networks) > 0 {
|
||||||
|
return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
|
||||||
|
}
|
||||||
|
|
||||||
if len(v6UnsafeNetworks) > 0 {
|
if len(v6UnsafeNetworks) > 0 {
|
||||||
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses")
|
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &cert.TBSCertificate{
|
t := &cert.TBSCertificate{
|
||||||
@@ -325,8 +323,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
}
|
}
|
||||||
|
|
||||||
crts = append(crts, nc)
|
crts = append(crts, nc)
|
||||||
|
}
|
||||||
|
|
||||||
case cert.Version2:
|
if version == 0 || version == cert.Version2 {
|
||||||
t := &cert.TBSCertificate{
|
t := &cert.TBSCertificate{
|
||||||
Version: cert.Version2,
|
Version: cert.Version2,
|
||||||
Name: *sf.name,
|
Name: *sf.name,
|
||||||
@@ -354,9 +353,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
}
|
}
|
||||||
|
|
||||||
crts = append(crts, nc)
|
crts = append(crts, nc)
|
||||||
default:
|
|
||||||
// this should be unreachable
|
|
||||||
return fmt.Errorf("invalid version: %d", version)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isP11 && *sf.inPubPath == "" {
|
if !isP11 && *sf.inPubPath == "" {
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ func Test_signHelp(t *testing.T) {
|
|||||||
" -unsafe-networks string\n"+
|
" -unsafe-networks string\n"+
|
||||||
" \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
|
" \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
|
||||||
" -version uint\n"+
|
" -version uint\n"+
|
||||||
" \tOptional: version of the certificate format to use. The default is to match the version of the signing CA\n",
|
" \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
|
||||||
ob.String(),
|
ob.String(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -204,7 +204,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
|
||||||
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses")
|
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
|
|||||||
@@ -50,6 +50,11 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
|
|||||||
}
|
}
|
||||||
|
|
||||||
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
|
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
|
||||||
|
|
||||||
|
b := NewBits(ReplayWindow)
|
||||||
|
// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
|
||||||
|
b.Update(l, 0)
|
||||||
|
|
||||||
hs, err := noise.NewHandshakeState(noise.Config{
|
hs, err := noise.NewHandshakeState(noise.Config{
|
||||||
CipherSuite: ncs,
|
CipherSuite: ncs,
|
||||||
Random: rand.Reader,
|
Random: rand.Reader,
|
||||||
@@ -69,7 +74,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
|
|||||||
ci := &ConnectionState{
|
ci := &ConnectionState{
|
||||||
H: hs,
|
H: hs,
|
||||||
initiator: initiator,
|
initiator: initiator,
|
||||||
window: NewBits(ReplayWindow),
|
window: b,
|
||||||
myCert: crt,
|
myCert: crt,
|
||||||
}
|
}
|
||||||
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
|
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
|
||||||
|
|||||||
@@ -25,12 +25,11 @@ import (
|
|||||||
|
|
||||||
func BenchmarkHotPath(b *testing.B) {
|
func BenchmarkHotPath(b *testing.B) {
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
// Put their info in our lighthouse
|
// Put their info in our lighthouse
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
|
||||||
|
|
||||||
// Start the servers
|
// Start the servers
|
||||||
myControl.Start()
|
myControl.Start()
|
||||||
@@ -39,9 +38,6 @@ func BenchmarkHotPath(b *testing.B) {
|
|||||||
r := router.NewR(b, myControl, theirControl)
|
r := router.NewR(b, myControl, theirControl)
|
||||||
r.CancelFlowLogs()
|
r.CancelFlowLogs()
|
||||||
|
|
||||||
assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
_ = r.RouteForAllUntilTxTun(theirControl)
|
_ = r.RouteForAllUntilTxTun(theirControl)
|
||||||
@@ -51,39 +47,6 @@ func BenchmarkHotPath(b *testing.B) {
|
|||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkHotPathRelay(b *testing.B) {
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
|
||||||
|
|
||||||
// Teach my how to get to the relay and that their can be reached via the relay
|
|
||||||
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
|
||||||
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
|
|
||||||
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
|
||||||
r := router.NewR(b, myControl, relayControl, theirControl)
|
|
||||||
r.CancelFlowLogs()
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
myControl.Start()
|
|
||||||
relayControl.Start()
|
|
||||||
theirControl.Start()
|
|
||||||
|
|
||||||
assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
|
||||||
_ = r.RouteForAllUntilTxTun(theirControl)
|
|
||||||
}
|
|
||||||
|
|
||||||
myControl.Stop()
|
|
||||||
theirControl.Stop()
|
|
||||||
relayControl.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGoodHandshake(t *testing.T) {
|
func TestGoodHandshake(t *testing.T) {
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
|||||||
@@ -292,7 +292,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
|
func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
|
||||||
// Send a packet from them to me
|
// Send a packet from them to me
|
||||||
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
|
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
|
||||||
bPacket := r.RouteForAllUntilTxTun(controlA)
|
bPacket := r.RouteForAllUntilTxTun(controlA)
|
||||||
@@ -325,7 +325,7 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
|
|||||||
assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
|
assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertUdpPacket(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
||||||
if toIp.Is6() {
|
if toIp.Is6() {
|
||||||
assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
|
assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
|
||||||
} else {
|
} else {
|
||||||
@@ -333,7 +333,7 @@ func assertUdpPacket(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertUdpPacket6(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
||||||
packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
|
packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
|
||||||
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
|
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
|
||||||
assert.NotNil(t, v6, "No ipv6 data found")
|
assert.NotNil(t, v6, "No ipv6 data found")
|
||||||
@@ -352,7 +352,7 @@ func assertUdpPacket6(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr,
|
|||||||
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
|
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertUdpPacket4(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
||||||
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
|
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
|
||||||
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
|
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
|
||||||
assert.NotNil(t, v4, "No ipv4 data found")
|
assert.NotNil(t, v4, "No ipv4 data found")
|
||||||
|
|||||||
@@ -383,9 +383,8 @@ firewall:
|
|||||||
# host: `any` or a literal hostname, ie `test-host`
|
# host: `any` or a literal hostname, ie `test-host`
|
||||||
# group: `any` or a literal group name, ie `default-group`
|
# group: `any` or a literal group name, ie `default-group`
|
||||||
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
|
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
|
||||||
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address.
|
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
|
||||||
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address.
|
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes.
|
||||||
# This can be used to filter destinations when using unsafe_routes.
|
|
||||||
# By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true.
|
# By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true.
|
||||||
# If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case.
|
# If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case.
|
||||||
# ca_name: An issuing CA name
|
# ca_name: An issuing CA name
|
||||||
|
|||||||
167
firewall.go
167
firewall.go
@@ -8,7 +8,6 @@ import (
|
|||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -23,7 +22,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type FirewallInterface interface {
|
type FirewallInterface interface {
|
||||||
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error
|
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type conn struct {
|
type conn struct {
|
||||||
@@ -248,11 +247,22 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddRule properly creates the in memory rule structure for a firewall table.
|
// AddRule properly creates the in memory rule structure for a firewall table.
|
||||||
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
|
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
|
||||||
|
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
|
||||||
|
// https://github.com/golang/go/issues/14131
|
||||||
|
sIp := ""
|
||||||
|
if ip.IsValid() {
|
||||||
|
sIp = ip.String()
|
||||||
|
}
|
||||||
|
lIp := ""
|
||||||
|
if localIp.IsValid() {
|
||||||
|
lIp = localIp.String()
|
||||||
|
}
|
||||||
|
|
||||||
// We need this rule string because we generate a hash. Removing this will break firewall reload.
|
// We need this rule string because we generate a hash. Removing this will break firewall reload.
|
||||||
ruleString := fmt.Sprintf(
|
ruleString := fmt.Sprintf(
|
||||||
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
|
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
|
||||||
incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha,
|
incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha,
|
||||||
)
|
)
|
||||||
f.rules += ruleString + "\n"
|
f.rules += ruleString + "\n"
|
||||||
|
|
||||||
@@ -260,7 +270,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
|||||||
if !incoming {
|
if !incoming {
|
||||||
direction = "outgoing"
|
direction = "outgoing"
|
||||||
}
|
}
|
||||||
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}).
|
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}).
|
||||||
Info("Firewall rule added")
|
Info("Firewall rule added")
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -287,7 +297,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
|||||||
return fmt.Errorf("unknown protocol %v", proto)
|
return fmt.Errorf("unknown protocol %v", proto)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha)
|
return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleHash returns a hash representation of all inbound and outbound rules
|
// GetRuleHash returns a hash representation of all inbound and outbound rules
|
||||||
@@ -327,6 +337,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i, t := range rs {
|
for i, t := range rs {
|
||||||
|
var groups []string
|
||||||
r, err := convertRule(l, t, table, i)
|
r, err := convertRule(l, t, table, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s rule #%v; %s", table, i, err)
|
return fmt.Errorf("%s rule #%v; %s", table, i, err)
|
||||||
@@ -336,10 +347,23 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
|
return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Host == "" && len(r.Groups) == 0 && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
|
if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
|
||||||
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
|
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(r.Groups) > 0 {
|
||||||
|
groups = r.Groups
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Group != "" {
|
||||||
|
// Check if we have both groups and group provided in the rule config
|
||||||
|
if len(groups) > 0 {
|
||||||
|
return fmt.Errorf("%s rule #%v; only one of group or groups should be defined, both provided", table, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
groups = []string{r.Group}
|
||||||
|
}
|
||||||
|
|
||||||
var sPort, errPort string
|
var sPort, errPort string
|
||||||
if r.Code != "" {
|
if r.Code != "" {
|
||||||
errPort = "code"
|
errPort = "code"
|
||||||
@@ -368,25 +392,23 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
|
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Cidr != "" && r.Cidr != "any" {
|
var cidr netip.Prefix
|
||||||
_, err = netip.ParsePrefix(r.Cidr)
|
if r.Cidr != "" {
|
||||||
|
cidr, err = netip.ParsePrefix(r.Cidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
|
return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.LocalCidr != "" && r.LocalCidr != "any" {
|
var localCidr netip.Prefix
|
||||||
_, err = netip.ParsePrefix(r.LocalCidr)
|
if r.LocalCidr != "" {
|
||||||
|
localCidr, err = netip.ParsePrefix(r.LocalCidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
|
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if warning := r.sanity(); warning != nil {
|
err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha)
|
||||||
l.Warnf("%s rule #%v; %s", table, i, warning)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
|
return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
|
||||||
}
|
}
|
||||||
@@ -633,7 +655,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
|
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
|
||||||
if startPort > endPort {
|
if startPort > endPort {
|
||||||
return fmt.Errorf("start port was lower than end port")
|
return fmt.Errorf("start port was lower than end port")
|
||||||
}
|
}
|
||||||
@@ -646,7 +668,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := fp[i].addRule(f, groups, host, cidr, localCidr, caName, caSha); err != nil {
|
if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -677,7 +699,7 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
|
|||||||
return fp[firewall.PortAny].match(p, c, caPool)
|
return fp[firewall.PortAny].match(p, c, caPool)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, localCidr, caName, caSha string) error {
|
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error {
|
||||||
fr := func() *FirewallRule {
|
fr := func() *FirewallRule {
|
||||||
return &FirewallRule{
|
return &FirewallRule{
|
||||||
Hosts: make(map[string]*firewallLocalCIDR),
|
Hosts: make(map[string]*firewallLocalCIDR),
|
||||||
@@ -691,14 +713,14 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, l
|
|||||||
fc.Any = fr()
|
fc.Any = fr()
|
||||||
}
|
}
|
||||||
|
|
||||||
return fc.Any.addRule(f, groups, host, cidr, localCidr)
|
return fc.Any.addRule(f, groups, host, ip, localIp)
|
||||||
}
|
}
|
||||||
|
|
||||||
if caSha != "" {
|
if caSha != "" {
|
||||||
if _, ok := fc.CAShas[caSha]; !ok {
|
if _, ok := fc.CAShas[caSha]; !ok {
|
||||||
fc.CAShas[caSha] = fr()
|
fc.CAShas[caSha] = fr()
|
||||||
}
|
}
|
||||||
err := fc.CAShas[caSha].addRule(f, groups, host, cidr, localCidr)
|
err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -708,7 +730,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, l
|
|||||||
if _, ok := fc.CANames[caName]; !ok {
|
if _, ok := fc.CANames[caName]; !ok {
|
||||||
fc.CANames[caName] = fr()
|
fc.CANames[caName] = fr()
|
||||||
}
|
}
|
||||||
err := fc.CANames[caName].addRule(f, groups, host, cidr, localCidr)
|
err := fc.CANames[caName].addRule(f, groups, host, ip, localIp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -740,24 +762,24 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool
|
|||||||
return fc.CANames[s.Certificate.Name()].match(p, c)
|
return fc.CANames[s.Certificate.Name()].match(p, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localCidr string) error {
|
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
|
||||||
flc := func() *firewallLocalCIDR {
|
flc := func() *firewallLocalCIDR {
|
||||||
return &firewallLocalCIDR{
|
return &firewallLocalCIDR{
|
||||||
LocalCIDR: new(bart.Lite),
|
LocalCIDR: new(bart.Lite),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if fr.isAny(groups, host, cidr) {
|
if fr.isAny(groups, host, ip) {
|
||||||
if fr.Any == nil {
|
if fr.Any == nil {
|
||||||
fr.Any = flc()
|
fr.Any = flc()
|
||||||
}
|
}
|
||||||
|
|
||||||
return fr.Any.addRule(f, localCidr)
|
return fr.Any.addRule(f, localCIDR)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(groups) > 0 {
|
if len(groups) > 0 {
|
||||||
nlc := flc()
|
nlc := flc()
|
||||||
err := nlc.addRule(f, localCidr)
|
err := nlc.addRule(f, localCIDR)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -773,34 +795,30 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localC
|
|||||||
if nlc == nil {
|
if nlc == nil {
|
||||||
nlc = flc()
|
nlc = flc()
|
||||||
}
|
}
|
||||||
err := nlc.addRule(f, localCidr)
|
err := nlc.addRule(f, localCIDR)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fr.Hosts[host] = nlc
|
fr.Hosts[host] = nlc
|
||||||
}
|
}
|
||||||
|
|
||||||
if cidr != "" {
|
if ip.IsValid() {
|
||||||
c, err := netip.ParsePrefix(cidr)
|
nlc, _ := fr.CIDR.Get(ip)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
nlc, _ := fr.CIDR.Get(c)
|
|
||||||
if nlc == nil {
|
if nlc == nil {
|
||||||
nlc = flc()
|
nlc = flc()
|
||||||
}
|
}
|
||||||
err = nlc.addRule(f, localCidr)
|
err := nlc.addRule(f, localCIDR)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fr.CIDR.Insert(c, nlc)
|
fr.CIDR.Insert(ip, nlc)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool {
|
func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool {
|
||||||
if len(groups) == 0 && host == "" && cidr == "" {
|
if len(groups) == 0 && host == "" && !ip.IsValid() {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -814,7 +832,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if cidr == "any" {
|
if ip.IsValid() && ip.Bits() == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -866,13 +884,8 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
|
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
|
||||||
if localCidr == "any" {
|
if !localIp.IsValid() {
|
||||||
flc.Any = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if localCidr == "" {
|
|
||||||
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
|
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
|
||||||
flc.Any = true
|
flc.Any = true
|
||||||
return nil
|
return nil
|
||||||
@@ -883,13 +896,12 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
|
} else if localIp.Bits() == 0 {
|
||||||
|
flc.Any = true
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c, err := netip.ParsePrefix(localCidr)
|
flc.LocalCIDR.Insert(localIp)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
flc.LocalCIDR.Insert(c)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -910,6 +922,7 @@ type rule struct {
|
|||||||
Code string
|
Code string
|
||||||
Proto string
|
Proto string
|
||||||
Host string
|
Host string
|
||||||
|
Group string
|
||||||
Groups []string
|
Groups []string
|
||||||
Cidr string
|
Cidr string
|
||||||
LocalCidr string
|
LocalCidr string
|
||||||
@@ -951,8 +964,7 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
|
|||||||
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
|
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
|
||||||
m["group"] = v[0]
|
m["group"] = v[0]
|
||||||
}
|
}
|
||||||
|
r.Group = toString("group", m)
|
||||||
singleGroup := toString("group", m)
|
|
||||||
|
|
||||||
if rg, ok := m["groups"]; ok {
|
if rg, ok := m["groups"]; ok {
|
||||||
switch reflect.TypeOf(rg).Kind() {
|
switch reflect.TypeOf(rg).Kind() {
|
||||||
@@ -969,60 +981,9 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//flatten group vs groups
|
|
||||||
if singleGroup != "" {
|
|
||||||
// Check if we have both groups and group provided in the rule config
|
|
||||||
if len(r.Groups) > 0 {
|
|
||||||
return r, fmt.Errorf("only one of group or groups should be defined, both provided")
|
|
||||||
}
|
|
||||||
r.Groups = []string{singleGroup}
|
|
||||||
}
|
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// sanity returns an error if the rule would be evaluated in a way that would short-circuit a configured check on a wildcard value
|
|
||||||
// rules are evaluated as "port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND local_cidr"
|
|
||||||
func (r *rule) sanity() error {
|
|
||||||
//port, proto, local_cidr are AND, no need to check here
|
|
||||||
//ca_sha and ca_name don't have a wildcard value, no need to check here
|
|
||||||
groupsEmpty := len(r.Groups) == 0
|
|
||||||
hostEmpty := r.Host == ""
|
|
||||||
cidrEmpty := r.Cidr == ""
|
|
||||||
|
|
||||||
if (groupsEmpty && hostEmpty && cidrEmpty) == true {
|
|
||||||
return nil //no content!
|
|
||||||
}
|
|
||||||
|
|
||||||
groupsHasAny := slices.Contains(r.Groups, "any")
|
|
||||||
if groupsHasAny && len(r.Groups) > 1 {
|
|
||||||
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the other groups specified", r.Groups)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.Host == "any" {
|
|
||||||
if !groupsEmpty {
|
|
||||||
return fmt.Errorf("groups specified as %s, but host=any will match any host, regardless of groups", r.Groups)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !cidrEmpty {
|
|
||||||
return fmt.Errorf("cidr specified as %s, but host=any will match any host, regardless of cidr", r.Cidr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if groupsHasAny {
|
|
||||||
if !hostEmpty && r.Host != "any" {
|
|
||||||
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified host %s", r.Groups, r.Host)
|
|
||||||
}
|
|
||||||
if !cidrEmpty {
|
|
||||||
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified cidr %s", r.Groups, r.Cidr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//todo alert on cidr-any
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parsePort(s string) (startPort, endPort int32, err error) {
|
func parsePort(s string) (startPort, endPort int32, err error) {
|
||||||
if s == "any" {
|
if s == "any" {
|
||||||
startPort = firewall.PortAny
|
startPort = firewall.PortAny
|
||||||
|
|||||||
242
firewall_test.go
242
firewall_test.go
@@ -73,106 +73,78 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||||||
ti6, err := netip.ParsePrefix("fd12::34/128")
|
ti6, err := netip.ParsePrefix("fd12::34/128")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
// An empty rule is any
|
// An empty rule is any
|
||||||
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
|
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
|
||||||
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
|
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
|
||||||
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
|
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
|
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
|
||||||
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
|
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
|
||||||
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
|
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
|
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
|
||||||
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
||||||
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
|
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
|
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
|
||||||
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||||
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
|
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", ""))
|
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
|
||||||
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||||
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
|
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", ""))
|
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
|
||||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||||
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
|
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", ""))
|
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
|
||||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||||
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
|
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
|
||||||
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
|
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
|
||||||
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
|
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", ""))
|
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
anyIp, err := netip.ParsePrefix("0.0.0.0/0")
|
anyIp, err := netip.ParsePrefix("0.0.0.0/0")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", ""))
|
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
|
||||||
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||||
table, ok := fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
|
|
||||||
assert.True(t, table.Any)
|
|
||||||
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
|
|
||||||
assert.False(t, ok)
|
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
anyIp6, err := netip.ParsePrefix("::/0")
|
anyIp6, err := netip.ParsePrefix("::/0")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", ""))
|
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
|
||||||
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
|
|
||||||
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
|
|
||||||
assert.True(t, table.Any)
|
|
||||||
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
|
|
||||||
assert.False(t, ok)
|
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", ""))
|
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", ""))
|
|
||||||
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
|
|
||||||
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
|
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", ""))
|
|
||||||
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
|
|
||||||
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
|
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
|
||||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", ""))
|
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||||
|
|
||||||
// Test error conditions
|
// Test error conditions
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", ""))
|
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", ""))
|
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop(t *testing.T) {
|
func TestFirewall_Drop(t *testing.T) {
|
||||||
@@ -208,7 +180,7 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
h.buildNetworks(myVpnNetworksTable, &c)
|
h.buildNetworks(myVpnNetworksTable, &c)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
@@ -227,28 +199,28 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
|
|
||||||
// ensure signer doesn't get in the way of group checks
|
// ensure signer doesn't get in the way of group checks
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caSha doesn't drop on match
|
// test caSha doesn't drop on match
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||||
|
|
||||||
// ensure ca name doesn't get in the way of group checks
|
// ensure ca name doesn't get in the way of group checks
|
||||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caName doesn't drop on match
|
// test caName doesn't drop on match
|
||||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -287,7 +259,7 @@ func TestFirewall_DropV6(t *testing.T) {
|
|||||||
h.buildNetworks(myVpnNetworksTable, &c)
|
h.buildNetworks(myVpnNetworksTable, &c)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
@@ -306,28 +278,28 @@ func TestFirewall_DropV6(t *testing.T) {
|
|||||||
|
|
||||||
// ensure signer doesn't get in the way of group checks
|
// ensure signer doesn't get in the way of group checks
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caSha doesn't drop on match
|
// test caSha doesn't drop on match
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||||
|
|
||||||
// ensure ca name doesn't get in the way of group checks
|
// ensure ca name doesn't get in the way of group checks
|
||||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caName doesn't drop on match
|
// test caName doesn't drop on match
|
||||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -338,12 +310,12 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pfix := netip.MustParsePrefix("172.1.1.1/32")
|
pfix := netip.MustParsePrefix("172.1.1.1/32")
|
||||||
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix.String(), "", "", "")
|
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
|
||||||
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix.String(), "", "")
|
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
|
||||||
|
|
||||||
pfix6 := netip.MustParsePrefix("fd11::11/128")
|
pfix6 := netip.MustParsePrefix("fd11::11/128")
|
||||||
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6.String(), "", "", "")
|
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
|
||||||
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix6.String(), "", "")
|
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
b.Run("fail on proto", func(b *testing.B) {
|
b.Run("fail on proto", func(b *testing.B) {
|
||||||
@@ -532,7 +504,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
|
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// h1/c1 lacks the proper groups
|
// h1/c1 lacks the proper groups
|
||||||
@@ -612,8 +584,8 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
|
h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// c1 should pass because host match
|
// c1 should pass because host match
|
||||||
@@ -627,7 +599,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
|
|
||||||
// Test a remote address match
|
// Test a remote address match
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
|
||||||
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -665,7 +637,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
|||||||
// Test a remote address match
|
// Test a remote address match
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -704,7 +676,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
@@ -717,7 +689,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
|
|
||||||
oldFw := fw
|
oldFw := fw
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
fw.Conntrack = oldFw.Conntrack
|
fw.Conntrack = oldFw.Conntrack
|
||||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||||
|
|
||||||
@@ -726,7 +698,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
|
|
||||||
oldFw = fw
|
oldFw = fw
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
fw.Conntrack = oldFw.Conntrack
|
fw.Conntrack = oldFw.Conntrack
|
||||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||||
|
|
||||||
@@ -765,7 +737,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
|||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
|
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// Packet spoofed by `c1`. Note that the remote addr is not a valid one.
|
// Packet spoofed by `c1`. Note that the remote addr is not a valid one.
|
||||||
@@ -959,28 +931,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
|||||||
mf := &mockFirewall{}
|
mf := &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding udp rule
|
// Test adding udp rule
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding icmp rule
|
// Test adding icmp rule
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding any rule
|
// Test adding any rule
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with cidr
|
// Test adding rule with cidr
|
||||||
cidr := netip.MustParsePrefix("10.0.0.0/8")
|
cidr := netip.MustParsePrefix("10.0.0.0/8")
|
||||||
@@ -988,14 +960,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
|||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with local_cidr
|
// Test adding rule with local_cidr
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr.String()}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with cidr ipv6
|
// Test adding rule with cidr ipv6
|
||||||
cidr6 := netip.MustParsePrefix("fd00::/8")
|
cidr6 := netip.MustParsePrefix("fd00::/8")
|
||||||
@@ -1003,75 +975,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
|
|||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with any cidr
|
|
||||||
conf = config.NewC(l)
|
|
||||||
mf = &mockFirewall{}
|
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
|
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
|
|
||||||
|
|
||||||
// Test adding rule with junk cidr
|
|
||||||
conf = config.NewC(l)
|
|
||||||
mf = &mockFirewall{}
|
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
|
|
||||||
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
|
|
||||||
|
|
||||||
// Test adding rule with local_cidr ipv6
|
// Test adding rule with local_cidr ipv6
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with any local_cidr
|
|
||||||
conf = config.NewC(l)
|
|
||||||
mf = &mockFirewall{}
|
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
|
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
|
|
||||||
|
|
||||||
// Test adding rule with junk local_cidr
|
|
||||||
conf = config.NewC(l)
|
|
||||||
mf = &mockFirewall{}
|
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
|
|
||||||
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
|
|
||||||
|
|
||||||
// Test adding rule with ca_sha
|
// Test adding rule with ca_sha
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
|
||||||
|
|
||||||
// Test adding rule with ca_name
|
// Test adding rule with ca_name
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
|
||||||
|
|
||||||
// Test single group
|
// Test single group
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test single groups
|
// Test single groups
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test multiple AND groups
|
// Test multiple AND groups
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
mf = &mockFirewall{}
|
mf = &mockFirewall{}
|
||||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
||||||
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||||
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
|
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
|
||||||
|
|
||||||
// Test Add error
|
// Test Add error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
@@ -1094,7 +1040,7 @@ func TestFirewall_convertRule(t *testing.T) {
|
|||||||
r, err := convertRule(l, c, "test", 1)
|
r, err := convertRule(l, c, "test", 1)
|
||||||
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
|
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []string{"group1"}, r.Groups)
|
assert.Equal(t, "group1", r.Group)
|
||||||
|
|
||||||
// Ensure group array of > 1 is errord
|
// Ensure group array of > 1 is errord
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
@@ -1114,63 +1060,7 @@ func TestFirewall_convertRule(t *testing.T) {
|
|||||||
|
|
||||||
r, err = convertRule(l, c, "test", 1)
|
r, err = convertRule(l, c, "test", 1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []string{"group1"}, r.Groups)
|
assert.Equal(t, "group1", r.Group)
|
||||||
}
|
|
||||||
|
|
||||||
func TestFirewall_convertRuleSanity(t *testing.T) {
|
|
||||||
l := test.NewLogger()
|
|
||||||
ob := &bytes.Buffer{}
|
|
||||||
l.SetOutput(ob)
|
|
||||||
|
|
||||||
noWarningPlease := []map[string]any{
|
|
||||||
{"group": "group1"},
|
|
||||||
{"groups": []any{"group2"}},
|
|
||||||
{"host": "bob"},
|
|
||||||
{"cidr": "1.1.1.1/1"},
|
|
||||||
{"groups": []any{"group2"}, "host": "bob"},
|
|
||||||
{"cidr": "1.1.1.1/1", "host": "bob"},
|
|
||||||
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
|
|
||||||
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
|
|
||||||
}
|
|
||||||
for _, c := range noWarningPlease {
|
|
||||||
r, err := convertRule(l, c, "test", 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NoError(t, r.sanity(), "should not generate a sanity warning, %+v", c)
|
|
||||||
}
|
|
||||||
|
|
||||||
yesWarningPlease := []map[string]any{
|
|
||||||
{"group": "group1"},
|
|
||||||
{"groups": []any{"group2"}},
|
|
||||||
{"cidr": "1.1.1.1/1"},
|
|
||||||
{"groups": []any{"group2"}, "host": "bob"},
|
|
||||||
{"cidr": "1.1.1.1/1", "host": "bob"},
|
|
||||||
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
|
|
||||||
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
|
|
||||||
}
|
|
||||||
for _, c := range yesWarningPlease {
|
|
||||||
c["host"] = "any"
|
|
||||||
r, err := convertRule(l, c, "test", 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = r.sanity()
|
|
||||||
require.Error(t, err, "I wanted a warning: %+v", c)
|
|
||||||
}
|
|
||||||
//reset the list
|
|
||||||
yesWarningPlease = []map[string]any{
|
|
||||||
{"group": "group1"},
|
|
||||||
{"groups": []any{"group2"}},
|
|
||||||
{"cidr": "1.1.1.1/1"},
|
|
||||||
{"groups": []any{"group2"}, "host": "bob"},
|
|
||||||
{"cidr": "1.1.1.1/1", "host": "bob"},
|
|
||||||
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
|
|
||||||
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
|
|
||||||
}
|
|
||||||
for _, c := range yesWarningPlease {
|
|
||||||
r, err := convertRule(l, c, "test", 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
r.Groups = append(r.Groups, "any")
|
|
||||||
err = r.sanity()
|
|
||||||
require.Error(t, err, "I wanted a warning: %+v", c)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type testcase struct {
|
type testcase struct {
|
||||||
@@ -1251,7 +1141,7 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
|
|||||||
myVpnNetworksTable.Insert(prefix)
|
myVpnNetworksTable.Insert(prefix)
|
||||||
}
|
}
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
|
|
||||||
return testsetup{
|
return testsetup{
|
||||||
c: c,
|
c: c,
|
||||||
@@ -1332,7 +1222,7 @@ func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
|
|||||||
tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
|
tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
|
||||||
tc.err = ErrNoMatchingRule
|
tc.err = ErrNoMatchingRule
|
||||||
tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
|
tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
|
||||||
require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
|
require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, unsafePrefix, "", ""))
|
||||||
tc.err = nil
|
tc.err = nil
|
||||||
tc.Test(t, unsafeSetup.fw) //should pass
|
tc.Test(t, unsafeSetup.fw) //should pass
|
||||||
})
|
})
|
||||||
@@ -1345,8 +1235,8 @@ type addRuleCall struct {
|
|||||||
endPort int32
|
endPort int32
|
||||||
groups []string
|
groups []string
|
||||||
host string
|
host string
|
||||||
ip string
|
ip netip.Prefix
|
||||||
localIp string
|
localIp netip.Prefix
|
||||||
caName string
|
caName string
|
||||||
caSha string
|
caSha string
|
||||||
}
|
}
|
||||||
@@ -1356,7 +1246,7 @@ type mockFirewall struct {
|
|||||||
nextCallReturn error
|
nextCallReturn error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error {
|
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
|
||||||
mf.lastCall = addRuleCall{
|
mf.lastCall = addRuleCall{
|
||||||
incoming: incoming,
|
incoming: incoming,
|
||||||
proto: proto,
|
proto: proto,
|
||||||
|
|||||||
4
go.mod
4
go.mod
@@ -23,9 +23,9 @@ require (
|
|||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/vishvananda/netlink v1.3.1
|
github.com/vishvananda/netlink v1.3.1
|
||||||
go.yaml.in/yaml/v3 v3.0.4
|
go.yaml.in/yaml/v3 v3.0.4
|
||||||
golang.org/x/crypto v0.45.0
|
golang.org/x/crypto v0.44.0
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
||||||
golang.org/x/net v0.47.0
|
golang.org/x/net v0.46.0
|
||||||
golang.org/x/sync v0.18.0
|
golang.org/x/sync v0.18.0
|
||||||
golang.org/x/sys v0.38.0
|
golang.org/x/sys v0.38.0
|
||||||
golang.org/x/term v0.37.0
|
golang.org/x/term v0.37.0
|
||||||
|
|||||||
8
go.sum
8
go.sum
@@ -162,8 +162,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
|||||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
|
||||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
||||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||||
@@ -182,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
|
|||||||
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
|||||||
125
handshake_ix.go
125
handshake_ix.go
@@ -99,11 +99,11 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H) {
|
func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
|
||||||
cs := f.pki.getCertState()
|
cs := f.pki.getCertState()
|
||||||
crt := cs.GetDefaultCertificate()
|
crt := cs.GetDefaultCertificate()
|
||||||
if crt == nil {
|
if crt == nil {
|
||||||
f.l.WithField("from", via).
|
f.l.WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||||
WithField("certVersion", cs.initiatingVersion).
|
WithField("certVersion", cs.initiatingVersion).
|
||||||
Error("Unable to handshake with host because no certificate is available")
|
Error("Unable to handshake with host because no certificate is available")
|
||||||
@@ -112,7 +112,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
|
|
||||||
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
Error("Failed to create connection state")
|
Error("Failed to create connection state")
|
||||||
return
|
return
|
||||||
@@ -123,7 +123,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
|
|
||||||
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
Error("Failed to call noise.ReadMessage")
|
Error("Failed to call noise.ReadMessage")
|
||||||
return
|
return
|
||||||
@@ -132,7 +132,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
hs := &NebulaHandshake{}
|
hs := &NebulaHandshake{}
|
||||||
err = hs.Unmarshal(msg)
|
err = hs.Unmarshal(msg)
|
||||||
if err != nil || hs.Details == nil {
|
if err != nil || hs.Details == nil {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
Error("Failed unmarshal handshake message")
|
Error("Failed unmarshal handshake message")
|
||||||
return
|
return
|
||||||
@@ -140,7 +140,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
|
|
||||||
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
Info("Handshake did not contain a certificate")
|
Info("Handshake did not contain a certificate")
|
||||||
return
|
return
|
||||||
@@ -153,7 +153,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
fp = "<error generating certificate fingerprint>"
|
fp = "<error generating certificate fingerprint>"
|
||||||
}
|
}
|
||||||
|
|
||||||
e := f.l.WithError(err).WithField("from", via).
|
e := f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
WithField("certVpnNetworks", rc.Networks()).
|
WithField("certVpnNetworks", rc.Networks()).
|
||||||
WithField("certFingerprint", fp)
|
WithField("certFingerprint", fp)
|
||||||
@@ -172,7 +172,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
if myCertOtherVersion == nil {
|
if myCertOtherVersion == nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithError(err).WithFields(m{
|
f.l.WithError(err).WithFields(m{
|
||||||
"from": via,
|
"udpAddr": addr,
|
||||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
||||||
"cert": remoteCert,
|
"cert": remoteCert,
|
||||||
}).Debug("Might be unable to handshake with host due to missing certificate version")
|
}).Debug("Might be unable to handshake with host due to missing certificate version")
|
||||||
@@ -184,7 +184,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("cert", remoteCert).
|
WithField("cert", remoteCert).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
Info("No networks in certificate")
|
Info("No networks in certificate")
|
||||||
@@ -201,7 +201,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||||
for i, network := range vpnNetworks {
|
for i, network := range vpnNetworks {
|
||||||
if f.myVpnAddrsTable.Contains(network.Addr()) {
|
if f.myVpnAddrsTable.Contains(network.Addr()) {
|
||||||
f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via).
|
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -215,18 +215,18 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !via.IsRelayed {
|
if addr.IsValid() {
|
||||||
|
// addr can be invalid when the tunnel is being relayed.
|
||||||
// We only want to apply the remote allow list for direct tunnels here
|
// We only want to apply the remote allow list for direct tunnels here
|
||||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
|
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) {
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||||
Debug("lighthouse.remote_allow_list denied incoming handshake")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
myIndex, err := generateIndex(f.l)
|
myIndex, err := generateIndex(f.l)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -251,7 +251,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
|
|
||||||
msgRxL := f.l.WithFields(m{
|
msgRxL := f.l.WithFields(m{
|
||||||
"vpnAddrs": vpnAddrs,
|
"vpnAddrs": vpnAddrs,
|
||||||
"from": via,
|
"udpAddr": addr,
|
||||||
"certName": certName,
|
"certName": certName,
|
||||||
"certVersion": certVersion,
|
"certVersion": certVersion,
|
||||||
"fingerprint": fingerprint,
|
"fingerprint": fingerprint,
|
||||||
@@ -283,7 +283,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
|
|
||||||
hsBytes, err := hs.Marshal()
|
hsBytes, err := hs.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -295,7 +295,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
|
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
|
||||||
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
|
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -303,7 +303,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
||||||
return
|
return
|
||||||
} else if dKey == nil || eKey == nil {
|
} else if dKey == nil || eKey == nil {
|
||||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -329,9 +329,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
ci.eKey = NewNebulaCipherState(eKey)
|
ci.eKey = NewNebulaCipherState(eKey)
|
||||||
|
|
||||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
||||||
if !via.IsRelayed {
|
hostinfo.SetRemote(addr)
|
||||||
hostinfo.SetRemote(via.UdpAddr)
|
|
||||||
}
|
|
||||||
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
|
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
|
||||||
|
|
||||||
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
||||||
@@ -339,7 +337,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
switch err {
|
switch err {
|
||||||
case ErrAlreadySeen:
|
case ErrAlreadySeen:
|
||||||
// Update remote if preferred
|
// Update remote if preferred
|
||||||
if existing.SetRemoteIfPreferred(f.hostMap, via) {
|
if existing.SetRemoteIfPreferred(f.hostMap, addr) {
|
||||||
// Send a test packet to ensure the other side has also switched to
|
// Send a test packet to ensure the other side has also switched to
|
||||||
// the preferred remote
|
// the preferred remote
|
||||||
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||||
@@ -347,21 +345,21 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
|
|
||||||
msg = existing.HandshakePacket[2]
|
msg = existing.HandshakePacket[2]
|
||||||
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
||||||
if !via.IsRelayed {
|
if addr.IsValid() {
|
||||||
err := f.outside.WriteTo(msg, via.UdpAddr)
|
err := f.outside.WriteTo(msg, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
|
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||||
WithError(err).Error("Failed to send handshake message")
|
WithError(err).Error("Failed to send handshake message")
|
||||||
} else {
|
} else {
|
||||||
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
|
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||||
Info("Handshake message sent")
|
Info("Handshake message sent")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
if via.relay == nil {
|
if via == nil {
|
||||||
f.l.Error("Handshake send failed: both addr and via.relay are nil.")
|
f.l.Error("Handshake send failed: both addr and via are nil.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||||
@@ -373,7 +371,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
}
|
}
|
||||||
case ErrExistingHostInfo:
|
case ErrExistingHostInfo:
|
||||||
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("oldHandshakeTime", existing.lastHandshakeTime).
|
WithField("oldHandshakeTime", existing.lastHandshakeTime).
|
||||||
@@ -389,7 +387,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
return
|
return
|
||||||
case ErrLocalIndexCollision:
|
case ErrLocalIndexCollision:
|
||||||
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -402,7 +400,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
default:
|
default:
|
||||||
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
|
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
|
||||||
// And we forget to update it here
|
// And we forget to update it here
|
||||||
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -416,23 +414,30 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
|
|
||||||
// Do the send
|
// Do the send
|
||||||
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
||||||
if !via.IsRelayed {
|
if addr.IsValid() {
|
||||||
err = f.outside.WriteTo(msg, via.UdpAddr)
|
err = f.outside.WriteTo(msg, addr)
|
||||||
log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("Failed to send handshake")
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
|
WithField("issuer", issuer).
|
||||||
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
|
WithError(err).Error("Failed to send handshake")
|
||||||
} else {
|
} else {
|
||||||
log.Info("Handshake message sent")
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
|
WithField("certName", certName).
|
||||||
|
WithField("certVersion", certVersion).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
|
WithField("issuer", issuer).
|
||||||
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
|
Info("Handshake message sent")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if via.relay == nil {
|
if via == nil {
|
||||||
f.l.Error("Handshake send failed: both addr and via.relay are nil.")
|
f.l.Error("Handshake send failed: both addr and via are nil.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||||
@@ -457,7 +462,7 @@ func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
|
func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
|
||||||
if hh == nil {
|
if hh == nil {
|
||||||
// Nothing here to tear down, got a bogus stage 2 packet
|
// Nothing here to tear down, got a bogus stage 2 packet
|
||||||
return true
|
return true
|
||||||
@@ -467,10 +472,10 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
|
|||||||
defer hh.Unlock()
|
defer hh.Unlock()
|
||||||
|
|
||||||
hostinfo := hh.hostinfo
|
hostinfo := hh.hostinfo
|
||||||
if !via.IsRelayed {
|
if addr.IsValid() {
|
||||||
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
|
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
|
||||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
|
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) {
|
||||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -478,7 +483,7 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
|
|||||||
ci := hostinfo.ConnectionState
|
ci := hostinfo.ConnectionState
|
||||||
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
||||||
Error("Failed to call noise.ReadMessage")
|
Error("Failed to call noise.ReadMessage")
|
||||||
|
|
||||||
@@ -487,7 +492,7 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
|
|||||||
// near future
|
// near future
|
||||||
return false
|
return false
|
||||||
} else if dKey == nil || eKey == nil {
|
} else if dKey == nil || eKey == nil {
|
||||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
Error("Noise did not arrive at a key")
|
Error("Noise did not arrive at a key")
|
||||||
|
|
||||||
@@ -499,7 +504,7 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
|
|||||||
hs := &NebulaHandshake{}
|
hs := &NebulaHandshake{}
|
||||||
err = hs.Unmarshal(msg)
|
err = hs.Unmarshal(msg)
|
||||||
if err != nil || hs.Details == nil {
|
if err != nil || hs.Details == nil {
|
||||||
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
|
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
||||||
|
|
||||||
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
|
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
|
||||||
@@ -508,7 +513,7 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
|
|||||||
|
|
||||||
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
Info("Handshake did not contain a certificate")
|
Info("Handshake did not contain a certificate")
|
||||||
@@ -522,7 +527,7 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
|
|||||||
fp = "<error generating certificate fingerprint>"
|
fp = "<error generating certificate fingerprint>"
|
||||||
}
|
}
|
||||||
|
|
||||||
e := f.l.WithError(err).WithField("from", via).
|
e := f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
WithField("certFingerprint", fp).
|
WithField("certFingerprint", fp).
|
||||||
@@ -537,7 +542,7 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||||
f.l.WithError(err).WithField("from", via).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
WithField("cert", remoteCert).
|
WithField("cert", remoteCert).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
@@ -560,8 +565,8 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
|
|||||||
ci.eKey = NewNebulaCipherState(eKey)
|
ci.eKey = NewNebulaCipherState(eKey)
|
||||||
|
|
||||||
// Make sure the current udpAddr being used is set for responding
|
// Make sure the current udpAddr being used is set for responding
|
||||||
if !via.IsRelayed {
|
if addr.IsValid() {
|
||||||
hostinfo.SetRemote(via.UdpAddr)
|
hostinfo.SetRemote(addr)
|
||||||
} else {
|
} else {
|
||||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||||
}
|
}
|
||||||
@@ -583,7 +588,7 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
|
|||||||
// Ensure the right host responded
|
// Ensure the right host responded
|
||||||
if !correctHostResponded {
|
if !correctHostResponded {
|
||||||
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
||||||
WithField("from", via).
|
WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
@@ -597,7 +602,7 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
|
|||||||
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
||||||
// Block the current used address
|
// Block the current used address
|
||||||
newHH.hostinfo.remotes = hostinfo.remotes
|
newHH.hostinfo.remotes = hostinfo.remotes
|
||||||
newHH.hostinfo.remotes.BlockRemote(via)
|
newHH.hostinfo.remotes.BlockRemote(addr)
|
||||||
|
|
||||||
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
|
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
|
||||||
WithField("vpnNetworks", vpnNetworks).
|
WithField("vpnNetworks", vpnNetworks).
|
||||||
@@ -620,7 +625,7 @@ func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, pack
|
|||||||
ci.window.Update(f.l, 2)
|
ci.window.Update(f.l, 2)
|
||||||
|
|
||||||
duration := time.Since(hh.startTime).Nanoseconds()
|
duration := time.Since(hh.startTime).Nanoseconds()
|
||||||
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
|
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
|
|||||||
@@ -136,11 +136,11 @@ func (hm *HandshakeManager) Run(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HandshakeManager) HandleIncoming(via *ViaSender, packet []byte, h *header.H) {
|
func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
|
||||||
// First remote allow list check before we know the vpnIp
|
// First remote allow list check before we know the vpnIp
|
||||||
if !via.IsRelayed {
|
if addr.IsValid() {
|
||||||
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) {
|
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) {
|
||||||
hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -149,11 +149,11 @@ func (hm *HandshakeManager) HandleIncoming(via *ViaSender, packet []byte, h *hea
|
|||||||
case header.HandshakeIXPSK0:
|
case header.HandshakeIXPSK0:
|
||||||
switch h.MessageCounter {
|
switch h.MessageCounter {
|
||||||
case 1:
|
case 1:
|
||||||
ixHandshakeStage1(hm.f, via, packet, h)
|
ixHandshakeStage1(hm.f, addr, via, packet, h)
|
||||||
|
|
||||||
case 2:
|
case 2:
|
||||||
newHostinfo := hm.queryIndex(h.RemoteIndex)
|
newHostinfo := hm.queryIndex(h.RemoteIndex)
|
||||||
tearDown := ixHandshakeStage2(hm.f, via, newHostinfo, packet, h)
|
tearDown := ixHandshakeStage2(hm.f, addr, via, newHostinfo, packet, h)
|
||||||
if tearDown && newHostinfo != nil {
|
if tearDown && newHostinfo != nil {
|
||||||
hm.DeleteHostInfo(newHostinfo.hostinfo)
|
hm.DeleteHostInfo(newHostinfo.hostinfo)
|
||||||
}
|
}
|
||||||
|
|||||||
31
hostmap.go
31
hostmap.go
@@ -1,9 +1,7 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -278,25 +276,9 @@ type HostInfo struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ViaSender struct {
|
type ViaSender struct {
|
||||||
UdpAddr netip.AddrPort
|
|
||||||
relayHI *HostInfo // relayHI is the host info object of the relay
|
relayHI *HostInfo // relayHI is the host info object of the relay
|
||||||
remoteIdx uint32 // remoteIdx is the index included in the header of the received packet
|
remoteIdx uint32 // remoteIdx is the index included in the header of the received packet
|
||||||
relay *Relay // relay contains the rest of the relay information, including the PeerIP of the host trying to communicate with us.
|
relay *Relay // relay contains the rest of the relay information, including the PeerIP of the host trying to communicate with us.
|
||||||
IsRelayed bool // IsRelayed is true if the packet was sent through a relay
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *ViaSender) String() string {
|
|
||||||
if v.IsRelayed {
|
|
||||||
return fmt.Sprintf("%s (relayed)", v.UdpAddr)
|
|
||||||
}
|
|
||||||
return v.UdpAddr.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *ViaSender) MarshalJSON() ([]byte, error) {
|
|
||||||
if v.IsRelayed {
|
|
||||||
return json.Marshal(m{"direct": v.UdpAddr})
|
|
||||||
}
|
|
||||||
return json.Marshal(m{"relay": v.UdpAddr})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type cachedPacket struct {
|
type cachedPacket struct {
|
||||||
@@ -712,7 +694,6 @@ func (i *HostInfo) GetCert() *cert.CachedCertificate {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Maybe use ViaSender here?
|
|
||||||
func (i *HostInfo) SetRemote(remote netip.AddrPort) {
|
func (i *HostInfo) SetRemote(remote netip.AddrPort) {
|
||||||
// We copy here because we likely got this remote from a source that reuses the object
|
// We copy here because we likely got this remote from a source that reuses the object
|
||||||
if i.remote != remote {
|
if i.remote != remote {
|
||||||
@@ -723,14 +704,14 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) {
|
|||||||
|
|
||||||
// SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
|
// SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
|
||||||
// time on the HostInfo will also be updated.
|
// time on the HostInfo will also be updated.
|
||||||
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, via *ViaSender) bool {
|
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool {
|
||||||
if via.IsRelayed {
|
if !newRemote.IsValid() {
|
||||||
|
// relays have nil udp Addrs
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
currentRemote := i.remote
|
currentRemote := i.remote
|
||||||
if !currentRemote.IsValid() {
|
if !currentRemote.IsValid() {
|
||||||
i.SetRemote(via.UdpAddr)
|
i.SetRemote(newRemote)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -743,7 +724,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, via *ViaSender) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if l.Contains(via.UdpAddr.Addr()) {
|
if l.Contains(newRemote.Addr()) {
|
||||||
newIsPreferred = true
|
newIsPreferred = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -753,7 +734,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, via *ViaSender) bool {
|
|||||||
i.lastRoam = time.Now()
|
i.lastRoam = time.Now()
|
||||||
i.lastRoamRemote = currentRemote
|
i.lastRoamRemote = currentRemote
|
||||||
|
|
||||||
i.SetRemote(via.UdpAddr)
|
i.SetRemote(newRemote)
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -222,13 +222,6 @@ func (f *Interface) activate() {
|
|||||||
WithField("boringcrypto", boringEnabled()).
|
WithField("boringcrypto", boringEnabled()).
|
||||||
Info("Nebula interface is active")
|
Info("Nebula interface is active")
|
||||||
|
|
||||||
if f.routines > 1 {
|
|
||||||
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
|
|
||||||
f.routines = 1
|
|
||||||
f.l.Warn("routines is not supported on this platform, falling back to a single routine")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
||||||
|
|
||||||
// Prepare n tun queues
|
// Prepare n tun queues
|
||||||
@@ -279,7 +272,7 @@ func (f *Interface) listenOut(i int) {
|
|||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||||
f.readOutsidePackets(&ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
75
outside.go
75
outside.go
@@ -19,21 +19,21 @@ const (
|
|||||||
minFwPacketLen = 4
|
minFwPacketLen = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
func (f *Interface) readOutsidePackets(via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
err := h.Parse(packet)
|
err := h.Parse(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
||||||
if len(packet) > 1 {
|
if len(packet) > 1 {
|
||||||
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err)
|
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//l.Error("in packet ", header, packet[HeaderLen:])
|
//l.Error("in packet ", header, packet[HeaderLen:])
|
||||||
if !via.IsRelayed {
|
if ip.IsValid() {
|
||||||
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
|
if f.myVpnNetworksTable.Contains(ip.Addr()) {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithField("from", via).Debug("Refusing to process double encrypted packet")
|
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -54,7 +54,8 @@ func (f *Interface) readOutsidePackets(via *ViaSender, out []byte, packet []byte
|
|||||||
|
|
||||||
switch h.Type {
|
switch h.Type {
|
||||||
case header.Message:
|
case header.Message:
|
||||||
if !f.handleEncrypted(ci, via, h) {
|
// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
|
||||||
|
if !f.handleEncrypted(ci, ip, h) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,7 +79,7 @@ func (f *Interface) readOutsidePackets(via *ViaSender, out []byte, packet []byte
|
|||||||
// Successfully validated the thing. Get rid of the Relay header.
|
// Successfully validated the thing. Get rid of the Relay header.
|
||||||
signedPayload = signedPayload[header.Len:]
|
signedPayload = signedPayload[header.Len:]
|
||||||
// Pull the Roaming parts up here, and return in all call paths.
|
// Pull the Roaming parts up here, and return in all call paths.
|
||||||
f.handleHostRoaming(hostinfo, via)
|
f.handleHostRoaming(hostinfo, ip)
|
||||||
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
|
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
|
||||||
f.connectionManager.In(hostinfo)
|
f.connectionManager.In(hostinfo)
|
||||||
f.connectionManager.RelayUsed(h.RemoteIndex)
|
f.connectionManager.RelayUsed(h.RemoteIndex)
|
||||||
@@ -95,14 +96,7 @@ func (f *Interface) readOutsidePackets(via *ViaSender, out []byte, packet []byte
|
|||||||
case TerminalType:
|
case TerminalType:
|
||||||
// If I am the target of this relay, process the unwrapped packet
|
// If I am the target of this relay, process the unwrapped packet
|
||||||
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
|
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
|
||||||
rVia := ViaSender{
|
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
|
||||||
UdpAddr: via.UdpAddr,
|
|
||||||
relayHI: hostinfo,
|
|
||||||
remoteIdx: relay.RemoteIndex,
|
|
||||||
relay: relay,
|
|
||||||
IsRelayed: true,
|
|
||||||
}
|
|
||||||
f.readOutsidePackets(&rVia, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
|
|
||||||
return
|
return
|
||||||
case ForwardingType:
|
case ForwardingType:
|
||||||
// Find the target HostInfo relay object
|
// Find the target HostInfo relay object
|
||||||
@@ -132,32 +126,31 @@ func (f *Interface) readOutsidePackets(via *ViaSender, out []byte, packet []byte
|
|||||||
|
|
||||||
case header.LightHouse:
|
case header.LightHouse:
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
if !f.handleEncrypted(ci, via, h) {
|
if !f.handleEncrypted(ci, ip, h) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("from", via).
|
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
||||||
WithField("packet", packet).
|
WithField("packet", packet).
|
||||||
Error("Failed to decrypt lighthouse packet")
|
Error("Failed to decrypt lighthouse packet")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: assert via is not relayed
|
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
|
||||||
lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f)
|
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
// Fallthrough to the bottom to record incoming traffic
|
||||||
|
|
||||||
case header.Test:
|
case header.Test:
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
if !f.handleEncrypted(ci, via, h) {
|
if !f.handleEncrypted(ci, ip, h) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("from", via).
|
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
||||||
WithField("packet", packet).
|
WithField("packet", packet).
|
||||||
Error("Failed to decrypt test packet")
|
Error("Failed to decrypt test packet")
|
||||||
return
|
return
|
||||||
@@ -166,7 +159,7 @@ func (f *Interface) readOutsidePackets(via *ViaSender, out []byte, packet []byte
|
|||||||
if h.Subtype == header.TestRequest {
|
if h.Subtype == header.TestRequest {
|
||||||
// This testRequest might be from TryPromoteBest, so we should roam
|
// This testRequest might be from TryPromoteBest, so we should roam
|
||||||
// to the new IP address before responding
|
// to the new IP address before responding
|
||||||
f.handleHostRoaming(hostinfo, via)
|
f.handleHostRoaming(hostinfo, ip)
|
||||||
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
|
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -177,34 +170,34 @@ func (f *Interface) readOutsidePackets(via *ViaSender, out []byte, packet []byte
|
|||||||
|
|
||||||
case header.Handshake:
|
case header.Handshake:
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
f.handshakeManager.HandleIncoming(via, packet, h)
|
f.handshakeManager.HandleIncoming(ip, via, packet, h)
|
||||||
return
|
return
|
||||||
|
|
||||||
case header.RecvError:
|
case header.RecvError:
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
f.handleRecvError(via.UdpAddr, h)
|
f.handleRecvError(ip, h)
|
||||||
return
|
return
|
||||||
|
|
||||||
case header.CloseTunnel:
|
case header.CloseTunnel:
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
if !f.handleEncrypted(ci, via, h) {
|
if !f.handleEncrypted(ci, ip, h) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.logger(f.l).WithField("from", via).
|
hostinfo.logger(f.l).WithField("udpAddr", ip).
|
||||||
Info("Close tunnel received, tearing down.")
|
Info("Close tunnel received, tearing down.")
|
||||||
|
|
||||||
f.closeTunnel(hostinfo)
|
f.closeTunnel(hostinfo)
|
||||||
return
|
return
|
||||||
|
|
||||||
case header.Control:
|
case header.Control:
|
||||||
if !f.handleEncrypted(ci, via, h) {
|
if !f.handleEncrypted(ci, ip, h) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("from", via).
|
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
|
||||||
WithField("packet", packet).
|
WithField("packet", packet).
|
||||||
Error("Failed to decrypt Control packet")
|
Error("Failed to decrypt Control packet")
|
||||||
return
|
return
|
||||||
@@ -214,11 +207,11 @@ func (f *Interface) readOutsidePackets(via *ViaSender, out []byte, packet []byte
|
|||||||
|
|
||||||
default:
|
default:
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via)
|
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.handleHostRoaming(hostinfo, via)
|
f.handleHostRoaming(hostinfo, ip)
|
||||||
|
|
||||||
f.connectionManager.In(hostinfo)
|
f.connectionManager.In(hostinfo)
|
||||||
}
|
}
|
||||||
@@ -237,36 +230,36 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
|
|||||||
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via *ViaSender) {
|
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) {
|
||||||
if !via.IsRelayed && hostinfo.remote != via.UdpAddr {
|
if udpAddr.IsValid() && hostinfo.remote != udpAddr {
|
||||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
|
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) {
|
||||||
hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming")
|
hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
|
if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
|
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
|
||||||
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
|
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
|
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
|
||||||
Info("Host roamed to new udp ip/port.")
|
Info("Host roamed to new udp ip/port.")
|
||||||
hostinfo.lastRoam = time.Now()
|
hostinfo.lastRoam = time.Now()
|
||||||
hostinfo.lastRoamRemote = hostinfo.remote
|
hostinfo.lastRoamRemote = hostinfo.remote
|
||||||
hostinfo.SetRemote(via.UdpAddr)
|
hostinfo.SetRemote(udpAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleEncrypted returns true if a packet should be processed, false otherwise
|
// handleEncrypted returns true if a packet should be processed, false otherwise
|
||||||
func (f *Interface) handleEncrypted(ci *ConnectionState, via *ViaSender, h *header.H) bool {
|
func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
|
||||||
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
|
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
|
||||||
if ci == nil {
|
if ci == nil {
|
||||||
if !via.IsRelayed {
|
if addr.IsValid() {
|
||||||
f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex)
|
f.maybeSendRecvError(addr, h.RemoteIndex)
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,5 @@ type Device interface {
|
|||||||
Networks() []netip.Prefix
|
Networks() []netip.Prefix
|
||||||
Name() string
|
Name() string
|
||||||
RoutesFor(netip.Addr) routing.Gateways
|
RoutesFor(netip.Addr) routing.Gateways
|
||||||
SupportsMultiqueue() bool
|
|
||||||
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -95,10 +95,6 @@ func (t *tun) Name() string {
|
|||||||
return "android"
|
return "android"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) SupportsMultiqueue() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -549,10 +549,6 @@ func (t *tun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) SupportsMultiqueue() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -105,10 +105,6 @@ func (t *disabledTun) Write(b []byte) (int, error) {
|
|||||||
return len(b), nil
|
return len(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *disabledTun) SupportsMultiqueue() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -450,10 +450,6 @@ func (t *tun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) SupportsMultiqueue() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -151,10 +151,6 @@ func (t *tun) Name() string {
|
|||||||
return "iOS"
|
return "iOS"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) SupportsMultiqueue() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -216,10 +216,6 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) SupportsMultiqueue() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -586,42 +582,48 @@ func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
||||||
|
|
||||||
var gateways routing.Gateways
|
var gateways routing.Gateways
|
||||||
|
|
||||||
link, err := netlink.LinkByName(t.Device)
|
link, err := netlink.LinkByName(t.Device)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name")
|
t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
|
||||||
return gateways
|
return gateways
|
||||||
}
|
}
|
||||||
|
|
||||||
// If this route is relevant to our interface and there is a gateway then add it
|
// If this route is relevant to our interface and there is a gateway then add it
|
||||||
if r.LinkIndex == link.Attrs().Index {
|
if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
|
||||||
gwAddr, ok := getGatewayAddr(r.Gw, r.Via)
|
gwAddr, ok := netip.AddrFromSlice(r.Gw)
|
||||||
if ok {
|
if !ok {
|
||||||
if t.isGatewayInVpnNetworks(gwAddr) {
|
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
|
||||||
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
|
|
||||||
} else {
|
|
||||||
// Gateway isn't in our overlay network, ignore
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
|
gwAddr = gwAddr.Unmap()
|
||||||
|
|
||||||
|
if !t.isGatewayInVpnNetworks(gwAddr) {
|
||||||
|
// Gateway isn't in our overlay network, ignore
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
||||||
|
} else {
|
||||||
|
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range r.MultiPath {
|
for _, p := range r.MultiPath {
|
||||||
// If this route is relevant to our interface and there is a gateway then add it
|
// If this route is relevant to our interface and there is a gateway then add it
|
||||||
if p.LinkIndex == link.Attrs().Index {
|
if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
|
||||||
gwAddr, ok := getGatewayAddr(p.Gw, p.Via)
|
gwAddr, ok := netip.AddrFromSlice(p.Gw)
|
||||||
if ok {
|
if !ok {
|
||||||
if t.isGatewayInVpnNetworks(gwAddr) {
|
t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
|
||||||
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
|
|
||||||
} else {
|
|
||||||
// Gateway isn't in our overlay network, ignore
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
|
gwAddr = gwAddr.Unmap()
|
||||||
|
|
||||||
|
if !t.isGatewayInVpnNetworks(gwAddr) {
|
||||||
|
// Gateway isn't in our overlay network, ignore
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
||||||
|
} else {
|
||||||
|
// p.Hops+1 = weight of the route
|
||||||
|
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -630,27 +632,10 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
|||||||
return gateways
|
return gateways
|
||||||
}
|
}
|
||||||
|
|
||||||
func getGatewayAddr(gw net.IP, via netlink.Destination) (netip.Addr, bool) {
|
|
||||||
// Try to use the old RTA_GATEWAY first
|
|
||||||
gwAddr, ok := netip.AddrFromSlice(gw)
|
|
||||||
if !ok {
|
|
||||||
// Fallback to the new RTA_VIA
|
|
||||||
rVia, ok := via.(*netlink.Via)
|
|
||||||
if ok {
|
|
||||||
gwAddr, ok = netip.AddrFromSlice(rVia.Addr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if gwAddr.IsValid() {
|
|
||||||
gwAddr = gwAddr.Unmap()
|
|
||||||
return gwAddr, true
|
|
||||||
}
|
|
||||||
|
|
||||||
return netip.Addr{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||||
|
|
||||||
gateways := t.getGatewaysFromRoute(&r.Route)
|
gateways := t.getGatewaysFromRoute(&r.Route)
|
||||||
|
|
||||||
if len(gateways) == 0 {
|
if len(gateways) == 0 {
|
||||||
// No gateways relevant to our network, no routing changes required.
|
// No gateways relevant to our network, no routing changes required.
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
|
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
|
||||||
|
|||||||
@@ -390,10 +390,6 @@ func (t *tun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) SupportsMultiqueue() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -310,10 +310,6 @@ func (t *tun) Name() string {
|
|||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) SupportsMultiqueue() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -132,10 +132,6 @@ func (t *TestTun) Read(b []byte) (int, error) {
|
|||||||
return len(p), nil
|
return len(p), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TestTun) SupportsMultiqueue() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -234,10 +234,6 @@ func (t *winTun) Write(b []byte) (int, error) {
|
|||||||
return t.tun.Write(b, 0)
|
return t.tun.Write(b, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *winTun) SupportsMultiqueue() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,10 +46,6 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
|
|||||||
return routing.Gateways{routing.NewGateway(ip, 1)}
|
return routing.Gateways{routing.NewGateway(ip, 1)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *UserDevice) SupportsMultiqueue() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -338,21 +338,21 @@ func (r *RemoteList) CopyCache() *CacheMap {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
|
// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
|
||||||
func (r *RemoteList) BlockRemote(bad *ViaSender) {
|
func (r *RemoteList) BlockRemote(bad netip.AddrPort) {
|
||||||
if bad.IsRelayed {
|
if !bad.IsValid() {
|
||||||
|
// relays can have nil udp Addrs
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Lock()
|
r.Lock()
|
||||||
defer r.Unlock()
|
defer r.Unlock()
|
||||||
|
|
||||||
// Check if we already blocked this addr
|
// Check if we already blocked this addr
|
||||||
if r.unlockedIsBad(bad.UdpAddr) {
|
if r.unlockedIsBad(bad) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// We copy here because we are taking something else's memory and we can't trust everything
|
// We copy here because we are taking something else's memory and we can't trust everything
|
||||||
r.badRemotes = append(r.badRemotes, bad.UdpAddr)
|
r.badRemotes = append(r.badRemotes, bad)
|
||||||
|
|
||||||
// Mark the next interaction must recollect/dedupe
|
// Mark the next interaction must recollect/dedupe
|
||||||
r.shouldRebuild = true
|
r.shouldRebuild = true
|
||||||
|
|||||||
@@ -34,10 +34,6 @@ func (NoopTun) Write([]byte) (int, error) {
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (NoopTun) SupportsMultiqueue() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
return nil, errors.New("unsupported")
|
return nil, errors.New("unsupported")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ type Conn interface {
|
|||||||
ListenOut(r EncReader)
|
ListenOut(r EncReader)
|
||||||
WriteTo(b []byte, addr netip.AddrPort) error
|
WriteTo(b []byte, addr netip.AddrPort) error
|
||||||
ReloadConfig(c *config.C)
|
ReloadConfig(c *config.C)
|
||||||
SupportsMultipleReaders() bool
|
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -34,9 +33,6 @@ func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
func (NoopConn) ListenOut(_ EncReader) {
|
func (NoopConn) ListenOut(_ EncReader) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
func (NoopConn) SupportsMultipleReaders() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -98,9 +98,9 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
|
|||||||
return ErrInvalidIPv6RemoteForSocket
|
return ErrInvalidIPv6RemoteForSocket
|
||||||
}
|
}
|
||||||
|
|
||||||
var rsa unix.RawSockaddrInet4
|
var rsa unix.RawSockaddrInet6
|
||||||
rsa.Family = unix.AF_INET
|
rsa.Family = unix.AF_INET6
|
||||||
rsa.Addr = ap.Addr().As4()
|
rsa.Addr = ap.Addr().As16()
|
||||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
|
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
|
||||||
sa = unsafe.Pointer(&rsa)
|
sa = unsafe.Pointer(&rsa)
|
||||||
addrLen = syscall.SizeofSockaddrInet4
|
addrLen = syscall.SizeofSockaddrInet4
|
||||||
@@ -184,10 +184,6 @@ func (u *StdConn) ListenOut(r EncReader) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) SupportsMultipleReaders() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) Rebind() error {
|
func (u *StdConn) Rebind() error {
|
||||||
var err error
|
var err error
|
||||||
if u.isV4 {
|
if u.isV4 {
|
||||||
|
|||||||
@@ -85,7 +85,3 @@ func (u *GenericConn) ListenOut(r EncReader) {
|
|||||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *GenericConn) SupportsMultipleReaders() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -72,10 +72,6 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
|||||||
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
|
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) SupportsMultipleReaders() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) Rebind() error {
|
func (u *StdConn) Rebind() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -315,10 +315,6 @@ func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *RIOConn) SupportsMultipleReaders() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *RIOConn) Rebind() error {
|
func (u *RIOConn) Rebind() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -127,10 +127,6 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
return u.Addr, nil
|
return u.Addr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TesterConn) SupportsMultipleReaders() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *TesterConn) Rebind() error {
|
func (u *TesterConn) Rebind() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user