Compare commits

..

9 Commits

Author SHA1 Message Date
Jay Wren
97977982cb whatif? 2025-12-12 15:32:28 -05:00
Jay Wren
69ed6646f1 dont call writeTo4 if writing to ipv6 2025-12-10 16:56:06 -05:00
Jack Doan
3ec527e42c cert.MarshalSigningPublicKeyToPEM should emit the 'ECDSA' variant of the banner (#1552)
Some checks failed
gofmt / Run gofmt (push) Failing after 2s
smoke-extra / Run extra smoke tests (push) Failing after 2s
smoke / Run multi node smoke test (push) Failing after 2s
Build and test / Build all and test on ubuntu-linux (push) Failing after 2s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
* cert.MarshalSigningPublicKeyToPEM should emit the 'ECDSA' variant of the banner

* oof owie ouch my tests
2025-12-10 10:39:36 -06:00
Nate Brown
2d16940232 Slight improvement to hot path benchmark, add a relay hot path benchmark (#1539)
Some checks failed
gofmt / Run gofmt (push) Failing after 15s
smoke-extra / Run extra smoke tests (push) Failing after 2s
smoke / Run multi node smoke test (push) Failing after 2s
Build and test / Build all and test on ubuntu-linux (push) Failing after 2s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
2025-12-09 22:29:26 -06:00
dependabot[bot]
cba294ffa4 Bump actions/checkout from 5 to 6 (#1541)
Bumps [actions/checkout](https://github.com/actions/checkout) from 5 to 6.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-09 22:25:48 -06:00
dependabot[bot]
48406f85da Bump the golang-x-dependencies group with 3 updates (#1550)
Bumps the golang-x-dependencies group with 3 updates: [golang.org/x/sync](https://github.com/golang/sync), [golang.org/x/sys](https://github.com/golang/sys) and [golang.org/x/term](https://github.com/golang/term).


Updates `golang.org/x/sync` from 0.18.0 to 0.19.0
- [Commits](https://github.com/golang/sync/compare/v0.18.0...v0.19.0)

Updates `golang.org/x/sys` from 0.38.0 to 0.39.0
- [Commits](https://github.com/golang/sys/compare/v0.38.0...v0.39.0)

Updates `golang.org/x/term` from 0.37.0 to 0.38.0
- [Commits](https://github.com/golang/term/compare/v0.37.0...v0.38.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sync
  dependency-version: 0.19.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/sys
  dependency-version: 0.39.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
- dependency-name: golang.org/x/term
  dependency-version: 0.38.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-09 22:19:53 -06:00
dependabot[bot]
14a1af132e Bump Apple-Actions/import-codesign-certs from 5 to 6 (#1549)
Bumps [Apple-Actions/import-codesign-certs](https://github.com/apple-actions/import-codesign-certs) from 5 to 6.
- [Release notes](https://github.com/apple-actions/import-codesign-certs/releases)
- [Commits](https://github.com/apple-actions/import-codesign-certs/compare/v5...v6)

---
updated-dependencies:
- dependency-name: Apple-Actions/import-codesign-certs
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-09 22:17:50 -06:00
Nate Brown
59e24b98bd v1.10.0 (#1534)
Some checks failed
gofmt / Run gofmt (push) Failing after 4s
smoke-extra / Run extra smoke tests (push) Failing after 2s
smoke / Run multi node smoke test (push) Failing after 2s
Build and test / Build all and test on ubuntu-linux (push) Failing after 2s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
Update CHANGELOG for Nebula v1.10.0
2025-12-04 14:42:31 -05:00
Nate Brown
56067afca2 Stab at better logging when a relay is being used (#1533)
Some checks failed
gofmt / Run gofmt (push) Failing after 5s
smoke-extra / Run extra smoke tests (push) Failing after 2s
smoke / Run multi node smoke test (push) Failing after 3s
Build and test / Build all and test on ubuntu-linux (push) Failing after 2s
Build and test / Build and test on linux with boringcrypto (push) Failing after 3s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
2025-12-03 17:48:29 -06:00
22 changed files with 401 additions and 338 deletions

View File

@@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:

View File

@@ -10,7 +10,7 @@ jobs:
name: Build Linux/BSD All name: Build Linux/BSD All
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -33,7 +33,7 @@ jobs:
name: Build Windows name: Build Windows
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -66,7 +66,7 @@ jobs:
HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }} HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -75,7 +75,7 @@ jobs:
- name: Import certificates - name: Import certificates
if: env.HAS_SIGNING_CREDS == 'true' if: env.HAS_SIGNING_CREDS == 'true'
uses: Apple-Actions/import-codesign-certs@v5 uses: Apple-Actions/import-codesign-certs@v6
with: with:
p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }} p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }}
p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }} p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }}
@@ -124,7 +124,7 @@ jobs:
# be overwritten # be overwritten
- name: Checkout code - name: Checkout code
if: ${{ env.HAS_DOCKER_CREDS == 'true' }} if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
uses: actions/checkout@v5 uses: actions/checkout@v6
- name: Download artifacts - name: Download artifacts
if: ${{ env.HAS_DOCKER_CREDS == 'true' }} if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
@@ -160,7 +160,7 @@ jobs:
needs: [build-linux, build-darwin, build-windows] needs: [build-linux, build-darwin, build-windows]
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- name: Download artifacts - name: Download artifacts
uses: actions/download-artifact@v6 uses: actions/download-artifact@v6

View File

@@ -20,7 +20,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:

View File

@@ -18,7 +18,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:

View File

@@ -18,7 +18,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -56,7 +56,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -77,7 +77,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:
@@ -98,7 +98,7 @@ jobs:
os: [windows-latest, macos-latest] os: [windows-latest, macos-latest]
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- uses: actions/setup-go@v6 - uses: actions/setup-go@v6
with: with:

View File

@@ -7,12 +7,85 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
## [1.10.0] - 2025-12-04
See the [v1.10.0](https://github.com/slackhq/nebula/milestone/16?closed=1) milestone for a complete list of changes.
### Added
- Support for ipv6 and multiple ipv4/6 addresses in the overlay.
A new v2 ASN.1 based certificate format.
Certificates now have a unified interface for external implementations.
(#1212, #1216, #1345, #1359, #1381, #1419, #1464, #1466, #1451, #1476, #1467, #1481, #1399, #1488, #1492, #1495, #1468, #1521, #1535, #1538)
- Add the ability to mark packets on linux to better target nebula packets in iptables/nftables. (#1331)
- Add ECMP support for `unsafe_routes`. (#1332)
- PKCS11 support for P256 keys when built with `pkcs11` tag (#1153, #1482)
### Changed ### Changed
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule - **NOTE**: `default_local_cidr_any` now defaults to false, meaning that any firewall rule
intended to target an `unsafe_routes` entry must explicitly declare it via the intended to target an `unsafe_routes` entry must explicitly declare it via the
`local_cidr` field. This is almost always the intended behavior. This flag is `local_cidr` field. This is almost always the intended behavior. This flag is
deprecated and will be removed in a future release. deprecated and will be removed in a future release. (#1373)
- Improve logging when a relay is in use on an inbound packet. (#1533)
- Avoid fatal errors if `rountines` is > 1 on systems that don't support more than 1 routine. (#1531)
- Log a warning if a firewall rule contains an `any` that negates a more restrictive filter. (#1513)
- Accept encrypted CA passphrase from an environment variable. (#1421)
- Allow handshaking with any trusted remote. (#1509)
- Log only the count of blocklisted certificate fingerprints instead of the entire list. (#1525)
- Don't fatal when the ssh server is unable to be configured successfully. (#1520)
- Update to build against go v1.25. (#1483)
- Allow projects using `nebula` as a library with userspace networking to configure the `logger` and build version. (#1239)
- Upgrade to `yaml.v3`. (#1148, #1371, #1438, #1478)
### Fixed
- Fix a potential bug with udp ipv4 only on darwin. (#1532)
- Improve lost packet statistics. (#1441, #1537)
- Honor `remote_allow_list` in hole punch response. (#1186)
- Fix a panic when `tun.use_system_route_table` is `true` and a route lacks a destination. (#1437)
- Fix an issue when `tun.use_system_route_table: true` could result in heavy CPU utilization when many thousands of routes
are present. (#1326)
- Fix tests for 32 bit machines. (#1394)
- Fix a possible 32bit integer underflow in config handling. (#1353)
- Fix moving a udp address from one vpn address to another in the `static_host_map`
which could cause rapid re-handshaking with an incorrect remote. (#1259)
- Improve smoke tests in environments where the docker network is not the default. (#1347)
## [1.9.7] - 2025-10-10
### Security
- Fix an issue where Nebula could incorrectly accept and process a packet from an erroneous source IP when the sender's
certificate is configured with unsafe_routes (cert v1/v2) or multiple IPs (cert v2). (#1494)
### Changed
- Disable sending `recv_error` messages when a packet is received outside the allowable counter window. (#1459)
- Improve error messages and remove some unnecessary fatal conditions in the Windows and generic udp listener. (#1453)
## [1.9.6] - 2025-7-15
### Added
- Support dropping inactive tunnels. This is disabled by default in this release but can be enabled with `tunnels.drop_inactive`. See example config for more details. (#1413)
### Fixed
- Fix Darwin freeze due to presence of some Network Extensions (#1426)
- Ensure the same relay tunnel is always used when multiple relay tunnels are present (#1422)
- Fix Windows freeze due to ICMP error handling (#1412)
- Fix relay migration panic (#1403)
## [1.9.5] - 2024-12-05
### Added
- Gracefully ignore v2 certificates. (#1282)
### Fixed
- Fix relays that refuse to re-establish after one of the remote tunnel pairs breaks. (#1277)
## [1.9.4] - 2024-09-09 ## [1.9.4] - 2024-09-09
@@ -671,7 +744,11 @@ created.)
- Initial public release. - Initial public release.
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.4...HEAD [Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.0...HEAD
[1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0
[1.9.7]: https://github.com/slackhq/nebula/releases/tag/v1.9.7
[1.9.6]: https://github.com/slackhq/nebula/releases/tag/v1.9.6
[1.9.5]: https://github.com/slackhq/nebula/releases/tag/v1.9.5
[1.9.4]: https://github.com/slackhq/nebula/releases/tag/v1.9.4 [1.9.4]: https://github.com/slackhq/nebula/releases/tag/v1.9.4
[1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3 [1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3
[1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2 [1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2

View File

@@ -99,13 +99,19 @@ func TestCertificateV1_PublicKeyPem(t *testing.T) {
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA= AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY----- -----END NEBULA P256 PUBLIC KEY-----
`)
pubP256KeyPemCA := []byte(`-----BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA ECDSA P256 PUBLIC KEY-----
`) `)
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem) pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
require.NoError(t, err) require.NoError(t, err)
nc.details.curve = Curve_P256 nc.details.curve = Curve_P256
nc.details.publicKey = pubP256Key nc.details.publicKey = pubP256Key
assert.Equal(t, Curve_P256, nc.Curve()) assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem)) assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPemCA))
assert.True(t, nc.IsCA()) assert.True(t, nc.IsCA())
nc.details.isCA = false nc.details.isCA = false

View File

@@ -114,12 +114,19 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA= AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY----- -----END NEBULA P256 PUBLIC KEY-----
`) `)
pubP256KeyPemCA := []byte(`-----BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA ECDSA P256 PUBLIC KEY-----
`)
pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem) pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
require.NoError(t, err) require.NoError(t, err)
nc.curve = Curve_P256 nc.curve = Curve_P256
nc.publicKey = pubP256Key nc.publicKey = pubP256Key
assert.Equal(t, Curve_P256, nc.Curve()) assert.Equal(t, Curve_P256, nc.Curve())
assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem)) assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPemCA))
assert.True(t, nc.IsCA()) assert.True(t, nc.IsCA())
nc.details.isCA = false nc.details.isCA = false

View File

@@ -86,7 +86,7 @@ func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte {
case Curve_CURVE25519: case Curve_CURVE25519:
return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b}) return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b})
case Curve_P256: case Curve_P256:
return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b}) return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PublicKeyBanner, Bytes: b})
default: default:
return nil return nil
} }

View File

@@ -25,11 +25,12 @@ import (
func BenchmarkHotPath(b *testing.B) { func BenchmarkHotPath(b *testing.B) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse // Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers // Start the servers
myControl.Start() myControl.Start()
@@ -38,6 +39,9 @@ func BenchmarkHotPath(b *testing.B) {
r := router.NewR(b, myControl, theirControl) r := router.NewR(b, myControl, theirControl)
r.CancelFlowLogs() r.CancelFlowLogs()
assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl) _ = r.RouteForAllUntilTxTun(theirControl)
@@ -47,6 +51,39 @@ func BenchmarkHotPath(b *testing.B) {
theirControl.Stop() theirControl.Stop()
} }
func BenchmarkHotPathRelay(b *testing.B) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(b, myControl, relayControl, theirControl)
r.CancelFlowLogs()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
b.ResetTimer()
for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl)
}
myControl.Stop()
theirControl.Stop()
relayControl.Stop()
}
func TestGoodHandshake(t *testing.T) { func TestGoodHandshake(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
@@ -1341,13 +1378,6 @@ func TestGoodHandshakeUnsafeDest(t *testing.T) {
"tun": m{ "tun": m{
"unsafe_routes": []m{route}, "unsafe_routes": []m{route},
}, },
"firewall": m{
"unsafe_outbound": []m{{
"port": "any",
"proto": "any",
"host": "any",
}},
},
} }
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg) myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg)
t.Logf("my config %v", myConfig) t.Logf("my config %v", myConfig)

View File

@@ -85,9 +85,8 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
}} }}
var unsafeNetworks []netip.Prefix var unsafeNetworks []netip.Prefix
var firewallUnsafeInbound []m
if sUnsafeNetworks != "" { if sUnsafeNetworks != "" {
firewallUnsafeInbound = []m{{ firewallInbound = []m{{
"proto": "any", "proto": "any",
"port": "any", "port": "any",
"host": "any", "host": "any",
@@ -123,8 +122,7 @@ func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certific
"port": "any", "port": "any",
"host": "any", "host": "any",
}}, }},
"inbound": firewallInbound, "inbound": firewallInbound,
"unsafe_inbound": firewallUnsafeInbound,
}, },
//"handshakes": m{ //"handshakes": m{
// "try_interval": "1s", // "try_interval": "1s",
@@ -294,7 +292,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
} }
} }
func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
// Send a packet from them to me // Send a packet from them to me
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B")) controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
bPacket := r.RouteForAllUntilTxTun(controlA) bPacket := r.RouteForAllUntilTxTun(controlA)
@@ -306,7 +304,7 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *n
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
} }
func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) { func assertHostInfoPair(t testing.TB, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) {
// Get both host infos // Get both host infos
//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things //TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false) hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
@@ -327,7 +325,7 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index") assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
} }
func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { func assertUdpPacket(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
if toIp.Is6() { if toIp.Is6() {
assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort) assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
} else { } else {
@@ -335,7 +333,7 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
} }
} }
func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { func assertUdpPacket6(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy) packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6) v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
assert.NotNil(t, v6, "No ipv6 data found") assert.NotNil(t, v6, "No ipv6 data found")
@@ -354,7 +352,7 @@ func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
assert.Equal(t, expected, data.Payload(), "Data was incorrect") assert.Equal(t, expected, data.Payload(), "Data was incorrect")
} }
func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { func assertUdpPacket4(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy) packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
assert.NotNil(t, v4, "No ipv4 data found") assert.NotNil(t, v4, "No ipv4 data found")

View File

@@ -23,15 +23,16 @@ import (
) )
type FirewallInterface interface { type FirewallInterface interface {
AddRule(unsafe, incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error
} }
type conn struct { type conn struct {
Expires time.Time // Time when this conntrack entry will expire Expires time.Time // Time when this conntrack entry will expire
// record why the original connection passed the firewall, so we can re-validate after ruleset changes. // record why the original connection passed the firewall, so we can re-validate
// after ruleset changes. Note, rulesVersion is a uint16 so that these two
// fields pack for free after the uint32 above
incoming bool incoming bool
unsafe bool
rulesVersion uint16 rulesVersion uint16
} }
@@ -39,10 +40,8 @@ type conn struct {
type Firewall struct { type Firewall struct {
Conntrack *FirewallConntrack Conntrack *FirewallConntrack
InRules *FirewallTable InRules *FirewallTable
OutRules *FirewallTable OutRules *FirewallTable
UnsafeInRules *FirewallTable
UnsafeOutRules *FirewallTable
InSendReject bool InSendReject bool
OutSendReject bool OutSendReject bool
@@ -55,7 +54,7 @@ type Firewall struct {
// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate. // routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
// The vpn addresses are a full bit match while the unsafe networks only match the prefix // The vpn addresses are a full bit match while the unsafe networks only match the prefix
routableNetworks *bart.Table[NetworkType] routableNetworks *bart.Lite
// assignedNetworks is a list of vpn networks assigned to us in the certificate. // assignedNetworks is a list of vpn networks assigned to us in the certificate.
assignedNetworks []netip.Prefix assignedNetworks []netip.Prefix
@@ -150,16 +149,17 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
tmax = defaultTimeout tmax = defaultTimeout
} }
routableNetworks := new(bart.Table[NetworkType]) routableNetworks := new(bart.Lite)
var assignedNetworks []netip.Prefix var assignedNetworks []netip.Prefix
for _, network := range c.Networks() { for _, network := range c.Networks() {
routableNetworks.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), NetworkTypeVPN) nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
routableNetworks.Insert(nprefix)
assignedNetworks = append(assignedNetworks, network) assignedNetworks = append(assignedNetworks, network)
} }
hasUnsafeNetworks := false hasUnsafeNetworks := false
for _, n := range c.UnsafeNetworks() { for _, n := range c.UnsafeNetworks() {
routableNetworks.Insert(n, NetworkTypeUnsafe) routableNetworks.Insert(n)
hasUnsafeNetworks = true hasUnsafeNetworks = true
} }
@@ -170,8 +170,6 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
}, },
InRules: newFirewallTable(), InRules: newFirewallTable(),
OutRules: newFirewallTable(), OutRules: newFirewallTable(),
UnsafeInRules: newFirewallTable(),
UnsafeOutRules: newFirewallTable(),
TCPTimeout: tcpTimeout, TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout, UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout, DefaultTimeout: defaultTimeout,
@@ -214,7 +212,6 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false) fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
//TODO: do we also need firewall.unsafe_inbound_action and firewall.unsafe_outbound_action?
inboundAction := c.GetString("firewall.inbound_action", "drop") inboundAction := c.GetString("firewall.inbound_action", "drop")
switch inboundAction { switch inboundAction {
case "reject": case "reject":
@@ -237,26 +234,12 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
fw.OutSendReject = false fw.OutSendReject = false
} }
// outbound rules err := AddFirewallRulesFromConfig(l, false, c, fw)
err := AddFirewallRulesFromConfig(l, false, false, c, fw)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// unsafe outbound rules err = AddFirewallRulesFromConfig(l, true, c, fw)
err = AddFirewallRulesFromConfig(l, true, false, c, fw)
if err != nil {
return nil, err
}
// inbound rules
err = AddFirewallRulesFromConfig(l, false, true, c, fw)
if err != nil {
return nil, err
}
// unsafe inbound rules
err = AddFirewallRulesFromConfig(l, true, true, c, fw)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -265,11 +248,11 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
} }
// AddRule properly creates the in memory rule structure for a firewall table. // AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(unsafe, incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
// We need this rule string because we generate a hash. Removing this will break firewall reload. // We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf( ruleString := fmt.Sprintf(
"unsafe: %v, incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
unsafe, incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha, incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha,
) )
f.rules += ruleString + "\n" f.rules += ruleString + "\n"
@@ -277,12 +260,8 @@ func (f *Firewall) AddRule(unsafe, incoming bool, proto uint8, startPort int32,
if !incoming { if !incoming {
direction = "outgoing" direction = "outgoing"
} }
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}).
fields := m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha} Info("Firewall rule added")
if unsafe {
fields["unsafe"] = true
}
f.l.WithField("firewallRule", fields).Info("Firewall rule added")
var ( var (
ft *FirewallTable ft *FirewallTable
@@ -290,18 +269,9 @@ func (f *Firewall) AddRule(unsafe, incoming bool, proto uint8, startPort int32,
) )
if incoming { if incoming {
if unsafe { ft = f.InRules
ft = f.UnsafeInRules
} else {
ft = f.InRules
}
} else { } else {
if unsafe { ft = f.OutRules
ft = f.UnsafeOutRules
} else {
ft = f.OutRules
}
} }
switch proto { switch proto {
@@ -338,21 +308,12 @@ func (f *Firewall) GetRuleHashes() string {
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10) return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
} }
func AddFirewallRulesFromConfig(l *logrus.Logger, unsafe, inbound bool, c *config.C, fw FirewallInterface) error { func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
var table string var table string
if inbound { if inbound {
if unsafe { table = "firewall.inbound"
table = "firewall.unsafe_inbound"
} else {
table = "firewall.inbound"
}
} else { } else {
if unsafe { table = "firewall.outbound"
table = "firewall.unsafe_outbound"
} else {
table = "firewall.outbound"
}
} }
r := c.Get(table) r := c.Get(table)
@@ -425,7 +386,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, unsafe, inbound bool, c *confi
l.Warnf("%s rule #%v; %s", table, i, warning) l.Warnf("%s rule #%v; %s", table, i, warning)
} }
err = fw.AddRule(unsafe, inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha) err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; `%s`", table, i, err) return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
} }
@@ -448,9 +409,6 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
return nil return nil
} }
var remoteNetworkType NetworkType
var ok bool
// Make sure remote address matches nebula certificate, and determine how to treat it // Make sure remote address matches nebula certificate, and determine how to treat it
if h.networks == nil { if h.networks == nil {
// Simple case: Certificate has one address and no unsafe networks // Simple case: Certificate has one address and no unsafe networks
@@ -458,14 +416,13 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
f.metrics(incoming).droppedRemoteAddr.Inc(1) f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
} }
remoteNetworkType = NetworkTypeVPN
} else { } else {
remoteNetworkType, ok = h.networks.Lookup(fp.RemoteAddr) nwType, ok := h.networks.Lookup(fp.RemoteAddr)
if !ok { if !ok {
f.metrics(incoming).droppedRemoteAddr.Inc(1) f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
} }
switch remoteNetworkType { switch nwType {
case NetworkTypeVPN: case NetworkTypeVPN:
break // nothing special break // nothing special
case NetworkTypeVPNPeer: case NetworkTypeVPNPeer:
@@ -480,27 +437,14 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
} }
// Make sure we are supposed to be handling this local ip address // Make sure we are supposed to be handling this local ip address
localNetworkType, ok := f.routableNetworks.Lookup(fp.LocalAddr) if !f.routableNetworks.Contains(fp.LocalAddr) {
if !ok {
f.metrics(incoming).droppedLocalAddr.Inc(1) f.metrics(incoming).droppedLocalAddr.Inc(1)
return ErrInvalidLocalIP return ErrInvalidLocalIP
} }
useUnsafe := remoteNetworkType == NetworkTypeUnsafe || localNetworkType == NetworkTypeUnsafe table := f.OutRules
var table *FirewallTable
if incoming { if incoming {
if useUnsafe { table = f.InRules
table = f.UnsafeInRules
} else {
table = f.InRules
}
} else {
if useUnsafe {
table = f.UnsafeOutRules
} else {
table = f.OutRules
}
} }
// We now know which firewall table to check against // We now know which firewall table to check against
@@ -510,13 +454,12 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
} }
// We always want to conntrack since it is a faster operation // We always want to conntrack since it is a faster operation
f.addConn(fp, useUnsafe, incoming) f.addConn(fp, incoming)
return nil return nil
} }
func (f *Firewall) metrics(incoming bool) firewallMetrics { func (f *Firewall) metrics(incoming bool) firewallMetrics {
//TODO: need unsafe metrics too
if incoming { if incoming {
return f.incomingMetrics return f.incomingMetrics
} else { } else {
@@ -556,6 +499,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
} }
c, ok := conntrack.Conns[fp] c, ok := conntrack.Conns[fp]
if !ok { if !ok {
conntrack.Unlock() conntrack.Unlock()
return false return false
@@ -564,19 +508,9 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
if c.rulesVersion != f.rulesVersion { if c.rulesVersion != f.rulesVersion {
// This conntrack entry was for an older rule set, validate // This conntrack entry was for an older rule set, validate
// it still passes with the current rule set // it still passes with the current rule set
var table *FirewallTable table := f.OutRules
if c.incoming { if c.incoming {
if c.unsafe { table = f.InRules
table = f.UnsafeInRules
} else {
table = f.InRules
}
} else {
if c.unsafe {
table = f.UnsafeOutRules
} else {
table = f.OutRules
}
} }
// We now know which firewall table to check against // We now know which firewall table to check against
@@ -585,7 +519,6 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
h.logger(f.l). h.logger(f.l).
WithField("fwPacket", fp). WithField("fwPacket", fp).
WithField("incoming", c.incoming). WithField("incoming", c.incoming).
WithField("unsafe", c.unsafe).
WithField("rulesVersion", f.rulesVersion). WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion). WithField("oldRulesVersion", c.rulesVersion).
Debugln("dropping old conntrack entry, does not match new ruleset") Debugln("dropping old conntrack entry, does not match new ruleset")
@@ -599,7 +532,6 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
h.logger(f.l). h.logger(f.l).
WithField("fwPacket", fp). WithField("fwPacket", fp).
WithField("incoming", c.incoming). WithField("incoming", c.incoming).
WithField("unsafe", c.unsafe).
WithField("rulesVersion", f.rulesVersion). WithField("rulesVersion", f.rulesVersion).
WithField("oldRulesVersion", c.rulesVersion). WithField("oldRulesVersion", c.rulesVersion).
Debugln("keeping old conntrack entry, does match new ruleset") Debugln("keeping old conntrack entry, does match new ruleset")
@@ -626,7 +558,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
return true return true
} }
func (f *Firewall) addConn(fp firewall.Packet, unsafe, incoming bool) { func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
var timeout time.Duration var timeout time.Duration
c := &conn{} c := &conn{}
@@ -649,7 +581,6 @@ func (f *Firewall) addConn(fp firewall.Packet, unsafe, incoming bool) {
// Record which rulesVersion allowed this connection, so we can retest after // Record which rulesVersion allowed this connection, so we can retest after
// firewall reload // firewall reload
c.incoming = incoming c.incoming = incoming
c.unsafe = unsafe
c.rulesVersion = f.rulesVersion c.rulesVersion = f.rulesVersion
c.Expires = time.Now().Add(timeout) c.Expires = time.Now().Add(timeout)
conntrack.Conns[fp] = c conntrack.Conns[fp] = c
@@ -1006,7 +937,6 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
r.Code = toString("code", m) r.Code = toString("code", m)
r.Proto = toString("proto", m) r.Proto = toString("proto", m)
r.Host = toString("host", m) r.Host = toString("host", m)
//TODO: create an alias to remote_cidr and deprecate cidr?
r.Cidr = toString("cidr", m) r.Cidr = toString("cidr", m)
r.LocalCidr = toString("local_cidr", m) r.LocalCidr = toString("local_cidr", m)
r.CAName = toString("ca_name", m) r.CAName = toString("ca_name", m)

View File

@@ -73,65 +73,65 @@ func TestFirewall_AddRule(t *testing.T) {
ti6, err := netip.ParsePrefix("fd12::34/128") ti6, err := netip.ParsePrefix("fd12::34/128")
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", ""))
// An empty rule is any // An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", ""))
assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6) _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6) ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp, err := netip.ParsePrefix("0.0.0.0/0") anyIp, err := netip.ParsePrefix("0.0.0.0/0")
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
table, ok := fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1")) table, ok := fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
assert.True(t, table.Any) assert.True(t, table.Any)
@@ -142,7 +142,7 @@ func TestFirewall_AddRule(t *testing.T) {
anyIp6, err := netip.ParsePrefix("::/0") anyIp6, err := netip.ParsePrefix("::/0")
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9")) table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
assert.True(t, table.Any) assert.True(t, table.Any)
@@ -150,29 +150,29 @@ func TestFirewall_AddRule(t *testing.T) {
assert.False(t, ok) assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1"))) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9"))) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9"))) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1"))) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions // Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.Error(t, fw.AddRule(false, true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", "")) require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", ""))
require.Error(t, fw.AddRule(false, true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", "")) require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", ""))
} }
func TestFirewall_Drop(t *testing.T) { func TestFirewall_Drop(t *testing.T) {
@@ -208,7 +208,7 @@ func TestFirewall_Drop(t *testing.T) {
h.buildNetworks(myVpnNetworksTable, &c) h.buildNetworks(myVpnNetworksTable, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@@ -227,28 +227,28 @@ func TestFirewall_Drop(t *testing.T) {
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
@@ -287,7 +287,7 @@ func TestFirewall_DropV6(t *testing.T) {
h.buildNetworks(myVpnNetworksTable, &c) h.buildNetworks(myVpnNetworksTable, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@@ -306,28 +306,28 @@ func TestFirewall_DropV6(t *testing.T) {
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
@@ -532,7 +532,7 @@ func TestFirewall_Drop2(t *testing.T) {
h1.buildNetworks(myVpnNetworksTable, c1.Certificate) h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// h1/c1 lacks the proper groups // h1/c1 lacks the proper groups
@@ -612,8 +612,8 @@ func TestFirewall_Drop3(t *testing.T) {
h3.buildNetworks(myVpnNetworksTable, c3.Certificate) h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", ""))
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha"))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// c1 should pass because host match // c1 should pass because host match
@@ -627,7 +627,7 @@ func TestFirewall_Drop3(t *testing.T) {
// Test a remote address match // Test a remote address match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
} }
@@ -665,7 +665,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
// Test a remote address match // Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
cp := cert.NewCAPool() cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
@@ -704,7 +704,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
h.buildNetworks(myVpnNetworksTable, c.Certificate) h.buildNetworks(myVpnNetworksTable, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@@ -717,7 +717,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw := fw oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@@ -726,7 +726,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw = fw oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@@ -765,7 +765,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Packet spoofed by `c1`. Note that the remote addr is not a valid one. // Packet spoofed by `c1`. Note that the remote addr is not a valid one.
@@ -958,28 +958,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
conf := config.NewC(l) conf := config.NewC(l)
mf := &mockFirewall{} mf := &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding udp rule // Test adding udp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding icmp rule // Test adding icmp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding any rule // Test adding any rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding rule with cidr // Test adding rule with cidr
@@ -987,14 +987,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
// Test adding rule with local_cidr // Test adding rule with local_cidr
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr.String()}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr.String()}, mf.lastCall)
// Test adding rule with cidr ipv6 // Test adding rule with cidr ipv6
@@ -1002,75 +1002,75 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
// Test adding rule with any cidr // Test adding rule with any cidr
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
// Test adding rule with junk cidr // Test adding rule with junk cidr
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with local_cidr ipv6 // Test adding rule with local_cidr ipv6
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
// Test adding rule with any local_cidr // Test adding rule with any local_cidr
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
// Test adding rule with junk local_cidr // Test adding rule with junk local_cidr
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with ca_sha // Test adding rule with ca_sha
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name // Test adding rule with ca_name
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
// Test single group // Test single group
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
// Test single groups // Test single groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
// Test multiple AND groups // Test multiple AND groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
// Test Add error // Test Add error
@@ -1078,7 +1078,7 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf = &mockFirewall{} mf = &mockFirewall{}
mf.nextCallReturn = errors.New("test error") mf.nextCallReturn = errors.New("test error")
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, false, true, conf, mf), "firewall.inbound rule #0; `test error`") require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
} }
func TestFirewall_convertRule(t *testing.T) { func TestFirewall_convertRule(t *testing.T) {
@@ -1251,7 +1251,7 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
myVpnNetworksTable.Insert(prefix) myVpnNetworksTable.Insert(prefix)
} }
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(false, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
return testsetup{ return testsetup{
c: c, c: c,
@@ -1332,14 +1332,13 @@ func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3") tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
tc.err = ErrNoMatchingRule tc.err = ErrNoMatchingRule
tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
require.NoError(t, unsafeSetup.fw.AddRule(true, true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", "")) require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
tc.err = nil tc.err = nil
tc.Test(t, unsafeSetup.fw) //should pass tc.Test(t, unsafeSetup.fw) //should pass
}) })
} }
type addRuleCall struct { type addRuleCall struct {
unsafe bool
incoming bool incoming bool
proto uint8 proto uint8
startPort int32 startPort int32
@@ -1357,9 +1356,8 @@ type mockFirewall struct {
nextCallReturn error nextCallReturn error
} }
func (mf *mockFirewall) AddRule(unsafe, incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error { func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error {
mf.lastCall = addRuleCall{ mf.lastCall = addRuleCall{
unsafe: unsafe,
incoming: incoming, incoming: incoming,
proto: proto, proto: proto,
startPort: startPort, startPort: startPort,

6
go.mod
View File

@@ -26,9 +26,9 @@ require (
golang.org/x/crypto v0.45.0 golang.org/x/crypto v0.45.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.47.0 golang.org/x/net v0.47.0
golang.org/x/sync v0.18.0 golang.org/x/sync v0.19.0
golang.org/x/sys v0.38.0 golang.org/x/sys v0.39.0
golang.org/x/term v0.37.0 golang.org/x/term v0.38.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3

12
go.sum
View File

@@ -191,8 +191,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -209,11 +209,11 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=

View File

@@ -99,11 +99,11 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
return true return true
} }
func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) {
cs := f.pki.getCertState() cs := f.pki.getCertState()
crt := cs.GetDefaultCertificate() crt := cs.GetDefaultCertificate()
if crt == nil { if crt == nil {
f.l.WithField("udpAddr", addr). f.l.WithField("from", via).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.initiatingVersion). WithField("certVersion", cs.initiatingVersion).
Error("Unable to handshake with host because no certificate is available") Error("Unable to handshake with host because no certificate is available")
@@ -112,7 +112,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
if err != nil { if err != nil {
f.l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to create connection state") Error("Failed to create connection state")
return return
@@ -123,7 +123,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil { if err != nil {
f.l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to call noise.ReadMessage") Error("Failed to call noise.ReadMessage")
return return
@@ -132,7 +132,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
hs := &NebulaHandshake{} hs := &NebulaHandshake{}
err = hs.Unmarshal(msg) err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil { if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed unmarshal handshake message") Error("Failed unmarshal handshake message")
return return
@@ -140,7 +140,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil { if err != nil {
f.l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake did not contain a certificate") Info("Handshake did not contain a certificate")
return return
@@ -153,7 +153,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
fp = "<error generating certificate fingerprint>" fp = "<error generating certificate fingerprint>"
} }
e := f.l.WithError(err).WithField("udpAddr", addr). e := f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVpnNetworks", rc.Networks()). WithField("certVpnNetworks", rc.Networks()).
WithField("certFingerprint", fp) WithField("certFingerprint", fp)
@@ -172,7 +172,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if myCertOtherVersion == nil { if myCertOtherVersion == nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithError(err).WithFields(m{ f.l.WithError(err).WithFields(m{
"udpAddr": addr, "from": via,
"handshake": m{"stage": 1, "style": "ix_psk0"}, "handshake": m{"stage": 1, "style": "ix_psk0"},
"cert": remoteCert, "cert": remoteCert,
}).Debug("Might be unable to handshake with host due to missing certificate version") }).Debug("Might be unable to handshake with host due to missing certificate version")
@@ -184,7 +184,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
} }
if len(remoteCert.Certificate.Networks()) == 0 { if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("from", via).
WithField("cert", remoteCert). WithField("cert", remoteCert).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("No networks in certificate") Info("No networks in certificate")
@@ -201,7 +201,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
vpnAddrs := make([]netip.Addr, len(vpnNetworks)) vpnAddrs := make([]netip.Addr, len(vpnNetworks))
for i, network := range vpnNetworks { for i, network := range vpnNetworks {
if f.myVpnAddrsTable.Contains(network.Addr()) { if f.myVpnAddrsTable.Contains(network.Addr()) {
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr). f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -215,18 +215,18 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
} }
} }
if addr.IsValid() { if !via.IsRelayed {
// addr can be invalid when the tunnel is being relayed.
// We only want to apply the remote allow list for direct tunnels here // We only want to apply the remote allow list for direct tunnels here
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) { if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
Debug("lighthouse.remote_allow_list denied incoming handshake")
return return
} }
} }
myIndex, err := generateIndex(f.l) myIndex, err := generateIndex(f.l)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -251,7 +251,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
msgRxL := f.l.WithFields(m{ msgRxL := f.l.WithFields(m{
"vpnAddrs": vpnAddrs, "vpnAddrs": vpnAddrs,
"udpAddr": addr, "from": via,
"certName": certName, "certName": certName,
"certVersion": certVersion, "certVersion": certVersion,
"fingerprint": fingerprint, "fingerprint": fingerprint,
@@ -283,7 +283,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
hsBytes, err := hs.Marshal() hsBytes, err := hs.Marshal()
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -295,7 +295,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -303,7 +303,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return return
} else if dKey == nil || eKey == nil { } else if dKey == nil || eKey == nil {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -329,7 +329,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
ci.eKey = NewNebulaCipherState(eKey) ci.eKey = NewNebulaCipherState(eKey)
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
hostinfo.SetRemote(addr) if !via.IsRelayed {
hostinfo.SetRemote(via.UdpAddr)
}
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
@@ -337,7 +339,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
switch err { switch err {
case ErrAlreadySeen: case ErrAlreadySeen:
// Update remote if preferred // Update remote if preferred
if existing.SetRemoteIfPreferred(f.hostMap, addr) { if existing.SetRemoteIfPreferred(f.hostMap, via) {
// Send a test packet to ensure the other side has also switched to // Send a test packet to ensure the other side has also switched to
// the preferred remote // the preferred remote
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
@@ -345,21 +347,21 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
msg = existing.HandshakePacket[2] msg = existing.HandshakePacket[2]
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
if addr.IsValid() { if !via.IsRelayed {
err := f.outside.WriteTo(msg, addr) err := f.outside.WriteTo(msg, via.UdpAddr)
if err != nil { if err != nil {
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message") WithError(err).Error("Failed to send handshake message")
} else { } else {
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent") Info("Handshake message sent")
} }
return return
} else { } else {
if via == nil { if via.relay == nil {
f.l.Error("Handshake send failed: both addr and via are nil.") f.l.Error("Handshake send failed: both addr and via.relay are nil.")
return return
} }
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
@@ -371,7 +373,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
} }
case ErrExistingHostInfo: case ErrExistingHostInfo:
// This means there was an existing tunnel and this handshake was older than the one we are currently based on // This means there was an existing tunnel and this handshake was older than the one we are currently based on
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("oldHandshakeTime", existing.lastHandshakeTime). WithField("oldHandshakeTime", existing.lastHandshakeTime).
@@ -387,7 +389,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return return
case ErrLocalIndexCollision: case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -400,7 +402,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
default: default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here // And we forget to update it here
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
@@ -414,30 +416,23 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
// Do the send // Do the send
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
if addr.IsValid() { if !via.IsRelayed {
err = f.outside.WriteTo(msg, addr) err = f.outside.WriteTo(msg, via.UdpAddr)
log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
if err != nil { if err != nil {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). log.WithError(err).Error("Failed to send handshake")
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake")
} else { } else {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). log.Info("Handshake message sent")
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake message sent")
} }
} else { } else {
if via == nil { if via.relay == nil {
f.l.Error("Handshake send failed: both addr and via are nil.") f.l.Error("Handshake send failed: both addr and via.relay are nil.")
return return
} }
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
@@ -462,7 +457,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return return
} }
func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
if hh == nil { if hh == nil {
// Nothing here to tear down, got a bogus stage 2 packet // Nothing here to tear down, got a bogus stage 2 packet
return true return true
@@ -472,10 +467,10 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
defer hh.Unlock() defer hh.Unlock()
hostinfo := hh.hostinfo hostinfo := hh.hostinfo
if addr.IsValid() { if !via.IsRelayed {
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list. // The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) { if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
return false return false
} }
} }
@@ -483,7 +478,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci := hostinfo.ConnectionState ci := hostinfo.ConnectionState
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Failed to call noise.ReadMessage") Error("Failed to call noise.ReadMessage")
@@ -492,7 +487,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
// near future // near future
return false return false
} else if dKey == nil || eKey == nil { } else if dKey == nil || eKey == nil {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key") Error("Noise did not arrive at a key")
@@ -504,7 +499,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
hs := &NebulaHandshake{} hs := &NebulaHandshake{}
err = hs.Unmarshal(msg) err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil { if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
@@ -513,7 +508,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil { if err != nil {
f.l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("from", via).
WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake did not contain a certificate") Info("Handshake did not contain a certificate")
@@ -527,7 +522,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
fp = "<error generating certificate fingerprint>" fp = "<error generating certificate fingerprint>"
} }
e := f.l.WithError(err).WithField("udpAddr", addr). e := f.l.WithError(err).WithField("from", via).
WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("certFingerprint", fp). WithField("certFingerprint", fp).
@@ -542,7 +537,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
} }
if len(remoteCert.Certificate.Networks()) == 0 { if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("from", via).
WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("cert", remoteCert). WithField("cert", remoteCert).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
@@ -565,8 +560,8 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci.eKey = NewNebulaCipherState(eKey) ci.eKey = NewNebulaCipherState(eKey)
// Make sure the current udpAddr being used is set for responding // Make sure the current udpAddr being used is set for responding
if addr.IsValid() { if !via.IsRelayed {
hostinfo.SetRemote(addr) hostinfo.SetRemote(via.UdpAddr)
} else { } else {
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
} }
@@ -588,7 +583,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
// Ensure the right host responded // Ensure the right host responded
if !correctHostResponded { if !correctHostResponded {
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("udpAddr", addr). WithField("from", via).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
@@ -602,7 +597,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
// Block the current used address // Block the current used address
newHH.hostinfo.remotes = hostinfo.remotes newHH.hostinfo.remotes = hostinfo.remotes
newHH.hostinfo.remotes.BlockRemote(addr) newHH.hostinfo.remotes.BlockRemote(via)
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()). f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
WithField("vpnNetworks", vpnNetworks). WithField("vpnNetworks", vpnNetworks).
@@ -625,7 +620,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci.window.Update(f.l, 2) ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds() duration := time.Since(hh.startTime).Nanoseconds()
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName). WithField("certName", certName).
WithField("certVersion", certVersion). WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).

View File

@@ -136,11 +136,11 @@ func (hm *HandshakeManager) Run(ctx context.Context) {
} }
} }
func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *header.H) {
// First remote allow list check before we know the vpnIp // First remote allow list check before we know the vpnIp
if addr.IsValid() { if !via.IsRelayed {
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) { if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) {
hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
return return
} }
} }
@@ -149,11 +149,11 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender,
case header.HandshakeIXPSK0: case header.HandshakeIXPSK0:
switch h.MessageCounter { switch h.MessageCounter {
case 1: case 1:
ixHandshakeStage1(hm.f, addr, via, packet, h) ixHandshakeStage1(hm.f, via, packet, h)
case 2: case 2:
newHostinfo := hm.queryIndex(h.RemoteIndex) newHostinfo := hm.queryIndex(h.RemoteIndex)
tearDown := ixHandshakeStage2(hm.f, addr, via, newHostinfo, packet, h) tearDown := ixHandshakeStage2(hm.f, via, newHostinfo, packet, h)
if tearDown && newHostinfo != nil { if tearDown && newHostinfo != nil {
hm.DeleteHostInfo(newHostinfo.hostinfo) hm.DeleteHostInfo(newHostinfo.hostinfo)
} }

View File

@@ -1,7 +1,9 @@
package nebula package nebula
import ( import (
"encoding/json"
"errors" "errors"
"fmt"
"net" "net"
"net/netip" "net/netip"
"slices" "slices"
@@ -276,9 +278,25 @@ type HostInfo struct {
} }
type ViaSender struct { type ViaSender struct {
UdpAddr netip.AddrPort
relayHI *HostInfo // relayHI is the host info object of the relay relayHI *HostInfo // relayHI is the host info object of the relay
remoteIdx uint32 // remoteIdx is the index included in the header of the received packet remoteIdx uint32 // remoteIdx is the index included in the header of the received packet
relay *Relay // relay contains the rest of the relay information, including the PeerIP of the host trying to communicate with us. relay *Relay // relay contains the rest of the relay information, including the PeerIP of the host trying to communicate with us.
IsRelayed bool // IsRelayed is true if the packet was sent through a relay
}
func (v ViaSender) String() string {
if v.IsRelayed {
return fmt.Sprintf("%s (relayed)", v.UdpAddr)
}
return v.UdpAddr.String()
}
func (v ViaSender) MarshalJSON() ([]byte, error) {
if v.IsRelayed {
return json.Marshal(m{"relay": v.UdpAddr})
}
return json.Marshal(m{"direct": v.UdpAddr})
} }
type cachedPacket struct { type cachedPacket struct {
@@ -694,6 +712,7 @@ func (i *HostInfo) GetCert() *cert.CachedCertificate {
return nil return nil
} }
// TODO: Maybe use ViaSender here?
func (i *HostInfo) SetRemote(remote netip.AddrPort) { func (i *HostInfo) SetRemote(remote netip.AddrPort) {
// We copy here because we likely got this remote from a source that reuses the object // We copy here because we likely got this remote from a source that reuses the object
if i.remote != remote { if i.remote != remote {
@@ -704,14 +723,14 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) {
// SetRemoteIfPreferred returns true if the remote was changed. The lastRoam // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
// time on the HostInfo will also be updated. // time on the HostInfo will also be updated.
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool { func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, via ViaSender) bool {
if !newRemote.IsValid() { if via.IsRelayed {
// relays have nil udp Addrs
return false return false
} }
currentRemote := i.remote currentRemote := i.remote
if !currentRemote.IsValid() { if !currentRemote.IsValid() {
i.SetRemote(newRemote) i.SetRemote(via.UdpAddr)
return true return true
} }
@@ -724,7 +743,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
return false return false
} }
if l.Contains(newRemote.Addr()) { if l.Contains(via.UdpAddr.Addr()) {
newIsPreferred = true newIsPreferred = true
} }
} }
@@ -734,7 +753,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
i.lastRoam = time.Now() i.lastRoam = time.Now()
i.lastRoamRemote = currentRemote i.lastRoamRemote = currentRemote
i.SetRemote(newRemote) i.SetRemote(via.UdpAddr)
return true return true
} }

View File

@@ -279,7 +279,7 @@ func (f *Interface) listenOut(i int) {
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
}) })
} }

View File

@@ -19,21 +19,21 @@ const (
minFwPacketLen = 4 minFwPacketLen = 4
) )
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet) err := h.Parse(packet)
if err != nil { if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(packet) > 1 { if len(packet) > 1 {
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err) f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err)
} }
return return
} }
//l.Error("in packet ", header, packet[HeaderLen:]) //l.Error("in packet ", header, packet[HeaderLen:])
if ip.IsValid() { if !via.IsRelayed {
if f.myVpnNetworksTable.Contains(ip.Addr()) { if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") f.l.WithField("from", via).Debug("Refusing to process double encrypted packet")
} }
return return
} }
@@ -54,8 +54,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
switch h.Type { switch h.Type {
case header.Message: case header.Message:
// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case. if !f.handleEncrypted(ci, via, h) {
if !f.handleEncrypted(ci, ip, h) {
return return
} }
@@ -79,7 +78,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
// Successfully validated the thing. Get rid of the Relay header. // Successfully validated the thing. Get rid of the Relay header.
signedPayload = signedPayload[header.Len:] signedPayload = signedPayload[header.Len:]
// Pull the Roaming parts up here, and return in all call paths. // Pull the Roaming parts up here, and return in all call paths.
f.handleHostRoaming(hostinfo, ip) f.handleHostRoaming(hostinfo, via)
// Track usage of both the HostInfo and the Relay for the received & authenticated packet // Track usage of both the HostInfo and the Relay for the received & authenticated packet
f.connectionManager.In(hostinfo) f.connectionManager.In(hostinfo)
f.connectionManager.RelayUsed(h.RemoteIndex) f.connectionManager.RelayUsed(h.RemoteIndex)
@@ -96,7 +95,14 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
case TerminalType: case TerminalType:
// If I am the target of this relay, process the unwrapped packet // If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) via = ViaSender{
UdpAddr: via.UdpAddr,
relayHI: hostinfo,
remoteIdx: relay.RemoteIndex,
relay: relay,
IsRelayed: true,
}
f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return return
case ForwardingType: case ForwardingType:
// Find the target HostInfo relay object // Find the target HostInfo relay object
@@ -126,31 +132,32 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
case header.LightHouse: case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) { if !f.handleEncrypted(ci, via, h) {
return return
} }
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet). WithField("packet", packet).
Error("Failed to decrypt lighthouse packet") Error("Failed to decrypt lighthouse packet")
return return
} }
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f) //TODO: assert via is not relayed
lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic // Fallthrough to the bottom to record incoming traffic
case header.Test: case header.Test:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) { if !f.handleEncrypted(ci, via, h) {
return return
} }
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet). WithField("packet", packet).
Error("Failed to decrypt test packet") Error("Failed to decrypt test packet")
return return
@@ -159,7 +166,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
if h.Subtype == header.TestRequest { if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam // This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding // to the new IP address before responding
f.handleHostRoaming(hostinfo, ip) f.handleHostRoaming(hostinfo, via)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
} }
@@ -170,34 +177,34 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
case header.Handshake: case header.Handshake:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handshakeManager.HandleIncoming(ip, via, packet, h) f.handshakeManager.HandleIncoming(via, packet, h)
return return
case header.RecvError: case header.RecvError:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handleRecvError(ip, h) f.handleRecvError(via.UdpAddr, h)
return return
case header.CloseTunnel: case header.CloseTunnel:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) { if !f.handleEncrypted(ci, via, h) {
return return
} }
hostinfo.logger(f.l).WithField("udpAddr", ip). hostinfo.logger(f.l).WithField("from", via).
Info("Close tunnel received, tearing down.") Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo) f.closeTunnel(hostinfo)
return return
case header.Control: case header.Control:
if !f.handleEncrypted(ci, ip, h) { if !f.handleEncrypted(ci, via, h) {
return return
} }
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet). WithField("packet", packet).
Error("Failed to decrypt Control packet") Error("Failed to decrypt Control packet")
return return
@@ -207,11 +214,11 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
default: default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip) hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via)
return return
} }
f.handleHostRoaming(hostinfo, ip) f.handleHostRoaming(hostinfo, via)
f.connectionManager.In(hostinfo) f.connectionManager.In(hostinfo)
} }
@@ -230,36 +237,36 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
} }
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
if udpAddr.IsValid() && hostinfo.remote != udpAddr { if !via.IsRelayed && hostinfo.remote != via.UdpAddr {
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) { if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming") hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming")
return return
} }
if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr). hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
} }
return return
} }
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr). hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
Info("Host roamed to new udp ip/port.") Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now() hostinfo.lastRoam = time.Now()
hostinfo.lastRoamRemote = hostinfo.remote hostinfo.lastRoamRemote = hostinfo.remote
hostinfo.SetRemote(udpAddr) hostinfo.SetRemote(via.UdpAddr)
} }
} }
// handleEncrypted returns true if a packet should be processed, false otherwise // handleEncrypted returns true if a packet should be processed, false otherwise
func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool { func (f *Interface) handleEncrypted(ci *ConnectionState, via ViaSender, h *header.H) bool {
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect // If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
if ci == nil { if ci == nil {
if addr.IsValid() { if !via.IsRelayed {
f.maybeSendRecvError(addr, h.RemoteIndex) f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex)
} }
return false return false
} }

View File

@@ -338,21 +338,21 @@ func (r *RemoteList) CopyCache() *CacheMap {
} }
// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
func (r *RemoteList) BlockRemote(bad netip.AddrPort) { func (r *RemoteList) BlockRemote(bad ViaSender) {
if !bad.IsValid() { if bad.IsRelayed {
// relays can have nil udp Addrs
return return
} }
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
// Check if we already blocked this addr // Check if we already blocked this addr
if r.unlockedIsBad(bad) { if r.unlockedIsBad(bad.UdpAddr) {
return return
} }
// We copy here because we are taking something else's memory and we can't trust everything // We copy here because we are taking something else's memory and we can't trust everything
r.badRemotes = append(r.badRemotes, bad) r.badRemotes = append(r.badRemotes, bad.UdpAddr)
// Mark the next interaction must recollect/dedupe // Mark the next interaction must recollect/dedupe
r.shouldRebuild = true r.shouldRebuild = true

View File

@@ -224,10 +224,6 @@ func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
} }
func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
if !ip.Addr().Is4() {
return ErrInvalidIPv6RemoteForSocket
}
var rsa unix.RawSockaddrInet4 var rsa unix.RawSockaddrInet4
rsa.Family = unix.AF_INET rsa.Family = unix.AF_INET
rsa.Addr = ip.Addr().As4() rsa.Addr = ip.Addr().As4()